From 43d38eeca5102865b860c8d248a73a6ef61d9beb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 22 Feb 2024 17:46:23 -0500 Subject: [PATCH 001/531] [Attn] Making decode attn kernel be aware of webgpu target (#1817) This PR enables the decode attn kernel to have awareness of the webgpu backend, so that it helps make sure the total number of threads does not exceed the 256 limit of WebGPU. Co-authored-by: Bohan Hou --- python/mlc_chat/nn/kv_cache.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index ac5f2d5d4c..e956037411 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -820,11 +820,16 @@ def _attention_decode( H_kv = num_kv_heads D = head_dim + thread_limit = 512 if str(target.kind) != "webgpu" else 256 + GROUP_SIZE = H_qo // H_kv VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) bdx = D // VEC_SIZE bdy = GROUP_SIZE - threads_per_CTA = max(512, bdx * bdy) + while bdx * bdy > thread_limit and bdy > 1: + bdy //= 2 + gdz = GROUP_SIZE // bdy + threads_per_CTA = max(thread_limit, bdx * bdy) bdz = threads_per_CTA // (bdx * bdy) tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1 log2e = math.log2(math.exp(1)) @@ -868,7 +873,7 @@ def batch_decode_paged_kv( sm_scale = 1.0 / math.sqrt(float(D)) * log2e for bx in T.thread_binding(B, thread="blockIdx.x"): - for by in T.thread_binding(H_kv, thread="blockIdx.y"): + for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): for ty in T.thread_binding(bdy, thread="threadIdx.y"): for tx in T.thread_binding(bdx, thread="threadIdx.x"): for tz in T.thread_binding(bdz, thread="threadIdx.z"): @@ -894,6 +899,8 @@ def batch_decode_paged_kv( st_d = T.alloc_buffer((1,), "float32", scope="local") O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") + by: T.int32 = fused_by_bz % H_kv + bz: T.int32 = fused_by_bz // H_kv batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] @@ -914,8 +921,8 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): Q_local[vec] = T.if_then_else( rotary_mode == 1, - _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec), qkv_dtype), - Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] + _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype), + Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] ) for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): @@ -1025,10 +1032,10 @@ def batch_decode_paged_kv( # store O to global memory for vec in T.vectorized(VEC_SIZE): - output[batch_idx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] = O_local[vec] + output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec] # store lse to global memory - lse[batch_idx, by * GROUP_SIZE + ty] = st_m[0] + T.log2(st_d[0]) + lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0]) # fmt: on # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches return batch_decode_paged_kv From e30a457a8369f64cb38de1cc6357db23aacf349b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 23 Feb 2024 13:21:54 -0500 Subject: [PATCH 002/531] [Serving][Refactor] Logit processor and logit bias support (#1828) This PR refactors the existing logit processing pipeline with 
a unified logit processor class. The logit processor class exposes two functions:
- `InplaceUpdateLogits`, which takes in the raw logits produced by the model and applies logit bias (introduced in this PR), presence/frequency/repetition penalties, and the token id mask, in that order, when needed.
- `ComputeProbsFromLogits`, which takes in the updated logits and invokes softmax with temperature to compute the probability distribution.

The logit processor runs entirely on GPU. That is, the logit bias / penalty / mask application and the softmax are all backed by GPU kernels. This is a key difference from the logit processing prior to this PR, where the processing happened on CPU, and softmax also ran on CPU whenever any logit processing was needed.

With the unified logit processor, we simplified the interface for handling the model's output logits in engine actions to make it cleaner. We also simplified the interface of Sampler.

Preliminary results show that LogitProcessor brings a modest performance improvement whenever any processing is needed.
---
 cpp/serve/config.cc                           |  22 +
 cpp/serve/config.h                            |   1 +
 cpp/serve/engine.cc                           |  31 +-
 cpp/serve/engine_actions/action.h             |  18 +-
 cpp/serve/engine_actions/batch_decode.cc      |  27 +-
 cpp/serve/engine_actions/batch_draft.cc       |  26 +-
 cpp/serve/engine_actions/batch_verify.cc      |  31 +-
 .../engine_actions/new_request_prefill.cc     |  41 +-
 cpp/serve/function_table.cc                   |  13 +-
 cpp/serve/function_table.h                    |   3 +
 cpp/serve/logit_processor.cc                  | 404 ++++++++++++++
 cpp/serve/logit_processor.h                   |  94 ++++
 cpp/serve/model.cc                            |  41 +-
 cpp/serve/model.h                             |  15 +-
 cpp/serve/request_state.cc                    |  14 +-
 cpp/serve/request_state.h                     |   5 +
 cpp/serve/sampler.cc                          | 512 ++++--------------
 cpp/serve/sampler.h                           |  22 +-
 .../compiler_pass/attach_to_ir_module.py      | 113 ++++
 python/mlc_chat/compiler_pass/pipeline.py     |   2 +
 .../mlc_chat/protocol/openai_api_protocol.py  |  46 +-
 python/mlc_chat/serve/config.py               |   6 +-
 tests/python/serve/server/test_server.py      |  71 ++-
 23 files changed, 1008 insertions(+), 550 deletions(-)
 create mode 100644 cpp/serve/logit_processor.cc
 create mode 100644 cpp/serve/logit_processor.h

diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 3c4d77d6a6..804ff9fe93 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -52,6 +52,22 @@ GenerationConfig::GenerationConfig(String config_json_str) {
     n->repetition_penalty = config["repetition_penalty"].get<double>();
     CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!";
   }
+  if (config.count("logit_bias")) {
+    CHECK(config["logit_bias"].is<picojson::null>() || config["logit_bias"].is<picojson::object>());
+    if (config["logit_bias"].is<picojson::object>()) {
+      picojson::object logit_bias_json = config["logit_bias"].get<picojson::object>();
+      std::vector<std::pair<int, float>> logit_bias;
+      logit_bias.reserve(logit_bias_json.size());
+      for (auto [token_id_str, bias] : logit_bias_json) {
+        CHECK(bias.is<double>());
+        double bias_value = bias.get<double>();
+        CHECK_LE(std::fabs(bias_value), 100.0)
+            << "Logit bias value should be in range [-100, 100].";
+        logit_bias.emplace_back(std::stoi(token_id_str), bias_value);
+      }
+      n->logit_bias = std::move(logit_bias);
+    }
+  }
   if (config.count("max_tokens")) {
     if (config["max_tokens"].is<int64_t>()) {
       n->max_tokens = config["max_tokens"].get<int64_t>();
@@ -115,6 +131,12 @@ String GenerationConfigNode::AsJSONString() const {
   config["max_tokens"] = picojson::value(static_cast<int64_t>(this->max_tokens));
   config["seed"] = picojson::value(static_cast<int64_t>(this->seed));
+  picojson::object logit_bias_obj;
+  for (auto [token_id, bias] : logit_bias) {
+    logit_bias_obj[std::to_string(token_id)] =
picojson::value(static_cast(bias)); + } + config["logit_bias"] = picojson::value(logit_bias_obj); + picojson::array stop_strs_arr; for (String stop_str : this->stop_strs) { stop_strs_arr.push_back(picojson::value(stop_str)); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 34bbfc9880..c9ebf0c847 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -25,6 +25,7 @@ class GenerationConfigNode : public Object { double frequency_penalty = 0.0; double presence_penalty = 0.0; double repetition_penalty = 1.0; + std::vector> logit_bias; int seed; bool ignore_eos = false; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 08376712be..28b1e70006 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -19,6 +19,7 @@ #include "engine_actions/action_commons.h" #include "engine_state.h" #include "event_trace_recorder.h" +#include "logit_processor.h" #include "model.h" #include "request.h" #include "request_state.h" @@ -53,10 +54,10 @@ class EngineImpl : public Engine { this->engine_mode_ = EngineMode(engine_mode_json_str); this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->sampler_ = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_); this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); this->token_table_ = tokenizer_->TokenTable(); // Step 2. Initialize each model independently. + // Create the logit processor and sampler. this->models_.clear(); for (const auto& model_info : model_infos) { TVMArgValue model_lib = std::get<0>(model_info); @@ -71,26 +72,35 @@ class EngineImpl : public Engine { << this->max_single_sequence_length_; this->models_.push_back(model); } + int max_logit_processor_num_token = kv_cache_config_->max_num_sequence; + if (engine_mode_->enable_speculative) { + max_logit_processor_num_token *= engine_mode_->spec_draft_length; + } + LogitProcessor logit_processor = + this->models_[0]->CreateLogitProcessor(max_logit_processor_num_token, trace_recorder); + Sampler sampler = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_); // Step 3. Initialize engine actions that represent state transitions. if (this->engine_mode_->enable_speculative) { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); this->actions_ = { EngineAction::NewRequestPrefill(this->models_, // - this->sampler_, // + logit_processor, // + sampler, // this->kv_cache_config_, // this->trace_recorder_), - EngineAction::BatchDraft(this->models_, this->sampler_, this->trace_recorder_, + EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, this->engine_mode_->spec_draft_length), - EngineAction::BatchVerify(this->models_, this->sampler_, this->kv_cache_config_, + EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, this->trace_recorder_)}; } else { - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - this->sampler_, // - this->kv_cache_config_, // - this->trace_recorder_), - EngineAction::BatchDecode(this->models_, this->sampler_, this->trace_recorder_)}; + this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->kv_cache_config_, // + this->trace_recorder_), + EngineAction::BatchDecode(this->models_, logit_processor, sampler, + this->trace_recorder_)}; } // Step 4. Automatically set the threading backend max concurrency. 
SetThreadMaxConcurrency(); @@ -196,7 +206,6 @@ class EngineImpl : public Engine { KVCacheConfig kv_cache_config_; EngineMode engine_mode_; int max_single_sequence_length_; - Sampler sampler_; Tokenizer tokenizer_; std::vector token_table_; // Models diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index cd2ef33f99..8e305e26af 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -53,13 +53,14 @@ class EngineAction : public ObjectRef { * \brief Create the action that prefills requests in the `waiting_queue` * of the engine state. * \param models The models to run prefill in. + * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param kv_cache_config The KV cache config to help decide prefill is doable. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction NewRequestPrefill(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the @@ -74,8 +75,8 @@ class EngineAction : public ObjectRef { * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction BatchDecode(Array models, Sampler sampler, - Optional trace_recorder); + static EngineAction BatchDecode(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder); /*! * \brief Create the action that runs one-step speculative draft proposal for @@ -88,8 +89,9 @@ class EngineAction : public ObjectRef { * \param draft_length The number of draft proposal rounds. * \return The created action object. */ - static EngineAction BatchDraft(Array models, Sampler sampler, - Optional trace_recorder, int draft_length = 4); + static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder, + int draft_length = 4); /*! * \brief Create the action that runs one-step speculative verification for requests in the @@ -102,8 +104,8 @@ class EngineAction : public ObjectRef { * \param trace_recorder The event trace recorder for requests. * \return The created action object. 
*/ - static EngineAction BatchVerify(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 410e94d286..627e46bc9a 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -24,9 +24,10 @@ namespace serve { */ class BatchDecodeActionObj : public EngineActionObj { public: - explicit BatchDecodeActionObj(Array models, Sampler sampler, - Optional trace_recorder) + explicit BatchDecodeActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), trace_recorder_(std::move(trace_recorder)) {} @@ -92,11 +93,17 @@ class BatchDecodeActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], embeddings->shape[0]); ICHECK_EQ(logits->shape[1], 1); + // - Update logits. + logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + // - Sample tokens. - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); std::vector next_tokens = - sampler_->BatchSampleTokens(logits, models_[0], mstates, generation_cfg, rngs); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. @@ -122,16 +129,20 @@ class BatchDecodeActionObj : public EngineActionObj { * models, the `Step` function of the created action will not take effect. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; }; -EngineAction EngineAction::BatchDecode(Array models, Sampler sampler, +EngineAction EngineAction::BatchDecode(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(trace_recorder))); + return EngineAction( + make_object(std::move(models), std::move(logit_processor), + std::move(sampler), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 3f5622cc6d..403350c4af 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -20,9 +20,10 @@ namespace serve { */ class BatchDraftActionObj : public EngineActionObj { public: - explicit BatchDraftActionObj(Array models, Sampler sampler, + explicit BatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, Optional trace_recorder, int draft_length) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { @@ -102,13 +103,19 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], embeddings->shape[0]); ICHECK_EQ(logits->shape[1], 1); + // - Update logits. + logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + // - Sample tokens. - RECORD_EVENT(trace_recorder_, request_ids, "start proposal sampling"); std::vector prob_dist; std::vector token_probs; std::vector next_tokens = sampler_->BatchSampleTokens( - logits, models_[model_id], mstates, generation_cfg, rngs, &prob_dist, &token_probs); - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal sampling"); + probs_device, request_ids, generation_cfg, rngs, &prob_dist, &token_probs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the draft tokens, prob dist, token probs of states. @@ -143,6 +150,8 @@ class BatchDraftActionObj : public EngineActionObj { /*! \brief The model to run draft generation in speculative decoding. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Event trace recorder. 
*/ @@ -151,11 +160,12 @@ class BatchDraftActionObj : public EngineActionObj { int draft_length_; }; -EngineAction EngineAction::BatchDraft(Array models, Sampler sampler, - Optional trace_recorder, +EngineAction EngineAction::BatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, Optional trace_recorder, int draft_length) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(trace_recorder), draft_length)); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder), + draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index ef33449fd7..e4aa836127 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -26,9 +26,11 @@ namespace serve { */ class BatchVerifyActionObj : public EngineActionObj { public: - explicit BatchVerifyActionObj(Array models, Sampler sampler, KVCacheConfig kv_cache_config, + explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), kv_cache_config_(std::move(kv_cache_config)), trace_recorder_(std::move(trace_recorder)), @@ -103,13 +105,22 @@ class BatchVerifyActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[0], 1); ICHECK_EQ(logits->shape[1], total_draft_length); + // - Update logits. std::vector cum_verify_lengths = {0}; for (int i = 0; i < num_requests; ++i) { cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); } + logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, + request_ids, &cum_verify_lengths, &draft_output_tokens); + + // - Compute probability distributions. + NDArray probs_device = logit_processor_->ComputeProbsFromLogits( + logits, generation_cfg, request_ids, &cum_verify_lengths); + std::vector> accepted_tokens_arr = sampler_->BatchVerifyDraftTokens( - logits, cum_verify_lengths, models_[verify_model_id_], verify_request_mstates, - generation_cfg, rngs, draft_output_tokens, draft_output_token_prob, draft_output_prob_dist); + probs_device, request_ids, cum_verify_lengths, verify_request_mstates, generation_cfg, rngs, + draft_output_tokens, draft_output_token_prob, draft_output_prob_dist); ICHECK_EQ(accepted_tokens_arr.size(), num_requests); for (int i = 0; i < num_requests; ++i) { @@ -222,6 +233,8 @@ class BatchVerifyActionObj : public EngineActionObj { * models, the `Step` function of the created action will not take effect. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief The kv cache config. */ @@ -233,15 +246,15 @@ class BatchVerifyActionObj : public EngineActionObj { /*! \brief The ids of verify/draft models. 
*/ const int verify_model_id_ = 0; const int draft_model_id_ = 1; - const float eps_ = 1e-9; + const float eps_ = 1e-5; }; -EngineAction EngineAction::BatchVerify(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, +EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(kv_cache_config), - std::move(trace_recorder))); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index bf0d607c92..a3f1b2d17c 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -18,10 +18,11 @@ namespace serve { */ class NewRequestPrefillActionObj : public EngineActionObj { public: - explicit NewRequestPrefillActionObj(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, + explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), kv_cache_config_(std::move(kv_cache_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -87,23 +88,31 @@ class NewRequestPrefillActionObj : public EngineActionObj { } } - // - Sample tokens. + // - Update logits. ICHECK(logits_for_sample.defined()); - logits_for_sample = logits_for_sample.CreateView({num_requests, 1, logits_for_sample->shape[2]}, - logits_for_sample->dtype); + Array generation_cfg; Array mstates_for_sample; std::vector rngs; + generation_cfg.reserve(num_requests); mstates_for_sample.reserve(num_requests); rngs.reserve(num_requests); for (int i = 0; i < num_requests; ++i) { + generation_cfg.push_back(requests[i]->generation_cfg); mstates_for_sample.push_back(rstates[i]->mstates[0]); rngs.push_back(&rstates[i]->rng); } - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - std::vector next_tokens = sampler_->BatchSampleTokens( - logits_for_sample, models_[0], mstates_for_sample, - requests.Map([](Request request) { return request->generation_cfg; }), rngs); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + logits_for_sample = logits_for_sample.CreateView({num_requests, logits_for_sample->shape[2]}, + logits_for_sample->dtype); + logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_sample, + request_ids); + + // - Compute probability distributions. + NDArray probs_device = + logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + + // - Sample tokens. + std::vector next_tokens = + sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); ICHECK_EQ(next_tokens.size(), num_requests); // - Update the committed tokens of states. @@ -199,6 +208,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief The models to run prefill in. */ Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief The KV cache config to help decide prefill is doable. 
*/ @@ -207,12 +218,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { Optional trace_recorder_; }; -EngineAction EngineAction::NewRequestPrefill(Array models, Sampler sampler, - KVCacheConfig kv_cache_config, +EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder) { - return EngineAction(make_object(std::move(models), std::move(sampler), - std::move(kv_cache_config), - std::move(trace_recorder))); + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 6dce770dc6..c4ebbe4be3 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -100,12 +100,9 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->get_global_func = [this](const std::string& name) -> PackedFunc { return SessionFuncAsPackedFunc(sess, sess->GetGlobalFunc(name), name); }; + this->model_metadata_ = + ModelMetadata::FromModule(this->disco_mod->DebugGetFromRemote(0), std::move(model_config)); this->_InitFunctions(); - { - Module mod = this->disco_mod->DebugGetFromRemote(0); - this->softmax_func_ = mod->GetFunction("softmax_with_temperature"); - this->model_metadata_ = ModelMetadata::FromModule(mod, std::move(model_config)); - } } else { Module executable{nullptr}; if (reload_lib.type_code() == kTVMModuleHandle) { @@ -193,7 +190,11 @@ void FunctionTable::_InitFunctions() { this->prefill_func_ = mod_get_func("batch_prefill"); this->decode_func_ = mod_get_func("batch_decode"); this->verify_func_ = mod_get_func("batch_verify"); - this->softmax_func_ = mod_get_func("softmax_with_temperature"); + Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm; + this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); + this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); + this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); + this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 24c6180707..e37b0e6f89 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -71,6 +71,9 @@ struct FunctionTable { PackedFunc decode_func_; PackedFunc verify_func_; PackedFunc softmax_func_; + PackedFunc apply_logit_bias_func_; + PackedFunc apply_penalty_func_; + PackedFunc apply_bitmask_func_; PackedFunc create_kv_cache_func_; PackedFunc reset_kv_cache_func_; bool support_backtracking_kv_; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc new file mode 100644 index 0000000000..a45c1f9f13 --- /dev/null +++ b/cpp/serve/logit_processor.cc @@ -0,0 +1,404 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/logit_processor.cc + * \brief The implementation of logit processor. 
+ */ +#include "logit_processor.h" + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +inline void CopyArray(NDArray src, NDArray dst) { + DLTensor dl_dst = *(dst.operator->()); + NDArray::CopyFromTo(src.operator->(), &dl_dst); +} + +/***************** LogitProcessor Implementation *****************/ + +TVM_REGISTER_OBJECT_TYPE(LogitProcessorObj); + +class LogitProcessorImpl : public LogitProcessorObj { + public: + /*! * \brief Constructor of LogitProcessorImpl. */ + explicit LogitProcessorImpl(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device, + Optional trace_recorder) + : max_num_token_(max_num_token), + vocab_size_(vocab_size), + bitmask_size_((vocab_size + 31) / 32), + softmax_func_(ft->softmax_func_), + device_(device), + apply_logit_bias_func_(ft->apply_logit_bias_func_), + apply_penalty_func_(ft->apply_penalty_func_), + apply_bitmask_func_(ft->apply_bitmask_func_), + trace_recorder_(std::move(trace_recorder)) { + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; + // Initialize auxiliary arrays on CPU. + seq_ids_host_ = NDArray::Empty({max_num_token}, dtype_i32_, device_cpu); + pos2seq_id_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_ids_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_cnt_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); + token_logit_bias_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, device_cpu); + penalties_host_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, device_cpu); + bitmask_host_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, device_cpu); + temperature_host_ = NDArray::Empty({max_num_token}, dtype_f32_, device_cpu); + // Initialize auxiliary arrays on GPU. 
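+    // These device buffers mirror the host-side staging arrays above and are allocated
+    // once at their worst-case sizes, so later batches can reuse them through CreateView
+    // without reallocating on every call.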
+ seq_ids_device_ = NDArray::Empty({max_num_token}, dtype_i32_, device); + pos2seq_id_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_ids_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_cnt_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device); + token_logit_bias_device_ = NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, device); + penalties_device_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, device); + bitmask_device_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, device); + temperature_device_ = NDArray::Empty({max_num_token}, dtype_f32_, device); + + CHECK(apply_logit_bias_func_.defined()) + << "Function \"apply_logit_bias_inplace\" not found in model"; + CHECK(apply_penalty_func_.defined()) << "Function \"apply_penalty_inplace\" not found in model"; + CHECK(apply_bitmask_func_.defined()) << "Function \"apply_bitmask_inplace\" not found in model"; + } + + void InplaceUpdateLogits(NDArray logits, // + const Array& generation_cfg, // + const Array& mstates, // + const Array& request_ids, // + const std::vector* cum_num_token, // + const std::vector>* draft_tokens) final { + CHECK_EQ(logits->ndim, 2); + CHECK_EQ(logits->shape[1], vocab_size_); + CHECK(logits.DataType() == DataType::Float(32)); + CHECK_EQ(generation_cfg.size(), mstates.size()); + CHECK_LE(logits->shape[0], max_num_token_); + int num_total_token = logits->shape[0]; + int num_sequence = generation_cfg.size(); + + CHECK((cum_num_token == nullptr) == (draft_tokens == nullptr)); + if (cum_num_token != nullptr) { + CHECK_EQ(draft_tokens->size(), num_sequence); + CHECK_EQ(cum_num_token->size(), num_sequence + 1); + CHECK_EQ(cum_num_token->back(), num_total_token); + } else { + CHECK_EQ(num_sequence, num_total_token); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start update logits"); + + // Update 1. logit bias + RECORD_EVENT(trace_recorder_, request_ids, "start apply logit bias"); + UpdateWithLogitBias(logits, generation_cfg, cum_num_token); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit bias"); + + // Update 2. penalties + RECORD_EVENT(trace_recorder_, request_ids, "start apply penalty"); + UpdateWithPenalty(logits, generation_cfg, mstates, cum_num_token, draft_tokens); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply penalty"); + + // Update 3. Vocabulary mask. + RECORD_EVENT(trace_recorder_, request_ids, "start apply logit mask"); + UpdateWithMask(logits, mstates, cum_num_token, draft_tokens); + RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit mask"); + + RECORD_EVENT(trace_recorder_, request_ids, "finish update logits"); + } + + NDArray ComputeProbsFromLogits(NDArray logits, const Array& generation_cfg, + const Array& request_ids, + const std::vector* cum_num_token) final { + // logits: (n, v) + CHECK_EQ(logits->ndim, 2); + CHECK_LE(logits->shape[0], max_num_token_); + CHECK_EQ(logits->shape[1], vocab_size_); + CHECK(logits.DataType() == DataType::Float(32)); + int num_total_token = logits->shape[0]; + int num_sequence = generation_cfg.size(); + + if (cum_num_token != nullptr) { + CHECK_EQ(cum_num_token->size(), num_sequence + 1); + CHECK_EQ(cum_num_token->back(), num_total_token); + } else { + CHECK_EQ(num_sequence, num_total_token); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start softmax"); + + // Construct: + // - temperature (max_num_token,) float32 + float* p_temperature = static_cast(temperature_host_->data); + + // - Set arrays. 
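+    // One temperature entry is written for every token a sequence owns in this batch
+    // (exactly one per sequence in the non-speculative case). Clamping with eps_ keeps
+    // the temperature strictly positive, so the softmax kernel's division by temperature
+    // stays well-defined even when a request asks for temperature 0.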
+ for (int i = 0; i < num_sequence; ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + for (int j = 0; j < num_token_to_process; ++j) { + p_temperature[token_offset + j] = std::max(generation_cfg[i]->temperature, eps_); + } + } + + // - View arrays. + NDArray temperature_host = temperature_host_.CreateView({num_total_token}, dtype_f32_); + NDArray temperature_device = temperature_device_.CreateView({num_total_token}, dtype_f32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/temperature_host, /*dst=*/temperature_device); + + // - Call kernel. + NDArray probs = softmax_func_(logits.CreateView({num_total_token, 1, vocab_size_}, dtype_f32_), + temperature_device); + ICHECK_EQ(probs->ndim, 3); + ICHECK_EQ(probs->shape[0], num_total_token); + ICHECK_EQ(probs->shape[1], 1); + ICHECK_EQ(probs->shape[2], vocab_size_); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish softmax"); + return probs.CreateView({num_total_token, vocab_size_}, probs->dtype); + } + + private: + void UpdateWithLogitBias(NDArray logits, const Array& generation_cfg, + const std::vector* cum_num_token) { + // Construct: + // - pos2seq_id (max_num_token * vocab_size,) int32 + // - token_ids (max_num_token * vocab_size,) int32 + // - token_logit_bias (max_num_token * vocab_size,) float32 + int* p_pos2seq_id = static_cast(pos2seq_id_host_->data); + int* p_token_ids = static_cast(token_ids_host_->data); + float* p_token_logit_bias = static_cast(token_logit_bias_host_->data); + + // - Set arrays. + int num_token_for_bias = 0; + int num_bias_token = 0; + for (int i = 0; i < static_cast(generation_cfg.size()); ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + for (int j = 0; j < num_token_to_process; ++j) { + if (!generation_cfg[i]->logit_bias.empty()) { + for (auto [token_id, bias] : generation_cfg[i]->logit_bias) { + p_pos2seq_id[num_bias_token] = token_offset + j; + p_token_ids[num_bias_token] = token_id; + p_token_logit_bias[num_bias_token] = bias; + ++num_bias_token; + } + ++num_token_for_bias; + } + } + } + + if (num_token_for_bias == 0) { + return; + } + + // - View arrays. + int num_token = num_bias_token; + NDArray pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_); + NDArray pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_); + NDArray token_logit_bias_host = token_logit_bias_host_.CreateView({num_token}, dtype_f32_); + NDArray token_logit_bias_device = token_logit_bias_device_.CreateView({num_token}, dtype_f32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); + CopyArray(/*src=*/token_logit_bias_host, /*dst=*/token_logit_bias_device); + + // - Call kernel. 
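+    // For each packed entry k, the kernel is expected to add token_logit_bias[k] to
+    // logits[pos2seq_id[k], token_ids[k]] in place, matching the host-side packing above.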
+ apply_logit_bias_func_(logits, pos2seq_id_device, token_ids_device, token_logit_bias_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + void UpdateWithPenalty(NDArray logits, const Array& generation_cfg, + const Array& mstates, + const std::vector* cum_num_token, + const std::vector>* draft_tokens) { + // Construct: + // - seq_ids (max_num_token,) int32 + // - pos2seq_id (max_num_token * vocab_size,) int32 + // - token_ids (max_num_token * vocab_size,) int32 + // - token_cnt (max_num_token * vocab_size,) int32 + // - penalties (max_num_token, 3) float32 + int* p_seq_ids = static_cast(seq_ids_host_->data); + int* p_pos2seq_id = static_cast(pos2seq_id_host_->data); + int* p_token_ids = static_cast(token_ids_host_->data); + int* p_token_cnt = static_cast(token_cnt_host_->data); + float* p_penalties = static_cast(penalties_host_->data); + + // - Set arrays. + int num_token_for_penalty = 0; + int num_penalty_appeared_token = 0; + for (int i = 0; i < static_cast(generation_cfg.size()); ++i) { + if (generation_cfg[i]->frequency_penalty != 0.0 || + generation_cfg[i]->presence_penalty != 0.0 || + generation_cfg[i]->repetition_penalty != 1.0) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + CHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty()); + for (int j = 0; j < num_token_to_process; ++j) { + p_seq_ids[num_token_for_penalty] = token_offset + j; + for (auto [token_id, cnt] : mstates[i]->appeared_token_ids) { + p_pos2seq_id[num_penalty_appeared_token] = num_token_for_penalty; + p_token_ids[num_penalty_appeared_token] = token_id; + p_token_cnt[num_penalty_appeared_token] = cnt; + ++num_penalty_appeared_token; + } + p_penalties[num_token_for_penalty * 3] = generation_cfg[i]->presence_penalty; + p_penalties[num_token_for_penalty * 3 + 1] = generation_cfg[i]->frequency_penalty; + p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; + ++num_token_for_penalty; + if (j > 0) { + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + } + } + if (num_token_to_process != 1) { + // Roll back. + mstates[i]->RemoveAllDraftTokens(); + } + } + } + + if (num_token_for_penalty == 0) { + return; + } + + // - View arrays. + int num_seq = num_token_for_penalty; + int num_token = num_penalty_appeared_token; + NDArray seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_); + NDArray seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_); + NDArray pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_); + NDArray pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_); + NDArray token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_); + NDArray token_cnt_host = token_cnt_host_.CreateView({num_token}, dtype_i32_); + NDArray token_cnt_device = token_cnt_device_.CreateView({num_token}, dtype_i32_); + NDArray penalties_host = penalties_host_.CreateView({num_seq, 3}, dtype_f32_); + NDArray penalties_device = penalties_device_.CreateView({num_seq, 3}, dtype_f32_); + + // - Copy arrays to GPU. 
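+    // penalties packs (presence, frequency, repetition) per penalized row, while
+    // pos2seq_id / token_ids / token_cnt enumerate every appeared token of that row,
+    // giving the penalty kernel the per-token occurrence counts it needs on device.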
+ CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); + CopyArray(/*src=*/token_cnt_host, /*dst=*/token_cnt_device); + CopyArray(/*src=*/penalties_host, /*dst=*/penalties_device); + + // - Call kernel. + apply_penalty_func_(logits, seq_ids_device, pos2seq_id_device, token_ids_device, + token_cnt_device, penalties_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + void UpdateWithMask(NDArray logits, const Array& mstates, + const std::vector* cum_num_token, + const std::vector>* draft_tokens) { + // Construct: + // - seq_ids (max_num_token,) int32 + // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32 + int* p_seq_ids = static_cast(seq_ids_host_->data); + int* p_bitmask = static_cast(bitmask_host_->data); + + // - Set arrays. + int num_token_for_mask = 0; + for (int i = 0; i < static_cast(mstates.size()); ++i) { + int num_token_to_process = + cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); + int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + CHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty()); + for (int j = 0; j < num_token_to_process; ++j) { + std::vector bitmask = mstates[i]->GetTokenBitmask(vocab_size_); + if (!bitmask.empty()) { + p_seq_ids[num_token_for_mask] = token_offset + j; + ICHECK_EQ(bitmask.size(), bitmask_size_); + for (int p = 0; p < bitmask_size_; ++p) { + p_bitmask[num_token_for_mask * bitmask_size_ + p] = bitmask[p]; + } + ++num_token_for_mask; + } + if (j > 0) { + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + } + } + if (num_token_to_process != 1) { + // Roll back. + mstates[i]->RemoveAllDraftTokens(); + } + } + + if (num_token_for_mask == 0) { + return; + } + + // - View arrays. + int num_seq = num_token_for_mask; + NDArray seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_); + NDArray seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_); + NDArray bitmask_host = bitmask_host_.CreateView({num_seq, bitmask_size_}, dtype_i32_); + NDArray bitmask_device = bitmask_device_.CreateView({num_seq, bitmask_size_}, dtype_i32_); + + // - Copy arrays to GPU. + CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); + CopyArray(/*src=*/bitmask_host, /*dst=*/bitmask_device); + + // - Call kernel. + apply_bitmask_func_(logits, seq_ids_device, bitmask_device); + if (trace_recorder_.defined()) { + TVMSynchronize(device_.device_type, device_.device_id, /*stream=*/nullptr); + } + } + + // Model configurations + const int max_num_token_; + const int vocab_size_; + const int bitmask_size_; + const DLDataType dtype_i32_ = DataType::Int(32); + const DLDataType dtype_f32_ = DataType::Float(32); + // Packed functions. 
+ Device device_; + PackedFunc softmax_func_; + PackedFunc apply_logit_bias_func_; + PackedFunc apply_penalty_func_; + PackedFunc apply_bitmask_func_; + // Auxiliary NDArrays on CPU + NDArray seq_ids_host_; + NDArray pos2seq_id_host_; + NDArray token_ids_host_; + NDArray token_cnt_host_; + NDArray token_logit_bias_host_; + NDArray penalties_host_; + NDArray bitmask_host_; + NDArray temperature_host_; + // Auxiliary NDArrays on GPU + NDArray seq_ids_device_; + NDArray pos2seq_id_device_; + NDArray token_ids_device_; + NDArray token_cnt_device_; + NDArray token_logit_bias_device_; + NDArray penalties_device_; + NDArray bitmask_device_; + NDArray temperature_device_; + // Event trace recorder. + Optional trace_recorder_; + // A small epsilon. + const double eps_ = 1e-5; +}; + +LogitProcessor::LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft, + DLDevice device, Optional trace_recorder) { + data_ = make_object(max_num_token, vocab_size, ft, device, + std::move(trace_recorder)); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/logit_processor.h b/cpp/serve/logit_processor.h new file mode 100644 index 0000000000..2425542731 --- /dev/null +++ b/cpp/serve/logit_processor.h @@ -0,0 +1,94 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/logit_processor.h + * \brief The header for logit processor. + */ + +#ifndef MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ +#define MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ + +#include +#include + +#include "../base.h" +#include "config.h" +#include "event_trace_recorder.h" +#include "function_table.h" +#include "request_state.h" + +namespace mlc { +namespace llm { +namespace serve { + +using tvm::Device; +using namespace tvm::runtime; + +/*! + * \brief The logit processor class that updates logits with regard + * presence/frequency penalties, logit bias, etc.. + */ +class LogitProcessorObj : public Object { + public: + /*! + * \brief In-place update a batch of logits with regard to the given + * generation config and request states. + * \param logits The batch of raw logits, in shape (num_total_token, vocab_size), + * where `num_total_token` may be larger than the number of sequences + * indicated by `generation_cfg`, in which case some sequences may have + * more than one token. + * \param generation_cfg The generation config of each sequence in the batch. + * \param mstates The request states of each sequence in the batch. + * \param request_ids The ids of each request. + * \param cum_num_token The pointer to the cumulative token length of the sequences. + * If the pointer is nullptr, it means each sequence has only one token. + * \param draft_tokens The pointer to the draft tokens of each sequence + * when speculation is enabled, in which case some sequences may have + * more than one token. + */ + virtual void InplaceUpdateLogits(NDArray logits, const Array& generation_cfg, + const Array& mstates, + const Array& request_ids, + const std::vector* cum_num_token = nullptr, + const std::vector>* draft_tokens = nullptr) = 0; + + /*! + * \brief Compute probability distributions for the input batch of logits. + * \param logits The batch of updated logits. + * \param generation_cfg The generation config of each sequence in the batch. + * \param request_ids The ids of each request. + * \param cum_num_token The pointer to the cumulative token length of the sequences. + * If the pointer is nullptr, it means each sequence has only one token. + * \return The batch of computed probability distributions on GPU. 
+ */ + virtual NDArray ComputeProbsFromLogits(NDArray logits, + const Array& generation_cfg, + const Array& request_ids, + const std::vector* cum_num_token = nullptr) = 0; + + static constexpr const char* _type_key = "mlc.serve.LogitProcessor"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(LogitProcessorObj, Object); +}; + +class LogitProcessor : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param max_num_token The max number of tokens in the token processor. + * \param vocab_size The model's vocabulary size. + * \param ft The packed function table. + * \param device The device that the model runs on. + * \param trace_recorder The event trace recorder. + */ + explicit LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device, + Optional trace_recorder); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogitProcessor, ObjectRef, LogitProcessorObj); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_LOGIT_PROCESSOR_H_ diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 48ff463667..ecaa5276d8 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -11,6 +11,8 @@ #include +#include "logit_processor.h" + namespace mlc { namespace llm { namespace serve { @@ -350,34 +352,14 @@ class ModelImpl : public ModelObj { return logits; } - NDArray SoftmaxWithTemperature(NDArray logits, Array generation_cfg) final { - // logits: (b, n, v) - CHECK_EQ(logits->ndim, 3); - CHECK_EQ(logits->shape[0], generation_cfg.size()); - CHECK_EQ(logits->device.device_type, device_.device_type); - CHECK_EQ(logits->device.device_id, device_.device_id); - - int batch_size = logits->shape[0]; - std::vector temperatures; - temperatures.reserve(batch_size); - for (GenerationConfig cfg : generation_cfg) { - temperatures.push_back(cfg->temperature); - } - NDArray temperatures_nd = - CopyArrayToDevice(temperatures, &temperature_arr_, logits->dtype, 32, device_); - ICHECK_EQ(temperatures_nd->ndim, 1); - ICHECK_EQ(temperatures_nd->shape[0], batch_size); - - NDArray probs = ft_.softmax_func_(logits, temperatures_nd); - ICHECK_EQ(probs->ndim, 3); - ICHECK_EQ(probs->shape[0], logits->shape[0]); - ICHECK_EQ(probs->shape[1], logits->shape[1]); - ICHECK_EQ(probs->shape[2], logits->shape[2]); - return probs; - } - /*********************** KV Cache Management ***********************/ + LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) { + return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } + void CreateKVCache(KVCacheConfig kv_cache_config) final { IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; @@ -451,6 +433,12 @@ class ModelImpl : public ModelObj { } else { LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; } + if (config.count("vocab_size")) { + CHECK(config["vocab_size"].is()); + this->vocab_size_ = config["vocab_size"].get(); + } else { + LOG(FATAL) << "Key \"vocab_size\" not found."; + } return config; } @@ -460,6 +448,7 @@ class ModelImpl : public ModelObj { int max_window_size_ = -1; int num_shards_ = -1; int max_num_sequence_ = -1; + int vocab_size_ = -1; //---------------------------- // TVM related states //---------------------------- diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 72a869198e..b561b7895e 100644 --- 
a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,7 +12,9 @@ #include "../base.h" #include "config.h" +#include "event_trace_recorder.h" #include "function_table.h" +#include "logit_processor.h" namespace mlc { namespace llm { @@ -92,15 +94,6 @@ class ModelObj : public Object { virtual NDArray BatchVerify(const NDArray& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; - /*! - * \brief Computing probabilities from logits with softmax and temperatures. - * \param logits The logits to compute from. - * \param generation_cfg The generation config which contains the temperatures. - * \return The computed probabilities distribution. - */ - virtual NDArray SoftmaxWithTemperature(NDArray logits, - Array generation_cfg) = 0; - /*********************** KV Cache Management ***********************/ /*! @@ -123,6 +116,10 @@ class ModelObj : public Object { /*********************** Utilities ***********************/ + /*! \brief Create a logit processor from this model. */ + virtual LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) = 0; + /*! * \brief Estimate number of CPU units required to drive the model * executing during TP. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index a4b5297337..b721d32ac6 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -31,6 +31,11 @@ int RequestModelStateNode::GetInputLength() const { return total_length; } +std::vector RequestModelStateNode::GetTokenBitmask(int vocab_size) const { + // TODO(mlc-team): implement this function. + return std::vector(); +} + void RequestModelStateNode::CommitToken(int32_t token_id) { committed_tokens.push_back(token_id); appeared_token_ids[token_id] += 1; @@ -43,14 +48,17 @@ void RequestModelStateNode::AddDraftToken(int32_t token_id) { void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); - appeared_token_ids[draft_output_tokens.back()] -= 1; + auto it = appeared_token_ids.find(draft_output_tokens.back()); draft_output_tokens.pop_back(); + CHECK(it != appeared_token_ids.end()); + if (--it->second == 0) { + appeared_token_ids.erase(it); + } } void RequestModelStateNode::RemoveAllDraftTokens() { while (!draft_output_tokens.empty()) { - appeared_token_ids[draft_output_tokens.back()] -= 1; - draft_output_tokens.pop_back(); + RemoveLastDraftToken(); } } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 82835d01df..ea0b688810 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -81,6 +81,11 @@ class RequestModelStateNode : public Object { /*! \brief Return the total length of the input data. */ int GetInputLength() const; + /*! + * \brief Return the token bitmask induced by the current state. + * The returned vector should have size "ceildiv(vocab_size, 32)". + */ + std::vector GetTokenBitmask(int vocab_size) const; /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(int32_t token_id); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 8ddfca527a..502bde72e6 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -18,128 +18,6 @@ namespace mlc { namespace llm { namespace serve { -/***** Utility function for in-place logits/prob update on CPU *****/ - -/*! - * \brief In-place apply repetition penalty to logits based on history tokens. - * \param logits The logits (a batch) to be in-place mutated. 
- * \param token_offset The offset of the token in the batch - * whose logits will be updated. - * \param state The request state that contains history tokens. - * \param repetition_penalty The value of repetition penalty. - */ -void ApplyRepetitionPenaltyOnCPU(NDArray logits, int token_offset, RequestModelState state, - double repetition_penalty) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - for (const auto& it : state->appeared_token_ids) { - int token_id = it.first; - ICHECK_GE(token_id, 0); - ICHECK_LT(token_id, vocab_size); - if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= repetition_penalty; - } else { - logits_raw_data[token_id] /= repetition_penalty; - } - } -} - -/*! - * \brief In-place apply frequency and presence penalty to logits based on history tokens. - * \param logits The logits (a batch) to be in-place mutated. - * \param token_offset The offset of the token in the batch - * whose logits will be updated. - * \param state The request state that contains history tokens. - * \param frequency_penalty The value of frequency penalty. - * \param presence_penalty The value of presence penalty. - */ -void ApplyFrequencyAndPresencePenaltyOnCPU(NDArray logits, int token_offset, - RequestModelState state, double frequency_penalty, - double presence_penalty) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - for (const auto& it : state->appeared_token_ids) { - int token_id = it.first; - int occurrences = it.second; - ICHECK_GE(token_id, 0); - ICHECK_LT(token_id, vocab_size); - logits_raw_data[token_id] -= occurrences * frequency_penalty + presence_penalty; - } -} - -/*! - * \brief In-place compute softmax with temperature on CPU. - * \param logits The logits (a batch) to compute softmax from. - * \param token_offset The offset of the token in the batch - * to compute softmax for. Only the logits of the specified - * token will be updated to probability after softmax. - * \param temperature The temperature to apply before softmax. - */ -void ApplySoftmaxWithTemperatureOnCPU(NDArray logits, int token_offset, double temperature) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, DLDeviceType::kDLCPU); - int vocab_size = logits->shape[1]; - - float* __restrict logits_raw_data = - static_cast(__builtin_assume_aligned(logits->data, 4)) + (token_offset * vocab_size); - float m = std::numeric_limits::min(); - float inv_temp = 1.0f / temperature; - double d = 0.0f; - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - float m_prev = m; - m = std::max(m, x); - d = d * std::exp(m_prev - m) + std::exp(x - m); - } - for (int i = 0; i < vocab_size; ++i) { - float x = logits_raw_data[i] * inv_temp; - logits_raw_data[i] = std::exp(x - m) / d; - } -} - -/*! - * \brief In-place set probability via argmax. - * This is used for zero-temperature sampling cases. 
- * \param logits The logits (a batch) to set probability. - * \param token_offset The offset of the token in the batch - * to set probability for. Only the logits of the specified - * token will be updated to probability. - */ -void SetProbWithArgmaxOnCPU(NDArray logits, int token_offset) { - // logits: (n, v) - CHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - int vocab_size = logits->shape[1]; - - float* logits_raw_data = static_cast(logits->data) + (token_offset * vocab_size); - int argmax_pos = -1; - float max_logits = std::numeric_limits::lowest(); - for (int i = 0; i < vocab_size; ++i) { - if (logits_raw_data[i] > max_logits) { - max_logits = logits_raw_data[i]; - argmax_pos = i; - } - } - - ICHECK_NE(argmax_pos, -1); - for (int i = 0; i < vocab_size; ++i) { - logits_raw_data[i] = i == argmax_pos ? 1.0f : 0.0f; - } -} - /*! * \brief Sample a value from the input probability distribution with top-p. * The input is a batch of distributions, and we use `unit_offset` to specify @@ -181,6 +59,30 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub if (!(*output_prob_dist)[unit_offset].defined()) { (*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0}); } + } + + if (top_p == 0) { + // Specially handle case where top_p == 0. + // This case is equivalent to doing argmax. + int argmax_pos = -1; + float max_prob = 0.0; + for (int i = 0; i < ndata; ++i) { + if (p_prob[i] > max_prob) { + max_prob = p_prob[i]; + argmax_pos = i; + } + } + if (output_prob_dist) { + float* __restrict p_output_prob = + static_cast(__builtin_assume_aligned((*output_prob_dist)[unit_offset]->data, 4)); + for (int i = 0; i < ndata; ++i) { + p_output_prob[i] = i == argmax_pos ? 1.0 : 0.0; + } + } + return std::make_pair(1.0, argmax_pos); + } + + if (output_prob_dist) { (*output_prob_dist)[unit_offset].CopyFromBytes(p_prob, ndata * sizeof(float)); } @@ -193,7 +95,6 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub return std::make_pair(p_prob[i], i); } } - LOG(INFO) << "prob sum = " << prob_sum << ", sample = " << uniform_sample; ICHECK(false) << "Possibly prob distribution contains NAN."; } @@ -278,37 +179,6 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub return sampled_index; } -/*! - * \brief Copy logits or prob distributions from device to CPU. - * The input array is in layout (b, n, v). - * This function flattens the first dimension, returns an NDArray - * in shape (b * n, v). - */ -NDArray CopyLogitsOrProbsToCPU(NDArray arr_on_device, NDArray* arr_on_cpu) { - // arr_on_device: (b, n, v) - ICHECK_EQ(arr_on_device->ndim, 3); - ICHECK(!arr_on_cpu->defined() || (*arr_on_cpu)->ndim == 2); - ICHECK(arr_on_device->device.device_type != kDLCPU); - if (arr_on_cpu->defined()) { - ICHECK_EQ((*arr_on_cpu)->shape[1], arr_on_device->shape[2]); - } - - int64_t init_size = arr_on_cpu->defined() ? 
(*arr_on_cpu)->shape[0] : 32; - int64_t num_tokens = arr_on_device->shape[0] * arr_on_device->shape[1]; - int64_t vocab_size = arr_on_device->shape[2]; - while (init_size < num_tokens) { - init_size *= 2; - } - if (!arr_on_cpu->defined() || init_size != (*arr_on_cpu)->shape[0]) { - (*arr_on_cpu) = - NDArray::Empty({init_size, vocab_size}, arr_on_device->dtype, DLDevice{kDLCPU, 0}); - } - ICHECK_LE(num_tokens, (*arr_on_cpu)->shape[0]); - NDArray view = arr_on_cpu->CreateView({num_tokens, vocab_size}, arr_on_device->dtype); - view.CopyFrom(arr_on_device); - return view; -} - /********************* CPU Sampler *********************/ class CPUSampler : public SamplerObj { @@ -323,44 +193,68 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray logits_on_device, Model model, - Array request_mstates, - Array generation_cfg, + std::vector BatchSampleTokens(NDArray probs_device, // + const Array& request_ids, + const Array& generation_cfg, const std::vector& rngs, std::vector* output_prob_dist, std::vector* output_token_probs) final { - NDArray probs_on_cpu = BatchComputeProb(logits_on_device, /*cum_sequence_length=*/nullptr, - model, request_mstates, generation_cfg); + // probs_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK_EQ(probs_device->ndim, 2); + // - Copy probs to CPU + RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); + NDArray probs_host = CopyProbsToCPU(probs_device); + RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); + // - Sample tokens from probabilities. - // NOTE: Though we have the probability field in RequestModelState, - // we do not save the probabilities right now. - // We will handle this in the future when we work on speculation. - std::vector output_tokens = SampleTokensFromProbs( - probs_on_cpu, request_mstates, generation_cfg, rngs, output_prob_dist, output_token_probs); - return output_tokens; + ICHECK_EQ(probs_host->shape[0], request_ids.size()); + ICHECK_EQ(probs_host->shape[0], generation_cfg.size()); + ICHECK_EQ(probs_host->shape[0], rngs.size()); + int n = probs_host->shape[0]; + + std::vector sampled_tokens; + sampled_tokens.resize(n); + if (output_prob_dist) { + output_prob_dist->resize(n); + } + if (output_token_probs) { + output_token_probs->resize(n); + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &sampled_tokens, &probs_host, &generation_cfg, &rngs, &request_ids, output_prob_dist, + output_token_probs](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); + // Sample top p from probability. + std::pair sample_result = SampleTopPFromProb( + probs_host, i, generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber(), output_prob_dist); + sampled_tokens[i] = sample_result.second; + if (output_token_probs) { + (*output_token_probs)[i] = sample_result.first; + } + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + }, + 0, n); + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sampled_tokens; } std::vector> BatchVerifyDraftTokens( - NDArray logits_on_device, const std::vector& cum_verify_lengths, Model model, - const Array& request_mstates, + NDArray probs_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& request_mstates, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_token_prob, const std::vector>& draft_output_prob_dist) final { - bool can_compute_prob_in_parallel = CanComputeProbInParallel(generation_cfg); - NDArray logits_or_probs_on_cpu{nullptr}; - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - if (can_compute_prob_in_parallel) { - logits_or_probs_on_cpu = BatchComputeProb(logits_on_device, &cum_verify_lengths, model, - request_mstates, generation_cfg); - } else { - RECORD_EVENT(trace_recorder_, request_ids, "start copy logits to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy logits to CPU"); - } - ICHECK(logits_or_probs_on_cpu->device.device_type == kDLCPU); - ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); + // probs_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_device->ndim, 2); + // - Copy probs to CPU + RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); + NDArray probs_host = CopyProbsToCPU(probs_device); + RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); @@ -372,20 +266,14 @@ class CPUSampler : public SamplerObj { accepted_tokens.resize(num_sequence); float* __restrict global_p_probs = - static_cast(__builtin_assume_aligned(logits_or_probs_on_cpu->data, 4)); - int vocab_size = logits_or_probs_on_cpu->shape[1]; + static_cast(__builtin_assume_aligned(probs_host->data, 4)); + int vocab_size = probs_host->shape[1]; tvm::runtime::parallel_for_with_threading_backend( [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; for (int cur_token_idx = 0; cur_token_idx < verify_end - verify_start; ++cur_token_idx) { - if (!can_compute_prob_in_parallel) { - SinglePosComputeProbsFromLogitsInplace(logits_or_probs_on_cpu, - verify_start + cur_token_idx, - request_mstates[i], generation_cfg[i]); - } - float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size; int cur_token = draft_output_tokens[i][cur_token_idx]; float q_value = draft_output_token_prob[i][cur_token_idx]; @@ -422,8 +310,10 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution int32_t new_token = - SampleTopPFromProb(logits_or_probs_on_cpu, verify_start + cur_token_idx, - generation_cfg[i]->top_p, rngs[i]->GetRandomNumber()) + SampleTopPFromProb( + probs_host, verify_start + cur_token_idx, + generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber()) .second; request_mstates[i]->CommitToken(new_token); accepted_tokens[i].push_back(cur_token); @@ -431,238 +321,42 @@ class CPUSampler : public SamplerObj { } }, 0, num_sequence); + RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); return accepted_tokens; } private: - /*! - * \brief Given the generation config of a batch, check if the - * probability distributions needs to be computed on device via softmax. - * \param generation_cfg The input generation config. - * \return A boolean flag indicating if the check result. - */ - bool RequireGPUSoftmax(Array generation_cfg) { - // - Return false if there is customized probability compute function. - if (flogits_to_probs_inplace_.defined()) { - return false; - } - // - Return false if any sampling param has frequency/presence penalty other than 0.0. - // - Return false if any sampling param has repetition penalty other than 1.0. - // - Return false if any sampling param has zero temperature. - for (GenerationConfig cfg : generation_cfg) { - if (cfg->frequency_penalty != 0.0 || cfg->presence_penalty != 0.0 || - cfg->repetition_penalty != 1.0 || cfg->temperature < 1e-6) { - return false; - } - } - return true; - } - - /*! - * \brief Given the generation config of a batch, check if the - * probability distributions need to be computed serially. - */ - bool CanComputeProbInParallel(const Array& generation_cfg) { - for (const GenerationConfig& cfg : generation_cfg) { - if (cfg->frequency_penalty != 0.0 || cfg->presence_penalty != 0.0 || - cfg->repetition_penalty != 1.0) { - return false; - } + /*! \brief Copy prob distributions from device to CPU. */ + NDArray CopyProbsToCPU(NDArray probs_device) { + // probs_device: (n, v) + ICHECK(probs_device->device.device_type != kDLCPU); + if (probs_host_.defined()) { + ICHECK_EQ(probs_host_->shape[1], probs_device->shape[1]); } - return true; - } - /*! - * \brief Compute the probability distribution of the input logits. - * \param logits_on_device The logits to compute probability distribution for. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. - * \param request_mstates The request states of each sequence in - * the batch with regard to the given model. - * \param generation_cfg The generation config of each request - * in the input batch. - * \return The probability distribution of the input logits. - */ - NDArray BatchComputeProb(NDArray logits_on_device, const std::vector* cum_sequence_length, - Model model, const Array& request_mstates, - const Array& generation_cfg) { - ICHECK(logits_on_device.defined()); - ICHECK_EQ(logits_on_device->ndim, 3); - int num_sequence; - if (cum_sequence_length == nullptr) { - ICHECK_EQ(logits_on_device->shape[1], 1) - << "Multi-token sampling for one sequence requiring `cum_sequence_length`."; - num_sequence = logits_on_device->shape[0]; - } else { - ICHECK(!cum_sequence_length->empty()); - num_sequence = static_cast(cum_sequence_length->size()) - 1; - ICHECK_EQ(logits_on_device->shape[0], 1); - ICHECK_EQ(logits_on_device->shape[1], cum_sequence_length->back()); + int64_t init_size = probs_host_.defined() ? 
probs_host_->shape[0] : 32; + int64_t num_tokens = probs_device->shape[0]; + int64_t vocab_size = probs_device->shape[1]; + while (init_size < num_tokens) { + init_size *= 2; } - ICHECK_EQ(generation_cfg.size(), num_sequence); - ICHECK_EQ(request_mstates.size(), num_sequence); - - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - - RECORD_EVENT(trace_recorder_, request_ids, "start query need GPU softmax"); - bool require_gpu_softmax = RequireGPUSoftmax(generation_cfg); - RECORD_EVENT(trace_recorder_, request_ids, "finish query need GPU softmax"); - - // - Compute probabilities from logits. - NDArray logits_or_probs_on_cpu{nullptr}; - if (require_gpu_softmax) { - RECORD_EVENT(trace_recorder_, request_ids, "start GPU softmax"); - Array generation_cfg_for_softmax; - if (cum_sequence_length == nullptr) { - generation_cfg_for_softmax = generation_cfg; - } else { - logits_on_device = logits_on_device.CreateView( - {logits_on_device->shape[1], 1, logits_on_device->shape[2]}, logits_on_device->dtype); - generation_cfg_for_softmax.reserve(logits_on_device->shape[1]); - for (int i = 0; i < num_sequence; ++i) { - for (int pos = cum_sequence_length->at(i); pos < cum_sequence_length->at(i + 1); ++pos) { - generation_cfg_for_softmax.push_back(generation_cfg[i]); - } - } - } - NDArray probs_on_device = - model->SoftmaxWithTemperature(logits_on_device, generation_cfg_for_softmax); - RECORD_EVENT(trace_recorder_, request_ids, "finish GPU softmax"); - RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(probs_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); - } else { - RECORD_EVENT(trace_recorder_, request_ids, "start copy logits to CPU"); - logits_or_probs_on_cpu = CopyLogitsOrProbsToCPU(logits_on_device, &logits_or_probs_on_cpu_); - RECORD_EVENT(trace_recorder_, request_ids, "finish copy logits to CPU"); - // The "BatchComputeProbsFromLogitsInplace" function updates - // `logits_or_probs_on_cpu` in place. - BatchComputeProbsFromLogitsInplace(logits_or_probs_on_cpu, cum_sequence_length, - std::move(request_mstates), generation_cfg); + if (!probs_host_.defined() || init_size != probs_host_->shape[0]) { + probs_host_ = + NDArray::Empty({init_size, vocab_size}, probs_device->dtype, DLDevice{kDLCPU, 0}); } - // `CopyLogitsOrProbsToCPU` flattens the first two dimensions. - ICHECK_EQ(logits_or_probs_on_cpu->ndim, 2); - return logits_or_probs_on_cpu; - } - - /*! - * \brief Compute the probability distribution from on-cpu logits for - * a batch of tokens **in place**. - * \param logits The input logits on CPU. - * \param states The request states, which contains the history generated tokens. - * \param generation_cfg The generation config. - * \note The function returns nothing. It in-place updates the input logits array. - */ - void BatchComputeProbsFromLogitsInplace(NDArray logits, - const std::vector* cum_sequence_length, - Array states, - Array generation_cfg) { - // logits: (n, v) - CHECK_EQ(logits->ndim, 2); - CHECK_EQ(logits->device.device_type, kDLCPU); - - // - Invoke environment compute function if exists. 
- if (flogits_to_probs_inplace_.defined()) { - IntTuple cum_sequence_length_obj; - if (cum_sequence_length != nullptr) { - cum_sequence_length_obj = - IntTuple{cum_sequence_length->begin(), cum_sequence_length->end()}; - } - flogits_to_probs_inplace_(logits, cum_sequence_length_obj, states, generation_cfg); - return; - } - - tvm::runtime::parallel_for_with_threading_backend( - [this, &logits, cum_sequence_length, &states, &generation_cfg](int i) { - int offset_start = cum_sequence_length == nullptr ? i : cum_sequence_length->at(i); - int offset_end = cum_sequence_length == nullptr ? i + 1 : cum_sequence_length->at(i + 1); - for (int offset = offset_start; offset < offset_end; ++offset) { - SinglePosComputeProbsFromLogitsInplace(logits, offset, states[i], generation_cfg[i]); - } - }, - 0, logits->shape[0]); - } - - void SinglePosComputeProbsFromLogitsInplace(NDArray logits, int offset, - const RequestModelState& state, - const GenerationConfig& generation_cfg) { - // - Apply frequency/presence penalty or repetition penalty (inplace). - if (generation_cfg->frequency_penalty != 0.0 || generation_cfg->presence_penalty != 0.0) { - RECORD_EVENT(trace_recorder_, state->request->id, "start frequency/presence penalty"); - ApplyFrequencyAndPresencePenaltyOnCPU(logits, offset, state, - generation_cfg->frequency_penalty, - generation_cfg->presence_penalty); - RECORD_EVENT(trace_recorder_, state->request->id, "finish frequency/presence penalty"); - } else if (generation_cfg->repetition_penalty != 1.0) { - RECORD_EVENT(trace_recorder_, state->request->id, "start repetition penalty"); - ApplyRepetitionPenaltyOnCPU(logits, offset, state, generation_cfg->repetition_penalty); - RECORD_EVENT(trace_recorder_, state->request->id, "finish repetition penalty"); - } - // - Compute probability (inplace) from logits. - // Using softmax if temperature is non-zero. - // Or set probability of the max-logit position to 1. - if (generation_cfg->temperature >= 1e-6) { - RECORD_EVENT(trace_recorder_, state->request->id, "start CPU softmax"); - ApplySoftmaxWithTemperatureOnCPU(logits, offset, generation_cfg->temperature); - RECORD_EVENT(trace_recorder_, state->request->id, "finish CPU softmax"); - } else { - RECORD_EVENT(trace_recorder_, state->request->id, "start argmax"); - SetProbWithArgmaxOnCPU(logits, offset); - RECORD_EVENT(trace_recorder_, state->request->id, "finish argmax"); - } - } - - std::vector SampleTokensFromProbs(NDArray probs, - Array request_mstates, - Array generation_cfg, - const std::vector& rngs, - std::vector* output_prob_dist, - std::vector* output_token_probs) { - // probs: (n, v) - CHECK_EQ(probs->ndim, 2); - CHECK_EQ(probs->device.device_type, kDLCPU); - ICHECK_EQ(probs->shape[0], request_mstates.size()); - ICHECK_EQ(probs->shape[0], generation_cfg.size()); - ICHECK_EQ(probs->shape[0], rngs.size()); - - Array request_ids = - request_mstates.Map([](const RequestModelState& mstate) { return mstate->request->id; }); - - int n = probs->shape[0]; - std::vector sampled_tokens; - sampled_tokens.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); - } - if (output_token_probs) { - output_token_probs->resize(n); - } - - tvm::runtime::parallel_for_with_threading_backend( - [this, &sampled_tokens, &probs, &generation_cfg, &rngs, &request_ids, output_prob_dist, - output_token_probs](int i) { - RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); - // Sample top p from probability. 
- std::pair sample_result = SampleTopPFromProb( - probs, i, generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); - sampled_tokens[i] = sample_result.second; - if (output_token_probs) { - (*output_token_probs)[i] = sample_result.first; - } - RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); - }, - 0, n); - return sampled_tokens; + ICHECK_LE(num_tokens, probs_host_->shape[0]); + NDArray view = probs_host_.CreateView({num_tokens, vocab_size}, probs_device->dtype); + view.CopyFrom(probs_device); + return view; } /*! \brief The event trace recorder for requests. */ Optional trace_recorder_; /*! \brief Customized function which computes prob distribution from logits */ PackedFunc flogits_to_probs_inplace_; - /*! \brief Shared array for logits and probability distributions on cpu. */ - NDArray logits_or_probs_on_cpu_{nullptr}; - const float eps_ = 1e-9; + /*! \brief Probability distribution array on CPU. */ + NDArray probs_host_{nullptr}; + const float eps_ = 1e-5; }; /*********************** Sampler ***********************/ diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index d74a7ef400..ac4820db64 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -32,12 +32,9 @@ using namespace tvm::runtime; class SamplerObj : public Object { public: /*! - * \brief Sample tokens from the input batch of logits. - * \param logits_on_device The logits to sample tokens from. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. - * \param request_mstates The request states of each sequence in - * the batch with regard to the given model. + * \brief Sample tokens from the input batch of prob distribution on device. + * \param probs_device The prob distributions on GPU to sample tokens from. + * \param request_ids The id of each request. * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. @@ -46,18 +43,17 @@ class SamplerObj : public Object { * \return The sampled tokens, one for each request in the batch. */ virtual std::vector BatchSampleTokens( - NDArray logits_on_device, Model model, Array request_mstates, - Array generation_cfg, const std::vector& rngs, + NDArray probs_device, const Array& request_ids, + const Array& generation_cfg, const std::vector& rngs, std::vector* output_prob_dist = nullptr, std::vector* output_token_probs = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param logits_on_device The logits of the large model. + * \param probs_device The prob distributions on GPU to sample tokens from. + * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. - * \param model The LLM model which contains the softmax - * function on device that might be used to compute probability distribution. * \param request_mstates The request states of each sequence in * the batch with regard to the large model. * \param generation_cfg The generation config of each request @@ -72,8 +68,8 @@ class SamplerObj : public Object { * \return The list of accepted tokens for each request. 
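 * \note Editorial illustration, not upstream documentation: assuming the standard
 *       speculative-sampling acceptance rule, a draft token with draft probability q and
 *       target-model probability p is kept with probability min(1, p / q); on rejection, a
 *       replacement token is sampled from the renormalized residual distribution.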
*/ virtual std::vector> BatchVerifyDraftTokens( - NDArray logits_on_device, const std::vector& cum_verify_lengths, Model model, - const Array& request_mstates, + NDArray probs_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& request_mstates, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_token_prob, diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 84a6c76243..58507299ac 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -1,8 +1,10 @@ """A couple of passes that simply attach additional information onto the IRModule.""" + from typing import Dict import tvm from tvm import IRModule, relax, tir +from tvm.script import tir as T @tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds") @@ -44,3 +46,114 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR if isinstance(func, relax.Function): mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc") +class AttachLogitProcessFunc: # pylint: disable=too-few-public-methods + """Attach logit processing TIR functions to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["apply_logit_bias_inplace"] = _apply_logit_bias_inplace + mod["apply_penalty_inplace"] = _apply_penalty_inplace + mod["apply_bitmask_inplace"] = _apply_bitmask_inplace + return mod + + +@T.prim_func +def _apply_logit_bias_inplace( + var_logits: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_logit_bias: T.handle, +) -> None: + """Function that applies logit bias in place.""" + T.func_attr( + {"global_symbol": "apply_logit_bias_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + # seq_ids + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32") + + for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): + for p1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp] + + +@T.prim_func +def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals + var_logits: T.handle, + var_seq_ids: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_token_cnt: T.handle, + var_penalties: T.handle, +) -> None: + """Function that applies penalties in place.""" + T.func_attr( + {"global_symbol": "apply_penalty_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + pos2seq_id = T.match_buffer(var_pos2seq_id, 
(num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") + penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32") + + for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): + for p1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * 1024 + p1) + T.where(p0 * 1024 + p1 < num_token) + # Penalties: (presence_penalty, frequency_penalty, repetition_penalty) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= ( + penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1] + ) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else( + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > 0, + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2], + ) + + +@T.prim_func +def _apply_bitmask_inplace( + var_logits: T.handle, + var_seq_ids: T.handle, + var_bitmask: T.handle, +) -> None: + """Function that applies vocabulary masking in place.""" + T.func_attr( + {"global_symbol": "apply_bitmask_inplace", "tir.noalias": True, "tir.is_scheduled": True} + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + bitmask = T.match_buffer(var_bitmask, (num_seq, (vocab_size + 31 // 32)), "int32") + + for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + 1023) // 1024, "blockIdx.x"): + for fused_s_v_1 in T.thread_binding(0, 1024, "threadIdx.x"): + with T.block("block"): + vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size) + vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size) + T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size) + logits[seq_ids[vs], vv] = T.if_then_else( + (bitmask[vs, vv // 32] >> (vv % 32)) & 1 == 1, + logits[seq_ids[vs], vv], + T.float32(-1e10), + ) diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 20676187bd..98922c6139 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -13,6 +13,7 @@ from .attach_to_ir_module import ( AttachAdditionalPrimFuncs, + AttachLogitProcessFunc, AttachMemoryPlanAttr, AttachVariableBounds, ) @@ -89,6 +90,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 0. 
Add additional information for compilation and remove unused Relax func RewriteKVCacheCreation(target, flashinfer, metadata), AttachVariableBounds(variable_bounds), + AttachLogitProcessFunc(), AttachAdditionalPrimFuncs(additional_tirs), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index 128d7e99d7..36b75f81a5 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -63,7 +63,7 @@ class CompletionRequest(BaseModel): echo: bool = False frequency_penalty: float = 0.0 presence_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[Dict[int, float]] = None logprobs: Optional[int] = None max_tokens: int = 16 n: int = 1 @@ -84,6 +84,22 @@ def check_penalty_range(cls, penalty_value: float) -> float: raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + class CompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length"]] = None @@ -149,7 +165,7 @@ class ChatCompletionRequest(BaseModel): model: str frequency_penalty: float = 0.0 presence_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: Optional[Dict[int, float]] = None max_tokens: Optional[int] = None n: int = 1 response_format: Literal["text", "json_object"] = "text" @@ -163,6 +179,30 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None ignore_eos: bool = False + @field_validator("frequency_penalty", "presence_penalty") + @classmethod + def check_penalty_range(cls, penalty_value: float) -> float: + """Check if the penalty value is in range [-2, 2].""" + if penalty_value < -2 or penalty_value > 2: + raise ValueError("Penalty value should be in range [-2, 2].") + return penalty_value + + @field_validator("logit_bias") + @classmethod + def check_logit_bias( + cls, logit_bias_value: Optional[Dict[int, float]] + ) -> Optional[Dict[int, float]]: + """Check if the logit bias key is given as an integer.""" + if logit_bias_value is None: + return None + for token_id, bias in logit_bias_value.items(): + if abs(bias) > 100: + raise ValueError( + "Logit bias value should be in range [-100, 100], while value " + f"{bias} is given for token id {token_id}" + ) + return logit_bias_value + class ChatCompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None @@ -214,7 +254,6 @@ def openai_api_get_unsupported_fields( """Get the unsupported fields in the request.""" unsupported_field_default_values: List[Tuple[str, Any]] = [ ("best_of", 1), - ("logit_bias", None), ("logprobs", None), ("n", 1), ("response_format", "text"), @@ -238,6 +277,7 @@ def openai_api_get_generation_config( "max_tokens", "frequency_penalty", "presence_penalty", + "logit_bias", "seed", "ignore_eos", ] diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py index 
4223148e8e..1962b61215 100644 --- a/python/mlc_chat/serve/config.py +++ b/python/mlc_chat/serve/config.py @@ -1,7 +1,7 @@ """Configuration dataclasses used in MLC LLM serving""" import json from dataclasses import asdict, dataclass, field -from typing import List, Optional +from typing import Dict, List, Optional @dataclass @@ -31,6 +31,9 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes It will be suppressed when any of frequency_penalty and presence_penalty is non-zero. + logit_bias : Optional[Dict[int, float]] + The bias logit value added to selected tokens prior to sampling. + max_tokens : Optional[int] The maximum number of generated tokens, or None, in which case the generation will not stop @@ -56,6 +59,7 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes frequency_penalty: float = 0.0 presence_penalty: float = 0.0 repetition_penalty: float = 1.0 + logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) max_tokens: Optional[int] = 128 seed: Optional[int] = None diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 65c63c2166..0721e97190 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -484,6 +484,54 @@ def test_openai_v1_completions_temperature( ) +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_logit_bias( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + # NOTE: This test only tests that the system does not break on logit bias. + # The test does not promise the correctness of logit bias handling. + + prompt = "What's the meaning of life?" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reason="length", + ) + + @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_presence_frequency_penalty( served_model: Tuple[str, str], @@ -889,26 +937,6 @@ def test_openai_v1_chat_completions_system_prompt_wrong_pos( assert num_chunks == 1 -def test_openai_v1_chat_completions_unsupported_args( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # Right now "tool_choice" is unsupported. - tool_choice = "auto" - payload = { - "model": served_model[0], - "messages": CHAT_COMPLETION_MESSAGES[0], - "tool_choice": tool_choice, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) - error_msg_prefix = 'Request fields "tool_choice" are not supported right now.' 
- expect_error(response.json(), msg_prefix=error_msg_prefix) - - def test_debug_dump_event_trace( served_model: Tuple[str, str], launch_server, # pylint: disable=unused-argument @@ -946,6 +974,8 @@ def test_debug_dump_event_trace( test_openai_v1_completions_stop_str(MODEL, None, stream=True) test_openai_v1_completions_temperature(MODEL, None, stream=False) test_openai_v1_completions_temperature(MODEL, None, stream=True) + test_openai_v1_completions_logit_bias(MODEL, None, stream=False) + test_openai_v1_completions_logit_bias(MODEL, None, stream=True) test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=False) test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=True) test_openai_v1_completions_seed(MODEL, None) @@ -965,6 +995,5 @@ def test_debug_dump_event_trace( test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True) test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False) test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=True) - test_openai_v1_chat_completions_unsupported_args(MODEL, None) test_debug_dump_event_trace(MODEL, None) From bcb9b6a33a672a70d760c9a8b03234124aab50c4 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Sat, 24 Feb 2024 21:35:18 +0800 Subject: [PATCH 003/531] [Serving][Grammar] BNF grammar simplifier and matcher (#1801) --- CMakeLists.txt | 5 + cpp/serve/grammar/grammar.cc | 109 ++++ cpp/serve/grammar/grammar.h | 127 +++-- cpp/serve/grammar/grammar_builder.h | 107 ++-- cpp/serve/grammar/grammar_parser.cc | 113 ++-- cpp/serve/grammar/grammar_parser.h | 2 +- cpp/serve/grammar/grammar_serializer.cc | 93 ++-- cpp/serve/grammar/grammar_serializer.h | 31 +- cpp/serve/grammar/grammar_simplifier.cc | 219 ++++++++ cpp/serve/grammar/grammar_simplifier.h | 184 +++++++ cpp/serve/grammar/grammar_state_matcher.cc | 517 ++++++++++++++++++ cpp/serve/grammar/grammar_state_matcher.h | 125 +++++ .../grammar/grammar_state_matcher_base.h | 236 ++++++++ .../grammar/grammar_state_matcher_preproc.h | 315 +++++++++++ .../grammar/grammar_state_matcher_state.h | 442 +++++++++++++++ cpp/serve/grammar/support.h | 123 +++++ cpp/{serve => support}/encoding.cc | 63 ++- cpp/{serve => support}/encoding.h | 9 +- cpp/tokenizers.cc | 10 + cpp/tokenizers.h | 16 + python/mlc_chat/serve/__init__.py | 2 +- python/mlc_chat/serve/grammar.py | 162 +++++- tests/python/__init__.py | 0 tests/python/conftest.py | 21 + tests/python/serve/test_grammar_parser.py | 217 +++++--- .../serve/test_grammar_state_matcher.py | 387 +++++++++++++ 26 files changed, 3312 insertions(+), 323 deletions(-) create mode 100644 cpp/serve/grammar/grammar_simplifier.cc create mode 100644 cpp/serve/grammar/grammar_simplifier.h create mode 100644 cpp/serve/grammar/grammar_state_matcher.cc create mode 100644 cpp/serve/grammar/grammar_state_matcher.h create mode 100644 cpp/serve/grammar/grammar_state_matcher_base.h create mode 100644 cpp/serve/grammar/grammar_state_matcher_preproc.h create mode 100644 cpp/serve/grammar/grammar_state_matcher_state.h create mode 100644 cpp/serve/grammar/support.h rename cpp/{serve => support}/encoding.cc (94%) rename cpp/{serve => support}/encoding.h (95%) create mode 100644 tests/python/__init__.py create mode 100644 tests/python/conftest.py create mode 100644 tests/python/serve/test_grammar_state_matcher.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 15b7c9ab2a..a1644f0894 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,11 @@ else () target_link_libraries(mlc_llm PUBLIC 
-Wl,--no-as-needed ${FLASH_ATTN_LIBRARY}) endif() +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + target_compile_definitions(mlc_llm PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(mlc_llm_objs PRIVATE "TVM_LOG_DEBUG") + target_compile_definitions(mlc_llm_static PRIVATE "TVM_LOG_DEBUG") +endif() if (BUILD_CPP_TEST) message(STATUS "Building cpp unittests") diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index 110838f5dc..89d3956501 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -5,12 +5,121 @@ #include "grammar.h" +#include "grammar_parser.h" +#include "grammar_serializer.h" +#include "grammar_simplifier.h" + namespace mlc { namespace llm { namespace serve { TVM_REGISTER_OBJECT_TYPE(BNFGrammarNode); +std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { + os << BNFGrammarPrinter(grammar).ToString(); + return os; +} + +BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, bool normalize, bool simplify) { + auto grammar = EBNFParser::Parse(ebnf_string); + if (normalize) { + grammar = NestedRuleUnwrapper(grammar).Apply(); + } + return grammar; +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") + .set_body_typed([](String ebnf_string, bool normalize, bool simplify) { + return BNFGrammar::FromEBNFString(ebnf_string, normalize, simplify); + }); + +BNFGrammar BNFGrammar::FromJSON(const String& json_string) { + return BNFJSONParser::Parse(json_string); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String json_string) { + return BNFGrammar::FromJSON(json_string); +}); + +const std::string kJSONGrammarString = R"( +main ::= ( + "{" ws members_or_embrace ws | + "[" ws elements_or_embrace ws +) +value ::= ( + "{" ws members_or_embrace | + "[" ws elements_or_embrace | + "\"" characters "\"" | + [0-9] fraction exponent | + [1-9] digits fraction exponent | + "-" [0-9] fraction exponent | + "-" [1-9] digits fraction exponent | + "true" | + "false" | + "null" +) +members_or_embrace ::= ( + "\"" characters "\"" ws ":" ws value members_rest ws "}" | + "}" +) +members ::= "\"" characters "\"" ws ":" ws value members_rest +members_rest ::= ( + "" | + "," ws "\"" characters "\"" ws ":" ws value members_rest | + " " ws "," ws "\"" characters "\"" ws ":" ws value members_rest | + "\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest | + "\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest +) +elements_or_embrace ::= ( + "{" ws members_or_embrace elements_rest ws "]" | + "[" ws elements_or_embrace elements_rest ws "]" | + "\"" characters "\"" elements_rest ws "]" | + [0-9] fraction exponent elements_rest ws "]" | + [1-9] digits fraction exponent elements_rest ws "]" | + "-" [0-9] fraction exponent elements_rest ws "]" | + "-" [1-9] digits fraction exponent elements_rest ws "]" | + "true" elements_rest ws "]" | + "false" elements_rest ws "]" | + "null" elements_rest ws "]" | + "]" +) +elements ::= ( + "{" ws members_or_embrace elements_rest | + "[" ws elements_or_embrace elements_rest | + "\"" characters "\"" elements_rest | + [0-9] fraction exponent elements_rest | + [1-9] digits fraction exponent elements_rest | + "-" [0-9] fraction exponent elements_rest | + "-" [1-9] digits fraction exponent elements_rest | + "true" elements_rest | + "false" elements_rest | + "null" elements_rest +) +elements_rest ::= ( + "" | + "," ws elements | + " " ws "," ws elements | + "\n" ws "," ws elements | + "\t" ws "," ws elements +) +characters ::= "" | [^"\\] characters | "\\" 
escape characters +escape ::= "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +digits ::= [0-9] | [0-9] digits +fraction ::= "" | "." digits +exponent ::= "" | "e" sign digits | "E" sign digits +sign ::= "" | "+" | "-" +ws ::= [ \n\t]* +)"; + +BNFGrammar BNFGrammar::GetGrammarOfJSON() { + static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, true, false); + return grammar; +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarGetGrammarOfJSON").set_body_typed([]() { + return BNFGrammar::GetGrammarOfJSON(); +}); + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 9461c893f8..22e674527d 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -23,32 +23,48 @@ using namespace tvm::runtime; /*! * \brief This class stores the abstract syntax tree (AST) of the Backus-Naur Form (BNF) grammar. * The BNF definition here is standard BNF, and the characters are represented using regex-style - * character ranges (e.g. [a-z], [^a-z]). + * character classes (e.g. [a-z], [^a-z]). * - * \details The BNF grammar consists of a set of rules. Each rule has a name and a definition, and - * represents a production rule. Each rule has a rule_id for reference. + * \details + * ### Rules + * The BNF grammar AST consists of a set of rules. Each rule contains a name and a definition, and + * corresponds to a production in the grammar. The definition of a rule is a RuleExpr. Each rule + * has a rule_id for reference. * - * The definition of a rule is a RuleExpr. Ruleexpr can be the definition of a rule or part of the - * definition of a rule. + * ### RuleExprs + * RuleExpr is the definition of a rule or part of the definition of a rule. It can contain + * elements, empty string, reference to other RuleExprs, or reference to other rules. Each RuleExpr + * corresponds to an rule_expr_id for reference. * * For example, in the following rule: rule ::= ("a" "b") | "c" * ("a" "b"), "c", ("a" "b") | "c" are all RuleExprs. * + * #### Types of RuleExprs * Every RuleExpr is represented by a type as well as a variable-length array containing its data. - * There are several types for RuleExpr: - * - Character range: a range of characters (each character is a unicode codepoint), - * e.g. [a-z], [ac-z] - * - Negative character range: all characters that are not in the range, e.g. [^a-z], [^ac-z] + * RuleExpr has several types: + * - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z], + * [ac-z]. + * A single character is represented by a character class with the same lower and upper bound. + * A string is represented by a sequence of character classes. + * - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z] * - EmptyStr: an empty string, i.e. "" * - Rule reference: a reference to another rule * - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together. * - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched. + * - Character class star: special support for a repetition of a character class. e.g. [a-z]* * - * For the format of the data, see BNFGrammarNode::DataKind. Each RuleExpr corresponds to an - * rule_expr_id for reference. + * #### Storage of RuleExprs + * Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see + * docs in BNFGrammarNode::RuleExprType. 
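 * For illustration (an editorial sketch, not text from this patch): with the layout written by
 * BNFGrammarBuilder::AddRuleExpr, a single-range character class such as [a-z] occupies four
 * consecutive slots of the data vector,
 *   [kCharacterClass, 2, 'a', 'z']
 * i.e. the type tag, the payload length, and then the lower/upper codepoint bounds, which is
 * exactly what GetRuleExpr() reads back when a RuleExpr is fetched by id.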
* * We store all RuleExprs in csr_matrix style. That is, they are stored consecutively in one vector * (data vector) and the starting position of each RuleExpr is recorded in the indptr vector. + * + * \remark The character class star RuleExpr is for the special support for elements like [a-z]* + * in the grammar. We add it to make the matching more efficient, as we can avoid recursion into + * rules when matching a sequence of characters. It should be used like: + * rule1 ::= ((element1 element2 rule2 ...) | ...) + * rule2 ::= character_class_star_rule_expr(id_of_a_character_class_rule_expr) */ class BNFGrammarNode : public Object { public: @@ -56,22 +72,25 @@ class BNFGrammarNode : public Object { struct Rule { /*! \brief The name of the rule. */ std::string name; - /*! \brief The RuleExpr id of the definition of the rule. */ - int32_t rule_expr_id; + /*! \brief The RuleExpr id of the body of the rule. */ + int32_t body_expr_id; }; /*! \brief Get the number of rules. */ size_t NumRules() const { return rules_.size(); } /*! \brief Get the rule with the given id. */ - const Rule& GetRule(int32_t rule_id) const { return rules_[rule_id]; } + const Rule& GetRule(int32_t rule_id) const { + DCHECK(rule_id >= 0 && rule_id < static_cast(rules_.size())) + << "rule_id " << rule_id << " is out of bound"; + return rules_[rule_id]; + } - /*! \brief The data kind of the content of rule_exprs. */ - enum class DataKind : int32_t { + /*! \brief The type of the rule expr. */ + enum class RuleExprType : int32_t { // data format: [lower0, upper0, lower1, upper1, ...] - // to represent a single character, just add the same lower and upper bound. - kCharacterRange, + kCharacterClass, // data format: [lower0, upper0, lower1, upper1, ...] - kNegCharacterRange, + kNegCharacterClass, // data format: [] kEmptyStr, // data format: [rule_id] @@ -80,37 +99,41 @@ class BNFGrammarNode : public Object { kSequence, // data format: [rule_expr_id0, rule_expr_id1, ...] kChoices, + // data format: [rule_expr_id] + kStarQuantifier, }; /*! \brief The object representing a rule expr. */ struct RuleExpr { - /*! \brief The data kind. */ - DataKind kind; + /*! \brief The type of the rule expr. */ + RuleExprType type; /*! \brief The data of the RuleExpr. A variable-length array. */ const int32_t* data; /*! \brief The length of the data array. */ - size_t data_len; + int32_t data_len; + const int32_t size() const { return data_len; } /*! \brief Get the i-th element of the data array. */ - const int32_t& operator[](int i) const { return data[i]; } + const int32_t& operator[](int i) const { + DCHECK(i >= 0 && i < static_cast(data_len)) << "Index " << i << " is out of bound"; + return data[i]; + } + const int32_t* begin() const { return data; } + const int32_t* end() const { return data + data_len; } }; /*! \brief Get the number of rule_exprs. */ size_t NumRuleExprs() const { return rule_expr_indptr_.size(); } /*! \brief Get the rule_expr with the given id. 
*/ RuleExpr GetRuleExpr(int32_t rule_expr_id) const { + DCHECK(rule_expr_id >= 0 && rule_expr_id < static_cast(rule_expr_indptr_.size())) + << "rule_expr_id " << rule_expr_id << " is out of bound"; int start_index = rule_expr_indptr_[rule_expr_id]; - DataKind kind = static_cast(rule_expr_data_[start_index]); - ++start_index; - int end_index; - if (rule_expr_id == static_cast(rule_expr_indptr_.size()) - 1) { - end_index = rule_expr_data_.size(); - } else { - end_index = rule_expr_indptr_[rule_expr_id + 1]; - } - ICHECK_GE(end_index, start_index); - return {kind, rule_expr_data_.data() + start_index, - static_cast(end_index - start_index)}; + auto start_ptr = rule_expr_data_.data() + start_index; + auto type = static_cast(start_ptr[0]); + auto data_ptr = start_ptr + 2; + auto data_len = start_ptr[1]; + return {type, data_ptr, data_len}; } static constexpr const char* _type_key = "mlc.serve.BNFGrammar"; @@ -134,7 +157,41 @@ class BNFGrammarNode : public Object { class BNFGrammar : public ObjectRef { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BNFGrammar, ObjectRef, BNFGrammarNode); + /*! + * \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and + * transform it into BNF AST. + * \param ebnf_string The EBNF-formatted string. + * \param normalize Whether to normalize the grammar. Default: true. Only set to false for the + * purpose of testing. + * + * \note In The normalized form of a BNF grammar, every rule is in the form: + * `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + * + * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character + * class or a rule reference. And if the rule can be empty, the first choice will be an empty + * string. + * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. + * Not implemented yet. + */ + static BNFGrammar FromEBNFString(const String& ebnf_string, bool normalize = true, + bool simplify = true); + + /*! + * \brief Construct a BNF grammar from the dumped JSON string. + * \param json_string The JSON-formatted string. This string should have the same format as + * the result of BNFGrammarJSONSerializer::ToString. + */ + static BNFGrammar FromJSON(const String& json_string); + + /*! + * \brief Get the grammar of standard JSON format. We have built-in support for JSON. + */ + static BNFGrammar GetGrammarOfJSON(); + + /*! \brief Print a BNF grammar. */ + friend std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar); + + TVM_DEFINE_OBJECT_REF_METHODS(BNFGrammar, ObjectRef, BNFGrammarNode); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index 095d050c6d..eaa8af04f9 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -24,7 +24,7 @@ using namespace tvm::runtime; class BNFGrammarBuilder { public: using Rule = BNFGrammarNode::Rule; - using DataKind = BNFGrammarNode::DataKind; + using RuleExprType = BNFGrammarNode::RuleExprType; using RuleExpr = BNFGrammarNode::RuleExpr; /*! \brief Default constructor. Creates a new grammar object. */ @@ -36,82 +36,91 @@ class BNFGrammarBuilder { * \param grammar The existing grammar. */ explicit BNFGrammarBuilder(const BNFGrammar& grammar) - : grammar_(make_object(*grammar.get())) {} + : grammar_(make_object(*grammar.get())) { + // for (size_t i = 0; i < grammar_->rules_.size(); ++i) { + // rule_name_to_id_[grammar_->rules_[i].name] = i; + // } + } - /*! 
\brief Finalize the grammar building and return the built grammar. */ - BNFGrammar Finalize() { return BNFGrammar(grammar_); } + /*! \brief Get the result grammar. */ + BNFGrammar Get() { return BNFGrammar(grammar_); } /****************** RuleExpr handling ******************/ - /*! \brief Insert a rule_expr and return the rule_expr id. */ - int32_t InsertRuleExpr(const RuleExpr& rule_expr) { + /*! \brief Add a rule_expr and return the rule_expr id. */ + int32_t AddRuleExpr(const RuleExpr& rule_expr) { grammar_->rule_expr_indptr_.push_back(grammar_->rule_expr_data_.size()); - grammar_->rule_expr_data_.push_back(static_cast(rule_expr.kind)); + grammar_->rule_expr_data_.push_back(static_cast(rule_expr.type)); + grammar_->rule_expr_data_.push_back(rule_expr.data_len); grammar_->rule_expr_data_.insert(grammar_->rule_expr_data_.end(), rule_expr.data, rule_expr.data + rule_expr.data_len); return static_cast(grammar_->rule_expr_indptr_.size()) - 1; } /*! - * \brief One element of a character range, containing a lower and a upper bound. Both bounds are + * \brief One element of a character class, containing a lower and a upper bound. Both bounds are * inclusive. */ - struct CharacterRangeElement { + struct CharacterClassElement { int32_t lower; int32_t upper; }; - /*! \brief Insert a RuleExpr for character range.*/ - int32_t InsertCharacterRange(const std::vector& elements) { - std::vector data; - for (const auto& range : elements) { - data.push_back(range.lower); - data.push_back(range.upper); - } - return InsertRuleExpr({DataKind::kCharacterRange, data.data(), data.size()}); - } - - /*! \brief Insert a RuleExpr for character range negation.*/ - int32_t InsertNegCharacterRange(const std::vector& elements) { + /*! + * \brief Add a RuleExpr for character class. + * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. + * \param is_neg_range Whether the character class is negated. + */ + int32_t AddCharacterClass(const std::vector& elements, + bool is_neg_range = false) { std::vector data; for (const auto& range : elements) { data.push_back(range.lower); data.push_back(range.upper); } - return InsertRuleExpr({DataKind::kNegCharacterRange, data.data(), data.size()}); + auto type = is_neg_range ? RuleExprType::kNegCharacterClass : RuleExprType::kCharacterClass; + return AddRuleExpr({type, data.data(), static_cast(data.size())}); } - /*! \brief Insert a RuleExpr for empty string.*/ - int32_t InsertEmptyStr() { return InsertRuleExpr({DataKind::kEmptyStr, nullptr, 0}); } + /*! \brief Add a RuleExpr for empty string.*/ + int32_t AddEmptyStr() { return AddRuleExpr({RuleExprType::kEmptyStr, nullptr, 0}); } - /*! \brief Insert a RuleExpr for rule reference.*/ - int32_t InsertRuleRef(int32_t rule_id) { + /*! \brief Add a RuleExpr for rule reference.*/ + int32_t AddRuleRef(int32_t rule_id) { std::vector data; data.push_back(rule_id); - return InsertRuleExpr({DataKind::kRuleRef, data.data(), data.size()}); + return AddRuleExpr({RuleExprType::kRuleRef, data.data(), static_cast(data.size())}); } - /*! \brief Insert a RuleExpr for RuleExpr sequence.*/ - int32_t InsertSequence(const std::vector& elements) { + /*! \brief Add a RuleExpr for RuleExpr sequence.*/ + int32_t AddSequence(const std::vector& elements) { std::vector data; data.insert(data.end(), elements.begin(), elements.end()); - return InsertRuleExpr({DataKind::kSequence, data.data(), data.size()}); + return AddRuleExpr({RuleExprType::kSequence, data.data(), static_cast(data.size())}); } - /*! 
\brief Insert a RuleExpr for RuleExpr choices.*/ - int32_t InsertChoices(const std::vector& choices) { + /*! \brief Add a RuleExpr for RuleExpr choices.*/ + int32_t AddChoices(const std::vector& choices) { std::vector data; data.insert(data.end(), choices.begin(), choices.end()); - return InsertRuleExpr({DataKind::kChoices, data.data(), data.size()}); + return AddRuleExpr({RuleExprType::kChoices, data.data(), static_cast(data.size())}); } + int32_t AddStarQuantifier(int32_t element) { + std::vector data; + data.push_back(element); + return AddRuleExpr( + {RuleExprType::kStarQuantifier, data.data(), static_cast(data.size())}); + } + + size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); } /*! \brief Get the rule_expr with the given id. */ RuleExpr GetRuleExpr(int32_t rule_expr_id) { return grammar_->GetRuleExpr(rule_expr_id); } /****************** Rule handling ******************/ - /*! \brief Insert a rule and return the rule id. */ - int32_t InsertRule(const Rule& rule) { + /*! \brief Add a rule and return the rule id. */ + int32_t AddRule(const Rule& rule) { int32_t id = grammar_->rules_.size(); auto rules = grammar_->rules_; grammar_->rules_.push_back(rule); @@ -120,33 +129,45 @@ class BNFGrammarBuilder { return id; } + int32_t AddRule(const std::string& name, int32_t body_expr_id) { + return AddRule({name, body_expr_id}); + } + + int32_t AddRuleWithHint(const std::string& name_hint, int32_t body_expr_id) { + return AddRule({GetNewRuleName(name_hint), body_expr_id}); + } + + size_t NumRules() const { return grammar_->NumRules(); } + /*! \brief Get the rule with the given id. */ const Rule& GetRule(int32_t rule_id) const { return grammar_->rules_[rule_id]; } /*! - * \brief Insert an rule without body, and return the rule id. The rule body should be set later + * \brief Add an rule without body, and return the rule id. The rule body should be set later * with BNFGrammarBuilder::UpdateRuleBody. This method is useful for cases where the rule id is * required to build the rule body. * \sa BNFGrammarBuilder::UpdateRuleBody */ - int32_t InsertEmptyRule(const std::string& name) { return InsertRule({name, -1}); } + int32_t AddEmptyRule(const std::string& name) { return AddRule({name, -1}); } /*! * \brief Update the rule body of the given rule, specified by rule id. Can be used to set the - * rule body of a rule inserted by BNFGrammarBuilder::InsertEmptyRule. + * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule. */ - void UpdateRuleBody(int32_t rule_id, int32_t rule_expr_id) { - grammar_->rules_[rule_id].rule_expr_id = rule_expr_id; + void UpdateRuleBody(int32_t rule_id, int32_t body_expr_id) { + CHECK(rule_id < static_cast(grammar_->rules_.size())) + << "Rule id " << rule_id << " is out of range."; + grammar_->rules_[rule_id].body_expr_id = body_expr_id; } /*! * \brief Update the rule body of the given rule, specified by rule name. Can be used to set the - * rule body of a rule inserted by BNFGrammarBuilder::InsertEmptyRule. + * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule. */ - void UpdateRuleBody(std::string rule_name, int32_t rule_expr_id) { + void UpdateRuleBody(std::string rule_name, int32_t body_expr_id) { int32_t rule_id = GetRuleId(rule_name); CHECK(rule_id != -1) << "Rule " << rule_name << " is not found."; - UpdateRuleBody(rule_id, rule_expr_id); + UpdateRuleBody(rule_id, body_expr_id); } /*! 
diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 375a9a8be8..b5f6be1849 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -6,7 +6,7 @@ #include "grammar_parser.h" #include "../../metadata/json_parser.h" -#include "../encoding.h" +#include "../../support/encoding.h" #include "grammar_builder.h" namespace mlc { @@ -24,7 +24,7 @@ class EBNFParserImpl { // Parsing different parts of the grammar std::string ParseName(bool accept_empty = false); - int32_t ParseCharacterRange(); + int32_t ParseCharacterClass(); int32_t ParseString(); int32_t ParseRuleRef(); int32_t ParseElement(); @@ -67,8 +67,8 @@ class EBNFParserImpl { // Throw a ParseError with the given message and the line and column number. [[noreturn]] void ThrowParseError(const std::string& msg) { - throw ParseError(msg + " at line " + std::to_string(cur_line_) + ", column " + - std::to_string(cur_column_)); + throw ParseError("EBNF parse error at line " + std::to_string(cur_line_) + ", column " + + std::to_string(cur_column_) + ": " + msg); } // The grammar builder @@ -123,23 +123,24 @@ std::string EBNFParserImpl::ParseName(bool accept_empty) { return std::string(start, cur_); } -// Character range: +// Character class: // 1. Examples: [a-z] [ab] [a-zA-Z0-9] [^a-z] [测] [\u0123] -// 2. "-" appearing in the start or end of the character range means itself. Only if it appears -// between two characters, it means a range. E.g. [a-] and [-a] means "a" or "-"" [a--] means a to - -// 3. "-" and "]" can be escaped: +// 2. The "-" character is treated as a literal character if it is the last or the first (after +// the "^"", if present) character within the brackets. E.g. [a-] and [-a] means "a" or "-" +// 3. "-" and "]" should be escaped when used as a literal character: // [\-] means - // [\]] means ] -// Character range should not contain newlines. -int32_t EBNFParserImpl::ParseCharacterRange() { +// Character class should not contain newlines. 
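// Illustrative example (editorial, not part of this patch): assuming the builder API declared in
// grammar_builder.h above, the class [a-c0-9] is parsed into two CharacterClassElement ranges and
// added as
//   builder_.AddCharacterClass({{'a', 'c'}, {'0', '9'}}, /*is_neg_range=*/false);
// while a negated class such as [^a-z] passes true for is_neg_range:
//   builder_.AddCharacterClass({{'a', 'z'}}, /*is_neg_range=*/true);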
+int32_t EBNFParserImpl::ParseCharacterClass() { + static constexpr TCodepoint kUnknownUpperBound = -4; static const std::unordered_map kCustomEscapeMap = {{"\\-", '-'}, {"\\]", ']'}}; - std::vector elements; + std::vector elements; - bool is_not_range = false; + bool is_negated = false; if (Peek() == '^') { - is_not_range = true; + is_negated = true; Consume(); } @@ -147,7 +148,7 @@ int32_t EBNFParserImpl::ParseCharacterRange() { bool past_is_single_char = false; while (Peek() && Peek() != ']') { if (Peek() == '\r' || Peek() == '\n') { - ThrowParseError("Character range should not contain newline"); + ThrowParseError("Character class should not contain newline"); } else if (Peek() == '-' && Peek(1) != ']' && !past_is_hyphen && past_is_single_char) { Consume(); past_is_hyphen = true; @@ -166,29 +167,29 @@ int32_t EBNFParserImpl::ParseCharacterRange() { if (past_is_hyphen) { ICHECK(!elements.empty()); if (elements.back().lower > codepoint) { - ThrowParseError("Invalid character range: lower bound is larger than upper bound"); + ThrowParseError("Invalid character class: lower bound is larger than upper bound"); } elements.back().upper = codepoint; past_is_hyphen = false; ICHECK(past_is_single_char == false); } else { - elements.push_back({codepoint, -1}); + elements.push_back({codepoint, kUnknownUpperBound}); past_is_single_char = true; } } for (auto& element : elements) { - if (element.upper == -1) { + if (element.upper == kUnknownUpperBound) { element.upper = element.lower; } } - return builder_.InsertCharacterRange(elements); + return builder_.AddCharacterClass(elements, is_negated); } // parse a c style string with utf8 support int32_t EBNFParserImpl::ParseString() { - std::vector character_ranges; + std::vector character_classes; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("String should not contain newline"); @@ -201,21 +202,21 @@ int32_t EBNFParserImpl::ParseString() { ThrowParseError("Invalid escape sequence"); } Consume(len); - character_ranges.push_back(builder_.InsertCharacterRange({{codepoint, codepoint}})); + character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); } - if (character_ranges.empty()) { - return builder_.InsertEmptyStr(); + if (character_classes.empty()) { + return builder_.AddEmptyStr(); } - return builder_.InsertSequence(character_ranges); + return builder_.AddSequence(character_classes); } int32_t EBNFParserImpl::ParseRuleRef() { std::string name = ParseName(); auto rule_id = builder_.GetRuleId(name); if (rule_id == -1) { - ThrowParseError("Rule " + name + " is not defined"); + ThrowParseError("Rule \"" + name + "\" is not defined"); } - return builder_.InsertRuleRef(rule_id); + return builder_.AddRuleRef(rule_id); } int32_t EBNFParserImpl::ParseElement() { @@ -236,7 +237,7 @@ int32_t EBNFParserImpl::ParseElement() { } case '[': { Consume(); - auto rule_expr_id = ParseCharacterRange(); + auto rule_expr_id = ParseCharacterClass(); if (Peek() != ']') { ThrowParseError("Expect ]"); } @@ -259,18 +260,14 @@ int32_t EBNFParserImpl::ParseElement() { ThrowParseError("Expect element"); } } - return -1; } int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - // a* --> rule ::= a rule | empty + // rule ::= a* + // We have special support for star quantifier in BNFGrammar AST auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_id = builder_.InsertEmptyRule(new_rule_name); - auto new_rule_ref = builder_.InsertRuleRef(new_rule_id); - auto new_rule_expr_id = 
builder_.InsertChoices( - {builder_.InsertSequence({rule_expr_id, new_rule_ref}), builder_.InsertEmptyStr()}); - builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); - return new_rule_id; + auto new_rule_expr_id = builder_.AddStarQuantifier(rule_expr_id); + return builder_.AddRule({new_rule_name, new_rule_expr_id}); } int32_t EBNFParserImpl::HandlePlusQuantifier(int32_t rule_expr_id) { @@ -278,16 +275,15 @@ int32_t EBNFParserImpl::HandlePlusQuantifier(int32_t rule_expr_id) { // We will use rule_expr a for two times in this case // So first we create a rule for rule_expr a auto a_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto a_rule_id = builder_.InsertRule({a_rule_name, rule_expr_id}); + auto a_rule_id = builder_.AddRule({a_rule_name, rule_expr_id}); // Then create the new rule_expr. auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_id = builder_.InsertEmptyRule(new_rule_name); - auto a_plus_ref = builder_.InsertRuleRef(new_rule_id); - auto a_ref1 = builder_.InsertRuleRef(a_rule_id); - auto a_ref2 = builder_.InsertRuleRef(a_rule_id); - auto new_rule_expr_id = - builder_.InsertChoices({builder_.InsertSequence({a_ref1, a_plus_ref}), a_ref2}); + auto new_rule_id = builder_.AddEmptyRule(new_rule_name); + auto a_plus_ref = builder_.AddRuleRef(new_rule_id); + auto a_ref1 = builder_.AddRuleRef(a_rule_id); + auto a_ref2 = builder_.AddRuleRef(a_rule_id); + auto new_rule_expr_id = builder_.AddChoices({builder_.AddSequence({a_ref1, a_plus_ref}), a_ref2}); builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); return new_rule_id; } @@ -295,8 +291,8 @@ int32_t EBNFParserImpl::HandlePlusQuantifier(int32_t rule_expr_id) { int32_t EBNFParserImpl::HandleQuestionQuantifier(int32_t rule_expr_id) { // a? --> rule ::= a | empty auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_expr_id = builder_.InsertChoices({rule_expr_id, builder_.InsertEmptyStr()}); - auto new_rule_id = builder_.InsertRule({new_rule_name, new_rule_expr_id}); + auto new_rule_expr_id = builder_.AddChoices({rule_expr_id, builder_.AddEmptyStr()}); + auto new_rule_id = builder_.AddRule({new_rule_name, new_rule_expr_id}); return new_rule_id; } @@ -311,11 +307,12 @@ int32_t EBNFParserImpl::ParseQuantifier() { // We will transform a*, a+, a? 
into a rule, and return the reference to this rule switch (Peek(-1)) { case '*': - return builder_.InsertRuleRef(HandleStarQuantifier(rule_expr_id)); + // We assume that the star quantifier should be the body of some rule now + return builder_.AddStarQuantifier(rule_expr_id); case '+': - return builder_.InsertRuleRef(HandlePlusQuantifier(rule_expr_id)); + return builder_.AddRuleRef(HandlePlusQuantifier(rule_expr_id)); case '?': - return builder_.InsertRuleRef(HandleQuestionQuantifier(rule_expr_id)); + return builder_.AddRuleRef(HandleQuestionQuantifier(rule_expr_id)); default: LOG(FATAL) << "Unreachable"; } @@ -329,7 +326,7 @@ int32_t EBNFParserImpl::ParseSequence() { elements.push_back(ParseQuantifier()); ConsumeSpace(in_parentheses_); } - return builder_.InsertSequence(elements); + return builder_.AddSequence(elements); } int32_t EBNFParserImpl::ParseChoices() { @@ -343,7 +340,7 @@ int32_t EBNFParserImpl::ParseChoices() { choices.push_back(ParseSequence()); ConsumeSpace(); } - return builder_.InsertChoices(choices); + return builder_.AddChoices(choices); } EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { @@ -369,9 +366,9 @@ void EBNFParserImpl::BuildRuleNameToId() { } Consume(3); if (builder_.GetRuleId(name) != -1) { - ThrowParseError("Rule " + name + " is defined multiple times"); + ThrowParseError("Rule \"" + name + "\" is defined multiple times"); } - builder_.InsertEmptyRule(name); + builder_.AddEmptyRule(name); } while (Peek() && Peek() != '\n' && Peek() != '\r') { Consume(); @@ -396,16 +393,16 @@ BNFGrammar EBNFParserImpl::DoParse(String ebnf_string) { ConsumeSpace(); while (Peek()) { auto new_rule = ParseRule(); - builder_.UpdateRuleBody(new_rule.name, new_rule.rule_expr_id); + builder_.UpdateRuleBody(new_rule.name, new_rule.body_expr_id); ConsumeSpace(); } if (builder_.GetRuleId("main") == -1) { - ThrowParseError("There must be a rule named main"); + ThrowParseError("There must be a rule named \"main\""); } - return builder_.Finalize(); + return builder_.Get(); } BNFGrammar EBNFParser::Parse(String ebnf_string) { @@ -413,10 +410,6 @@ BNFGrammar EBNFParser::Parse(String ebnf_string) { return parser.DoParse(ebnf_string); } -TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString").set_body_typed([](String ebnf_string) { - return EBNFParser::Parse(ebnf_string); -}); - BNFGrammar BNFJSONParser::Parse(String json_string) { auto node = make_object(); auto grammar_json = json::ParseToJsonObject(json_string); @@ -425,7 +418,7 @@ BNFGrammar BNFJSONParser::Parse(String json_string) { auto rule_json_obj = rule_json.get(); auto name = json::Lookup(rule_json.get(), "name"); auto rule_expr = static_cast( - json::Lookup(rule_json.get(), "rule_expr_id")); + json::Lookup(rule_json.get(), "body_expr_id")); node->rules_.push_back(BNFGrammarNode::Rule({name, rule_expr})); } auto rule_expr_data_json = json::Lookup(grammar_json, "rule_expr_data"); @@ -439,10 +432,6 @@ BNFGrammar BNFJSONParser::Parse(String json_string) { return BNFGrammar(std::move(node)); } -TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String json_string) { - return BNFJSONParser::Parse(json_string); -}); - } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index b934b055b0..6c5b0c03fa 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -20,7 +20,7 @@ using namespace tvm::runtime; /*! * \brief This class parses a BNF/EBNF grammar string into an BNF abstract syntax tree (AST). 
- * \details This function accepts the EBNF notation from the W3C XML Specification + * \details This function accepts the EBNF notation defined in the W3C XML Specification * (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following * changes: * - Using # as comment mark instead of /**\/ diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index 69641c4186..b77e194199 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -9,7 +9,7 @@ #include #include -#include "../encoding.h" +#include "../../support/encoding.h" namespace mlc { namespace llm { @@ -17,37 +17,45 @@ namespace serve { using namespace tvm::runtime; -std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) { +std::string BNFGrammarPrinter::PrintRule(const Rule& rule) { + return rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); +} + +std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { + return PrintRule(grammar_->GetRule(rule_id)); +} + +std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { std::string result; - auto rule_expr = grammar_->GetRuleExpr(rule_expr_id); - switch (rule_expr.kind) { - case DataKind::kCharacterRange: - result += PrintCharacterRange(rule_expr); - break; - case DataKind::kNegCharacterRange: - result += PrintCharacterRange(rule_expr); - break; - case DataKind::kEmptyStr: - result += PrintEmptyStr(rule_expr); - break; - case DataKind::kRuleRef: - result += PrintRuleRef(rule_expr); - break; - case DataKind::kSequence: - result += PrintSequence(rule_expr); - break; - case DataKind::kChoices: - result += PrintChoices(rule_expr); - break; + switch (rule_expr.type) { + case RuleExprType::kCharacterClass: + return PrintCharacterClass(rule_expr); + case RuleExprType::kNegCharacterClass: + return PrintCharacterClass(rule_expr); + case RuleExprType::kEmptyStr: + return PrintEmptyStr(rule_expr); + case RuleExprType::kRuleRef: + return PrintRuleRef(rule_expr); + case RuleExprType::kSequence: + return PrintSequence(rule_expr); + case RuleExprType::kChoices: + return PrintChoices(rule_expr); + case RuleExprType::kStarQuantifier: + return PrintStarQuantifier(rule_expr); + default: + LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast(rule_expr.type); } - return result; } -std::string BNFGrammarPrinter::PrintCharacterRange(const RuleExpr& rule_expr) { +std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) { + return PrintRuleExpr(grammar_->GetRuleExpr(rule_expr_id)); +} + +std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { static const std::unordered_map kCustomEscapeMap = {{'-', "\\-"}, {']', "\\]"}}; std::string result = "["; - if (rule_expr.kind == DataKind::kNegCharacterRange) { + if (rule_expr.type == RuleExprType::kNegCharacterClass) { result += "^"; } for (auto i = 0; i < rule_expr.data_len; i += 2) { @@ -70,55 +78,40 @@ std::string BNFGrammarPrinter::PrintRuleRef(const RuleExpr& rule_expr) { std::string BNFGrammarPrinter::PrintSequence(const RuleExpr& rule_expr) { std::string result; - auto prev_require_parentheses = require_parentheses_; - // If the sequence contains > 1 elements, and is nested in another rule_expr with > 1 elements, - // we need to print parentheses. 
- auto now_require_parentheses = require_parentheses_ && rule_expr.data_len > 1; - require_parentheses_ = require_parentheses_ || rule_expr.data_len > 1; - if (now_require_parentheses) { - result += "("; - } + result += "("; for (int i = 0; i < rule_expr.data_len; ++i) { result += PrintRuleExpr(rule_expr[i]); if (i + 1 != rule_expr.data_len) { result += " "; } } - if (now_require_parentheses) { - result += ")"; - } - require_parentheses_ = prev_require_parentheses; + result += ")"; return result; } std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) { std::string result; - auto prev_require_parentheses = require_parentheses_; - auto now_require_parentheses = require_parentheses_ && rule_expr.data_len > 1; - require_parentheses_ = require_parentheses_ || rule_expr.data_len > 1; - if (now_require_parentheses) { - result += "("; - } + result += "("; for (int i = 0; i < rule_expr.data_len; ++i) { result += PrintRuleExpr(rule_expr[i]); if (i + 1 != rule_expr.data_len) { result += " | "; } } - if (now_require_parentheses) { - result += ")"; - } - require_parentheses_ = prev_require_parentheses; + result += ")"; return result; } +std::string BNFGrammarPrinter::PrintStarQuantifier(const RuleExpr& rule_expr) { + return PrintRuleExpr(rule_expr[0]) + "*"; +} + String BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); for (auto i = 0; i < num_rules; ++i) { - auto rule = grammar_->GetRule(i); - result += rule.name + " ::= " + PrintRuleExpr(rule.rule_expr_id) + "\n"; + result += PrintRule(grammar_->GetRule(i)) + "\n"; } return result; } @@ -134,7 +127,7 @@ String BNFGrammarJSONSerializer::ToString() { for (const auto& rule : grammar_->rules_) { picojson::object rule_json; rule_json["name"] = picojson::value(rule.name); - rule_json["rule_expr_id"] = picojson::value(static_cast(rule.rule_expr_id)); + rule_json["body_expr_id"] = picojson::value(static_cast(rule.body_expr_id)); rules_json.push_back(picojson::value(rule_json)); } grammar_json["rules"] = picojson::value(rules_json); diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index d183e62b75..2bf47392bc 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -38,7 +38,8 @@ class BNFGrammarSerializer { */ class BNFGrammarPrinter : public BNFGrammarSerializer { private: - using DataKind = BNFGrammarNode::DataKind; + using Rule = BNFGrammarNode::Rule; + using RuleExprType = BNFGrammarNode::RuleExprType; using RuleExpr = BNFGrammarNode::RuleExpr; public: @@ -51,24 +52,28 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { /*! \brief Print the complete grammar. */ String ToString() final; - /*! \brief Print a rule_expr corresponding to the given id. */ + /*! \brief Print a rule. */ + std::string PrintRule(const Rule& rule); + /*! \brief Print a rule corresponding to the given id. */ + std::string PrintRule(int32_t rule_id); + /*! \brief Print a RuleExpr. */ + std::string PrintRuleExpr(const RuleExpr& rule_expr); + /*! \brief Print a RuleExpr corresponding to the given id. */ std::string PrintRuleExpr(int32_t rule_expr_id); - /*! \brief Print rule_exprs for character range. */ - std::string PrintCharacterRange(const RuleExpr& rule_expr); - /*! \brief Print rule_exprs for empty string. */ + private: + /*! \brief Print a RuleExpr for character class. */ + std::string PrintCharacterClass(const RuleExpr& rule_expr); + /*! \brief Print a RuleExpr for empty string. 
*/ std::string PrintEmptyStr(const RuleExpr& rule_expr); - /*! \brief Print rule_exprs for rule reference. */ + /*! \brief Print a RuleExpr for rule reference. */ std::string PrintRuleRef(const RuleExpr& rule_expr); - /*! \brief Print rule_exprs for rule_expr sequence. */ + /*! \brief Print a RuleExpr for rule_expr sequence. */ std::string PrintSequence(const RuleExpr& rule_expr); - /*! \brief Print rule_exprs for rule_expr choices. */ + /*! \brief Print a RuleExpr for rule_expr choices. */ std::string PrintChoices(const RuleExpr& rule_expr); - - private: - // Only print parentheses when necessary (i.e. when this rule_expr contains multiple elements - // and is nested within another multi-element rule_expr) - bool require_parentheses_ = false; + /*! \brief Print a RuleExpr for star quantifier. */ + std::string PrintStarQuantifier(const RuleExpr& rule_expr); }; /*! diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_simplifier.cc new file mode 100644 index 0000000000..ccbfe971f2 --- /dev/null +++ b/cpp/serve/grammar/grammar_simplifier.cc @@ -0,0 +1,219 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_simplifier.cc + */ + +#include "grammar_simplifier.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief Eliminates single-element sequence or choice nodes in the grammar. + * \example The sequence `(a)` or the choice `(a)` will be replaced by `a` in a rule. + * \example The rule `A ::= ((b) (((d))))` will be replaced by `A ::= (b d)`. + */ +class SingleElementSequenceOrChoiceEliminator : public BNFGrammarMutator { + public: + using BNFGrammarMutator::Apply; + using BNFGrammarMutator::BNFGrammarMutator; + + private: + int32_t VisitSequence(const RuleExpr& rule_expr) { + std::vector sequence_ids; + for (int32_t i : rule_expr) { + sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + } + if (sequence_ids.size() == 1) { + return sequence_ids[0]; + } else { + return builder_.AddSequence(sequence_ids); + } + } + + int32_t VisitChoices(const RuleExpr& rule_expr) { + std::vector choice_ids; + for (int32_t i : rule_expr) { + choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + } + if (choice_ids.size() == 1) { + return choice_ids[0]; + } else { + return builder_.AddChoices(choice_ids); + } + } +}; + +class NestedRuleUnwrapperImpl : public BNFGrammarMutator { + public: + using BNFGrammarMutator::BNFGrammarMutator; + + BNFGrammar Apply() final { + grammar_ = SingleElementSequenceOrChoiceEliminator(grammar_).Apply(); + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + builder_.AddEmptyRule(grammar_->GetRule(i).name); + } + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + auto rule = grammar_->GetRule(i); + auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id); + cur_rule_name_ = rule.name; + auto new_body_expr_id = VisitRuleBody(rule_expr); + builder_.UpdateRuleBody(i, new_body_expr_id); + } + return builder_.Get(); + } + + private: + /*! \brief Visit a RuleExpr as the rule body. 
*/ + int32_t VisitRuleBody(const RuleExpr& rule_expr) { + switch (rule_expr.type) { + case RuleExprType::kSequence: + return builder_.AddChoices({builder_.AddSequence(VisitSequence_(rule_expr))}); + case RuleExprType::kChoices: + return builder_.AddChoices(VisitChoices_(rule_expr)); + case RuleExprType::kEmptyStr: + return builder_.AddChoices({builder_.AddEmptyStr()}); + case RuleExprType::kCharacterClass: + case RuleExprType::kNegCharacterClass: + case RuleExprType::kRuleRef: + return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})}); + case RuleExprType::kStarQuantifier: + return builder_.AddStarQuantifier(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); + default: + LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); + } + } + + /*! + * \brief Visit a RuleExpr containing choices. + * \returns A list of new choice RuleExpr ids. + */ + std::vector VisitChoices_(const RuleExpr& rule_expr) { + std::vector new_choice_ids; + bool found_empty = false; + for (auto i : rule_expr) { + auto choice_expr = grammar_->GetRuleExpr(i); + switch (choice_expr.type) { + case RuleExprType::kSequence: + VisitSequenceInChoices(choice_expr, &new_choice_ids, &found_empty); + break; + case RuleExprType::kChoices: + VisitChoicesInChoices(choice_expr, &new_choice_ids, &found_empty); + break; + case RuleExprType::kEmptyStr: + found_empty = true; + break; + case RuleExprType::kCharacterClass: + case RuleExprType::kNegCharacterClass: + case RuleExprType::kRuleRef: + VisitElementInChoices(choice_expr, &new_choice_ids); + break; + default: + LOG(FATAL) << "Unexpected choice type: " << static_cast(choice_expr.type); + } + } + if (found_empty) { + new_choice_ids.insert(new_choice_ids.begin(), builder_.AddEmptyStr()); + } + ICHECK_GE(new_choice_ids.size(), 1); + return new_choice_ids; + } + + /*! \brief Visit a sequence RuleExpr that is one of a list of choices. */ + void VisitSequenceInChoices(const RuleExpr& rule_expr, std::vector* new_choice_ids, + bool* found_empty) { + auto sub_sequence_ids = VisitSequence_(rule_expr); + if (sub_sequence_ids.size() == 0) { + *found_empty = true; + } else { + new_choice_ids->push_back(builder_.AddSequence(sub_sequence_ids)); + } + } + + /*! \brief Visit a choice RuleExpr that is one of a list of choices. */ + void VisitChoicesInChoices(const RuleExpr& rule_expr, std::vector* new_choice_ids, + bool* found_empty) { + auto sub_choice_ids = VisitChoices_(rule_expr); + bool contains_empty = builder_.GetRuleExpr(sub_choice_ids[0]).type == RuleExprType::kEmptyStr; + if (contains_empty) { + *found_empty = true; + new_choice_ids->insert(new_choice_ids->end(), sub_choice_ids.begin() + 1, + sub_choice_ids.end()); + } else { + new_choice_ids->insert(new_choice_ids->end(), sub_choice_ids.begin(), sub_choice_ids.end()); + } + } + + /*! \brief Visit an atom element RuleExpr that is one of a list of choices. */ + void VisitElementInChoices(const RuleExpr& rule_expr, std::vector* new_choice_ids) { + auto sub_expr_id = builder_.AddRuleExpr(rule_expr); + new_choice_ids->push_back(builder_.AddSequence({sub_expr_id})); + } + + /*! + * \brief Visit a RuleExpr containing a sequence. + * \returns A list of new sequence RuleExpr ids. 
+ */ + std::vector VisitSequence_(const RuleExpr& rule_expr) { + std::vector new_sequence_ids; + for (auto i : rule_expr) { + auto seq_expr = grammar_->GetRuleExpr(i); + switch (seq_expr.type) { + case RuleExprType::kSequence: + VisitSequenceInSequence(seq_expr, &new_sequence_ids); + break; + case RuleExprType::kChoices: + VisitChoiceInSequence(seq_expr, &new_sequence_ids); + break; + case RuleExprType::kEmptyStr: + break; + case RuleExprType::kCharacterClass: + case RuleExprType::kNegCharacterClass: + case RuleExprType::kRuleRef: + VisitElementInSequence(seq_expr, &new_sequence_ids); + break; + default: + LOG(FATAL) << "Unexpected sequence type: " << static_cast(seq_expr.type); + } + } + return new_sequence_ids; + } + + /*! \brief Visit a sequence RuleExpr that is one element in another sequence. */ + void VisitSequenceInSequence(const RuleExpr& rule_expr, std::vector* new_sequence_ids) { + auto sub_sequence_ids = VisitSequence_(rule_expr); + new_sequence_ids->insert(new_sequence_ids->end(), sub_sequence_ids.begin(), + sub_sequence_ids.end()); + } + + /*! \brief Visit a choice RuleExpr that is one element in a sequence. */ + void VisitChoiceInSequence(const RuleExpr& rule_expr, std::vector* new_sequence_ids) { + auto sub_choice_ids = VisitChoices_(rule_expr); + if (sub_choice_ids.size() == 1) { + auto choice_element_expr = builder_.GetRuleExpr(sub_choice_ids[0]); + if (choice_element_expr.type != RuleExprType::kEmptyStr) { + new_sequence_ids->insert(new_sequence_ids->end(), choice_element_expr.begin(), + choice_element_expr.end()); + } + } else { + auto new_choice_id = builder_.AddChoices(sub_choice_ids); + auto new_choice_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_choice", new_choice_id); + new_sequence_ids->push_back(builder_.AddRuleRef(new_choice_rule_id)); + } + } + + /*! \brief Visit an atom element RuleExpr that is in a sequence. */ + void VisitElementInSequence(const RuleExpr& rule_expr, std::vector* new_sequence_ids) { + new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr)); + } + + /*! \brief The name of the current rule being visited. */ + std::string cur_rule_name_; +}; + +BNFGrammar NestedRuleUnwrapper::Apply() { return NestedRuleUnwrapperImpl(grammar_).Apply(); } + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_simplifier.h new file mode 100644 index 0000000000..4ccc0b55e7 --- /dev/null +++ b/cpp/serve/grammar/grammar_simplifier.h @@ -0,0 +1,184 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_simplifier.h + * \brief The header for the simplification of the BNF AST. + */ + +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ + +#include +#include + +#include "grammar.h" +#include "grammar_builder.h" +#include "grammar_serializer.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief Base class for visitors and mutators of the BNF grammar. + * \tparam T The type of the return value of visitor functions. Typical values: + * - int32_t: the id of the new rule_expr + * - void: no return value + * \tparam ReturnType The type of the return value of the transform function Apply(). Typical values + * are void (for visitor) and BNFGrammar (for mutator). + */ +template +class BNFGrammarMutator { + public: + /*! + * \brief Constructor. + * \param grammar The grammar to visit or mutate. 
+ */ + explicit BNFGrammarMutator(const BNFGrammar& grammar) : grammar_(grammar) {} + + /*! + * \brief Apply the transformation to the grammar, or visit the grammar. + * \return The transformed grammar, or the visiting result, or void. + * \note Should be called only once after the mutator is constructed. + */ + virtual ReturnType Apply() { + if constexpr (std::is_same::value && std::is_same::value) { + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + auto rule = grammar_->GetRule(i); + auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id); + auto new_body_expr_id = VisitExpr(rule_expr); + builder_.AddRule(rule.name, new_body_expr_id); + } + return builder_.Get(); + } else if constexpr (!std::is_same::value) { + return ReturnType(); + } + } + + protected: + using Rule = BNFGrammarNode::Rule; + using RuleExpr = BNFGrammarNode::RuleExpr; + using RuleExprType = BNFGrammarNode::RuleExprType; + + /*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */ + virtual T VisitExpr(const RuleExpr& rule_expr) { + switch (rule_expr.type) { + case RuleExprType::kSequence: + return VisitSequence(rule_expr); + case RuleExprType::kChoices: + return VisitChoices(rule_expr); + case RuleExprType::kEmptyStr: + return VisitEmptyStr(rule_expr); + case RuleExprType::kCharacterClass: + case RuleExprType::kNegCharacterClass: + return VisitCharacterClass(rule_expr); + case RuleExprType::kRuleRef: + return VisitRuleRef(rule_expr); + case RuleExprType::kStarQuantifier: + return VisitStarQuantifier(rule_expr); + default: + LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); + } + } + + /*! \brief Visit a sequence RuleExpr. */ + virtual T VisitSequence(const RuleExpr& rule_expr) { + if constexpr (std::is_same::value) { + for (auto i : rule_expr) { + VisitExpr(grammar_->GetRuleExpr(i)); + } + } else if constexpr (std::is_same::value) { + std::vector sequence_ids; + for (int32_t i : rule_expr) { + sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + } + return builder_.AddSequence(sequence_ids); + } else { + return T(); + } + } + + /*! \brief Visit a choices RuleExpr. */ + virtual T VisitChoices(const RuleExpr& rule_expr) { + if constexpr (std::is_same::value) { + for (auto i : rule_expr) { + VisitExpr(grammar_->GetRuleExpr(i)); + } + } else if constexpr (std::is_same::value) { + std::vector choice_ids; + for (int32_t i : rule_expr) { + choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + } + return builder_.AddChoices(choice_ids); + } else { + return T(); + } + } + + /*! \brief Visit an element RuleExpr, including empty string, character class, and rule ref. */ + virtual T VisitElement(const RuleExpr& rule_expr) { + if constexpr (std::is_same::value) { + return; + } else if constexpr (std::is_same::value) { + return builder_.AddRuleExpr(rule_expr); + } else { + return T(); + } + } + + /*! \brief Visit an empty string RuleExpr. */ + virtual T VisitEmptyStr(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + + /*! \brief Visit a character class RuleExpr. */ + virtual T VisitCharacterClass(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + + /*! \brief Visit a rule reference RuleExpr. */ + virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + + /*! \brief Visit a star quantifier RuleExpr. 
*/ + virtual T VisitStarQuantifier(const RuleExpr& rule_expr) { + if constexpr (std::is_same::value) { + VisitExpr(grammar_->GetRuleExpr(rule_expr[0])); + } else if constexpr (std::is_same::value) { + return builder_.AddStarQuantifier(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); + } else { + return T(); + } + } + + /*! \brief The grammar to visit or mutate. */ + BNFGrammar grammar_; + /*! + * \brief The builder to build the new grammar. It is empty when the mutator is constructed, and + * can be used to build a new grammar in subclasses. + */ + BNFGrammarBuilder builder_; +}; + +/*! + * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in + * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + * + * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class + * or a rule reference. And if the rule can be empty, the first choice will be an empty string. + * + * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice + * containing a sequence of three elements. The empty string is removed. + * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by + * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three + * choices is a sequence containing a single element. + * \example The rule `A ::= (a | (b (c | d)))` will be replaced by + * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested + * choices. + */ +class NestedRuleUnwrapper : public BNFGrammarMutator { + public: + using BNFGrammarMutator::BNFGrammarMutator; + + BNFGrammar Apply() final; +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc new file mode 100644 index 0000000000..79cc8a351a --- /dev/null +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -0,0 +1,517 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_state_matcher.cc + */ +#include "grammar_state_matcher.h" + +#include +#include + +#include "../../tokenizers.h" +#include "grammar.h" +#include "grammar_serializer.h" +#include "grammar_state_matcher_base.h" +#include "grammar_state_matcher_preproc.h" +#include "grammar_state_matcher_state.h" +#include "support.h" + +namespace mlc { +namespace llm { +namespace serve { + +/* + * Note on the matching algorithm + * + * Given a context-free grammar, we match the characters in a string one by one. + * + * We adopt a non-deterministic pushdown automata (NPDA) in matching. To be specific, we maintain + * several stacks, each of which represents a possible path in the NPDA, and update the stacks + * during matching. + * + * ## Stack Structure (see grammar_state_matcher_state.h) + * The element of every stack is a RulePosition object, referring a position in the grammar. If a + * RulePosition is a RuleRef element (referring to another rule), the next element of the stack will + * be a position in this rule. If a RulePosition is a CharacterClass element, it will be the last + * in the stack, meaning *the next* character to match. + * + * ## Matching Process (see grammar_state_matcher_base.h) + * When accepting a new character and it is accepted by a stack, the last element of the stack will + * be advanced to the next position in the grammar. 
If it gets to the end of the rule, several + * elements at the end may be popped out, and the last element of the stack will be advanced. + * + * One stack may split since there may be multiple possible next positions. In this case, similar + * stacks with different top elements will be added. When ome stack cannot accept the new character, + * it will be removed from the stacks. + * + * ## Storage of Stacks (see grammar_state_matcher_state.h) + * Note these stacks form a tree structure as when splitting, the new stacks share the same prefix. + * We store all RulePositions as a tree, where every path from tree root to a node represents a + * stack. To represent stack tops, we attach additional pointers pointing the stack top nodes. + * Also, We maintain a history of the stack top pointers, so we can rollback to the previous state. + * + * All tree nodes are maintained by a buffer, and utilize reference counting to recycle. If a node + * is neither pointed by a stack top pointer, not pointed by some child nodes, it will be freed. + * + * ## Example + * ### Grammar + * main ::= [a] R + * R ::= [b] S [c] | [b] [c] T + * S ::= "" | [c] [d] + * T ::= [e] + * + * ### Previous step + * Previous accepted string: ab + * Previous stack tree: + * A------ + * | \ \ + * B D< E< + * | + * C< + * + * A: (rule main, choice 0, element 1) + * B: (rule R, choice 0, element 1) + * C: (rule S, choice 1, element 0) + * D: (rule R, choice 0, element 2) + * E: (rule R, choice 1, element 1) + * < means the stack top pointers in the previous step. + * The stacks in the previous step is: (A, B, C), (A, D), (A, E) + * + * ### Current step + * Current accepted string: abc + * Current stack tree: + * A----------------- G<< + * | \ \ \ + * B--- D< E< H + * | \ | + * C< F<< I<< + * + * F: (rule S, choice 1, element 1) + * G: (rule main, choice 0, element 2) (means the matching process has finished, and will be deleted + * when next char comes) + * H: (rule R, choice 1, element 2) + * I: (rule T, choice 0, element 0) + * << means the stack top pointers in the current step. + * The stacks in the current step is: (A, B, F), (A, H, I), (G,) + * + * ## Preprocess (see grammar_state_matcher_preproc.h) + * We will store all information about tokens that needed in matching in a GrammarStateInitContext + * object. Tokens are sorted by codepoint, allowing us to reuse the repeated prefixes between + * different tokens. + * + * For a given position in a rule, if we only consider this rule and its sub-rules during matching, + * without considering its parent rules (in actual matching, we also need to consider its parent + * rules), we can already determine that some tokens are acceptable while others are definitely + * rejected. Therefore, for a position in a rule, we can divide the token set into three categories: + * - accepted_indices: If a token is accepted by this rule + * - rejected_indices: If a token is rejected by this rule + * - uncertain_indices: Whether it can be accepted depends on the information from the parent + * level during actual matching. To be specific, If this token has a prefix that has not been + * rejected and has reached the end of this rule, then it is possible for it to be further accepted + * by the parent rule. + * + * During actual matching, we will directly accept or reject the tokens in accepted_indices and + * rejected_indices, and only consider the tokens in uncertain_indices. That speeds up the matching + * process. 
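The storage scheme described in this note, stacks stored as paths in a shared tree and recycled by reference counting, can be sketched in a few lines. The following is a simplified, self-contained illustration, not the actual RulePositionTree: real nodes carry a RulePosition payload, are kept in a recycling buffer, and track references from both child nodes and stack-top pointers.

// Minimal sketch of the shared-prefix stack storage described above.
#include <cstdint>
#include <vector>

struct Node {
  int32_t parent = -1;    // -1 marks the bottom of a stack
  int32_t ref_count = 0;  // references from child nodes and stack-top pointers
};

class StackTree {
 public:
  // Create a node on top of `parent` (or a new stack bottom if parent == -1).
  int32_t NewNode(int32_t parent) {
    if (parent != -1) ++nodes_[parent].ref_count;
    nodes_.push_back({parent, 0});
    return static_cast<int32_t>(nodes_.size()) - 1;
  }
  // Pin a node as a stack top so it stays alive.
  void AddStackTop(int32_t id) { ++nodes_[id].ref_count; }
  // Unpin a stack top; nodes whose count reaches zero release their parent too.
  void RemoveStackTop(int32_t id) {
    while (id != -1 && --nodes_[id].ref_count == 0) {
      id = nodes_[id].parent;  // the real tree would also return the node to a free buffer
    }
  }

 private:
  std::vector<Node> nodes_;
};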
+ */ + +using namespace tvm::runtime; + +TVM_REGISTER_OBJECT_TYPE(GrammarStateMatcherNode); + +/* \brief The concrete implementation of GrammarStateMatcherNode. */ +class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public GrammarStateMatcherBase { + private: + using RuleExpr = BNFGrammarNode::RuleExpr; + using RuleExprType = BNFGrammarNode::RuleExprType; + + public: + GrammarStateMatcherNodeImpl(std::shared_ptr init_ctx, + int max_rollback_steps = 0) + : GrammarStateMatcherBase(init_ctx->grammar), + init_ctx_(init_ctx), + max_rollback_steps_(max_rollback_steps) {} + + bool AcceptToken(int32_t token_id) final; + + void FindNextTokenBitmask(DLTensor* next_token_bitmask) final; + + void Rollback(int num_tokens) final; + + int MaxRollbackSteps() final { return max_rollback_steps_; } + + void ResetState() final { + stack_tops_history_.Reset(); + token_size_history_.clear(); + InitStackState(); + } + + private: + /*! + * \brief If is_uncertain_saved is true, find the next token in uncertain_indices. Otherwise, + * find the next token that is set to true in uncertain_tokens_bitset. + * \param iterator_uncertain The helper iterator to iterate over uncertain_indices or + * uncertain_tokens_bitset. + * \returns The index of the next token, or -1 if no more token. + */ + int GetNextUncertainToken(bool is_uncertain_saved, int* iterator_uncertain, + const std::vector& uncertain_indices, + const std::vector& uncertain_tokens_bitset); + + /*! \brief Set the acceptable next token in next_token_bitmask. */ + void SetTokenBitmask(DLTensor* next_token_bitmask, std::vector& accepted_indices, + std::vector& rejected_indices, bool can_reach_end); + + friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher); + + std::shared_ptr init_ctx_; + int max_rollback_steps_; + std::deque token_size_history_; + + // Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation. + std::vector tmp_accepted_indices_; + std::vector tmp_rejected_indices_; + std::vector tmp_accepted_indices_delta_; + std::vector tmp_rejected_indices_delta_; + std::vector tmp_uncertain_tokens_bitset_; +}; + +bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { + CHECK(init_ctx_->codepoint_tokens_lookup.count(token_id) > 0); + const auto& token = init_ctx_->codepoint_tokens_lookup[token_id].token; + for (auto codepoint : token) { + if (!AcceptCodepoint(codepoint, false)) { + return false; + } + } + token_size_history_.push_back(token.size()); + if (token_size_history_.size() > max_rollback_steps_) { + DiscardEarliestCodepoints(token_size_history_.front()); + token_size_history_.pop_front(); + } + return true; +} + +void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitmask) { + const auto& tokens_sorted_by_codepoint = init_ctx_->tokens_sorted_by_codepoint; + const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar; + const auto& latest_stack_tops = stack_tops_history_.GetLatest(); + + // We check all the stacks one by one, and find the accepted token set or the rejected token set + // for each stack. We will try to find the small one of the two sets. + // The final accepted token set is the union of the accepted token sets of all stacks. + // The final rejected token set is the intersection of the rejected token sets of all stacks. + + // Note these indices store the indices in tokens_sorted_by_codepoint, instead of the token ids. + tmp_accepted_indices_.clear(); + // {-1} means the universal set, i.e. 
all tokens initially + tmp_rejected_indices_.assign({-1}); + + for (auto top : latest_stack_tops) { + // Step 1. Find the current catagorized_tokens + auto cur_rule_position = tree_[top]; + auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); + if (cur_rule_position.parent_id == RulePosition::kNoParent && + cur_rule_position.element_id == current_sequence.size()) { + continue; + } + + const auto& catagorized_tokens = catagorized_tokens_for_grammar.at( + {cur_rule_position.sequence_id, cur_rule_position.element_id}); + + // For each stack, we will check every uncertain token and put them into the accepted or + // rejected list. + // If the accepted tokens are saved, it means it is likely to be smaller than the rejected + // tokens, so we will just find the accepted tokens, and vice versa. + bool is_find_accept_mode = + catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kAccepted; + + // If uncertain tokens are saved, we will iterate over the uncertain tokens. + // Otherwise, we will iterate over all_tokens - accepted_tokens - rejected_tokens. + bool is_uncertain_saved = + catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kUncertain; + + // Step 2. Update the accepted tokens in accepted_indices_delta, or the rejected tokens in + // rejected_indices_delta. + + // Examine only the current one stack + stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); + + const std::vector* prev_token = nullptr; + int prev_matched_size = 0; + + tmp_accepted_indices_delta_.clear(); + tmp_rejected_indices_delta_.clear(); + + if (!is_uncertain_saved) { + // unc_tokens = all_tokens - accepted_tokens - rejected_tokens + tmp_uncertain_tokens_bitset_.assign(tokens_sorted_by_codepoint.size(), true); + for (auto idx : catagorized_tokens.accepted_indices) { + tmp_uncertain_tokens_bitset_[idx] = false; + } + for (auto idx : catagorized_tokens.rejected_indices) { + tmp_uncertain_tokens_bitset_[idx] = false; + } + } + + int iterator_uncertain = -1; + + while (true) { + // Step 2.1. Find the current token. + auto idx = + GetNextUncertainToken(is_uncertain_saved, &iterator_uncertain, + catagorized_tokens.uncertain_indices, tmp_uncertain_tokens_bitset_); + if (idx == -1) { + break; + } + const auto& cur_token = tokens_sorted_by_codepoint[idx].token; + + // Step 2.2. Find the longest common prefix with the accepted part of the previous token. + // We can reuse the previous matched size to avoid unnecessary matching. + int prev_useful_size = 0; + if (prev_token) { + prev_useful_size = std::min(prev_matched_size, static_cast(cur_token.size())); + for (int j = 0; j < prev_useful_size; ++j) { + if (cur_token[j] != (*prev_token)[j]) { + prev_useful_size = j; + break; + } + } + RollbackCodepoints(prev_matched_size - prev_useful_size); + } + + // Step 2.3. Find if the current token is accepted or rejected. + bool accepted = true; + prev_matched_size = prev_useful_size; + + for (int j = prev_useful_size; j < cur_token.size(); ++j) { + if (!AcceptCodepoint(cur_token[j], false)) { + accepted = false; + break; + } + prev_matched_size = j + 1; + } + + // Step 2.4. Push the result to the delta list. + if (accepted && is_find_accept_mode) { + tmp_accepted_indices_delta_.push_back(idx); + } else if (!accepted && !is_find_accept_mode) { + tmp_rejected_indices_delta_.push_back(idx); + } + + prev_token = &cur_token; + } + + RollbackCodepoints(prev_matched_size + 1); + + // Step 3. 
Update the accepted_indices and rejected_indices + if (is_find_accept_mode) { + // accepted_indices += catagorized_tokens.accepted_indices + accepted_indices_delta + IntsetUnion(&tmp_accepted_indices_delta_, catagorized_tokens.accepted_indices); + IntsetUnion(&tmp_accepted_indices_, tmp_accepted_indices_delta_); + } else { + // rejected_indices = Intersect( + // rejected_indices, + // catagorized_tokens.rejected_indices + rejected_indices_delta) + IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices); + IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_); + } + } + + // Finally update the rejected_ids bitset + bool can_reach_end = CanReachEnd(); + SetTokenBitmask(next_token_bitmask, tmp_accepted_indices_, tmp_rejected_indices_, can_reach_end); +} + +void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { + CHECK(num_tokens <= token_size_history_.size()); + while (num_tokens > 0) { + int steps = token_size_history_.back(); + RollbackCodepoints(steps); + token_size_history_.pop_back(); + --num_tokens; + } +} + +void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, + std::vector& accepted_indices, + std::vector& rejected_indices, + bool can_reach_end) { + // accepted_ids = Union(accepted_indices, all_tokens - rejected_indices) + // rejected_ids = Intersect(all_tokens - accepted_indices, rejected_indices) + DCHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && + next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape); + + BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), + next_token_bitmask->shape[0]); + + if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { + // If rejected_indices is the universal set, the final accepted token set is just + // accepted_indices + next_token_bitset.Reset(init_ctx_->vocab_size, false); + for (int idx : accepted_indices) { + next_token_bitset.Set(init_ctx_->tokens_sorted_by_codepoint[idx].id, true); + } + + if (can_reach_end) { + // add end tokens + for (int idx : init_ctx_->stop_token_ids) { + next_token_bitset.Set(idx, true); + } + } + } else { + // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices) + next_token_bitset.Reset(init_ctx_->vocab_size, true); + + auto it_acc = accepted_indices.begin(); + for (auto i : rejected_indices) { + while (it_acc != accepted_indices.end() && *it_acc < i) { + ++it_acc; + } + if (it_acc == accepted_indices.end() || *it_acc != i) { + next_token_bitset.Set(init_ctx_->tokens_sorted_by_codepoint[i].id, false); + } + } + + for (int idx : init_ctx_->special_token_ids) { + next_token_bitset.Set(idx, false); + } + if (!can_reach_end) { + for (int idx : init_ctx_->stop_token_ids) { + next_token_bitset.Set(idx, false); + } + } + } +} + +int GrammarStateMatcherNodeImpl::GetNextUncertainToken( + bool is_uncertain_saved, int* iterator_uncertain, const std::vector& uncertain_indices, + const std::vector& uncertain_tokens_bitset) { + if (is_uncertain_saved) { + ++*iterator_uncertain; + if (*iterator_uncertain == uncertain_indices.size()) { + return -1; + } + return uncertain_indices[*iterator_uncertain]; + } else { + ++*iterator_uncertain; + while (*iterator_uncertain < uncertain_tokens_bitset.size() && + !uncertain_tokens_bitset[*iterator_uncertain]) { + ++*iterator_uncertain; + } + if (*iterator_uncertain == uncertain_tokens_bitset.size()) { + return -1; + } + return *iterator_uncertain; + } +} + 
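SetTokenBitmask above writes the result of each step into a bitmask with one bit per token id, packed into uint32 words; the matcher header requires a buffer of shape (ceil(vocab_size / 32),) with dtype uint32. A self-contained sketch of that packing follows. It is illustrative only: the patch's BitsetManager wraps an externally allocated buffer (the DLTensor data) rather than owning its storage, and its method names differ.

// Illustrative packing of a per-token accept/reject mask into uint32 words,
// matching the (vocab_size + 31) / 32 layout the matcher expects.
#include <cstdint>
#include <vector>

class TokenBitmask {
 public:
  explicit TokenBitmask(int vocab_size) : words_((vocab_size + 31) / 32, 0) {}

  void Set(int token_id, bool allowed) {
    uint32_t mask = 1u << (token_id % 32);
    if (allowed) {
      words_[token_id / 32] |= mask;
    } else {
      words_[token_id / 32] &= ~mask;
    }
  }

  bool Get(int token_id) const {
    return (words_[token_id / 32] >> (token_id % 32)) & 1u;
  }

  void Fill(bool allowed) {
    for (auto& w : words_) w = allowed ? ~0u : 0u;
  }

 private:
  std::vector<uint32_t> words_;
};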
+GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr init_ctx, + int max_rollback_steps) + : ObjectRef(make_object(init_ctx, max_rollback_steps)) {} + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") + .set_body_typed([](BNFGrammar grammar, Optional tokenizer, int max_rollback_steps) { + auto init_ctx = CreateInitContext( + grammar, tokenizer ? tokenizer.value()->TokenTable() : std::vector()); + return GrammarStateMatcher(init_ctx, max_rollback_steps); + }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") + .set_body([](TVMArgs args, TVMRetValue* rv) { + BNFGrammar grammar = args[0]; + std::vector token_table; + for (int i = 1; i < args.size() - 1; ++i) { + token_table.push_back(args[i]); + } + int max_rollback_steps = args[args.size() - 1]; + auto init_ctx = CreateInitContext(grammar, token_table); + *rv = GrammarStateMatcher(init_ctx, max_rollback_steps); + }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptCodepoint") + .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint) { + auto mutable_node = + const_cast(matcher.as()); + return mutable_node->AcceptCodepoint(codepoint); + }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherAcceptToken") + .set_body_typed([](GrammarStateMatcher matcher, int32_t token_id) { + return matcher->AcceptToken(token_id); + }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherRollback") + .set_body_typed([](GrammarStateMatcher matcher, int num_tokens) { + matcher->Rollback(num_tokens); + }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherMaxRollbackSteps") + .set_body_typed([](GrammarStateMatcher matcher) { return matcher->MaxRollbackSteps(); }); + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") + .set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); }); + +/*! \brief Check if a matcher can accept the complete string, and then reach the end of the + * grammar. For test purpose. */ +bool MatchCompleteString(GrammarStateMatcher matcher, String str) { + auto mutable_node = + const_cast(matcher.as()); + auto codepoints = Utf8StringToCodepoints(str.c_str()); + int accepted_cnt = 0; + for (auto codepoint : codepoints) { + if (!mutable_node->AcceptCodepoint(codepoint, false)) { + mutable_node->RollbackCodepoints(accepted_cnt); + return false; + } + ++accepted_cnt; + } + return mutable_node->CanReachEnd(); +} + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") + .set_body_typed([](GrammarStateMatcher matcher, String str) { + return MatchCompleteString(matcher, str); + }); + +/*! + * \brief Find the ids of the rejected tokens for the next step. + * \returns A tuple of rejected token ids. 
+ */ +IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { + auto init_ctx = matcher.as()->init_ctx_; + auto vocab_size = init_ctx->vocab_size; + auto bitset_size = BitsetManager::GetBitsetSize(vocab_size); + auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, + DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); + auto dltensor_manager = ndarray.ToDLPack(); + auto dltensor = ndarray.ToDLPack()->dl_tensor; + + auto start = std::chrono::high_resolution_clock::now(); + matcher->FindNextTokenBitmask(&dltensor); + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "FindNextTokenBitmask takes " + << std::chrono::duration_cast(end - start).count() << "us"; + + auto bitset = BitsetManager(reinterpret_cast(dltensor.data), bitset_size); + std::vector rejected_ids; + for (int i = 0; i < vocab_size; i++) { + if (bitset[i] == 0) { + rejected_ids.push_back(i); + } + } + + std::cout << ", found accepted: " << vocab_size - rejected_ids.size() + << ", rejected: " << rejected_ids.size() << std::endl; + + dltensor_manager->deleter(dltensor_manager); + + auto ret = IntTuple(rejected_ids); + return ret; +} + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextRejectedTokens") + .set_body_typed(FindNextRejectedTokens); + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h new file mode 100644 index 0000000000..0ea4b12b95 --- /dev/null +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -0,0 +1,125 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_state_matcher.h + * \brief The header for the support of matching tokens to BNF grammar. This is the core + * logic of the grammar-guided generation. + */ + +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ + +#include +#include + +#include +#include +#include + +#include "../../support/encoding.h" +#include "grammar.h" +#include "support.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief A stateful matcher to match tokens to the specified BNF grammar. This class is the core + * logic of the grammar-guided generation. + * + * \details This class implements the non-deterministic pushdown automaton (NPDA) matching algorithm + * to match characters to a BNF grammar. It keep track of the current state of the matching process + * by maintaining several stacks internally as possible paths in the NPDA. It also supports + * backtracking. + * + * It is particularly capable of finding the set of tokens that are acceptable for the next step + * and storing them in a bitmask. This aids in grammar-guided generation. + * + * \example + * \code + * Tokenizer tokenizer = ...; + * auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, tokenizer->TokenTable()); + * GrammarStateMatcher matcher(init_ctx, 10); + * matcher->AcceptToken(67); + * + * // Construct a DLTensor with shape (tokenizer.GetVocabSize() + 31) / 32, and dtype uint32. + * DLTensor next_token_bitmask = ...; + * matcher->FindNextTokenBitmask(&next_token_bitmask); + * + * // Rollback is supported + * matcher->Rollback(1); + * \endcode + */ +class GrammarStateMatcherNode : public Object { + public: + /*! + * \brief Accept one token and update the state of the matcher. + * \param token_id The id of the token to accept. + * \return Whether the token is accepted. 
+ */ + virtual bool AcceptToken(int32_t token_id) = 0; + + /*! + * \brief Find the set of tokens that are acceptable for the next step and store them in a + * bitmask. + * \param next_token_bitmask The bitmask to store the result. The bitmask must be pre-allocated, + * and its shape needs to be (ceil(vocab_size, 32),), with a dtype of uint32. + */ + virtual void FindNextTokenBitmask(DLTensor* next_token_bitmask) = 0; + + /*! + * \brief Rollback the matcher to a previous state. + * \param num_tokens The number of tokens to rollback. It cannot exceed the current number of + * steps, nor can it exceed the specified maximum number of rollback steps. + */ + virtual void Rollback(int num_tokens) = 0; + + /*! \brief Get the maximum number of rollback steps allowed. */ + virtual int MaxRollbackSteps() = 0; + + /*! \brief Reset the matcher to the initial state. */ + virtual void ResetState() = 0; + + static constexpr const char* _type_key = "mlc.serve.GrammarStateMatcher"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(GrammarStateMatcherNode, Object); +}; + +/*! + * \brief The init context of a GrammarStateMatcher. It contains the preprocessing results of the + * grammar and tokenizer. + */ +class GrammarStateInitContext; + +class GrammarStateMatcher : public ObjectRef { + public: + /*! + * \brief Construct a GrammarStateMatcher from the preprocessing result of type + * GrammarStateInitContext. + * \param init_ctx The init context. It is obtained through + * CreateInitContext as a result of preprocessing the grammar and tokenizer. + */ + GrammarStateMatcher(std::shared_ptr init_ctx, + int max_rollback_steps = 0); + + /*! + * \brief Specify a grammar and token_table to return their preprocessing results. These results + * are used to construct a GrammarStateMatcher. They can be stored elsewhere for quick + * construction of GrammarStateMatcher. + * \param grammar The grammar that the matcher follows. + * \param token_table The tokens that the matcher requires for matching. + */ + static std::shared_ptr CreateInitContext( + const BNFGrammar& grammar, const std::vector& token_table); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarStateMatcher, ObjectRef, GrammarStateMatcherNode); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_ diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h new file mode 100644 index 0000000000..11623661e7 --- /dev/null +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -0,0 +1,236 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_state_matcher_base.h + * \brief The base class of GrammarStateMatcher. It implements a character-based matching automata. + */ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ + +#include + +#include "../../tokenizers.h" +#include "grammar.h" +#include "grammar_state_matcher_state.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! \brief The base class of GrammarStateMatcher. It implements a character-based matching + * automata, and supports accepting a character, rolling back by character, etc. 
+ */ +class GrammarStateMatcherBase { + protected: + using RuleExpr = BNFGrammarNode::RuleExpr; + using RuleExprType = BNFGrammarNode::RuleExprType; + + public: + /*! + * \brief Construct a GrammarStateMatcherBase with the given grammar and initial rule position. + * \param grammar The grammar to match. + * \param init_rule_position The initial rule position. If not specified, the main rule will be + * used. + */ + GrammarStateMatcherBase(const BNFGrammar& grammar, RulePosition init_rule_position = {}) + : grammar_(grammar), tree_(grammar), stack_tops_history_(&tree_) { + InitStackState(init_rule_position); + } + + /*! \brief Accept one codepoint. */ + bool AcceptCodepoint(TCodepoint codepoint, bool verbose = false); + + /*! \brief Check if the end of the main rule is reached. If so, the stop token can be accepted. */ + bool CanReachEnd() const; + + /*! \brief Rollback the matcher to a previous state. */ + void RollbackCodepoints(int rollback_codepoint_cnt); + + /*! \brief Discard the earliest history. */ + void DiscardEarliestCodepoints(int discard_codepoint_cnt); + + /*! \brief Print the stack state. */ + std::string PrintStackState(int steps_behind_latest = 0) const; + + protected: + // Init the stack state according to the given rule position. + // If init_rule_position is {}, init the stack with the main rule. + void InitStackState(RulePosition init_rule_position = {}); + + // Update the old stack top to the next position, and push the new stack tops to new_stack_tops. + void UpdateNewStackTops(int32_t old_node_id, std::vector* new_stack_tops); + + BNFGrammar grammar_; + RulePositionTree tree_; + StackTopsHistory stack_tops_history_; + + // Temporary data for AcceptCodepoint. + std::vector tmp_new_stack_tops_; +}; + +/*! \brief Check the codepoint is contained in the character class. */ +inline bool CharacterClassContains(const BNFGrammarNode::RuleExpr& rule_expr, + TCodepoint codepoint) { + DCHECK(rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass || + rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass); + for (int i = 0; i < rule_expr.size(); i += 2) { + if (rule_expr.data[i] <= codepoint && codepoint <= rule_expr.data[i + 1]) { + return rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass; + } + } + return rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass; +} + +inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool verbose) { + if (verbose) { + std::cout << "Stack before accepting: " << PrintStackState() << std::endl; + } + tmp_new_stack_tops_.clear(); + + const auto& prev_stack_tops = stack_tops_history_.GetLatest(); + for (auto old_top : prev_stack_tops) { + const auto& rule_position = tree_[old_top]; + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + if (rule_position.parent_id == RulePosition::kNoParent && + rule_position.element_id == current_sequence.size()) { + // This RulePosition means previous elements has matched the complete rule. + // But we are still need to accept a new character, so this stack will become invalid. + continue; + } + auto current_char_class = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + // Special support for star quantifiers of character classes. 
+ if (current_char_class.type == RuleExprType::kRuleRef) { + DCHECK(rule_position.char_class_id != -1); + current_char_class = grammar_->GetRuleExpr(rule_position.char_class_id); + } + DCHECK(current_char_class.type == RuleExprType::kCharacterClass || + current_char_class.type == RuleExprType::kNegCharacterClass); + auto ok = CharacterClassContains(current_char_class, codepoint); + if (!ok) { + continue; + } + UpdateNewStackTops(old_top, &tmp_new_stack_tops_); + } + if (tmp_new_stack_tops_.empty()) { + if (verbose) { + std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) + << "\" Rejected" << std::endl; + } + return false; + } + stack_tops_history_.PushHistory(tmp_new_stack_tops_); + if (verbose) { + std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) + << "\" Accepted" << std::endl; + std::cout << "Stack after accepting: " << PrintStackState() << std::endl; + } + return true; +} + +inline bool GrammarStateMatcherBase::CanReachEnd() const { + const auto& last_stack_tops = stack_tops_history_.GetLatest(); + return std::any_of(last_stack_tops.begin(), last_stack_tops.end(), + [&](int32_t id) { return tree_.IsEndPosition(tree_[id]); }); +} + +inline void GrammarStateMatcherBase::RollbackCodepoints(int rollback_codepoint_cnt) { + stack_tops_history_.Rollback(rollback_codepoint_cnt); +} + +inline void GrammarStateMatcherBase::DiscardEarliestCodepoints(int discard_codepoint_cnt) { + stack_tops_history_.DiscardEarliest(discard_codepoint_cnt); +} + +inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_latest) const { + return stack_tops_history_.PrintHistory(steps_behind_latest); +} + +inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_position) { + if (init_rule_position == kInvalidRulePosition) { + // Initialize the stack with the main rule. + auto main_rule = grammar_->GetRule(0); + auto main_rule_expr = grammar_->GetRuleExpr(main_rule.body_expr_id); + std::vector new_stack_tops; + for (auto i : main_rule_expr) { + DCHECK(grammar_->GetRuleExpr(i).type == RuleExprType::kSequence || + grammar_->GetRuleExpr(i).type == RuleExprType::kEmptyStr); + new_stack_tops.push_back(tree_.NewNode(RulePosition(0, i, 0, RulePosition::kNoParent))); + } + stack_tops_history_.PushHistory(new_stack_tops); + } else { + stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); + } +} + +inline void GrammarStateMatcherBase::UpdateNewStackTops(int32_t old_node_id, + std::vector* new_stack_tops) { + const auto& old_rule_position = tree_[old_node_id]; + // For char_class*, the old rule position itself is also the next position + if (old_rule_position.char_class_id != -1) { + new_stack_tops->push_back(tree_.NewNode(old_rule_position)); + } + + auto cur_rule_position = tree_.GetNextPosition(tree_[old_node_id]); + + // Continuously iterate to the next position (if reachs the end of the current rule, go to the + // next position of the parent rule). Push it into new_stack_tops. If this position can not + // be empty, exit the loop. + // Positions that can be empty: reference to a rule that can be empty, or a star quantifier + // rule. 
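+ // For example, if the next element refers to a rule such as `r ::= [a]*`, or to a rule
+ // that has an empty choice, its position is pushed and the iteration moves on to the
+ // element after it; a plain character class ends the iteration because it cannot match
+ // the empty string.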
+ for (; !tree_.IsEndPosition(cur_rule_position); + cur_rule_position = tree_.GetNextPosition(cur_rule_position)) { + auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); + auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); + if (element.type == RuleExprType::kCharacterClass || + element.type == RuleExprType::kNegCharacterClass) { + // Character class: cannot be empty. Break the loop. + new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); + break; + } else { + // RuleRef + DCHECK(element.type == RuleExprType::kRuleRef); + auto new_rule_id = element[0]; + auto new_rule = grammar_->GetRule(new_rule_id); + auto new_rule_expr = grammar_->GetRuleExpr(new_rule.body_expr_id); + if (new_rule_expr.type == RuleExprType::kStarQuantifier) { + cur_rule_position.char_class_id = new_rule_expr[0]; + new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); + } else { + DCHECK(new_rule_expr.type == RuleExprType::kChoices); + + bool contain_empty = false; + + // For rule containing choices, expand the rule and push all positions into new_stack_tops + for (auto j : new_rule_expr) { + auto sequence = grammar_->GetRuleExpr(j); + if (sequence.type == RuleExprType::kEmptyStr) { + contain_empty = true; + continue; + } + DCHECK(sequence.type == RuleExprType::kSequence); + DCHECK(grammar_->GetRuleExpr(sequence[0]).type == RuleExprType::kCharacterClass || + grammar_->GetRuleExpr(sequence[0]).type == RuleExprType::kNegCharacterClass); + // Note: rule_position is not inserted to the tree yet, so it need to be inserted first + auto parent_id = tree_.NewNode(cur_rule_position); + new_stack_tops->push_back(tree_.NewNode(RulePosition(new_rule_id, j, 0, parent_id))); + } + + if (!contain_empty) { + break; + } + } + } + } + + // Reaches the end of the main rule. Insert a special node to indicate the end. + if (tree_.IsEndPosition(cur_rule_position)) { + new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); + } +} + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_ diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h new file mode 100644 index 0000000000..62a1f2a6af --- /dev/null +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -0,0 +1,315 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_state_matcher_preproc.h + * \brief The header for the preprocessing of the grammar state matcher. + */ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_ + +#include + +#include "../../support/encoding.h" +#include "grammar.h" +#include "grammar_state_matcher_base.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! \brief A token and its id. */ +struct TokenAndId { + std::vector token; + int32_t id; + /*! \brief Compare tokens by their unicode codepoint sequence. */ + bool operator<(const TokenAndId& other) const; +}; + +/*! + * \brief Preprocessed information, for a given specific rule and position, divides the token set + * into three categories: accepted, rejected, and uncertain. + * \note Since the union of these three sets is the whole token set, we only need to store the + * smaller two sets. The unsaved set is specified by not_saved_index. + * \note These indices are the indices of tokens_sorted_by_codepoint in the GrammarStateInitContext + * object, instead of the token ids. 
That helps the matching process. + */ +struct CatagorizedTokens { + std::vector accepted_indices; + std::vector rejected_indices; + std::vector uncertain_indices; + enum class NotSavedIndex { kAccepted = 0, kRejected = 1, kUncertain = 2 }; + NotSavedIndex not_saved_index; + + CatagorizedTokens() = default; + + CatagorizedTokens(std::vector&& accepted_indices, + std::vector&& rejected_indices, + std::vector&& uncertain_indices); +}; + +/*! + * \brief All information that we need to match tokens in the tokenizer to the specified grammar. + * It is the result of preprocessing. + * \sa mlc::llm::serve::GrammarStateMatcher + */ +class GrammarStateInitContext { + public: + BNFGrammar grammar; + /*! \brief The vocabulary size of the tokenizer. */ + size_t vocab_size; + /*! \brief The sorted token and its id. Tokens are sorted to reuse the common prefix during + * matching. */ + std::vector tokens_sorted_by_codepoint; + /*! \brief The mapping from token id to token represented by codepoints. */ + std::unordered_map codepoint_tokens_lookup; + /*! \brief The stop tokens. They can be accepted iff GramamrMatcher can reach the end of the + * grammar. */ + std::vector stop_token_ids; + /*! \brief The special tokens. Currently we will ignore these tokens during grammar-guided + * matching. */ + std::vector special_token_ids; + + /*! \brief A sequence id and its position. */ + struct SequenceIdAndPosition { + int32_t sequence_id; + int32_t element_id; + bool operator==(const SequenceIdAndPosition& other) const { + return sequence_id == other.sequence_id && element_id == other.element_id; + } + }; + + /*! \brief Hash function for SequenceIdAndPosition. */ + struct SequenceIdAndPositionHash { + std::size_t operator()(const SequenceIdAndPosition& k) const { + return std::hash()(k.sequence_id) ^ (std::hash()(k.element_id) << 1); + } + }; + + /*! \brief Mapping from sequence id and its position to the catagorized tokens. */ + std::unordered_map + catagorized_tokens_for_grammar; +}; + +/* \brief The concrete implementation of GrammarStateMatcherNode. */ +class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase { + public: + GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position) + : GrammarStateMatcherBase(grammar, init_rule_position) {} + + CatagorizedTokens GetCatagorizedTokens(const std::vector& tokens_sorted_by_codepoint, + bool is_main_rule); + + private: + using RuleExpr = BNFGrammarNode::RuleExpr; + using RuleExprType = BNFGrammarNode::RuleExprType; + + // Temporary data for GetCatagorizedTokens. + std::vector tmp_accepted_indices_; + std::vector tmp_rejected_indices_; + std::vector tmp_uncertain_indices_; + std::vector tmp_can_see_end_stack_; +}; + +inline bool TokenAndId::operator<(const TokenAndId& other) const { + for (size_t i = 0; i < token.size(); ++i) { + if (i >= other.token.size()) { + return false; + } + if (token[i] < other.token[i]) { + return true; + } else if (token[i] > other.token[i]) { + return false; + } + } + return token.size() < other.token.size(); +} + +inline CatagorizedTokens::CatagorizedTokens(std::vector&& accepted_indices, + std::vector&& rejected_indices, + std::vector&& uncertain_indices) { + auto size_acc = accepted_indices.size(); + auto size_rej = rejected_indices.size(); + auto size_unc = uncertain_indices.size(); + not_saved_index = + (size_acc >= size_rej && size_acc >= size_unc) + ? NotSavedIndex::kAccepted + : (size_rej >= size_unc ? 
NotSavedIndex::kRejected : NotSavedIndex::kUncertain); + + if (not_saved_index != NotSavedIndex::kAccepted) { + this->accepted_indices = std::move(accepted_indices); + } + if (not_saved_index != NotSavedIndex::kRejected) { + this->rejected_indices = std::move(rejected_indices); + } + if (not_saved_index != NotSavedIndex::kUncertain) { + this->uncertain_indices = std::move(uncertain_indices); + } +} + +inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( + const std::vector& tokens_sorted_by_codepoint, bool is_main_rule) { + // Support the current stack contains only one stack with one RulePosition. + // Iterate over all tokens. Split them into three categories: + // - accepted_indices: If a token is accepted by current rule + // - rejected_indices: If a token is rejected by current rule + // - uncertain_indices: If a prefix of a token is accepted by current rule and comes to the end + // of the rule. + + // Note many tokens may contain the same prefix, so we will avoid unnecessary matching + + tmp_accepted_indices_.clear(); + tmp_rejected_indices_.clear(); + tmp_uncertain_indices_.clear(); + // For every character in the current token, stores whether it is possible to reach the end of + // the rule when matching until this character. Useful for rollback. + tmp_can_see_end_stack_.assign({CanReachEnd()}); + + int prev_matched_size = 0; + for (int i = 0; i < static_cast(tokens_sorted_by_codepoint.size()); ++i) { + const auto& token = tokens_sorted_by_codepoint[i].token; + const auto* prev_token = i > 0 ? &tokens_sorted_by_codepoint[i - 1].token : nullptr; + + // Find the longest common prefix with the accepted part of the previous token. + auto prev_useful_size = 0; + if (prev_token) { + prev_useful_size = std::min(prev_matched_size, static_cast(token.size())); + for (int j = 0; j < prev_useful_size; ++j) { + if (token[j] != (*prev_token)[j]) { + prev_useful_size = j; + break; + } + } + RollbackCodepoints(prev_matched_size - prev_useful_size); + tmp_can_see_end_stack_.erase( + tmp_can_see_end_stack_.end() - (prev_matched_size - prev_useful_size), + tmp_can_see_end_stack_.end()); + } + + // Find if the current token is accepted or rejected or uncertain. + bool accepted = true; + bool can_see_end = tmp_can_see_end_stack_.back(); + prev_matched_size = prev_useful_size; + for (int j = prev_useful_size; j < token.size(); ++j) { + if (!AcceptCodepoint(token[j], false)) { + accepted = false; + break; + } + if (CanReachEnd()) { + can_see_end = true; + } + tmp_can_see_end_stack_.push_back(can_see_end); + prev_matched_size = j + 1; + } + if (accepted) { + tmp_accepted_indices_.push_back(i); + } else if (can_see_end && !is_main_rule) { + // If the current rule is the main rule, there will be no uncertain indices since we will + // never consider its parent rule. Unaccepted tokens are just rejected. 
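+ // "Uncertain" means a prefix of the token was accepted and the end of this rule was
+ // reached, so whether the full token is acceptable depends on the parent rules in the
+ // actual matching stack.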
+ tmp_uncertain_indices_.push_back(i); + } else { + tmp_rejected_indices_.push_back(i); + } + } + RollbackCodepoints(prev_matched_size); + return CatagorizedTokens(std::move(tmp_accepted_indices_), std::move(tmp_rejected_indices_), + std::move(tmp_uncertain_indices_)); +} + +inline std::string ReplaceUnderscoreWithSpace(const std::string& str, + const std::string& kSpecialUnderscore) { + std::string res; + size_t pos = 0; + while (pos < str.size()) { + size_t found = str.find(kSpecialUnderscore, pos); + if (found == std::string::npos) { + res += str.substr(pos); + break; + } + res += str.substr(pos, found - pos) + " "; + pos = found + kSpecialUnderscore.size(); + } + return res; +} + +inline std::shared_ptr CreateInitContext( + const BNFGrammar& grammar, const std::vector& token_table) { + using RuleExprType = BNFGrammarNode::RuleExprType; + auto ptr = std::make_shared(); + + ptr->grammar = grammar; + ptr->vocab_size = token_table.size(); + + if (ptr->vocab_size == 0) { + return ptr; + } + + for (int i = 0; i < token_table.size(); ++i) { + auto token = token_table[i]; + if (token == "" || token == "" || token == "") { + ptr->special_token_ids.push_back(i); + } else if (token == "") { + ptr->stop_token_ids.push_back(i); + } else if (token.size() == 1 && + (static_cast(token[0]) >= 128 || token[0] == 0)) { + // Currently we consider all tokens with one character that >= 128 as special tokens. + ptr->special_token_ids.push_back(i); + } else { + // First replace the special underscore with space. + auto codepoints = Utf8StringToCodepoints(token.c_str()); + DCHECK(!codepoints.empty() && + codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) + << "Invalid token: " << token; + ptr->tokens_sorted_by_codepoint.push_back({codepoints, i}); + ptr->codepoint_tokens_lookup[i] = {codepoints, i}; + } + } + std::sort(ptr->tokens_sorted_by_codepoint.begin(), ptr->tokens_sorted_by_codepoint.end()); + + // Find the corresponding catagorized tokens for: + // 1. All character elements in the grammar + // 2. All RuleRef elements that refers to a rule of a StarQuantifier of a character class + for (int i = 0; i < static_cast(grammar->NumRules()); ++i) { + auto rule = grammar->GetRule(i); + auto rule_expr = grammar->GetRuleExpr(rule.body_expr_id); + // Skip StarQuantifier since we just handle it at the reference element during matching. + if (rule_expr.type == RuleExprType::kStarQuantifier) { + continue; + } + DCHECK(rule_expr.type == RuleExprType::kChoices); + for (auto sequence_id : rule_expr) { + auto sequence_expr = grammar->GetRuleExpr(sequence_id); + if (sequence_expr.type == RuleExprType::kEmptyStr) { + continue; + } + DCHECK(sequence_expr.type == RuleExprType::kSequence); + for (int element_id = 0; element_id < sequence_expr.size(); ++element_id) { + auto element_expr = grammar->GetRuleExpr(sequence_expr[element_id]); + auto cur_rule_position = RulePosition{i, sequence_id, element_id}; + if (element_expr.type == RuleExprType::kRuleRef) { + auto ref_rule = grammar->GetRule(element_expr[0]); + auto ref_rule_expr = grammar->GetRuleExpr(ref_rule.body_expr_id); + if (ref_rule_expr.type == RuleExprType::kChoices) { + continue; + } else { + // Reference to a StarQuantifier of a character class. 
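+ // Record the character class id so that matching at this position consults the star
+ // quantifier's character class directly (see the special handling in AcceptCodepoint).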
+ cur_rule_position.char_class_id = ref_rule_expr[0]; + } + } + + auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, cur_rule_position); + auto cur_catagorized_tokens_for_grammar = + grammar_state_matcher.GetCatagorizedTokens(ptr->tokens_sorted_by_codepoint, i == 0); + ptr->catagorized_tokens_for_grammar[{sequence_id, element_id}] = + cur_catagorized_tokens_for_grammar; + } + } + } + return ptr; +} + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // TVM_LLVM_COMPILE_ENGINE_CPP_SERVE_GRAMMAR_STATE_MATCHER_PREPROC_H_ diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h new file mode 100644 index 0000000000..d8f2185f98 --- /dev/null +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -0,0 +1,442 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/grammar_state_matcher_state.h + * \brief The header for the definition of the state used in the grammar state matcher. + */ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ + +#include +#include + +#include "grammar.h" +#include "grammar_serializer.h" + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! \brief Specifies a position in a rule. */ +struct RulePosition { + /*! \brief The rule's id. */ + int32_t rule_id = -1; + /*! \brief Which choice in this rule is selected. */ + int32_t sequence_id = -1; + /*! \brief Which element of the choice sequence is being visited. */ + int32_t element_id = -1; + /*! + * \brief If the element refers to another rule, and another rule is a star quantifier of + * a character class, this field will be set to the id of the character class. + * This is part of the special support of star quantifiers of character classes. + */ + int32_t char_class_id = -1; + /*! \brief The id of the parent node in the RulePositionTree. */ + int32_t parent_id = -1; + /*! \brief The reference count of this RulePosition. If reduces to zero, the node will be + * removed from the RulePositionBuffer. */ + int reference_count = 0; + + /*! \brief A parent_id value of kNoParent means this RulePosition is the root of the tree. */ + static constexpr int32_t kNoParent = -1; + + constexpr RulePosition() = default; + constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, + int32_t parent_id = kNoParent, int32_t char_class_id = -1) + : rule_id(rule_id), + sequence_id(sequence_id), + element_id(element_id), + char_class_id(char_class_id), + parent_id(parent_id) {} + + bool operator==(const RulePosition& other) const { + return rule_id == other.rule_id && sequence_id == other.sequence_id && + element_id == other.element_id && char_class_id == other.char_class_id && + parent_id == other.parent_id; + } + + bool operator!=(const RulePosition& other) const { return !(*this == other); } +}; + +/*! \brief A special value for invalid RulePosition. */ +inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1, -1); + +/*! \brief A buffer to manage all RulePositions. */ +class RulePositionBuffer { + public: + /*! + * \brief Allocate a new RulePosition. with given initial value. + * \returns The id of the allocated node. 
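+ * \note Freed slots are reused before the underlying buffer grows, so the returned id may
+ * be the id of a previously freed node.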
+ */ + int32_t Allocate(RulePosition rule_position) { + int32_t id; + if (free_nodes_.empty()) { + buffer_.emplace_back(); + id = buffer_.size() - 1; + } else { + id = free_nodes_.back(); + DCHECK(buffer_[id] == kInvalidRulePosition); + free_nodes_.pop_back(); + } + rule_position.reference_count = 0; + buffer_[id] = rule_position; + return id; + } + + /*! \brief Free the RulePosition with the given id. */ + void Free(int32_t id) { + DCHECK(buffer_[id] != kInvalidRulePosition); + buffer_[id] = kInvalidRulePosition; + free_nodes_.push_back(id); + } + + /*! \brief Get the capacity of the buffer. */ + size_t Capacity() const { return buffer_.size(); } + + /*! \brief Get the number of allocated nodes. */ + size_t Size() const { + DCHECK(buffer_.size() >= free_nodes_.size()); + return buffer_.size() - free_nodes_.size(); + } + + /*! \brief Get the RulePosition with the given id. */ + RulePosition& operator[](int32_t id) { return buffer_[id]; } + const RulePosition& operator[](int32_t id) const { return buffer_[id]; } + + void Reset() { + buffer_.clear(); + free_nodes_.clear(); + } + + friend class RulePositionTree; + + private: + /*! \brief The buffer to store all RulePositions. */ + std::vector buffer_; + /*! \brief A stack to store all free node ids. */ + std::vector free_nodes_; +}; + +/*! + * \brief A tree structure to store all stacks. Every stack contains several RulePositions, and + * is represented as a path from the root to a leaf node. + */ +class RulePositionTree { + public: + /*! \brief Construct a RulePositionTree associated with the given grammar. */ + RulePositionTree(const BNFGrammar& grammar) : grammar_(grammar) {} + + /*! + * \brief Create a new node with the given RulePosition. The reference count of the new node + * is zero. + * + * \note Later, this node should either be pointed by some child rule, or become a stack top + * node (so it will be pointed to by an attached pointer) to be maintained in the + * reference-counting based memory management. + */ + int32_t NewNode(const RulePosition& rule_position) { + auto id = node_buffer_.Allocate(rule_position); + if (rule_position.parent_id != RulePosition::kNoParent) { + DCHECK(rule_position.parent_id < static_cast(node_buffer_.Capacity()) && + node_buffer_[rule_position.parent_id] != kInvalidRulePosition); + node_buffer_[rule_position.parent_id].reference_count++; + } + return id; + } + + /*! + * \brief Update a node in the stack to the next position. Next position means either the next + * element in the current rule, or if the current element is the last element in the rule, the + * next element in the parent rule. If the current node is the last element in the main rule, it + * is at the end position. + */ + RulePosition GetNextPosition(RulePosition rule_position) const; + + bool IsEndPosition(const RulePosition& rule_position) const; + + /*! \brief Attach an additional reference to the node with the given id. */ + void AttachRefTo(int32_t id) { + DCHECK(id != RulePosition::kNoParent); + node_buffer_[id].reference_count++; + } + + /*! \brief Remove a reference to the node with the given id. If the reference count becomes zero, + * free the node and recursively all its ancestors with zero reference count. 
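+ * For example, removing the last reference to a stack-top node frees that node and then
+ * walks up its parent chain, freeing every ancestor whose reference count also drops to
+ * zero.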
*/ + void RemoveRefTo(int32_t id) { + DCHECK(id != RulePosition::kNoParent); + auto cur_node = id; + while (cur_node != RulePosition::kNoParent) { + node_buffer_[cur_node].reference_count--; + if (node_buffer_[cur_node].reference_count != 0) { + break; + } + auto next_node = node_buffer_[cur_node].parent_id; + node_buffer_.Free(cur_node); + cur_node = next_node; + } + } + + /*! \brief Get the RulePosition with the given id. */ + const RulePosition& operator[](int32_t id) const { + DCHECK(id != RulePosition::kNoParent); + return node_buffer_[id]; + } + + /*! \brief Print the node with the given id to a string. */ + std::string PrintNode(int32_t id) const; + + /*! \brief Print the stack with the given top id to a string. */ + std::string PrintStackByTopId(int32_t top_id) const; + + /*! + * \brief Check the well-formedness of the tree and the associated buffer. For debug purpose. + * \details This function checks the following properties: + * 1. Every node is pointed directly or indirectly by a outside pointer. + * 2. Every node's reference count is consistent with the actual reference count. + * 3. All ids and positions are valid. + * 4. If a node in the buffer is free, it should be equal to kInvalidRulePosition. + */ + void CheckWellFormed(const std::vector& outside_pointers) const; + + /*! \brief Reset the tree and the associated buffer. */ + void Reset() { node_buffer_.Reset(); } + + private: + /*! \brief The grammar associated with this RulePositionTree. */ + BNFGrammar grammar_; + /*! \brief The buffer to store all RulePositions. */ + RulePositionBuffer node_buffer_; +}; + +/*! + * \brief A class to maintain the stack tops and its history to support rollback. + * \details This class helps to maintain nodes by automatically maintaining the attached references. + * If a node is not existing in any stack in the history record, it will be freed. + * + * It can store up to the previous max_rollback_steps + 1 steps of history, and thus supports + * rolling back up to max_rollback_steps steps. + */ +class StackTopsHistory { + public: + /*! + * \param tree The RulePositionTree to be associated with. Possibly modify the tree by attaching + * and removing references to the stack top nodes. + * \param max_rollback_steps The maximum number of rollback steps to be supported. + */ + StackTopsHistory(RulePositionTree* tree) : tree_(tree) {} + + /*! + * \brief Push a new history record consisting a list of stack tops. These nodes will be recorded + * as existing in a stack (by attaching a reference to them). + * \param stack_tops The stack tops to be pushed. + * \param drop_old Whether to drop the oldest history record if the history size exceeds the + * limit. If the history is dropped, node that do not exist in any stack any more will be freed. + */ + void PushHistory(const std::vector& stack_tops) { + stack_tops_history_.push_back(stack_tops); + for (auto id : stack_tops) { + tree_->AttachRefTo(id); + } + } + + /*! \brief Roll back to several previous steps. Possibly frees node that do not exist in any stack + * any more. */ + void Rollback(int rollback_steps) { + DCHECK(rollback_steps < stack_tops_history_.size()) + << "The number of requested rollback steps is greater than or equal to the current " + "history " + << "size: " << rollback_steps << " vs " << stack_tops_history_.size() << "."; + while (rollback_steps--) { + PopLatest(); + } + } + + /*! \brief Discard the earliest several steps. Possibly frees node that do not exist in any stack + * any more. 
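+ * Callers can use this to bound the history size, e.g. to retain only the latest
+ * max_rollback_steps + 1 records mentioned in the class comment.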
*/ + void DiscardEarliest(int discard_steps) { + DCHECK(discard_steps < stack_tops_history_.size()) + << "The number of requested discard steps is greater than or equal to the current " + "history " + << "size: " << discard_steps << " vs " << stack_tops_history_.size() << "."; + while (discard_steps--) { + PopEarliest(); + } + } + + /*! \brief Get the latest stack tops. */ + const std::vector& GetLatest() const { return stack_tops_history_.back(); } + + /*! + * \brief Print one history record. + * \param history_position_to_latest The number of steps behind the latest record. 0 means the + * latest record. + */ + std::string PrintHistory(int history_position_to_latest = 0) const; + + /*! \brief Get the number of history records. */ + int Size() const { return stack_tops_history_.size(); } + + /*! \brief Check the well-formedness of the tree and the associated buffer. */ + void CheckWellFormed() const; + + /*! \brief Reset the history and the associated node tree. */ + void Reset() { + stack_tops_history_.clear(); + tree_->Reset(); + } + + private: + /*! \brief Pop the oldest history record. Possibly frees node that do not exist in any stack any + * more. */ + void PopEarliest() { + const auto& old_stack_tops = stack_tops_history_.front(); + for (auto id : old_stack_tops) { + tree_->RemoveRefTo(id); + } + stack_tops_history_.pop_front(); + } + + /*! \brief Pop the latest history record. Possibly frees node that do not exist in any stack any + * more. */ + void PopLatest() { + const auto& new_stack_tops = stack_tops_history_.back(); + for (auto id : new_stack_tops) { + tree_->RemoveRefTo(id); + } + stack_tops_history_.pop_back(); + } + + /*! \brief Modifiable pointer to the RulePositionTree. */ + RulePositionTree* tree_; + /*! \brief The history of stack tops. */ + std::deque> stack_tops_history_; +}; + +/*! \brief See GetNextPosition. */ +inline bool RulePositionTree::IsEndPosition(const RulePosition& rule_position) const { + return rule_position.parent_id == RulePosition::kNoParent && + grammar_->GetRuleExpr(rule_position.sequence_id).size() == rule_position.element_id; +} + +/*! + * \brief Update a node in the stack to the next position. Next position means either the next + * element in the current rule, or if the current element is the last element in the rule, the + * next element in the parent rule. If the current node is the last element in the main rule, it + * is at the end position. 
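+ * For example, given `main ::= (r [c])` and `r ::= ([a] [b])`, the next position after `[b]`
+ * inside `r` is the position of `[c]` in `main`, reached by advancing the parent position.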
+ */ +inline RulePosition RulePositionTree::GetNextPosition(RulePosition rule_position) const { + if (IsEndPosition(rule_position)) { + return kInvalidRulePosition; + } + rule_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, + rule_position.element_id + 1, rule_position.parent_id); + while (rule_position.parent_id != RulePosition::kNoParent && + grammar_->GetRuleExpr(rule_position.sequence_id).size() == rule_position.element_id) { + auto parent_rule_position = node_buffer_[rule_position.parent_id]; + rule_position = + RulePosition(parent_rule_position.rule_id, parent_rule_position.sequence_id, + parent_rule_position.element_id + 1, parent_rule_position.parent_id); + } + return rule_position; +} + +inline std::string RulePositionTree::PrintNode(int32_t id) const { + std::stringstream ss; + const auto& rule_position = node_buffer_[id]; + ss << "id: " << id; + ss << ", rule " << rule_position.rule_id << ": " << grammar_->GetRule(rule_position.rule_id).name; + ss << ", sequence " << rule_position.sequence_id << ": " + << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); + ss << ", element id: " << rule_position.element_id << ", parent id: " << rule_position.parent_id + << ", ref count: " << rule_position.reference_count; + return ss.str(); +} + +inline std::string RulePositionTree::PrintStackByTopId(int32_t top_id) const { + std::stringstream ss; + std::vector stack; + for (auto cur_id = top_id; cur_id != RulePosition::kNoParent; + cur_id = node_buffer_[cur_id].parent_id) { + stack.push_back(cur_id); + } + ss << "{\n"; + for (auto it = stack.rbegin(); it != stack.rend(); ++it) { + ss << PrintNode(*it) << "\n"; + } + ss << "}"; + return ss.str(); +} + +inline void RulePositionTree::CheckWellFormed(const std::vector& outside_pointers) const { + const auto& buffer = node_buffer_.buffer_; + std::unordered_set free_nodes_set(node_buffer_.free_nodes_.begin(), + node_buffer_.free_nodes_.end()); + int buffer_size = static_cast(buffer.size()); + std::vector new_reference_counter(buffer_size, 0); + std::vector visited(buffer_size, false); + std::queue visit_queue; + for (auto id : outside_pointers) { + CHECK(id >= 0 && id < buffer_size); + CHECK(buffer[id] != kInvalidRulePosition); + new_reference_counter[id]++; + if (visited[id] == false) { + visited[id] = true; + visit_queue.push(id); + } + } + while (!visit_queue.empty()) { + auto cur_id = visit_queue.front(); + visit_queue.pop(); + const auto& rule_position = buffer[cur_id]; + if (rule_position.parent_id != RulePosition::kNoParent) { + CHECK(rule_position.parent_id >= 0 && rule_position.parent_id < buffer_size); + CHECK(buffer[rule_position.parent_id] != kInvalidRulePosition); + new_reference_counter[rule_position.parent_id]++; + if (visited[rule_position.parent_id] == false) { + visited[rule_position.parent_id] = true; + visit_queue.push(rule_position.parent_id); + } + } + } + + for (int i = 0; i < static_cast(buffer.size()); ++i) { + if (free_nodes_set.count(i)) { + CHECK(buffer[i] == kInvalidRulePosition); + CHECK(visited[i] == false); + } else { + CHECK(visited[i] == true); + CHECK(buffer[i] != kInvalidRulePosition); + CHECK(new_reference_counter[i] == buffer[i].reference_count) + << "Reference counters unmatch for node #" << i << ": Updated " + << new_reference_counter[i] << ", Original " << buffer[i].reference_count; + } + } +} + +inline std::string StackTopsHistory::PrintHistory(int history_position_to_latest) const { + const auto& latest_tops = + stack_tops_history_[stack_tops_history_.size() - 1 
- history_position_to_latest]; + std::stringstream ss; + ss << "Stacks tops size: " << latest_tops.size() << std::endl; + int cnt = 0; + for (auto id : latest_tops) { + ss << "Stack #" << cnt << ": " << tree_->PrintStackByTopId(id) << "\n"; + ++cnt; + } + return ss.str(); +} + +inline void StackTopsHistory::CheckWellFormed() const { + std::vector outside_pointers; + for (const auto& stack_tops : stack_tops_history_) { + outside_pointers.insert(outside_pointers.end(), stack_tops.begin(), stack_tops.end()); + } + tree_->CheckWellFormed(outside_pointers); +} + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_ diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h new file mode 100644 index 0000000000..9ee6ffb3b3 --- /dev/null +++ b/cpp/serve/grammar/support.h @@ -0,0 +1,123 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/support.h + * \brief The header for utilities used in grammar-guided generation. + */ +#ifndef MLC_LLM_SERVE_GRAMMAR_SUPPORT_H_ +#define MLC_LLM_SERVE_GRAMMAR_SUPPORT_H_ + +#include + +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +/*! \brief Manages a segment of externally provided memory and use it as a bitset. */ +class BitsetManager { + public: + BitsetManager(uint32_t* data, int buffer_size) : data_(data), buffer_size_(buffer_size) {} + + static int GetBitsetSize(int size) { return (size + 31) / 32; } + + bool operator[](int index) const { + DCHECK(index >= 0 && index / 32 < buffer_size_); + return (data_[index / 32] >> (index % 32)) & 1; + } + + void Set(int index, bool value) { + DCHECK(index >= 0 && index / 32 < buffer_size_); + if (value) { + data_[index / 32] |= 1 << (index % 32); + } else { + data_[index / 32] &= ~(1 << (index % 32)); + } + } + + void Reset(int size, bool value) { + DCHECK(buffer_size_ >= GetBitsetSize(size)); + std::memset(data_, value ? 0xFF : 0, GetBitsetSize(size) * sizeof(uint32_t)); + } + + private: + uint32_t* const data_; + const int buffer_size_; +}; + +/*! + * \brief Let lhs be the union of lhs and rhs. Suppose that both sets are sorted. + * \note No additional vectors are allocated, and the time complexity is O(n) + */ +void IntsetUnion(std::vector* lhs, const std::vector& rhs) { + int original_lhs_size = lhs->size(); + int rhs_size = rhs.size(); + + lhs->resize(original_lhs_size + rhs_size); + + auto it_lhs = lhs->rbegin() + rhs_size; + auto it_rhs = rhs.rbegin(); + auto it_result = lhs->rbegin(); + + while (it_lhs != lhs->rend() && it_rhs != rhs.rend()) { + if (*it_lhs > *it_rhs) { + *it_result = *it_lhs; + ++it_lhs; + } else if (*it_lhs < *it_rhs) { + *it_result = *it_rhs; + ++it_rhs; + } else { + *it_result = *it_lhs; + ++it_lhs; + ++it_rhs; + } + ++it_result; + } + + while (it_rhs != rhs.rend()) { + *it_result = *it_rhs; + ++it_result; + ++it_rhs; + } + + auto last = std::unique(lhs->begin(), lhs->end()); + lhs->erase(last, lhs->end()); +} + +/*! + * \brief Let lhs be the intersection of lhs and rhs. Suppose that both sets are sorted. + * \note No additional vector is allocated, and the time complexity is O(n). + * \note Support the case where lhs is the universal set by setting lhs to {-1}. The result will be + * rhs then. 
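+ * For example, lhs = {1, 3, 5} intersected with rhs = {3, 4, 5} leaves lhs = {3, 5}, while
+ * lhs = {-1} (the universal set) intersected with any rhs leaves lhs equal to rhs.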
+ */ +void IntsetIntersection(std::vector* lhs, const std::vector& rhs) { + if (lhs->size() == 1 && (*lhs)[0] == -1) { + *lhs = rhs; + return; + } + + auto it_lhs = lhs->begin(); + auto it_rhs = rhs.begin(); + auto it_result = lhs->begin(); + + while (it_lhs != lhs->end() && it_rhs != rhs.end()) { + if (*it_lhs < *it_rhs) { + ++it_lhs; + } else if (*it_lhs > *it_rhs) { + ++it_rhs; + } else { + *it_result = *it_lhs; + ++it_lhs; + ++it_rhs; + ++it_result; + } + } + lhs->erase(it_result, lhs->end()); +} + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_SUPPORT_H_ diff --git a/cpp/serve/encoding.cc b/cpp/support/encoding.cc similarity index 94% rename from cpp/serve/encoding.cc rename to cpp/support/encoding.cc index a839584cf7..0509c1eb2a 100644 --- a/cpp/serve/encoding.cc +++ b/cpp/support/encoding.cc @@ -10,7 +10,6 @@ namespace mlc { namespace llm { -namespace serve { std::string CodepointToUtf8(TCodepoint codepoint) { ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint; @@ -37,6 +36,33 @@ std::string CodepointToUtf8(TCodepoint codepoint) { return utf8; } +std::string CodepointToPrintable( + TCodepoint codepoint, const std::unordered_map& custom_escape_map) { + static const std::unordered_map kCodepointToEscape = { + {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, + {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, + {'\v', "\\v"}, {'\0', "\\0"}, {'\x1B', "\\e"}}; + + if (auto it = custom_escape_map.find(codepoint); it != custom_escape_map.end()) { + return it->second; + } + + if (auto it = kCodepointToEscape.find(codepoint); it != kCodepointToEscape.end()) { + return it->second; + } + + if (codepoint >= 0x20 && codepoint <= 0x7E) { + return std::string({static_cast(codepoint)}); + } + + // convert codepoint to hex + int width = codepoint <= 0xFFFF ? 4 : 8; + std::stringstream ss; + ss << std::setfill('0') << std::setw(width) << std::hex << codepoint; + auto hex = ss.str(); + return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; +} + std::pair Utf8ToCodepoint(const char* utf8) { const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off @@ -77,31 +103,17 @@ std::pair Utf8ToCodepoint(const char* utf8) { return {res, bytes}; } -std::string CodepointToPrintable( - TCodepoint codepoint, const std::unordered_map& custom_escape_map) { - static const std::unordered_map kCodepointToEscape = { - {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, - {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, - {'\v', "\\v"}, {'\0', "\\0"}, {'\x1B', "\\e"}}; - - if (auto it = custom_escape_map.find(codepoint); it != custom_escape_map.end()) { - return it->second; - } - - if (auto it = kCodepointToEscape.find(codepoint); it != kCodepointToEscape.end()) { - return it->second; - } - - if (codepoint >= 0x20 && codepoint <= 0x7E) { - return std::string({static_cast(codepoint)}); +std::vector Utf8StringToCodepoints(const char* utf8) { + std::vector codepoints; + while (*utf8 != 0) { + auto [codepoint, bytes] = Utf8ToCodepoint(utf8); + if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + return {codepoint}; + } + codepoints.push_back(codepoint); + utf8 += bytes; } - - // convert codepoint to hex - int width = codepoint <= 0xFFFF ? 4 : 8; - std::stringstream ss; - ss << std::setfill('0') << std::setw(width) << std::hex << codepoint; - auto hex = ss.str(); - return codepoint <= 0xFFFF ? 
"\\u" + hex : "\\U" + hex; + return codepoints; } int HexCharToInt(char c) { @@ -168,6 +180,5 @@ std::pair Utf8OrEscapeToCodepoint( } } -} // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/encoding.h b/cpp/support/encoding.h similarity index 95% rename from cpp/serve/encoding.h rename to cpp/support/encoding.h index 88fba475e9..f28aae6d74 100644 --- a/cpp/serve/encoding.h +++ b/cpp/support/encoding.h @@ -8,10 +8,10 @@ #include #include +#include namespace mlc { namespace llm { -namespace serve { /*! \brief Represents a unicode codepoint. */ using TCodepoint = int32_t; @@ -42,9 +42,9 @@ std::string CodepointToPrintable( */ enum class CharHandlingError : TCodepoint { /*! \brief The UTF-8 string is invalid. */ - kInvalidUtf8 = -1, + kInvalidUtf8 = -10, /*! \brief The escape sequence is invalid. */ - kInvalidEscape = -2, + kInvalidEscape = -11, }; /*! @@ -55,6 +55,8 @@ enum class CharHandlingError : TCodepoint { */ std::pair Utf8ToCodepoint(const char* utf8); +std::vector Utf8StringToCodepoints(const char* utf8); + /*! * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function * supports escape sequences in C ("\n", "\t", "\u0123"). User can specify more escape sequences @@ -69,7 +71,6 @@ std::pair Utf8ToCodepoint(const char* utf8); std::pair Utf8OrEscapeToCodepoint( const char* utf8, const std::unordered_map& custom_escape_map = {}); -} // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index 2b4fef71cd..ef866f3bfc 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -34,6 +34,16 @@ std::string TokenizerObj::Decode(const std::vector& token_ids) const { return tokenizer->Decode(token_ids); } +size_t TokenizerObj::GetVocabSize() const { return tokenizer->GetVocabSize(); } + +std::string TokenizerObj::IdToToken(int32_t token_id) const { + return tokenizer->IdToToken(token_id); +} + +int32_t TokenizerObj::TokenToId(const std::string& token) const { + return tokenizer->TokenToId(token); +} + Tokenizer Tokenizer::FromPath(const String& _path) { std::filesystem::path path(_path.operator std::string()); std::filesystem::path sentencepiece; diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index a86c45ea53..16d9ba456b 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -33,6 +33,22 @@ class TokenizerObj : public Object { /*! \brief Return the token table of the tokenizer. */ const std::vector& TokenTable(); + /*! + * \brief Returns the vocabulary size. Special tokens are considered. + */ + size_t GetVocabSize() const; + + /*! + * \brief Convert the given id to its corresponding token if it exists. If not, return an + * empty string. + */ + std::string IdToToken(int32_t token_id) const; + + /*! + * \brief Convert the given token to its corresponding id if it exists. If not, return -1. 
+ */ + int32_t TokenToId(const std::string& token) const; + static constexpr const char* _type_key = "mlc.Tokenizer"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; diff --git a/python/mlc_chat/serve/__init__.py b/python/mlc_chat/serve/__init__.py index f4560cee8f..8e31ae5f7e 100644 --- a/python/mlc_chat/serve/__init__.py +++ b/python/mlc_chat/serve/__init__.py @@ -5,6 +5,6 @@ from .config import EngineMode, GenerationConfig, KVCacheConfig from .data import Data, TextData, TokenData from .engine import Engine -from .grammar import BNFGrammar +from .grammar import BNFGrammar, GrammarStateMatcher from .request import Request, RequestStreamOutput from .server import PopenServer diff --git a/python/mlc_chat/serve/grammar.py b/python/mlc_chat/serve/grammar.py index bf0eedbfa8..3df954cb22 100644 --- a/python/mlc_chat/serve/grammar.py +++ b/python/mlc_chat/serve/grammar.py @@ -1,7 +1,10 @@ """Classes handling the grammar guided generation of MLC LLM serving""" +from typing import List, Union + import tvm._ffi from tvm.runtime import Object +from ..tokenizer import Tokenizer from . import _ffi_api @@ -14,7 +17,9 @@ class BNFGrammar(Object): """ @staticmethod - def from_ebnf_string(ebnf_string: str) -> "BNFGrammar": + def from_ebnf_string( + ebnf_string: str, normalize: bool = True, simplify: bool = True + ) -> "BNFGrammar": r"""Parse a BNF grammar from a string in BNF/EBNF format. This method accepts the EBNF notation from the W3C XML Specification @@ -31,13 +36,28 @@ def from_ebnf_string(ebnf_string: str) -> "BNFGrammar": ebnf_string : str The grammar string. + normalize : bool + Whether to normalize the grammar. Default: true. Only set to false for the purpose of + testing. + + In The normalized form of a BNF grammar, every rule is in the form: + `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + + I.e. a list of choices, each choice is a sequence of elements. Elements can be a + character class or a rule reference. And if the rule can be empty, the first choice + will be an empty string. + + simplify : bool + Whether to simplify the grammar to make matching more efficient. Default: true. Not + implemented yet. + Returns ------- grammar : BNFGrammar The parsed BNF grammar. """ return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string + ebnf_string, normalize, simplify ) def to_string(self) -> str: @@ -50,6 +70,9 @@ def to_string(self) -> str: """ return str(_ffi_api.BNFGrammarToString(self)) # type: ignore # pylint: disable=no-member + def __str__(self) -> str: + return self.to_string() + @staticmethod def from_json(json_string: str) -> "BNFGrammar": """Load a BNF grammar from the raw representation of the AST in JSON format. @@ -82,3 +105,138 @@ def to_json(self, prettify: bool = True) -> str: return str( _ffi_api.BNFGrammarToJSON(self, prettify) # type: ignore # pylint: disable=no-member ) + + @staticmethod + def get_grammar_of_json() -> "BNFGrammar": + """Get the grammar of standard JSON. + + Returns + ------- + grammar : BNFGrammar + The JSON grammar. + """ + return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.serve.GrammarStateMatcher") # pylint: disable=protected-access +class GrammarStateMatcher(Object): + """A stateful matcher to match tokens to the specified BNF grammar. This class is the core logic + of the grammar-guided generation. 
+ + This class implements the non-deterministic pushdown automaton (NPDA) matching algorithm to + match characters to a BNF grammar. It keep track of the current state of the matching process by + maintaining several stacks internally as possible paths in the NPDA. It also supports + backtracking. + + It is particularly capable of finding the set of tokens that are acceptable for the next step + and storing them in a bitmask. This aids in grammar-guided generation. + + Parameters + ---------- + grammar : BNFGrammar + The BNF grammar to match. + + tokenizer : Union[None, Tokenizer, List[str]] + The tokenizer to use, or the list of tokens. + + (For debug purpose) If None, the matcher will use an empty token set, and can only accept + and match characters. Default: None. + + max_rollback_steps : int + The maximum number of steps to rollback when backtracking. Default: 0. + """ + + def __init__( + self, + grammar: BNFGrammar, + tokenizer: Union[None, Tokenizer, List[str]] = None, + max_rollback_steps: int = 0, + ): + if isinstance(tokenizer, list): + self.__init_handle_by_constructor__( + _ffi_api.GrammarStateMatcherFromTokenTable, # type: ignore # pylint: disable=no-member + grammar, + *tokenizer, + max_rollback_steps, + ) + else: + self.__init_handle_by_constructor__( + _ffi_api.GrammarStateMatcherFromTokenizer, # type: ignore # pylint: disable=no-member + grammar, + tokenizer, + max_rollback_steps, + ) + + def accept_token(self, token_id: int) -> bool: + """Accept one token and update the state of the matcher. + + Parameters + ---------- + token_id : int + The id of the token to accept. + + Returns + ------- + accepted : bool + Whether the token is accepted. + """ + return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member + + def find_next_rejected_tokens(self) -> List[int]: + """Find the ids of the rejected tokens for the next step. + + Returns + ------- + rejected_token_ids : List[int] + A list of rejected token ids. + """ + + return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self) # type: ignore # pylint: disable=no-member + + def rollback(self, num_tokens: int) -> None: + """Rollback the matcher to a previous state. + + Parameters + ---------- + num_tokens : int + The number of tokens to rollback. It cannot exceed the current number of steps, nor can + it exceed the specified maximum number of rollback steps. + """ + _ffi_api.GrammarStateMatcherRollback(self, num_tokens) # type: ignore # pylint: disable=no-member + + def max_rollback_steps(self) -> int: + """Get the maximum number of rollback steps allowed. + + Returns + ------- + max_rollback_steps : int + The maximum number of rollback steps. + """ + return _ffi_api.GrammarStateMatcherMaxRollbackSteps(self) # type: ignore # pylint: disable=no-member + + def reset_state(self) -> None: + """Reset the matcher to the initial state.""" + _ffi_api.GrammarStateMatcherResetState(self) # type: ignore # pylint: disable=no-member + + def debug_accept_char(self, codepoint: int) -> bool: + """Accept one unicode codepoint to the current state. + + Parameters + ---------- + codepoint : int + The unicode codepoint of the character to be accepted. + """ + return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member + self, codepoint + ) + + def debug_match_complete_string(self, string: str) -> bool: + """Check if a matcher can accept the complete string, and then reach the end of the + grammar. 
+ + Parameters + ---------- + string : str + The string to be matched. + """ + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member diff --git a/tests/python/__init__.py b/tests/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/python/conftest.py b/tests/python/conftest.py new file mode 100644 index 0000000000..b19fce722c --- /dev/null +++ b/tests/python/conftest.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,unused-import +import pytest +import tvm.testing + +pytest_plugins = ["tvm.testing.plugin"] diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index d9eea18cda..dd6cc64b5d 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -13,46 +13,45 @@ def test_bnf_simple(): b ::= "b" c ::= "c" """ - expected = """main ::= b c -b ::= [b] -c ::= [c] + expected = """main ::= ((b c)) +b ::= (([b])) +c ::= (([c])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_string() assert after == expected def test_ebnf(): before = """main ::= b c | b main -b ::= "b"* d +b ::= "b"* c ::= [acep-z]+ d ::= "d"? """ - expected = """main ::= (b c) | (b main) -b ::= b_1 d -c ::= c_2 -d ::= d_1 -b_1 ::= ([b] b_1) | "" -c_1 ::= [acep-z] -c_2 ::= (c_1 c_2) | c_1 -d_1 ::= [d] | "" + expected = """main ::= ((b c) | (b main)) +b ::= [b]* +c ::= ((c_2)) +d ::= ((d_1)) +c_1 ::= (([acep-z])) +c_2 ::= ((c_1 c_2) | (c_1)) +d_1 ::= ("" | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_string() - print(after) assert after == expected def test_char(): - before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] rest + before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 rest1 ::= "\?\"\'测试あc" "👀" "" """ - expected = r"""main ::= [a-z] [A-z] [\u0234] ([\u0345] [\u00ff]) [\-A-Z] [\-\-] rest -rest ::= [a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1 -rest1 ::= ([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) [\U0001f440] "" + expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest)) +rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) +rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) + # Disable unwrap_nesting_rules to expose the result before unwrapping. 
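+ # (The two False arguments below are `normalize` and `simplify` of
+ # BNFGrammar.from_ebnf_string, as documented earlier in this patch.)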
+ bnf_grammar = BNFGrammar.from_ebnf_string(before, False, False) after = bnf_grammar.to_string() assert after == expected @@ -65,9 +64,9 @@ def test_space(): "f" | "g" """ - expected = """main ::= ([a] [b] ([c] [d] [e])) | [f] | [g] + expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_string() assert after == expected @@ -75,9 +74,31 @@ def test_space(): def test_nest(): before = """main::= "a" ("b" | "c" "d") | (("e" "f")) """ - expected = """main ::= ([a] ([b] | ([c] [d]))) | ([e] [f]) + expected = """main ::= (([a] main_choice) | ([e] [f])) +main_choice ::= (([b]) | ([c] [d])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + +def test_flatten(): + before = """main ::= or_test sequence_test nested_test empty_test +or_test ::= ([a] | "b") | "de" | "" | or_test | [^a-z] +sequence_test ::= [a] "a" ("b" ("c" | "d")) ("d" "e") sequence_test "" +nested_test ::= ("a" ("b" ("c" "d"))) | ("a" | ("b" | "c")) | nested_rest +nested_rest ::= ("a" | ("b" "c" | ("d" | "e" "f"))) | ((("g"))) +empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" +""" + expected = """main ::= ((or_test sequence_test nested_test empty_test)) +or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z])) +sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test)) +nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest)) +nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g])) +empty_test ::= ("" | ([d]) | ([a])) +sequence_test_choice ::= (([c]) | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_string() assert after == expected @@ -89,35 +110,39 @@ def test_json(): with open(json_ebnf_path, "r", encoding="utf-8") as file: before = file.read() - expected = r"""main ::= element -value ::= object | array | string | number | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l]) -object ::= ([{] ws [}]) | ([{] members [}]) -members ::= member | (member [,] members) -member ::= ws string ws [:] element -array ::= ([[] ws [\]]) | ([[] elements [\]]) -elements ::= element | (element [,] elements) -element ::= ws value ws -string ::= [\"] characters [\"] -characters ::= "" | (character characters) -character ::= [\"\\] | ([\\] escape) -escape ::= [\"] | [\\] | [/] | [b] | [f] | [n] | [r] | [t] | ([u] hex hex hex hex) -hex ::= [A-Fa-f0-9] -number ::= integer fraction exponent -integer ::= digit | (onenine digits) | ([\-] digit) | ([\-] onenine digits) -digits ::= digit | (digit digits) -digit ::= [0-9] -onenine ::= [1-9] -fraction ::= "" | ([.] 
digits) -exponent ::= "" | (([e] | [E]) ("" | [+] | [\-]) digits) -ws ::= "" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws) -""" - - bnf_grammar = BNFGrammar.from_ebnf_string(before) + expected = r"""main ::= ((element)) +value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l])) +object ::= (([{] ws [}]) | ([{] members [}])) +members ::= ((member) | (member [,] members)) +member ::= ((ws string ws [:] element)) +array ::= (([[] ws [\]]) | ([[] elements [\]])) +elements ::= ((element) | (element [,] elements)) +element ::= ((ws value ws)) +string ::= (([\"] characters [\"])) +characters ::= ("" | (character characters)) +character ::= (([^\"\\]) | ([\\] escape)) +escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex)) +hex ::= (([A-Fa-f0-9])) +number ::= ((integer fraction exponent)) +integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits)) +digits ::= ((digit) | (digit digits)) +digit ::= (([0-9])) +onenine ::= (([1-9])) +fraction ::= ("" | ([.] digits)) +exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) +ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws)) +exponent_choice ::= (([e]) | ([E])) +exponent_choice_1 ::= ("" | ([+]) | ([\-])) +""" + + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_string() assert after == expected def test_to_string_roundtrip(): + """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" + before = r"""main ::= (b c) | (b main) b ::= b_1 d c ::= c_1 @@ -127,51 +152,72 @@ def test_to_string_roundtrip(): c_2 ::= [acep-z] d_1 ::= [d] | "" """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) - string = bnf_grammar.to_string() - new_grammar = BNFGrammar.from_ebnf_string(string) - new_string = new_grammar.to_string() - assert string == new_string + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + output_string_1 = bnf_grammar_1.to_string() + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, True, False) + output_string_2 = bnf_grammar_2.to_string() + assert output_string_1 == output_string_2 def test_error(): - with pytest.raises(TVMError, match="Rule a is not defined at line 1, column 11"): + with pytest.raises( + TVMError, match='TVMError: EBNF parse error at line 1, column 11: Rule "a" is not defined' + ): BNFGrammar.from_ebnf_string("main ::= a b") - with pytest.raises(TVMError, match="Expect element at line 1, column 15"): + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 15: Expect element" + ): BNFGrammar.from_ebnf_string('main ::= "a" |') - with pytest.raises(TVMError, match='Expect " at line 1, column 15'): + with pytest.raises(TVMError, match='TVMError: EBNF parse error at line 1, column 15: Expect "'): BNFGrammar.from_ebnf_string('main ::= "a" "') - with pytest.raises(TVMError, match="Expect rule name at line 1, column 1"): + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 1: Expect rule name" + ): BNFGrammar.from_ebnf_string('::= "a"') with pytest.raises( - TVMError, match="Character range should not contain newline at line 1, column 12" + TVMError, + match="TVMError: EBNF parse error at line 1, column 12: Character class should not contain " + "newline", ): BNFGrammar.from_ebnf_string("main ::= [a\n]") - with pytest.raises(TVMError, match="Invalid escape sequence at line 1, column 11"): + with pytest.raises( + TVMError, 
match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" + ): BNFGrammar.from_ebnf_string(r'main ::= "\@"') - with pytest.raises(TVMError, match="Invalid escape sequence at line 1, column 11"): + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" + ): BNFGrammar.from_ebnf_string(r'main ::= "\uFF"') with pytest.raises( TVMError, - match="Invalid character range: lower bound is larger than upper bound at " - "line 1, column 14", + match="TVMError: EBNF parse error at line 1, column 14: Invalid character class: " + "lower bound is larger than upper bound", ): BNFGrammar.from_ebnf_string(r"main ::= [Z-A]") - with pytest.raises(TVMError, match="Expect ::= at line 1, column 6"): + with pytest.raises( + TVMError, match="TVMError: EBNF parse error at line 1, column 6: Expect ::=" + ): BNFGrammar.from_ebnf_string(r'main := "a"') - with pytest.raises(TVMError, match="Rule main is defined multiple times at line 2, column 9"): + with pytest.raises( + TVMError, + match='TVMError: EBNF parse error at line 2, column 9: Rule "main" is defined multiple ' + "times", + ): BNFGrammar.from_ebnf_string('main ::= "a"\nmain ::= "b"') - with pytest.raises(TVMError, match="There must be a rule named main at line 1, column 10"): + with pytest.raises( + TVMError, + match='TVMError: EBNF parse error at line 1, column 10: There must be a rule named "main"', + ): BNFGrammar.from_ebnf_string('a ::= "a"') @@ -181,34 +227,33 @@ def test_to_json(): c ::= [a-z] """ expected = ( - '{"rule_expr_indptr":[0,2,4,7,9,11,14,17,20,23,26,30,32,34,37,39],' - '"rule_expr_data":[3,1,3,2,4,0,1,3,1,3,0,4,3,4,5,2,5,0,98,98,0,99,99,0,100,' - "100,4,7,8,9,4,10,5,11,0,97,122,4,13,5,14]," - '"rules":[{"rule_expr_id":6,"name":"main"},{"rule_expr_id":12,"name":"b"},' - '{"rule_expr_id":15,"name":"c"}]}' + '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"' + ":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100," + '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' + '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' ) - bnf_grammar = BNFGrammar.from_ebnf_string(before) + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) after = bnf_grammar.to_json(False) assert after == expected def test_to_json_roundtrip(): - before = r"""main ::= (b c) | (b main) -b ::= b_1 d -c ::= c_1 -d ::= d_1 -b_1 ::= ([b] b_1) | "" -c_1 ::= (c_2 c_1) | c_2 -c_2 ::= [acep-z] -d_1 ::= [d] | "" + before = r"""main ::= ((b c) | (b main)) +b ::= ((b_1 d)) +c ::= ((c_1)) +d ::= ((d_1)) +b_1 ::= ("" | ([b] b_1)) +c_1 ::= ((c_2 c_1) | (c_2)) +c_2 ::= (([acep-z])) +d_1 ::= ("" | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before) - json = bnf_grammar.to_json(False) - new_grammar = BNFGrammar.from_json(json) - new_json = new_grammar.to_json(False) - after = new_grammar.to_string() - assert json == new_json - assert after == before + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + output_json_1 = bnf_grammar_1.to_json(False) + bnf_grammar_2 = BNFGrammar.from_json(output_json_1) + output_json_2 = bnf_grammar_2.to_json(False) + output_str = bnf_grammar_2.to_string() + assert output_json_1 == output_json_2 + assert output_str == before if __name__ == "__main__": diff --git a/tests/python/serve/test_grammar_state_matcher.py b/tests/python/serve/test_grammar_state_matcher.py new file mode 100644 index 0000000000..cf7229af21 --- /dev/null +++ 
b/tests/python/serve/test_grammar_state_matcher.py @@ -0,0 +1,387 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking +from typing import List + +import pytest +import tvm +import tvm.testing + +from mlc_chat.serve import BNFGrammar, GrammarStateMatcher +from mlc_chat.tokenizer import Tokenizer + + +@pytest.fixture(scope="function") +def json_grammar(): + return BNFGrammar.get_grammar_of_json() + + +(json_input_accepted,) = tvm.testing.parameters( + ('{"name": "John"}',), + ('{ "name" : "John" } \n',), + ("{}",), + ("[]",), + ('{"name": "Alice", "age": 30, "city": "New York"}',), + ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), + ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), + ('[{"name": "David"}, {"name": "Sophia"}]',), + ( + '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' + ' "hasPets": false}', + ), + ( + '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' + '{"city": "Chicago", "zipcode": "60601"}}}', + ), + ( + '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' + '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', + ), + ( + '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' + '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' + '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', + ), + ( + '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' + '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' + '["CAD", "Project Management"], "projects": [{"name": "Project A", ' + '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', + ), +) + + +def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) + + +# test_json_accept(json_grammar(), '{"name": "John"}') +# exit() + +(json_input_refused,) = tvm.testing.parameters( + (r'{ name: "John" }',), + (r'{ "name": "John", "age": 30, }',), # x + (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), + (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), # x + (r'{ "name": "John", "age": 30.5.7 }',), + (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' + r'["hiking", "swimming",]}] }', # + ), + (r'{ "name": "John", "age": 30, "status": "\P\J" }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' + r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' + r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', + ), +) + + +def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): + assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) + + +(json_input_pressure,) = tvm.testing.parameters( + # Extra long string: 1k chars + ( + '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' + "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " + "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " + "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. 
Class aptent taciti sociosqu " + "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " + "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " + "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " + "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " + "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " + "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " + "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " + "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " + "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " + "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " + "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " + "Integer euismod lacus luctus magna. Quisque cursus, metus vitae pharetra auctor, sem " + 'massa mattis sem, at interdum magna augue eget diam."]', + ), + # long and complex json: 3k chars + ( + r"""{ + "web-app": { + "servlet": [ + { + "servlet-name": "cofaxCDS", + "servlet-class": "org.cofax.cds.CDSServlet", + "init-param": { + "configGlossary:installationAt": "Philadelphia, PA", + "configGlossary:adminEmail": "ksm@pobox.com", + "configGlossary:poweredBy": "Cofax", + "configGlossary:poweredByIcon": "/images/cofax.gif", + "configGlossary:staticPath": "/content/static", + "templateProcessorClass": "org.cofax.WysiwygTemplate", + "templateLoaderClass": "org.cofax.FilesTemplateLoader", + "templatePath": "templates", + "templateOverridePath": "", + "defaultListTemplate": "listTemplate.htm", + "defaultFileTemplate": "articleTemplate.htm", + "useJSP": false, + "jspListTemplate": "listTemplate.jsp", + "jspFileTemplate": "articleTemplate.jsp", + "cachePackageTagsTrack": 200, + "cachePackageTagsStore": 200, + "cachePackageTagsRefresh": 60, + "cacheTemplatesTrack": 100, + "cacheTemplatesStore": 50, + "cacheTemplatesRefresh": 15, + "cachePagesTrack": 200, + "cachePagesStore": 100, + "cachePagesRefresh": 10, + "cachePagesDirtyRead": 10, + "searchEngineListTemplate": "forSearchEnginesList.htm", + "searchEngineFileTemplate": "forSearchEngines.htm", + "searchEngineRobotsDb": "WEB-INF/robots.db", + "useDataStore": true, + "dataStoreClass": "org.cofax.SqlDataStore", + "redirectionClass": "org.cofax.SqlRedirection", + "dataStoreName": "cofax", + "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", + "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", + "dataStoreUser": "sa", + "dataStorePassword": "dataStoreTestQuery", + "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", + "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", + "dataStoreInitConns": 10, + "dataStoreMaxConns": 100, + "dataStoreConnUsageLimit": 100, + "dataStoreLogLevel": "debug", + "maxUrlLength": 500 + } + }, + { + "servlet-name": "cofaxEmail", + "servlet-class": "org.cofax.cds.EmailServlet", + "init-param": { + "mailHost": "mail1", + "mailHostOverride": "mail2" + } + }, + { + "servlet-name": "cofaxAdmin", + "servlet-class": "org.cofax.cds.AdminServlet" + }, + { + "servlet-name": "fileServlet", + "servlet-class": "org.cofax.cds.FileServlet" + }, + { + "servlet-name": "cofaxTools", + "servlet-class": "org.cofax.cms.CofaxToolsServlet", + "init-param": { + "templatePath": 
"toolstemplates/", + "log": 1, + "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", + "logMaxSize": "", + "dataLog": 1, + "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", + "dataLogMaxSize": "", + "removePageCache": "/content/admin/remove?cache=pages&id=", + "removeTemplateCache": "/content/admin/remove?cache=templates&id=", + "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", + "lookInContext": 1, + "adminGroupID": 4, + "betaServer": true + } + } + ], + "servlet-mapping": { + "cofaxCDS": "/", + "cofaxEmail": "/cofaxutil/aemail/*", + "cofaxAdmin": "/admin/*", + "fileServlet": "/static/*", + "cofaxTools": "/tools/*" + }, + "taglib": { + "taglib-uri": "cofax.tld", + "taglib-location": "/WEB-INF/tlds/cofax.tld" + } + } +} """, + ), +) + + +def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) + + +(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( + ( + # short test + '{"id": 1,"name": "Example"} ', + [ + # fmt: off + 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + # fmt: on + ], + ), + ( + # long test + """{ +"id": 1, +"na": "ex", +"ac": True, +"t": ["t1", "t2"], +"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, +"res": "res" +} +""", + [ + # fmt: off + 31989, 31907, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 31910, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 31974, 31910, 31910, 278, 278, 278, 31973, + 31841, 31841, 31841, 31841, 31841, 31841, 31841, 31841, 271, 271, 31974, 31974, 31974, + 31974, 31974, 31974, 31974, 31974, 31910, 31910, 278, 278, 278, 31973, 31973, 31973, + 31973, 31973, 31973, 31973, 31973, 31841, 31841, 31903, 278, 278, 278, 278, 31973, + 31841, 31841, 31901, 278, 278, 278, 278, 31973, 31841, 31841, 270, 270, 270, 31968, + 31970, 31910, 31910, 278, 278, 278, 278, 31973, 31841, 31841, 31835, 31943, 31841, + 31841, 31943, 31841, 31841, 31943, 31970, 31974, 31910, 31910, 278, 278, 278, 278, + 31973, 31841, 31841, 271, 271, 271, 271, 31974, 31974, 31980, 31980 + # fmt: on + ], + ), +) + + +def test_find_rejected_tokens( + json_grammar: BNFGrammar, input_find_rejected_tokens: str, expected_rejected_sizes: List[int] +): + tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + tokenizer = Tokenizer(tokenizer_path) + grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) + + real_sizes = [] + for c in input_find_rejected_tokens: + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + print("Accepting char:", c) + grammar_state_matcher.debug_accept_char(ord(c)) + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + assert real_sizes == expected_rejected_sizes + + +def test_accept_token(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + + result = [] + + expected = [ + ["{"], + ['"', "}", "\n", " ", '"a":true'], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], + ["a", 
"abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], + [":", "\n", " ", ':"'], + ['"', "{", "6", "\n", " "], + ["}", ", ", "6", "\n", " "], + [" ", "\n", '"', '"a":true'], + [" ", "\n", '"', '"a":true'], + ["}", ", ", "\n", " "], + ["", "\n", " "], + ["", "\n", " "], + ] + + for id in input_ids: + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + assert id in accepted + grammar_state_matcher.accept_token(id) + + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + + assert result == expected + + +def test_rollback(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) + + assert grammar_state_matcher.max_rollback_steps() == 5 + + input_ids_splitted = [input_ids[i : i + 2] for i in range(0, len(input_ids), 2)] + + for i_1, i_2 in input_ids_splitted: + orig_result = [] + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_1) + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_2) + grammar_state_matcher.rollback(2) + result_after_rollback = [] + result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_1) + result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i_2) + assert orig_result == result_after_rollback + + +def test_reset(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + + orig_result = [] + + for i in input_ids: + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i) + + grammar_state_matcher.reset_state() + + result_after_reset = [] + + for i in input_ids: + result_after_reset.append(grammar_state_matcher.find_next_rejected_tokens()) + grammar_state_matcher.accept_token(i) + + assert orig_result == result_after_reset + + +if __name__ == "__main__": + # Run a benchmark to show the performance before running tests + test_find_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + '{"id": 1,"name": "Example"} ', + [ + # fmt: off + 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, + 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + # fmt: on + ], + ) + + tvm.testing.main() From ce42880209d570c977d88bb173cd145dc0c37048 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 24 Feb 2024 15:06:23 -0500 Subject: [PATCH 004/531] [Serving] LogProbs support (#1832) This PR introduces the logprobs support with OpenAI API compatibility. 
It enhances the sampler with a function to get the top-probability tokens (supporting 5 tokens at most as of now). To make it easy to pass logprob results back from serving engine to frontend, we choose to pass logprob results in JSON string with OpenAI API spec. Unit tests are added to ensure the correctness of logprobs. And the logprobs support also work with speculative decoding. --- cpp/serve/config.cc | 14 ++ cpp/serve/config.h | 2 + cpp/serve/data.cc | 82 +++++++++ cpp/serve/data.h | 70 ++++++++ cpp/serve/engine.cc | 2 +- cpp/serve/engine_actions/action_commons.cc | 22 ++- cpp/serve/engine_actions/action_commons.h | 2 + cpp/serve/engine_actions/batch_decode.cc | 8 +- cpp/serve/engine_actions/batch_draft.cc | 18 +- cpp/serve/engine_actions/batch_verify.cc | 33 ++-- .../engine_actions/new_request_prefill.cc | 7 +- cpp/serve/logit_processor.cc | 10 +- cpp/serve/logit_processor.h | 10 +- cpp/serve/request.cc | 18 -- cpp/serve/request.h | 39 ---- cpp/serve/request_state.cc | 43 +++-- cpp/serve/request_state.h | 48 +++-- cpp/serve/sampler.cc | 168 +++++++++++++----- cpp/serve/sampler.h | 29 ++- .../mlc_chat/protocol/openai_api_protocol.py | 45 ++++- python/mlc_chat/serve/__init__.py | 5 +- python/mlc_chat/serve/async_engine.py | 38 ++-- python/mlc_chat/serve/config.py | 13 ++ python/mlc_chat/serve/data.py | 63 ++++++- python/mlc_chat/serve/engine.py | 40 +++-- .../serve/entrypoints/openai_entrypoints.py | 86 +++++++-- python/mlc_chat/serve/request.py | 48 +---- tests/python/serve/server/test_server.py | 37 ++++ tests/python/serve/test_serve_async_engine.py | 2 +- .../serve/test_serve_async_engine_spec.py | 2 +- tests/python/serve/test_serve_engine.py | 22 +-- tests/python/serve/test_serve_engine_spec.py | 22 +-- 32 files changed, 712 insertions(+), 336 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 804ff9fe93..fde09ac32c 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -52,6 +52,18 @@ GenerationConfig::GenerationConfig(String config_json_str) { n->repetition_penalty = config["repetition_penalty"].get(); CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; } + if (config.count("logprobs")) { + CHECK(config["logprobs"].is()); + n->logprobs = config["logprobs"].get(); + } + if (config.count("top_logprobs")) { + CHECK(config["top_logprobs"].is()); + n->top_logprobs = config["top_logprobs"].get(); + CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5) + << "At most 5 top logprob tokens are supported"; + CHECK(n->top_logprobs == 0 || n->logprobs) + << "\"logprobs\" must be true to support \"top_logprobs\""; + } if (config.count("logit_bias")) { CHECK(config["logit_bias"].is() || config["logit_bias"].is()); if (config["logit_bias"].is()) { @@ -128,6 +140,8 @@ String GenerationConfigNode::AsJSONString() const { config["frequency_penalty"] = picojson::value(this->frequency_penalty); config["presence_penalty"] = picojson::value(this->presence_penalty); config["repetition_penalty"] = picojson::value(this->repetition_penalty); + config["logprobs"] = picojson::value(this->logprobs); + config["top_logprobs"] = picojson::value(static_cast(this->top_logprobs)); config["max_tokens"] = picojson::value(static_cast(this->max_tokens)); config["seed"] = picojson::value(static_cast(this->seed)); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index c9ebf0c847..9e316bf370 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -25,6 +25,8 @@ class GenerationConfigNode : public Object { double frequency_penalty = 0.0; double 
presence_penalty = 0.0; double repetition_penalty = 1.0; + bool logprobs = false; + int top_logprobs = 0; std::vector> logit_bias; int seed; bool ignore_eos = false; diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc index 08d8afda3c..54e404ae1f 100644 --- a/cpp/serve/data.cc +++ b/cpp/serve/data.cc @@ -77,6 +77,88 @@ TVM_REGISTER_GLOBAL("mlc.serve.TokenDataGetTokenIds").set_body_typed([](TokenDat return data->token_ids; }); +/****************** SampleResult ******************/ + +/*! \brief Convert a single token with probability to JSON string. */ +inline void TokenToLogProbJSON(const Tokenizer& tokenizer, const TokenProbPair& token_prob, + std::ostringstream* os) { + const std::string& token = tokenizer->TokenTable()[token_prob.first]; + + (*os) << "\"token\": \""; + for (char ch : token) { + if (ch >= 33 && ch <= 126) { + // The character is in ASCII visible range. + // Handle escape characters in JSON. + if (ch == '"') { + (*os) << "\\\""; + } else if (ch == '\\') { + (*os) << "\\\\"; + } else { + (*os) << ch; + } + } + } + (*os) << "\", "; + (*os) << "\"logprob\": " << std::log(std::max(token_prob.second, 1e-10f)) << ", "; + (*os) << "\"bytes\": ["; + int token_len = token.size(); + for (int pos = 0; pos < token_len; ++pos) { + (*os) << static_cast(static_cast(token[pos])); + if (pos != token_len - 1) { + (*os) << ", "; + } + } + (*os) << "]"; +} + +std::string SampleResult::GetLogProbJSON(const Tokenizer& tokenizer, bool logprob) const { + ICHECK(top_prob_tokens.empty() || logprob); + if (!logprob) { + // Logprob is not needed. + return ""; + } + + std::ostringstream os; + os << "{"; + // - Convert the sampled token to JSON. + TokenToLogProbJSON(tokenizer, sampled_token_id, &os); + // - Convert the tokens with top probabilities. + os << ", \"top_logprobs\": ["; + int num_top = top_prob_tokens.size(); + for (int i = 0; i < num_top; ++i) { + os << "{"; + TokenToLogProbJSON(tokenizer, top_prob_tokens[i], &os); + os << "}"; + if (i != num_top - 1) { + os << ", "; + } + } + os << "]}"; + return os.str(); +} + +/****************** RequestStreamOutput ******************/ + +TVM_REGISTER_OBJECT_TYPE(RequestStreamOutputObj); + +RequestStreamOutput::RequestStreamOutput(String request_id, + const std::vector& delta_token_ids, + Optional> delta_logprob_json_strs, + Optional finish_reason) { + ObjectPtr n = make_object(); + n->request_id = std::move(request_id); + n->delta_token_ids = IntTuple{delta_token_ids.begin(), delta_token_ids.end()}; + n->delta_logprob_json_strs = std::move(delta_logprob_json_strs); + n->finish_reason = std::move(finish_reason); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.serve.RequestStreamOutputUnpack") + .set_body_typed([](RequestStreamOutput output) { + return Array{output->request_id, output->delta_token_ids, + output->delta_logprob_json_strs, output->finish_reason}; + }); + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/data.h b/cpp/serve/data.h index e097529df2..a63bdf81c4 100644 --- a/cpp/serve/data.h +++ b/cpp/serve/data.h @@ -5,11 +5,14 @@ #ifndef MLC_LLM_SERVE_DATA_H_ #define MLC_LLM_SERVE_DATA_H_ +#include #include #include #include #include +#include "../tokenizers.h" + namespace mlc { namespace llm { namespace serve { @@ -86,6 +89,73 @@ class TokenData : public Data { TVM_DEFINE_OBJECT_REF_METHODS(TokenData, Data, TokenDataNode); }; +/****************** SampleResult ******************/ + +// The pair of a token id and its probability in sampling. +using TokenProbPair = std::pair; + +/*! 
+ * \brief The class of sampler's sampling result. + * It's not a TVM object since it will not be used directly on Python side. + */ +struct SampleResult { + /*! \brief The token id and probability of the sampled token. */ + TokenProbPair sampled_token_id; + /*! \brief The token id and probability of the tokens with top probabilities. */ + std::vector top_prob_tokens; + + /*! + * \brief Get the logprob JSON string of this token with regard + * to OpenAI API at https://platform.openai.com/docs/api-reference/chat/object. + * \param tokenizer The tokenizer for token table lookup. + * \param logprob A boolean indicating if need to return log probability. + * \return A JSON string that conforms to the logprob spec in OpenAI API. + */ + std::string GetLogProbJSON(const Tokenizer& tokenizer, bool logprob) const; +}; + +/****************** RequestStreamOutput ******************/ + +/*! + * \brief The generated delta request output that is streamed back + * through callback stream function. + */ +class RequestStreamOutputObj : public Object { + public: + /*! \brief The id of the request that the function is invoked for. */ + String request_id; + /*! + * \brief The new generated token ids since the last callback invocation + * for the input request. + */ + IntTuple delta_token_ids; + /*! \brief The logprobs JSON strings of the new generated tokens since last invocation. */ + Optional> delta_logprob_json_strs; + /*! + * \brief The finish reason of the request when it is finished, + * of None if the request has not finished yet. + */ + Optional finish_reason; + + static constexpr const char* _type_key = "mlc.serve.RequestStreamOutput"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(RequestStreamOutputObj, Object); +}; + +/*! + * \brief Managed reference to RequestStreamOutputObj. 
+ * \sa RequestStreamOutputObj + */ +class RequestStreamOutput : public ObjectRef { + public: + explicit RequestStreamOutput(String request_id, const std::vector& delta_token_ids, + Optional> delta_logprob_json_strs, + Optional finish_reason); + + TVM_DEFINE_OBJECT_REF_METHODS(RequestStreamOutput, ObjectRef, RequestStreamOutputObj); +}; + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 28b1e70006..5c2e2f0be9 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -178,7 +178,7 @@ class EngineImpl : public Engine { for (EngineAction action : actions_) { Array processed_requests = action->Step(estate_); if (!processed_requests.empty()) { - ActionStepPostProcess(processed_requests, estate_, models_, + ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_, request_stream_callback_.value(), max_single_sequence_length_); return; } diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index f5344e9a0e..5526bed2d1 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -46,6 +46,7 @@ void ProcessFinishedRequest(Array finished_requests, EngineState estate } void ActionStepPostProcess(Array requests, EngineState estate, Array models, + const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, int max_single_sequence_length) { Array finished_requests; @@ -57,15 +58,18 @@ void ActionStepPostProcess(Array requests, EngineState estate, ArrayGetRequestState(request); - auto [delta_token_ids, finish_reason] = rstate->GetReturnTokenIds(max_single_sequence_length); + auto [delta_token_ids, delta_logprob_json_strs, finish_reason] = + rstate->GetReturnTokenIds(tokenizer, max_single_sequence_length); // When there is no new delta tokens nor a finish reason, no need to invoke callback. if (delta_token_ids.empty() && !finish_reason.defined()) { continue; } - callback_delta_outputs.push_back( - RequestStreamOutput(request->id, TokenData(delta_token_ids), finish_reason)); + callback_delta_outputs.push_back(RequestStreamOutput( + request->id, delta_token_ids, + request->generation_cfg->logprobs > 0 ? delta_logprob_json_strs : Optional>(), + finish_reason)); if (finish_reason.defined()) { finished_requests.push_back(request); } @@ -91,21 +95,23 @@ void PreemptLastRunningRequest(EngineState estate, const Array& models, request->input_total_length + rstate->mstates[0]->committed_tokens.size() - 1; for (RequestModelState mstate : rstate->mstates) { mstate->RemoveAllDraftTokens(); - mstate->draft_output_token_prob.clear(); - mstate->draft_output_prob_dist.clear(); ICHECK(mstate->inputs.empty()); ICHECK(!mstate->committed_tokens.empty()); + std::vector committed_token_ids; + committed_token_ids.reserve(mstate->committed_tokens.size()); + for (const SampleResult& committed_token : mstate->committed_tokens) { + committed_token_ids.push_back(committed_token.sampled_token_id.first); + } Array inputs = request->inputs; if (const auto* token_input = inputs.back().as()) { // Merge the TokenData so that a single time TokenEmbed is needed. 
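// Note: committed_tokens now store SampleResult entries rather than raw ids,
// so only the token ids extracted above are re-packed into the pending input;
// the probability info is not needed when re-prefilling after preemption.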
std::vector token_ids{token_input->token_ids->data, token_input->token_ids->data + token_input->token_ids.size()}; - token_ids.insert(token_ids.end(), mstate->committed_tokens.begin(), - mstate->committed_tokens.end()); + token_ids.insert(token_ids.end(), committed_token_ids.begin(), committed_token_ids.end()); inputs.Set(inputs.size() - 1, TokenData(token_ids)); } else { - inputs.push_back(TokenData(mstate->committed_tokens)); + inputs.push_back(TokenData(committed_token_ids)); } mstate->inputs = std::move(inputs); } diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index c629a15296..520180beff 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -35,11 +35,13 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, Array requests, EngineState estate, Array models, + const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, int max_single_sequence_length); diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 627e46bc9a..d7821020a1 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -67,7 +67,7 @@ class BatchDecodeActionObj : public EngineActionObj { rngs.reserve(num_requests); for (Request request : estate->running_queue) { RequestState rstate = estate->GetRequestState(request); - input_tokens.push_back(rstate->mstates[0]->committed_tokens.back()); + input_tokens.push_back(rstate->mstates[0]->committed_tokens.back().sampled_token_id.first); request_ids.push_back(request->id); request_internal_ids.push_back(rstate->mstates[0]->internal_id); mstates.push_back(rstate->mstates[0]); @@ -102,13 +102,13 @@ class BatchDecodeActionObj : public EngineActionObj { logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); // - Sample tokens. - std::vector next_tokens = + std::vector sample_results = sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); - ICHECK_EQ(next_tokens.size(), num_requests); + ICHECK_EQ(sample_results.size(), num_requests); // - Update the committed tokens of states. for (int i = 0; i < num_requests; ++i) { - mstates[i]->CommitToken(next_tokens[i]); + mstates[i]->CommitToken(sample_results[i]); } auto tend = std::chrono::high_resolution_clock::now(); diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 403350c4af..d9eba8e037 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -80,8 +80,9 @@ class BatchDraftActionObj : public EngineActionObj { input_tokens.clear(); for (int i = 0; i < num_requests; ++i) { // The first draft proposal uses the last committed token. - input_tokens.push_back(draft_id == 0 ? mstates[i]->committed_tokens.back() - : mstates[i]->draft_output_tokens.back()); + input_tokens.push_back( + draft_id == 0 ? mstates[i]->committed_tokens.back().sampled_token_id.first + : mstates[i]->draft_output_tokens.back().sampled_token_id.first); } // - Compute embeddings. @@ -113,16 +114,13 @@ class BatchDraftActionObj : public EngineActionObj { // - Sample tokens. 
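// The draft model keeps the full probability distribution of every draft
// position (prob_dist) alongside the sampled token, so that the verify step
// can later run rejection sampling against the large model's probabilities.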
std::vector prob_dist; - std::vector token_probs; - std::vector next_tokens = sampler_->BatchSampleTokens( - probs_device, request_ids, generation_cfg, rngs, &prob_dist, &token_probs); - ICHECK_EQ(next_tokens.size(), num_requests); + std::vector sample_results = sampler_->BatchSampleTokens( + probs_device, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), num_requests); - // - Update the draft tokens, prob dist, token probs of states. + // - Add draft token to the state. for (int i = 0; i < num_requests; ++i) { - mstates[i]->AddDraftToken(next_tokens[i]); - mstates[i]->draft_output_prob_dist.push_back(prob_dist[i]); - mstates[i]->draft_output_token_prob.push_back(token_probs[i]); + mstates[i]->AddDraftToken(sample_results[i], prob_dist[i]); estate->stats.total_draft_length += 1; } } diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index e4aa836127..b608c5b3b3 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -59,8 +59,7 @@ class BatchVerifyActionObj : public EngineActionObj { Array verify_request_mstates; Array generation_cfg; std::vector rngs; - std::vector> draft_output_tokens; - std::vector> draft_output_token_prob; + std::vector> draft_output_tokens; std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_requests); all_tokens_to_verify.reserve(total_draft_length); @@ -68,7 +67,6 @@ class BatchVerifyActionObj : public EngineActionObj { rngs.reserve(num_requests); generation_cfg.reserve(num_requests); draft_output_tokens.reserve(num_requests); - draft_output_token_prob.reserve(num_requests); draft_output_prob_dist.reserve(num_requests); for (int i = 0; i < num_requests; ++i) { @@ -77,18 +75,16 @@ class BatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!draft_lengths.empty()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_token_prob.size()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); // the last committed token + all the draft tokens but the last one. 
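// Each input position yields the large model's distribution for the next
// draft token, so the last draft token is only a prediction target and is
// never fed back in as input here.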
- all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back()); - all_tokens_to_verify.insert(all_tokens_to_verify.end(), - draft_mstate->draft_output_tokens.begin(), - draft_mstate->draft_output_tokens.end() - 1); + all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()) - 1; ++j) { + all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(requests[i]->generation_cfg); rngs.push_back(&rstates[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_token_prob.push_back(draft_mstate->draft_output_token_prob); draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } @@ -118,16 +114,17 @@ class BatchVerifyActionObj : public EngineActionObj { NDArray probs_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - std::vector> accepted_tokens_arr = sampler_->BatchVerifyDraftTokens( - probs_device, request_ids, cum_verify_lengths, verify_request_mstates, generation_cfg, rngs, - draft_output_tokens, draft_output_token_prob, draft_output_prob_dist); - ICHECK_EQ(accepted_tokens_arr.size(), num_requests); + std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( + probs_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, + draft_output_prob_dist); + ICHECK_EQ(sample_results_arr.size(), num_requests); for (int i = 0; i < num_requests; ++i) { - const std::vector& accepted_tokens = accepted_tokens_arr[i]; - int accept_length = accepted_tokens.size(); - for (int32_t token_id : accepted_tokens) { - rstates[i]->mstates[draft_model_id_]->CommitToken(token_id); + const std::vector& sample_results = sample_results_arr[i]; + int accept_length = sample_results.size(); + for (SampleResult sample_result : sample_results) { + rstates[i]->mstates[verify_model_id_]->CommitToken(sample_result); + rstates[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.current_total_seq_len += accept_length; estate->stats.total_accepted_length += accept_length; @@ -149,8 +146,6 @@ class BatchVerifyActionObj : public EngineActionObj { // clear the draft model states for (int i = 0; i < num_requests; ++i) { rstates[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); - rstates[i]->mstates[draft_model_id_]->draft_output_token_prob.clear(); - rstates[i]->mstates[draft_model_id_]->draft_output_prob_dist.clear(); } auto tend = std::chrono::high_resolution_clock::now(); diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index a3f1b2d17c..72f54388e7 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -59,7 +59,6 @@ class NewRequestPrefillActionObj : public EngineActionObj { RequestModelState mstate = rstates[i]->mstates[model_id]; ICHECK_EQ(mstate->GetInputLength(), prefill_lengths[i]); ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_token_prob.empty()); ICHECK(mstate->draft_output_prob_dist.empty()); ICHECK(!mstate->inputs.empty()); // Add the sequence to the model. @@ -111,9 +110,9 @@ class NewRequestPrefillActionObj : public EngineActionObj { logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); // - Sample tokens. 
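// Sampling now returns SampleResult values (token id, probability, and the
// optional top-probability candidates) instead of bare token ids.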
- std::vector next_tokens = + std::vector sample_results = sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); - ICHECK_EQ(next_tokens.size(), num_requests); + ICHECK_EQ(sample_results.size(), num_requests); // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. @@ -122,7 +121,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { auto tnow = std::chrono::high_resolution_clock::now(); for (int i = 0; i < num_requests; ++i) { for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - rstates[i]->mstates[model_id]->CommitToken(next_tokens[i]); + rstates[i]->mstates[model_id]->CommitToken(sample_results[i]); } if (mstates_for_sample[i]->committed_tokens.size() == 1) { rstates[i]->tprefill_finish = tnow; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index a45c1f9f13..24ce003fe3 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -67,7 +67,7 @@ class LogitProcessorImpl : public LogitProcessorObj { const Array& mstates, // const Array& request_ids, // const std::vector* cum_num_token, // - const std::vector>* draft_tokens) final { + const std::vector>* draft_tokens) final { CHECK_EQ(logits->ndim, 2); CHECK_EQ(logits->shape[1], vocab_size_); CHECK(logits.DataType() == DataType::Float(32)); @@ -219,7 +219,7 @@ class LogitProcessorImpl : public LogitProcessorObj { void UpdateWithPenalty(NDArray logits, const Array& generation_cfg, const Array& mstates, const std::vector* cum_num_token, - const std::vector>* draft_tokens) { + const std::vector>* draft_tokens) { // Construct: // - seq_ids (max_num_token,) int32 // - pos2seq_id (max_num_token * vocab_size,) int32 @@ -256,7 +256,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); } } if (num_token_to_process != 1) { @@ -301,7 +301,7 @@ class LogitProcessorImpl : public LogitProcessorObj { void UpdateWithMask(NDArray logits, const Array& mstates, const std::vector* cum_num_token, - const std::vector>* draft_tokens) { + const std::vector>* draft_tokens) { // Construct: // - seq_ids (max_num_token,) int32 // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32 @@ -326,7 +326,7 @@ class LogitProcessorImpl : public LogitProcessorObj { ++num_token_for_mask; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1]); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); } } if (num_token_to_process != 1) { diff --git a/cpp/serve/logit_processor.h b/cpp/serve/logit_processor.h index 2425542731..915f101218 100644 --- a/cpp/serve/logit_processor.h +++ b/cpp/serve/logit_processor.h @@ -45,11 +45,11 @@ class LogitProcessorObj : public Object { * when speculation is enabled, in which case some sequences may have * more than one token. */ - virtual void InplaceUpdateLogits(NDArray logits, const Array& generation_cfg, - const Array& mstates, - const Array& request_ids, - const std::vector* cum_num_token = nullptr, - const std::vector>* draft_tokens = nullptr) = 0; + virtual void InplaceUpdateLogits( + NDArray logits, const Array& generation_cfg, + const Array& mstates, const Array& request_ids, + const std::vector* cum_num_token = nullptr, + const std::vector>* draft_tokens = nullptr) = 0; /*! 
* \brief Compute probability distributions for the input batch of logits. diff --git a/cpp/serve/request.cc b/cpp/serve/request.cc index e727d8ebf7..25162d79fb 100644 --- a/cpp/serve/request.cc +++ b/cpp/serve/request.cc @@ -78,24 +78,6 @@ TVM_REGISTER_GLOBAL("mlc.serve.RequestGetGenerationConfigJSON").set_body_typed([ return request->generation_cfg->AsJSONString(); }); -/****************** RequestStreamOutput ******************/ - -TVM_REGISTER_OBJECT_TYPE(RequestStreamOutputObj); - -RequestStreamOutput::RequestStreamOutput(String request_id, TokenData delta_tokens, - Optional finish_reason) { - ObjectPtr n = make_object(); - n->request_id = std::move(request_id); - n->delta_tokens = std::move(delta_tokens); - n->finish_reason = std::move(finish_reason); - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("mlc.serve.RequestStreamOutputUnpack") - .set_body_typed([](RequestStreamOutput output) { - return Array{output->request_id, output->delta_tokens, output->finish_reason}; - }); - } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/request.h b/cpp/serve/request.h index bdc3224f91..fb1eda7fd9 100644 --- a/cpp/serve/request.h +++ b/cpp/serve/request.h @@ -76,45 +76,6 @@ class Request : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Request, ObjectRef, RequestNode); }; -/****************** RequestStreamOutput ******************/ - -/*! - * \brief The generated delta request output that is streamed back - * through callback stream function. - */ -class RequestStreamOutputObj : public Object { - public: - /*! \brief The id of the request that the function is invoked for. */ - String request_id; - /*! - * \brief The new generated tokens since the last callback invocation - * for the input request. - */ - TokenData delta_tokens; - /*! - * \brief The finish reason of the request when it is finished, - * of None if the request has not finished yet. - */ - Optional finish_reason; - - static constexpr const char* _type_key = "mlc.serve.RequestStreamOutput"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_FINAL_OBJECT_INFO(RequestStreamOutputObj, Object); -}; - -/*! - * \brief Managed reference to RequestStreamOutputObj. 
- * \sa RequestStreamOutputObj - */ -class RequestStreamOutput : public ObjectRef { - public: - explicit RequestStreamOutput(String request_id, TokenData delta_tokens, - Optional finish_reason); - - TVM_DEFINE_OBJECT_REF_METHODS(RequestStreamOutput, ObjectRef, RequestStreamOutputObj); -}; - } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index b721d32ac6..cea6af7bff 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -36,20 +36,22 @@ std::vector RequestModelStateNode::GetTokenBitmask(int vocab_size) const { return std::vector(); } -void RequestModelStateNode::CommitToken(int32_t token_id) { - committed_tokens.push_back(token_id); - appeared_token_ids[token_id] += 1; +void RequestModelStateNode::CommitToken(SampleResult sampled_token) { + committed_tokens.push_back(std::move(sampled_token)); + appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } -void RequestModelStateNode::AddDraftToken(int32_t token_id) { - draft_output_tokens.push_back(token_id); - appeared_token_ids[token_id] += 1; +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist) { + draft_output_tokens.push_back(std::move(sampled_token)); + draft_output_prob_dist.push_back(std::move(prob_dist)); + appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); - auto it = appeared_token_ids.find(draft_output_tokens.back()); + auto it = appeared_token_ids.find(draft_output_tokens.back().sampled_token_id.first); draft_output_tokens.pop_back(); + draft_output_prob_dist.pop_back(); CHECK(it != appeared_token_ids.end()); if (--it->second == 0) { appeared_token_ids.erase(it); @@ -83,19 +85,20 @@ RequestState::RequestState(Request request, int num_models, int64_t internal_id, data_ = std::move(n); } -std::pair, Optional> RequestStateNode::GetReturnTokenIds( - int max_single_sequence_length) { +DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenizer, + int max_single_sequence_length) { // - Case 0. There is remaining draft output ==> Unfinished // All draft outputs are supposed to be processed before finish. 
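// If any model still holds unverified draft tokens, report nothing yet;
// the verification step will commit or discard those tokens first.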
for (RequestModelState mstate : mstates) { if (!mstate->draft_output_tokens.empty()) { - return {{}, Optional()}; + return {{}, {}, Optional()}; } } std::vector return_token_ids; + std::vector logprob_json_strs; Optional finish_reason; - const std::vector& committed_tokens = mstates[0]->committed_tokens; + const std::vector& committed_tokens = mstates[0]->committed_tokens; int num_committed_tokens = committed_tokens.size(); ICHECK_LE(this->next_callback_token_pos, num_committed_tokens); @@ -103,7 +106,10 @@ std::pair, Optional> RequestStateNode::GetReturnTok ICHECK(!stop_str_handler->StopTriggered()); while (next_callback_token_pos < num_committed_tokens) { std::vector delta_token_ids = - stop_str_handler->Put(committed_tokens[next_callback_token_pos++]); + stop_str_handler->Put(committed_tokens[next_callback_token_pos].sampled_token_id.first); + logprob_json_strs.push_back(committed_tokens[next_callback_token_pos].GetLogProbJSON( + tokenizer, request->generation_cfg->logprobs)); + ++next_callback_token_pos; return_token_ids.insert(return_token_ids.end(), delta_token_ids.begin(), delta_token_ids.end()); if (stop_str_handler->StopTriggered()) { finish_reason = "stop"; @@ -131,25 +137,24 @@ std::pair, Optional> RequestStateNode::GetReturnTok } if (finish_reason.defined()) { - return {return_token_ids, finish_reason}; + return {return_token_ids, logprob_json_strs, finish_reason}; } // Case 3. Generation reaches the specified max generation length ==> Finished // `max_tokens` means the generation length is limited by model capacity. if (request->generation_cfg->max_tokens >= 0 && - static_cast(committed_tokens.size()) >= request->generation_cfg->max_tokens) { + num_committed_tokens >= request->generation_cfg->max_tokens) { std::vector remaining = stop_str_handler->Finish(); return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); - return {return_token_ids, String("length")}; + return {return_token_ids, logprob_json_strs, String("length")}; } // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished - if (request->input_total_length + static_cast(committed_tokens.size()) >= - max_single_sequence_length) { + if (request->input_total_length + num_committed_tokens >= max_single_sequence_length) { std::vector remaining = stop_str_handler->Finish(); return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); - return {return_token_ids, String("length")}; + return {return_token_ids, logprob_json_strs, String("length")}; } - return {return_token_ids, Optional()}; + return {return_token_ids, logprob_json_strs, Optional()}; } } // namespace serve diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index ea0b688810..134d1df4bd 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -43,21 +43,22 @@ class RequestModelStateNode : public Object { /*! \brief The corresponding model id of this state. */ int model_id = -1; /*! - * \brief The committed generated token ids. A token is "committed" - * means it will no longer be updated (or changed). + * \brief The committed generated token ids and related probability info. + * A token is "committed" means it will no longer be updated (or changed). */ - std::vector committed_tokens; + std::vector committed_tokens; /*! \brief The list of input data yet for the model to prefill. */ Array inputs; // NOTE: The following fields are reserved for future speculative inference // settings, and are produced by the speculative small models. /*! 
- * \brief The draft generated token ids, which are usually generated - * by "small" speculative models. These tokens will be fed to a "large" - * model to determine the final result of speculation. + * \brief The draft generated token ids and related probability info, + * which are usually generated by "small" speculative models. + * These tokens will be fed to a "large" model to determine the final + * result of speculation. */ - std::vector draft_output_tokens; + std::vector draft_output_tokens; /*! * \brief The probability distribution on each position in the * draft. We keep the distributions for stochastic sampling when merging @@ -66,16 +67,6 @@ class RequestModelStateNode : public Object { * and draft outputs in speculative inference settings. */ std::vector draft_output_prob_dist; - /*! - * \brief The probability of the sampled token on each position in the - * draft. We keep the probabilities for stochastic sampling when merging - * speculations from multiple models. - * - * \note `draft_token_prob` can be inferred from `draft_tokens` and - * `draft_prob_dist`, but we still keep it so that we can have option - * choosing only to use one between them. - */ - std::vector draft_output_token_prob; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -87,9 +78,9 @@ class RequestModelStateNode : public Object { */ std::vector GetTokenBitmask(int vocab_size) const; /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ - void CommitToken(int32_t token_id); + void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(int32_t token_id); + void AddDraftToken(SampleResult sampled_token, NDArray prob_dist); /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ void RemoveLastDraftToken(); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. */ @@ -109,6 +100,12 @@ class RequestModelState : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode); }; +struct DeltaRequestReturn { + std::vector delta_token_ids; + Array delta_logprob_json_strs; + Optional finish_reason; +}; + class RequestStateNode : public Object { public: /*! \brief The request that this state corresponds to. */ @@ -134,14 +131,15 @@ class RequestStateNode : public Object { std::chrono::high_resolution_clock::time_point tprefill_finish; /*! - * \brief Get the delta token ids for this request to return since - * the last time calling into this function, and return the finish - * reason if the request generation has finished. + * \brief Get the delta token ids and the logprob JSON strings for this + * request to return since the last time calling into this function, + * and return the finish reason if the request generation has finished. + * \param tokenizer The tokenizer for logprob process. * \param max_single_sequence_length The maximum allowed single sequence length. - * \return The delta token ids to return, and the optional finish reason. + * \return The delta token ids to return, the logprob JSON strings of each + * delta token id, and the optional finish reason. 
*/ - std::pair, Optional> GetReturnTokenIds( - int max_single_sequence_length); + DeltaRequestReturn GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length); static constexpr const char* _type_key = "mlc.serve.RequestState"; static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 502bde72e6..6a6bb65de9 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -28,13 +28,12 @@ namespace serve { * \param uniform_sample The random number in [0, 1] for sampling. * \param output_prob_dist Optional pointer to store the corresponding probability distribution of * each token, offset by unit_offset. If nullptr provided, nothing will be stored out. - * \return The sampled prob and value. + * \return The sampled value and probability. * \note This function is an enhancement of SampleTopPFromProb in TVM Unity. * We will upstream the enhancement after it gets stable. */ -std::pair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, - double uniform_sample, - std::vector* output_prob_dist = nullptr) { +TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample, + std::vector* output_prob_dist = nullptr) { // prob: (*, v) // The prob array may have arbitrary ndim and shape. // The last dimension corresponds to the prob distribution size. @@ -66,11 +65,17 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub // This case is equivalent to doing argmax. int argmax_pos = -1; float max_prob = 0.0; + float sum_prob = 0.0; for (int i = 0; i < ndata; ++i) { if (p_prob[i] > max_prob) { max_prob = p_prob[i]; argmax_pos = i; } + // Early exit. + sum_prob += p_prob[i]; + if (1 - sum_prob <= max_prob) { + break; + } } if (output_prob_dist) { float* __restrict p_output_prob = @@ -79,7 +84,7 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub p_output_prob[i] = i == argmax_pos ? 1.0 : 0.0; } } - return std::make_pair(1.0, argmax_pos); + return {argmax_pos, 1.0}; } if (output_prob_dist) { @@ -92,7 +97,7 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub for (int64_t i = 0; i < ndata; ++i) { prob_sum += p_prob[i]; if (prob_sum >= uniform_sample) { - return std::make_pair(p_prob[i], i); + return {i, p_prob[i]}; } } ICHECK(false) << "Possibly prob distribution contains NAN."; @@ -170,13 +175,77 @@ std::pair SampleTopPFromProb(NDArray prob, int unit_offset, doub // usually it is much less by applying this filtering(order of 10 - 20) data.reserve(256); std::pair sampled_index = sample_top_p_with_filter(top_p / 1024); - if (sampled_index.second >= 0) return sampled_index; + if (sampled_index.second >= 0) return {sampled_index.second, sampled_index.first}; } // fallback via full prob, rare case data.reserve(ndata); std::pair sampled_index = sample_top_p_with_filter(0.0f); ICHECK_GE(sampled_index.second, 0); - return sampled_index; + return {sampled_index.second, sampled_index.first}; +} + +namespace detail { + +/*! \brief Implementation of getting top probs on CPU. */ +template +std::vector ComputeTopProbsImpl(const float* p_prob, int ndata) { + std::vector top_probs; + top_probs.reserve(num_top_probs); + for (int i = 0; i < num_top_probs; ++i) { + top_probs.emplace_back(-1, -1.0f); + } + + float sum_prob = 0.0; + // Selection argsort. 
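// Insertion sort each token into the fixed-size top_probs array, dropping
// the smallest entry. The running sum enables an early exit: once the unseen
// probability mass (1 - sum_prob) cannot exceed the smallest kept entry, no
// later token can change the top-k.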
+ for (int p = 0; p < ndata; ++p) { + int i = num_top_probs - 1; + for (; i >= 0; --i) { + if (p_prob[p] > top_probs[i].second) { + if (i != num_top_probs - 1) { + top_probs[i + 1] = top_probs[i]; + } + } else { + break; + } + } + if (i != num_top_probs - 1) { + top_probs[i + 1] = {p, p_prob[p]}; + } + + // Early exit. + sum_prob += p_prob[p]; + if (1 - sum_prob <= top_probs[num_top_probs - 1].second) { + break; + } + } + return top_probs; +} + +} // namespace detail + +/*! \brief Get the probs of a few number of tokens with top probabilities. */ +inline std::vector ComputeTopProbs(NDArray prob, int unit_offset, + int num_top_probs) { + ICHECK_LE(num_top_probs, 5); + ICHECK_EQ(prob->ndim, 2); + int ndata = prob->shape[1]; + const float* __restrict p_prob = + static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata); + switch (num_top_probs) { + case 0: + return {}; + case 1: + return detail::ComputeTopProbsImpl<1>(p_prob, ndata); + case 2: + return detail::ComputeTopProbsImpl<2>(p_prob, ndata); + case 3: + return detail::ComputeTopProbsImpl<3>(p_prob, ndata); + case 4: + return detail::ComputeTopProbsImpl<4>(p_prob, ndata); + case 5: + return detail::ComputeTopProbsImpl<5>(p_prob, ndata); + } + throw; } /********************* CPU Sampler *********************/ @@ -193,12 +262,11 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_device, // - const Array& request_ids, - const Array& generation_cfg, - const std::vector& rngs, - std::vector* output_prob_dist, - std::vector* output_token_probs) final { + std::vector BatchSampleTokens(NDArray probs_device, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist) final { // probs_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_device->ndim, 2); @@ -213,40 +281,39 @@ class CPUSampler : public SamplerObj { ICHECK_EQ(probs_host->shape[0], rngs.size()); int n = probs_host->shape[0]; - std::vector sampled_tokens; - sampled_tokens.resize(n); + std::vector sample_results; + sample_results.resize(n); if (output_prob_dist) { output_prob_dist->resize(n); } - if (output_token_probs) { - output_token_probs->resize(n); - } tvm::runtime::parallel_for_with_threading_backend( - [this, &sampled_tokens, &probs_host, &generation_cfg, &rngs, &request_ids, output_prob_dist, - output_token_probs](int i) { + [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, + output_prob_dist](int i) { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. - std::pair sample_result = SampleTopPFromProb( + sample_results[i].sampled_token_id = SampleTopPFromProb( probs_host, i, generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); - sampled_tokens[i] = sample_result.second; - if (output_token_probs) { - (*output_token_probs)[i] = sample_result.first; + if (output_prob_dist == nullptr) { + // When `output_prob_dist` is not nullptr, it means right now + // we are sampling for a small model in speculation, in which + // case we do not need to get the top probs. 
+ sample_results[i].top_prob_tokens = + ComputeTopProbs(probs_host, i, generation_cfg[i]->top_logprobs); } RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); }, 0, n); RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); - return sampled_tokens; + return sample_results; } - std::vector> BatchVerifyDraftTokens( + std::vector> BatchVerifyDraftTokens( NDArray probs_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& request_mstates, - const Array& generation_cfg, const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_token_prob, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) final { // probs_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); @@ -259,11 +326,10 @@ class CPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_token_prob.size(), num_sequence); CHECK_EQ(draft_output_prob_dist.size(), num_sequence); - std::vector> accepted_tokens; - accepted_tokens.resize(num_sequence); + std::vector> sample_results; + sample_results.resize(num_sequence); float* __restrict global_p_probs = static_cast(__builtin_assume_aligned(probs_host->data, 4)); @@ -275,19 +341,23 @@ class CPUSampler : public SamplerObj { int verify_end = cum_verify_lengths[i + 1]; for (int cur_token_idx = 0; cur_token_idx < verify_end - verify_start; ++cur_token_idx) { float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size; - int cur_token = draft_output_tokens[i][cur_token_idx]; - float q_value = draft_output_token_prob[i][cur_token_idx]; + int cur_token = draft_output_tokens[i][cur_token_idx].sampled_token_id.first; + float q_value = draft_output_tokens[i][cur_token_idx].sampled_token_id.second; float p_value = p_probs[cur_token]; if (p_value >= q_value) { - request_mstates[i]->CommitToken(cur_token); - accepted_tokens[i].push_back(cur_token); + sample_results[i].push_back( + SampleResult{{cur_token, p_value}, + ComputeTopProbs(probs_host, verify_start + cur_token_idx, + generation_cfg[i]->top_logprobs)}); continue; } float r = rngs[i]->GetRandomNumber(); if (r < p_value / (q_value + eps_)) { - request_mstates[i]->CommitToken(cur_token); - accepted_tokens[i].push_back(cur_token); + sample_results[i].push_back( + SampleResult{{cur_token, p_value}, + ComputeTopProbs(probs_host, verify_start + cur_token_idx, + generation_cfg[i]->top_logprobs)}); continue; } @@ -309,20 +379,20 @@ class CPUSampler : public SamplerObj { } // sample a new token from the new distribution - int32_t new_token = - SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()) - .second; - request_mstates[i]->CommitToken(new_token); - accepted_tokens[i].push_back(cur_token); + SampleResult sample_result; + sample_result.sampled_token_id = SampleTopPFromProb( + probs_host, verify_start + cur_token_idx, + generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber()); + sample_result.top_prob_tokens = ComputeTopProbs( + probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + sample_results[i].push_back(sample_result); break; } }, 0, num_sequence); RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); - return accepted_tokens; + return sample_results; } private: diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index ac4820db64..6f9c6acf47 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -12,6 +12,7 @@ #include "../base.h" #include "../random.h" +#include "data.h" #include "event_trace_recorder.h" #include "model.h" #include "request_state.h" @@ -39,14 +40,15 @@ class SamplerObj : public Object { * in the input batch. * \param rngs The random number generator of each sequence. * \param output_prob_dist The output probability distribution - * \param output_token_probs The output token probabilities - * \return The sampled tokens, one for each request in the batch. + * \return The batch of sampling results, which contain the sampled token id + * and other probability info. */ - virtual std::vector BatchSampleTokens( - NDArray probs_device, const Array& request_ids, - const Array& generation_cfg, const std::vector& rngs, - std::vector* output_prob_dist = nullptr, - std::vector* output_token_probs = nullptr) = 0; + virtual std::vector BatchSampleTokens( + NDArray probs_device, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model @@ -54,25 +56,20 @@ class SamplerObj : public Object { * \param probs_device The prob distributions on GPU to sample tokens from. * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. - * \param request_mstates The request states of each sequence in - * the batch with regard to the large model. * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. * \param draft_output_tokens The draft tokens generated by the small model for * each sequence. - * \param draft_output_token_prob The draft tokens' probabilities computed from - * the small model for each sequence. * \param draft_output_prob_dist The probability distribution computed from the * small model for each sequence. * \return The list of accepted tokens for each request. 
*/ - virtual std::vector> BatchVerifyDraftTokens( + virtual std::vector> BatchVerifyDraftTokens( NDArray probs_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& request_mstates, - const Array& generation_cfg, const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_token_prob, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) = 0; static constexpr const char* _type_key = "mlc.serve.Sampler"; diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index 36b75f81a5..2ae26bf752 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union import shortuuid -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator ################ Commons ################ @@ -18,8 +18,21 @@ class ListResponse(BaseModel): data: List[Any] +class TopLogProbs(BaseModel): + token: str + logprob: float + bytes: Optional[List[int]] + + +class LogProbsContent(BaseModel): + token: str + logprob: float + bytes: Optional[List[int]] + top_logprobs: List[TopLogProbs] = [] + + class LogProbs(BaseModel): - pass + content: List[LogProbsContent] class UsageInfo(BaseModel): @@ -63,8 +76,9 @@ class CompletionRequest(BaseModel): echo: bool = False frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = None - logprobs: Optional[int] = None max_tokens: int = 16 n: int = 1 seed: Optional[int] = None @@ -100,6 +114,15 @@ def check_logit_bias( ) return logit_bias_value + @model_validator(mode="after") + def check_logprobs(self) -> "CompletionRequest": + """Check if the logprobs requirements are valid.""" + if self.top_logprobs < 0 or self.top_logprobs > 5: + raise ValueError('"top_logprobs" must be in range [0, 5]') + if not self.logprobs and self.top_logprobs > 0: + raise ValueError('"logprobs" must be True to support "top_logprobs"') + return self + class CompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length"]] = None @@ -165,6 +188,8 @@ class ChatCompletionRequest(BaseModel): model: str frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = None max_tokens: Optional[int] = None n: int = 1 @@ -203,17 +228,28 @@ def check_logit_bias( ) return logit_bias_value + @model_validator(mode="after") + def check_logprobs(self) -> "ChatCompletionRequest": + """Check if the logprobs requirements are valid.""" + if self.top_logprobs < 0 or self.top_logprobs > 5: + raise ValueError('"top_logprobs" must be in range [0, 5]') + if not self.logprobs and self.top_logprobs > 0: + raise ValueError('"logprobs" must be True to support "top_logprobs"') + return self + class ChatCompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None index: int = 0 message: ChatCompletionMessage + logprobs: Optional[LogProbs] = None class ChatCompletionStreamResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None index: int = 0 delta: ChatCompletionMessage + logprobs: 
Optional[LogProbs] = None class ChatCompletionResponse(BaseModel): @@ -254,7 +290,6 @@ def openai_api_get_unsupported_fields( """Get the unsupported fields in the request.""" unsupported_field_default_values: List[Tuple[str, Any]] = [ ("best_of", 1), - ("logprobs", None), ("n", 1), ("response_format", "text"), ] @@ -277,6 +312,8 @@ def openai_api_get_generation_config( "max_tokens", "frequency_penalty", "presence_penalty", + "logprobs", + "top_logprobs", "logit_bias", "seed", "ignore_eos", diff --git a/python/mlc_chat/serve/__init__.py b/python/mlc_chat/serve/__init__.py index 8e31ae5f7e..59185ec520 100644 --- a/python/mlc_chat/serve/__init__.py +++ b/python/mlc_chat/serve/__init__.py @@ -1,10 +1,11 @@ """Subdirectory of serving.""" + # Load MLC LLM library by importing base from .. import base from .async_engine import AsyncThreadedEngine from .config import EngineMode, GenerationConfig, KVCacheConfig -from .data import Data, TextData, TokenData +from .data import Data, RequestStreamOutput, TextData, TokenData from .engine import Engine from .grammar import BNFGrammar, GrammarStateMatcher -from .request import Request, RequestStreamOutput +from .request import Request from .server import PopenServer diff --git a/python/mlc_chat/serve/async_engine.py b/python/mlc_chat/serve/async_engine.py index d478add478..74058ea314 100644 --- a/python/mlc_chat/serve/async_engine.py +++ b/python/mlc_chat/serve/async_engine.py @@ -15,7 +15,7 @@ from .config import EngineMode, GenerationConfig, KVCacheConfig from .engine import ModelInfo, _estimate_max_total_sequence_length, _process_model_args from .event_trace_recorder import EventTraceRecorder -from .request import Request, RequestStreamOutput +from .request import Request class AsyncRequestStream: @@ -31,13 +31,13 @@ class AsyncRequestStream: """ # The asynchronous queue to hold elements of - # - either a tuple of (str, int, Optional[str]), denoting the - # delta output text, the number of delta tokens, the optional - # finish reason respectively, + # - either a tuple of (str, int, List[str], Optional[str]), denoting the + # delta output text, the number of delta tokens, the logprob JSON strings + # of delta tokens, and the optional finish reason respectively, # - or an exception. if sys.version_info >= (3, 9): _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object - Union[Tuple[str, int, Optional[str]], Exception] + Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception] ] else: _queue: asyncio.Queue @@ -48,7 +48,10 @@ def __init__(self) -> None: self._queue = asyncio.Queue() self._finished = False - def push(self, item_or_exception: Union[Tuple[str, int, Optional[str]], Exception]) -> None: + def push( + self, + item_or_exception: Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception], + ) -> None: """Push a new token to the stream.""" if self._finished: # No new item is expected after finish. 
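After this change, each item pushed through AsyncRequestStream is a 4-tuple of (delta text, number of delta tokens, optional list of logprob JSON strings, optional finish reason). A minimal consumer sketch under that assumption — the stream object itself is taken as given, not constructed here:

from typing import List, Optional, Tuple

# Shape of one stream item after this patch.
StreamItem = Tuple[str, int, Optional[List[str]], Optional[str]]

async def drain_stream(stream) -> Tuple[str, List[str]]:
    """Collect the full text and all logprob JSON strings from one request stream."""
    text = ""
    logprob_json_strs: List[str] = []
    async for delta_text, _num_delta_tokens, delta_logprob_json_strs, finish_reason in stream:
        text += delta_text
        if delta_logprob_json_strs is not None:
            logprob_json_strs += delta_logprob_json_strs
        if finish_reason is not None:
            break
    return text, logprob_json_strs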
@@ -69,7 +72,7 @@ def finish(self) -> None: def __aiter__(self): return self - async def __anext__(self) -> Tuple[str, int, Optional[str]]: + async def __anext__(self) -> Tuple[str, int, Optional[List[str]], Optional[str]]: result = await self._queue.get() if isinstance(result, StopIteration): raise StopAsyncIteration @@ -183,12 +186,13 @@ def terminate(self): async def generate( self, prompt: Union[str, List[int]], generation_config: GenerationConfig, request_id: str - ) -> AsyncGenerator[Tuple[str, int, str], Any]: + ) -> AsyncGenerator[Tuple[str, int, Optional[List[str]], Optional[str]], Any]: """Asynchronous text generation interface. The method is a coroutine that streams a tuple at a time via yield. Each tuple is contained of - the delta text in type str, - the number of delta tokens in type int, + - the logprob JSON strings of delta tokens, - the optional finish reason in type Optional[str]. Parameters @@ -252,15 +256,15 @@ def _abort(self, request_id: str): self._request_tools.pop(request_id, None) self._ffi["abort_request"](request_id) - def _request_stream_callback(self, delta_outputs: List[RequestStreamOutput]) -> None: + def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: """The request stream callback function for engine to stream back the request generation results. Parameters ---------- - delta_outputs : List[RequestStreamOutput] + delta_outputs : List[data.RequestStreamOutput] The delta output of each requests. - Check out RequestStreamOutput for the fields of the outputs. + Check out data.RequestStreamOutput for the fields of the outputs. Note ---- @@ -275,10 +279,15 @@ def _request_stream_callback(self, delta_outputs: List[RequestStreamOutput]) -> self._request_stream_callback_impl, delta_outputs ) - def _request_stream_callback_impl(self, delta_outputs: List[RequestStreamOutput]) -> None: + def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: """The underlying implementation of request stream callback.""" for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + ( + request_id, + delta_token_ids, + delta_logprob_json_strs, + finish_reason, + ) = delta_output.unpack() tools = self._request_tools.get(request_id, None) if tools is None: continue @@ -287,14 +296,13 @@ def _request_stream_callback_impl(self, delta_outputs: List[RequestStreamOutput] stream, text_streamer = tools self.record_event(request_id, event="start detokenization") - delta_token_ids = delta_tokens.token_ids delta_text = text_streamer.put(delta_token_ids) if finish_reason is not None: delta_text += text_streamer.finish() self.record_event(request_id, event="finish detokenization") # Push new delta text to the stream. 
- stream.push((delta_text, len(delta_token_ids), finish_reason)) + stream.push((delta_text, len(delta_token_ids), delta_logprob_json_strs, finish_reason)) if finish_reason is not None: stream.finish() self._request_tools.pop(request_id, None) diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py index 1962b61215..ccc152ab36 100644 --- a/python/mlc_chat/serve/config.py +++ b/python/mlc_chat/serve/config.py @@ -1,4 +1,5 @@ """Configuration dataclasses used in MLC LLM serving""" + import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Optional @@ -31,6 +32,16 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes It will be suppressed when any of frequency_penalty and presence_penalty is non-zero. + logprobs : bool + Whether to return log probabilities of the output tokens or not. + If true, the log probabilities of each output token will be returned. + + top_logprobs : int + An integer between 0 and 5 specifying the number of most likely + tokens to return at each token position, each with an associated + log probability. + `logprobs` must be set to True if this parameter is used. + logit_bias : Optional[Dict[int, float]] The bias logit value added to selected tokens prior to sampling. @@ -59,6 +70,8 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes frequency_penalty: float = 0.0 presence_penalty: float = 0.0 repetition_penalty: float = 1.0 + logprobs: bool = False + top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) max_tokens: Optional[int] = 128 diff --git a/python/mlc_chat/serve/data.py b/python/mlc_chat/serve/data.py index 75a18f4097..15c0a4f205 100644 --- a/python/mlc_chat/serve/data.py +++ b/python/mlc_chat/serve/data.py @@ -1,5 +1,6 @@ """Classes denoting multi-modality data used in MLC LLM serving""" -from typing import List + +from typing import List, Optional, Tuple import tvm._ffi from tvm.runtime import Object @@ -54,3 +55,63 @@ def __init__(self, token_ids: List[int]): def token_ids(self) -> List[int]: """Return the token ids of the TokenData.""" return list(_ffi_api.TokenDataGetTokenIds(self)) # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access +class RequestStreamOutput(Object): + """The generated delta request output that is streamed back + through callback stream function. + It contains four fields (in order): + + request_id : str + The id of the request that the function is invoked for. + + delta_tokens : List[int] + The new generated tokens since the last callback invocation + for the input request. + + delta_logprob_json_strs : Optional[List[str]] + The logprobs JSON strings of the new generated tokens + since last invocation. + + finish_reason : Optional[str] + The finish reason of the request when it is finished, + of None if the request has not finished yet. + + Note + ---- + We do not provide constructor, since in practice only C++ side + instantiates this class. + """ + + def unpack(self) -> Tuple[str, List[int], Optional[List[str]], Optional[str]]: + """Return the fields of the delta output in a tuple. + + Returns + ------- + request_id : str + The id of the request that the function is invoked for. + + delta_tokens : List[int] + The new generated tokens since the last callback invocation + for the input request. 
+ + delta_logprob_json_strs : Optional[List[str]] + The logprobs JSON strings of the new generated tokens + since last invocation. + + finish_reason : Optional[str] + The finish reason of the request when it is finished, + of None if the request has not finished yet. + """ + fields = _ffi_api.RequestStreamOutputUnpack(self) # type: ignore # pylint: disable=no-member + return ( + str(fields[0]), + list(fields[1]), + ( + [str(logprob_json_str) for logprob_json_str in fields[2]] + if fields[2] is not None + else None + ), + str(fields[3]) if fields[3] is not None else None, + ) diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index 5d34afa5dc..407fb72f17 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -22,7 +22,7 @@ from . import data from .config import EngineMode, GenerationConfig, KVCacheConfig from .event_trace_recorder import EventTraceRecorder -from .request import Request, RequestStreamOutput +from .request import Request logging.enable_logging() logger = logging.getLogger(__name__) @@ -269,7 +269,7 @@ def __init__( # pylint: disable=too-many-arguments models: Union[ModelInfo, List[ModelInfo]], kv_cache_config: KVCacheConfig, engine_mode: Optional[EngineMode] = None, - request_stream_callback: Optional[Callable[[List[RequestStreamOutput]], None]] = None, + request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, enable_tracing: bool = False, ): if isinstance(models, ModelInfo): @@ -329,7 +329,7 @@ def generate( self, prompts: Union[str, List[str], List[int], List[List[int]]], generation_config: Union[GenerationConfig, List[GenerationConfig]], - ) -> List[str]: + ) -> Tuple[List[str], List[Optional[List[str]]]]: """Generate texts for a list of input prompts. Each prompt can be a string or a list of token ids. The generation for each prompt is independent. @@ -350,8 +350,12 @@ def generate( Returns ------- - results : List[str] + output_text : List[str] The text generation results, one string for each input prompt. + + output_logprobs_str : List[Optional[List[str]]] + The logprob strings of each token for each input prompt, or None + if an input prompt does not require logprobs. """ if isinstance(prompts, str): # `prompts` is a single string. @@ -362,7 +366,7 @@ def generate( "str, a list of token ids or multiple lists of token ids." 
) if len(prompts) == 0: - return [] + return [], [] if isinstance(prompts[0], int): # `prompts` is a list of token ids prompts = [prompts] # type: ignore @@ -376,10 +380,12 @@ def generate( ), "Number of generation config and number of prompts mismatch" num_finished_requests = 0 - outputs: List[str] = [] + output_texts: List[str] = [] + output_logprobs_str: List[Optional[List[str]]] = [] text_streamers: List[TextStreamer] = [] - for _ in range(num_requests): - outputs.append("") + for i in range(num_requests): + output_texts.append("") + output_logprobs_str.append([] if generation_config[i].logprobs else None) text_streamers.append(TextStreamer(self.tokenizer)) # Save a copy of the original function callback since `generate` @@ -388,18 +394,26 @@ def generate( original_callback = self._ffi["get_request_stream_callback"]() # Define the callback function for request generation results - def request_stream_callback(delta_outputs: List[RequestStreamOutput]): + def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): nonlocal num_finished_requests for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + ( + request_id, + delta_token_ids, + delta_logprob_json_strs, + finish_reason, + ) = delta_output.unpack() rid = int(request_id) text_streamer = text_streamers[rid] + if output_logprobs_str[rid] is not None: + assert delta_logprob_json_strs is not None + output_logprobs_str[rid] += delta_logprob_json_strs - delta_text = text_streamer.put(delta_tokens.token_ids) + delta_text = text_streamer.put(delta_token_ids) if finish_reason is not None: delta_text += text_streamer.finish() - outputs[rid] += delta_text + output_texts[rid] += delta_text if finish_reason is not None: num_finished_requests += 1 @@ -426,7 +440,7 @@ def request_stream_callback(delta_outputs: List[RequestStreamOutput]): # Restore the callback function in engine. self._ffi["set_request_stream_callback"](original_callback) - return outputs + return output_texts, output_logprobs_str def add_request(self, request: Request) -> None: """Add a new request to the engine. diff --git a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py index 20027deed4..de85ab83f3 100644 --- a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py @@ -23,6 +23,8 @@ CompletionResponse, CompletionResponseChoice, ListResponse, + LogProbs, + LogProbsContent, ModelResponse, UsageInfo, ) @@ -109,9 +111,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: num_completion_tokens = 0 finish_reason = None async_engine.record_event(request_id, event="invoke generate") - async for delta_text, num_delta_tokens, finish_reason in async_engine.generate( - prompt, generation_cfg, request_id - ): + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): num_completion_tokens += num_delta_tokens if delta_text == "": # Ignore empty delta text -- do not yield. 
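Given the updated return signature of Engine.generate above (a pair of lists instead of a single list of strings), a hedged usage sketch; the engine object is assumed to be constructed elsewhere, and the generation settings are only an example:

from typing import List

from mlc_chat.serve import Engine, GenerationConfig

def generate_with_logprobs(engine: Engine, prompts: List[str]) -> None:
    """Run the updated Engine.generate and report per-prompt logprob counts."""
    config = GenerationConfig(max_tokens=16, logprobs=True, top_logprobs=2)
    output_texts, output_logprobs_str = engine.generate(prompts, config)
    for text, logprobs in zip(output_texts, output_logprobs_str):
        print(text)
        # `logprobs` is a list of per-token logprob JSON strings, or None
        # when the corresponding prompt did not request logprobs.
        if logprobs is not None:
            print(f"{len(logprobs)} logprob JSON strings")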
@@ -123,6 +128,16 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: CompletionResponseChoice( finish_reason=finish_reason, text=delta_text, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in delta_logprob_json_strs + ] + ) + if delta_logprob_json_strs is not None + else None + ), ) ], model=request.model, @@ -163,10 +178,14 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) num_completion_tokens = 0 finish_reason: Optional[str] = None + logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None async_engine.record_event(request_id, event="invoke generate") - async for delta_text, num_delta_tokens, finish_reason in async_engine.generate( - prompt, generation_cfg, request_id - ): + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. @@ -178,6 +197,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) output_text += delta_text num_completion_tokens += num_delta_tokens + if logprob_json_strs is not None: + assert delta_logprob_json_strs is not None + logprob_json_strs += delta_logprob_json_strs assert finish_reason is not None suffix = request.suffix if request.suffix is not None else "" async_engine.record_event(request_id, event="finish") @@ -187,6 +209,16 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: CompletionResponseChoice( finish_reason=finish_reason, text=output_text + suffix, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in logprob_json_strs + ] + ) + if logprob_json_strs is not None + else None + ), ) ], model=request.model, @@ -378,9 +410,12 @@ async def request_chat_completion( async def completion_stream_generator() -> AsyncGenerator[str, None]: assert request.n == 1 async_engine.record_event(request_id, event="invoke generate") - async for delta_text, _, finish_reason in async_engine.generate( - prompt, generation_cfg, request_id - ): + async for ( + delta_text, + _, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): if delta_text == "": async_engine.record_event(request_id, event="skip empty delta text") # Ignore empty delta text -- do not yield. 
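The model_validate_json conversion used in the handlers above can be exercised on its own. The JSON string below is illustrative only (real strings are produced by the serving engine), but it matches the LogProbsContent schema introduced in this patch:

from mlc_chat.protocol.openai_api_protocol import LogProbs, LogProbsContent

# Illustrative logprob JSON string shaped like LogProbsContent.
logprob_json_str = (
    '{"token": "Hello", "logprob": -0.12, "bytes": null, '
    '"top_logprobs": [{"token": "Hello", "logprob": -0.12, "bytes": null}, '
    '{"token": "Hi", "logprob": -2.3, "bytes": null}]}'
)

logprobs = LogProbs(content=[LogProbsContent.model_validate_json(logprob_json_str)])
print(logprobs.content[0].token)                  # "Hello"
print(logprobs.content[0].top_logprobs[1].token)  # "Hi"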
@@ -395,6 +430,16 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ChatCompletionStreamResponseChoice( finish_reason=finish_reason, delta=ChatCompletionMessage(content=delta_text, role="assistant"), + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in delta_logprob_json_strs + ] + ) + if delta_logprob_json_strs is not None + else None + ), ) ], model=request.model, @@ -413,10 +458,14 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: output_text = "" num_completion_tokens = 0 finish_reason: Optional[str] = None + logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None async_engine.record_event(request_id, event="invoke generate") - async for delta_text, num_delta_tokens, finish_reason in async_engine.generate( - prompt, generation_cfg, request_id - ): + async for ( + delta_text, + num_delta_tokens, + delta_logprob_json_strs, + finish_reason, + ) in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. @@ -428,6 +477,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) output_text += delta_text num_completion_tokens += num_delta_tokens + if logprob_json_strs is not None: + assert delta_logprob_json_strs is not None + logprob_json_strs += delta_logprob_json_strs assert finish_reason is not None async_engine.record_event(request_id, event="finish") @@ -467,6 +519,16 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ChatCompletionResponseChoice( finish_reason=finish_reason, message=message, + logprobs=( + LogProbs( + content=[ + LogProbsContent.model_validate_json(logprob_json_str) + for logprob_json_str in logprob_json_strs + ] + ) + if logprob_json_strs is not None + else None + ), ) ], model=request.model, diff --git a/python/mlc_chat/serve/request.py b/python/mlc_chat/serve/request.py index f725a1c6d1..5c2d8ad196 100644 --- a/python/mlc_chat/serve/request.py +++ b/python/mlc_chat/serve/request.py @@ -1,12 +1,13 @@ """The request class in MLC LLM serving""" -from typing import List, Optional, Tuple, Union + +from typing import List, Union import tvm._ffi from tvm.runtime import Object from . import _ffi_api from .config import GenerationConfig -from .data import Data, TokenData +from .data import Data @tvm._ffi.register_object("mlc.serve.Request") # pylint: disable=protected-access @@ -55,46 +56,3 @@ def generation_config(self) -> GenerationConfig: return GenerationConfig.from_json( _ffi_api.RequestGetGenerationConfigJSON(self) # type: ignore # pylint: disable=no-member ) - - -@tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access -class RequestStreamOutput(Object): - """The generated delta request output that is streamed back - through callback stream function. - It contains three fields (in order): - - request_id : str - The id of the request that the function is invoked for. - - delta_tokens : data.TokenData - The new generated tokens since the last callback invocation - for the input request. - - finish_reason : Optional[str] - The finish reason of the request when it is finished, - of None if the request has not finished yet. - - Note - ---- - We do not provide constructor, since in practice only C++ side - instantiates this class. 
- """ - - def unpack(self) -> Tuple[str, TokenData, Optional[str]]: - """Return the fields of the delta output in a tuple. - - Returns - ------- - request_id : str - The id of the request that the function is invoked for. - - delta_tokens : data.TokenData - The new generated tokens since the last callback invocation - for the input request. - - finish_reason : Optional[str] - The finish reason of the request when it is finished, - of None if the request has not finished yet. - """ - fields = _ffi_api.RequestStreamOutputUnpack(self) # type: ignore # pylint: disable=no-member - return str(fields[0]), fields[1], str(fields[2]) if fields[2] is not None else None diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 0721e97190..a30b744018 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -646,6 +646,39 @@ def test_openai_v1_completions_prompt_overlong( assert num_chunks == 1 +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_invalid_logprobs( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model[0], + "prompt": "What is the meaning of life?", + "max_tokens": 256, + "stream": stream, + "logprobs": False, + "top_logprobs": 4, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json()["detail"][0]["msg"].endswith( + '"logprobs" must be True to support "top_logprobs"' + ) + + payload["logprobs"] = True + payload["top_logprobs"] = 6 + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + assert response.json()["detail"][0]["msg"].endswith('"top_logprobs" must be in range [0, 5]') + + def test_openai_v1_completions_unsupported_args( served_model: Tuple[str, str], launch_server, # pylint: disable=unused-argument @@ -783,6 +816,8 @@ def test_openai_v1_chat_completions_openai_package( model=served_model[0], messages=messages, stream=stream, + logprobs=True, + top_logprobs=2, ) if not stream: check_openai_nonstream_response( @@ -981,6 +1016,8 @@ def test_debug_dump_event_trace( test_openai_v1_completions_seed(MODEL, None) test_openai_v1_completions_prompt_overlong(MODEL, None, stream=False) test_openai_v1_completions_prompt_overlong(MODEL, None, stream=True) + test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=False) + test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=True) test_openai_v1_completions_unsupported_args(MODEL, None) test_openai_v1_completions_request_cancellation(MODEL, None) diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 70ddc00f72..df8e64bec0 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -44,7 +44,7 @@ async def generate_task( ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_text, num_delta_tokens, finish_reason in async_engine.generate( + async for delta_text, _, _, _ in async_engine.generate( prompt, generation_cfg, request_id=request_id ): outputs[rid] += delta_text diff --git 
a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index e314e5fc46..89a113d1bb 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -54,7 +54,7 @@ async def generate_task( ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_text, num_delta_tokens, finish_reason in async_engine.generate( + async for delta_text, _, _, _ in async_engine.generate( prompt, generation_cfg, request_id=request_id ): outputs[rid] += delta_text diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 5d65978f10..373a97a743 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -87,8 +87,8 @@ def test_engine_basic(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, _ = delta_output.unpack() - outputs[int(request_id)] += delta_tokens.token_ids + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids # Create engine engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) @@ -153,10 +153,10 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() if finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_tokens.token_ids + outputs[int(request_id)] += delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -231,10 +231,10 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() if finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_tokens.token_ids + outputs[int(request_id)] += delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -312,11 +312,11 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() if finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") self.finished_requests += 1 - outputs[int(request_id)] += delta_tokens.token_ids + outputs[int(request_id)] += delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -376,8 +376,10 @@ def test_engine_generate(): max_tokens = 256 # Generate output. 
- outputs = engine.generate(prompts[:num_requests], GenerationConfig(max_tokens=max_tokens)) - for req_id, output in enumerate(outputs): + output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) + ) + for req_id, output in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") print(f"Output {req_id}:{output}\n") diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 6bb8c9e08d..1eee361fd8 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -93,8 +93,8 @@ def test_engine_basic(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, _ = delta_output.unpack() - outputs[int(request_id)] += delta_tokens.token_ids + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids # Create engine engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) @@ -164,10 +164,10 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, finish_reason = delta_output.unpack() + request_id, delta_token_ids, _, finish_reason = delta_output.unpack() if finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_tokens.token_ids + outputs[int(request_id)] += delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -224,8 +224,10 @@ def test_engine_generate(): max_tokens = 256 # Generate output. - outputs = engine.generate(prompts[:num_requests], GenerationConfig(max_tokens=max_tokens)) - for req_id, output in enumerate(outputs): + output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) + ) + for req_id, output in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") print(f"Output {req_id}:{output}\n") @@ -253,8 +255,8 @@ def test_engine_efficiency(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, _ = delta_output.unpack() - outputs[int(request_id)] += delta_tokens.token_ids + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids # Create engine engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) @@ -324,8 +326,8 @@ def test_engine_spec_efficiency(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_tokens, _ = delta_output.unpack() - outputs[int(request_id)] += delta_tokens.token_ids + request_id, delta_token_ids, _, _ = delta_output.unpack() + outputs[int(request_id)] += delta_token_ids # Create engine spec_engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) From 1cbd67b4eefe8f9cd8f4a1d798221483ec120fe9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 26 Feb 2024 15:38:42 -0500 Subject: [PATCH 005/531] [Serving] Support Mixtral in MLC Serve (#1840) This PR supports Mixtral in MLC serve. The main thing is only introducing the Mistral conversation template to Python registry so that MLC Serve can use. 
Besides that, this PR updates the KV cache capacity analysis to make it more accurate in terms of usage calculation, while being conservative since there is a known issue regarding batch-prefill embedding taking which may lead to OOM. We will reset the follow up on the issue with a fix in the future and then enable the estimation to use more GPU vRAM. --- python/mlc_chat/conversation_template.py | 17 ++++ python/mlc_chat/serve/async_engine.py | 2 +- python/mlc_chat/serve/engine.py | 55 +++++++++--- tests/python/serve/server/test_server.py | 106 +++++++++++------------ 4 files changed, 112 insertions(+), 68 deletions(-) diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_chat/conversation_template.py index 6ca148f021..9ec0a6bfee 100644 --- a/python/mlc_chat/conversation_template.py +++ b/python/mlc_chat/conversation_template.py @@ -51,6 +51,23 @@ def get_conv_template(name: str) -> Optional[Conversation]: ) ) +# Mistral default +ConvTemplateRegistry.register_conv_template( + Conversation( + name="mistral_default", + system_template=f"[INST] {MessagePlaceholders.SYSTEM.value}\n\n ", + system_message="Always assist with care, respect, and truth. Respond with utmost " + "utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. " + "Ensure replies promote fairness and positivity.", + roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, + seps=[" "], + role_content_sep=" ", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + ) +) + # Gorilla ConvTemplateRegistry.register_conv_template( Conversation( diff --git a/python/mlc_chat/serve/async_engine.py b/python/mlc_chat/serve/async_engine.py index 74058ea314..97330fea0d 100644 --- a/python/mlc_chat/serve/async_engine.py +++ b/python/mlc_chat/serve/async_engine.py @@ -128,7 +128,7 @@ def __init__( if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths + models, config_file_paths, kv_cache_config.max_num_sequence ) if kv_cache_config.prefill_chunk_size is None: kv_cache_config.prefill_chunk_size = prefill_chunk_size diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index 407fb72f17..f5e69e6d54 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -138,12 +138,15 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals - models: List[ModelInfo], config_file_paths: List[str] + models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int ) -> int: """Estimate the max total sequence length (capacity) of the KV cache.""" assert len(models) != 0 kv_bytes_per_token = 0 + kv_aux_workspace_bytes = 0 + model_workspace_bytes = 0 + logit_processor_workspace_bytes = 0 params_bytes = 0 temp_func_bytes = 0 @@ -169,15 +172,26 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals model_config = json_object["model_config"] num_layers = model_config["num_hidden_layers"] hidden_size = model_config["hidden_size"] - num_qo_heads = model_config["num_attention_heads"] - num_kv_heads = model_config["num_key_value_heads"] + head_dim = model_config["head_dim"] + vocab_size = model_config["vocab_size"] tensor_parallel_shards = model_config["tensor_parallel_shards"] - kv_bytes_per_token += ( - (hidden_size / num_qo_heads) - * (num_kv_heads / tensor_parallel_shards) # on single GPU - * num_layers - * 4 # key, value, fp16 - * 1.10 # over 
estimation to guarantee safety + num_qo_heads = model_config["num_attention_heads"] / tensor_parallel_shards + num_kv_heads = model_config["num_key_value_heads"] / tensor_parallel_shards + prefill_chunk_size = model_config["prefill_chunk_size"] + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 + kv_aux_workspace_bytes += ( + (max_num_sequence + 1) * 88 + + prefill_chunk_size * (num_qo_heads + 1) * 8 + + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + + 48 * 1024 * 1024 + ) + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 ) # Get single-card GPU size. @@ -191,7 +205,15 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals ) max_total_sequence_length = int( - (int(gpu_size_bytes) * 0.97 - params_bytes * 1.04 - temp_func_bytes) / kv_bytes_per_token + ( + int(gpu_size_bytes) * 0.85 + - params_bytes + - temp_func_bytes + - kv_aux_workspace_bytes + - model_workspace_bytes + - logit_processor_workspace_bytes + ) + / kv_bytes_per_token ) assert max_total_sequence_length > 0, ( "Cannot estimate KV cache capacity. " @@ -199,7 +221,12 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals ) total_size = ( - params_bytes * 1.05 + temp_func_bytes + kv_bytes_per_token * max_total_sequence_length + params_bytes + + temp_func_bytes + + kv_aux_workspace_bytes + + model_workspace_bytes + + logit_processor_workspace_bytes + + kv_bytes_per_token * max_total_sequence_length ) logger.info( "%s: %d.", @@ -211,8 +238,8 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals green("Estimated total single GPU memory usage"), total_size / 1024 / 1024, params_bytes / 1024 / 1024, - kv_bytes_per_token * max_total_sequence_length / 1024 / 1024, - temp_func_bytes / 1024 / 1024, + (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, + (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, ) return int(max_total_sequence_length) @@ -299,7 +326,7 @@ def __init__( # pylint: disable=too-many-arguments if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths + models, config_file_paths, kv_cache_config.max_num_sequence ) if kv_cache_config.prefill_chunk_size is None: kv_cache_config.prefill_chunk_size = prefill_chunk_size diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index a30b744018..324c4b377c 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -42,7 +42,7 @@ def check_openai_nonstream_response( model: str, object_str: str, num_choices: int, - finish_reason: str, + finish_reasons: List[str], completion_tokens: Optional[int] = None, echo_prompt: Optional[str] = None, suffix: Optional[str] = None, @@ -57,7 +57,7 @@ def check_openai_nonstream_response( assert len(choices) == num_choices for idx, choice in enumerate(choices): assert choice["index"] == idx - assert choice["finish_reason"] == finish_reason + assert choice["finish_reason"] in finish_reasons text: str if not is_chat_completion: @@ -95,7 +95,7 @@ def check_openai_stream_response( model: str, object_str: str, num_choices: int, - finish_reason: str, + finish_reasons: List[str], 
completion_tokens: Optional[int] = None, echo_prompt: Optional[str] = None, suffix: Optional[str] = None, @@ -126,9 +126,9 @@ def check_openai_stream_response( outputs[idx] += delta["content"] if finished[idx]: - assert choice["finish_reason"] == finish_reason + assert choice["finish_reason"] in finish_reasons elif choice["finish_reason"] is not None: - assert choice["finish_reason"] == finish_reason + assert choice["finish_reason"] in finish_reasons finished[idx] = True if not is_chat_completion: @@ -171,7 +171,7 @@ def test_openai_v1_models( # `served_model` and `launch_server` are pytest fixtures # defined in conftest.py. - response = requests.get(OPENAI_V1_MODELS_URL, timeout=60).json() + response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json() assert response["object"] == "list" models = response["data"] assert isinstance(models, list) @@ -202,7 +202,7 @@ def test_openai_v1_completions( "stream": stream, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -210,7 +210,7 @@ def test_openai_v1_completions( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) else: @@ -225,7 +225,7 @@ def test_openai_v1_completions( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) @@ -255,7 +255,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) else: @@ -268,7 +268,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) @@ -284,7 +284,7 @@ def test_openai_v1_completions_invalid_requested_model( "prompt": "What is the meaning of life?", "max_tokens": 10, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) expect_error( response_str=response.json(), msg_prefix=f'The requested model "{model}" is not served.' 
) @@ -309,7 +309,7 @@ def test_openai_v1_completions_echo( "stream": stream, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -317,7 +317,7 @@ def test_openai_v1_completions_echo( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, echo_prompt=prompt, ) @@ -333,7 +333,7 @@ def test_openai_v1_completions_echo( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, echo_prompt=prompt, ) @@ -359,7 +359,7 @@ def test_openai_v1_completions_suffix( "stream": stream, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -367,7 +367,7 @@ def test_openai_v1_completions_suffix( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, suffix=suffix, ) @@ -383,7 +383,7 @@ def test_openai_v1_completions_suffix( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, suffix=suffix, ) @@ -411,7 +411,7 @@ def test_openai_v1_completions_stop_str( "stream": stream, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -419,7 +419,7 @@ def test_openai_v1_completions_stop_str( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="stop", + finish_reasons=["stop", "length"], stop=stop, ) else: @@ -434,7 +434,7 @@ def test_openai_v1_completions_stop_str( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="stop", + finish_reasons=["stop", "length"], stop=stop, ) @@ -458,7 +458,7 @@ def test_openai_v1_completions_temperature( "temperature": 0.0, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -466,7 +466,7 @@ def test_openai_v1_completions_temperature( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) else: responses = [] @@ -480,7 +480,7 @@ def test_openai_v1_completions_temperature( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) @@ -506,7 +506,7 @@ def test_openai_v1_completions_logit_bias( "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. 
} - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -514,7 +514,7 @@ def test_openai_v1_completions_logit_bias( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) else: responses = [] @@ -528,7 +528,7 @@ def test_openai_v1_completions_logit_bias( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) @@ -552,7 +552,7 @@ def test_openai_v1_completions_presence_frequency_penalty( "presence_penalty": 2.0, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -560,7 +560,7 @@ def test_openai_v1_completions_presence_frequency_penalty( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) else: responses = [] @@ -574,7 +574,7 @@ def test_openai_v1_completions_presence_frequency_penalty( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) @@ -595,8 +595,8 @@ def test_openai_v1_completions_seed( "seed": 233, } - response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) - response2 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) + response2 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) for response in [response1, response2]: check_openai_nonstream_response( response.json(), @@ -604,7 +604,7 @@ def test_openai_v1_completions_seed( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], ) text1 = response1.json()["choices"][0]["text"] @@ -621,7 +621,7 @@ def test_openai_v1_completions_prompt_overlong( # `served_model` and `launch_server` are pytest fixtures # defined in conftest.py. 
- num_tokens = 17000 + num_tokens = 1000000 prompt = [128] * num_tokens payload = { "model": served_model[0], @@ -630,7 +630,7 @@ def test_openai_v1_completions_prompt_overlong( "stream": stream, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) error_msg_prefix = ( f"Request prompt has {num_tokens} tokens in total, larger than the model capacity" ) @@ -664,7 +664,7 @@ def test_openai_v1_completions_invalid_logprobs( "top_logprobs": 4, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert response.json()["detail"][0]["msg"].endswith( '"logprobs" must be True to support "top_logprobs"' @@ -673,8 +673,8 @@ def test_openai_v1_completions_invalid_logprobs( payload["logprobs"] = True payload["top_logprobs"] = 6 - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert response.json()["detail"][0]["msg"].endswith('"top_logprobs" must be in range [0, 5]') @@ -695,7 +695,7 @@ def test_openai_v1_completions_unsupported_args( "best_of": best_of, } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) error_msg_prefix = 'Request fields "best_of" are not supported right now.' expect_error(response.json(), msg_prefix=error_msg_prefix) @@ -719,7 +719,7 @@ def test_openai_v1_completions_request_cancellation( # The server should still be alive after a request cancelled. # We query `v1/models` to validate the server liveness. 
- response = requests.get(OPENAI_V1_MODELS_URL, timeout=60).json() + response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json() assert response["object"] == "list" models = response["data"] @@ -774,7 +774,7 @@ def test_openai_v1_chat_completions( "stream": stream, } - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -782,7 +782,7 @@ def test_openai_v1_chat_completions( model=served_model[0], object_str="chat.completion", num_choices=1, - finish_reason="stop", + finish_reasons=["stop"], ) else: responses = [] @@ -796,7 +796,7 @@ def test_openai_v1_chat_completions( model=served_model[0], object_str="chat.completion.chunk", num_choices=1, - finish_reason="stop", + finish_reasons=["stop"], ) @@ -826,7 +826,7 @@ def test_openai_v1_chat_completions_openai_package( model=served_model[0], object_str="chat.completion", num_choices=1, - finish_reason="stop", + finish_reasons=["stop"], ) else: responses = [] @@ -838,7 +838,7 @@ def test_openai_v1_chat_completions_openai_package( model=served_model[0], object_str="chat.completion.chunk", num_choices=1, - finish_reason="stop", + finish_reasons=["stop"], ) @@ -860,7 +860,7 @@ def test_openai_v1_chat_completions_max_tokens( "max_tokens": max_tokens, } - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -868,7 +868,7 @@ def test_openai_v1_chat_completions_max_tokens( model=served_model[0], object_str="chat.completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) else: @@ -883,7 +883,7 @@ def test_openai_v1_chat_completions_max_tokens( model=served_model[0], object_str="chat.completion.chunk", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) @@ -907,7 +907,7 @@ def test_openai_v1_chat_completions_ignore_eos( "ignore_eos": True, } - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) if not stream: check_openai_nonstream_response( response.json(), @@ -915,7 +915,7 @@ def test_openai_v1_chat_completions_ignore_eos( model=served_model[0], object_str="chat.completion", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) else: @@ -930,7 +930,7 @@ def test_openai_v1_chat_completions_ignore_eos( model=served_model[0], object_str="chat.completion.chunk", num_choices=1, - finish_reason="length", + finish_reasons=["length"], completion_tokens=max_tokens, ) @@ -958,7 +958,7 @@ def test_openai_v1_chat_completions_system_prompt_wrong_pos( "stream": stream, } - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) error_msg = "System prompt at position 1 in the message list is invalid." if not stream: expect_error(response.json(), msg_prefix=error_msg) @@ -980,7 +980,7 @@ def test_debug_dump_event_trace( # defined in conftest.py. # We only check that the request does not fail. 
payload = {"model": served_model[0]} - response = requests.post(DEBUG_DUMP_EVENT_TRACE_URL, json=payload, timeout=60) + response = requests.post(DEBUG_DUMP_EVENT_TRACE_URL, json=payload, timeout=180) assert response.status_code == HTTPStatus.OK From 607dc5a7486e0ca87cd7f8fa9e2e8223e1eec490 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 27 Feb 2024 09:28:38 -0500 Subject: [PATCH 006/531] [Fix] Fix `u_char` for Windows build (#1848) Prior to this PR, `u_char` was used while it is not a standard type in C++, which causes Windows build failure. This PR fixes it by using `unsigned char`. --- cpp/serve/data.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc index 54e404ae1f..770619f7c3 100644 --- a/cpp/serve/data.cc +++ b/cpp/serve/data.cc @@ -103,7 +103,7 @@ inline void TokenToLogProbJSON(const Tokenizer& tokenizer, const TokenProbPair& (*os) << "\"bytes\": ["; int token_len = token.size(); for (int pos = 0; pos < token_len; ++pos) { - (*os) << static_cast(static_cast(token[pos])); + (*os) << static_cast(static_cast(token[pos])); if (pos != token_len - 1) { (*os) << ", "; } From c4d1b69cf0613f581b4bdfdb17415d8e30ce4a04 Mon Sep 17 00:00:00 2001 From: Git bot Date: Tue, 27 Feb 2024 16:13:57 +0000 Subject: [PATCH 007/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 59c3556043..2c1ce3ab46 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 59c3556043abdc88f3ed98e07aa6176ac9a3f0cd +Subproject commit 2c1ce3ab467f9367c14afd9579ed1388aaae0b90 From 31e05717ca61af268335a5958699a47931866e43 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:05:13 -0500 Subject: [PATCH 008/531] [Fix] Add phi lm head name to is_final_fc, add q4f16_ft to CI (#1849) [Fix] Add phi lm head name to is_final_fc --- python/mlc_chat/quantization/utils.py | 2 +- tests/python/integration/test_model_compile.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlc_chat/quantization/utils.py b/python/mlc_chat/quantization/utils.py index 4159da8f04..05a9b9e233 100644 --- a/python/mlc_chat/quantization/utils.py +++ b/python/mlc_chat/quantization/utils.py @@ -44,4 +44,4 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments def is_final_fc(name: str) -> bool: """Determines whether the parameter is the last layer based on its name.""" # TODO: use more specious condition to determine final fc # pylint: disable=fixme - return name in ["head", "lm_head"] + return name in ["head", "lm_head", "lm_head.linear", "embed_out"] diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py index 92c8894ed9..7dbdbf8109 100644 --- a/tests/python/integration/test_model_compile.py +++ b/tests/python/integration/test_model_compile.py @@ -65,6 +65,7 @@ "q3f16_1", "q4f16_1", "q4f32_1", + "q4f16_ft", ] TENSOR_PARALLEL_SHARDS = [ 1, @@ -102,6 +103,9 @@ def test_model_compile(): # pylint: disable=too-many-locals TENSOR_PARALLEL_SHARDS, ) ): + if not target.startswith("cuda") and quant == "q4f16_ft": + # FasterTransformer only works with cuda + continue log_file = os.path.join(tmp_dir, f"lib{idx}.log") cmd = [ sys.executable, From 89f3e41447f132780412f5c0e9f4d6592242f983 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 27 Feb 2024 21:07:39 -0600 Subject: [PATCH 009/531] [Build] Replace mod_transform_before_build 
with IRModule pass (#1852) Instead of a python function that returns an updated `IRModule`, the new `optimize_mod_pipeline` function returns a `tvm.ir.transform.Pass` which can be applied to an `IRModule`. --- mlc_llm/core.py | 82 ++++++++++++++++++------------------------------- 1 file changed, 30 insertions(+), 52 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index bd86e0a4c9..fa415cc36e 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -550,41 +550,12 @@ def get_cuda_sm_version(): return sm -def mod_transform_before_build( - mod: tvm.IRModule, - param_manager: param_manager.ParamManager, +def optimize_mod_pipeline( args: argparse.Namespace, config: Dict, -) -> tvm.IRModule: +) -> tvm.ir.transform.Pass: """First-stage: Legalize ops and trace""" - if args.model.startswith("minigpt"): - model_names = ["embed"] - else: - model_names = [ - "prefill", - "decode", - ] - - if not args.use_vllm_attention: - model_names += [ - "create_kv_cache", - "softmax_with_temperature", - "get_metadata", - ] - else: - # This is equivalent to prefill but without KV cache. It is used for - # determining the number of paged cache blocks that can be allocated. - model_names.append("evaluate") - - if args.sep_embed: - model_names = ["embed", "prefill_with_embed"] + model_names[1:] - if args.enable_batching: - model_names[2] = "decode_with_embed" - if args.model.lower().startswith("rwkv-"): - model_names += ["reset_kv_cache"] - - mod = param_manager.transform_dequantize()(mod) - mod = relax.transform.BundleModelParams()(mod) + seq = [] use_ft_quant = args.quantization.name in [ "q4f16_ft", @@ -592,7 +563,7 @@ def mod_transform_before_build( "q4f16_ft_group", "q8f16_ft_group", ] - mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) + seq.append(mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)) if ( not args.enable_batching @@ -610,12 +581,12 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() # pylint: disable=no-value-for-parameter - mod = fuse_split_rotary_embedding( + seq.append(fuse_split_rotary_embedding( config.num_attention_heads // args.num_shards, num_key_value_heads // args.num_shards, config.hidden_size // args.num_shards, config.position_embedding_base, - )(mod) + ) if args.target_kind == "cuda": patterns = [] @@ -625,8 +596,8 @@ def mod_transform_before_build( if has_cutlass and not args.no_cutlass_attn: # pylint: disable=no-value-for-parameter if args.use_flash_attn_mqa: - mod = rewrite_attention(use_flash_mqa=True)(mod) - mod = rewrite_attention(use_flash_mqa=False)(mod) + seq.append(rewrite_attention(use_flash_mqa=True)) + seq.append(rewrite_attention(use_flash_mqa=False)) patterns += get_patterns_with_prefix("cutlass.attention") if has_cutlass and not args.no_cutlass_norm: @@ -650,31 +621,37 @@ def mod_transform_before_build( if hasattr(config, "rms_norm_eps"): options["cutlass"]["rms_eps"] = config.rms_norm_eps - mod = tvm.transform.Sequential( + seq.extend( [ relax.transform.FuseOpsByPattern( patterns, bind_constants=False, annotate_codegen=True ), annotate_workspace, relax.transform.AllocateWorkspace(), - relax.transform.RunCodegen(options, entry_functions=model_names), + relax.transform.RunCodegen(options), ] - )(mod) + ) if args.target_kind == "android": - mod = mlc_llm.transform.FuseTranspose1Matmul()(mod) - mod = mlc_llm.transform.FuseTranspose2Matmul()(mod) - mod = mlc_llm.transform.FuseTransposeMatmul()(mod) - mod = relax.pipeline.get_pipeline()(mod) # pylint: 
disable=no-value-for-parameter - mod = mlc_llm.transform.FuseDecodeMatmulEwise()(mod) - mod = mlc_llm.transform.FuseDecodeTake()(mod) - mod = relax.transform.DeadCodeElimination(model_names)(mod) - mod = mlc_llm.transform.CleanUpTIRAttrs()(mod) - mod_deploy = mod + seq.extend( + [ + mlc_llm.transform.FuseTranspose1Matmul(), + mlc_llm.transform.FuseTranspose2Matmul(), + ] + ) + seq.extend( + [ + mlc_llm.transform.FuseTransposeMatmul(), + relax.pipeline.get_pipeline(), + mlc_llm.transform.FuseDecodeMatmulEwise(), + mlc_llm.transform.FuseDecodeTake(), + relax.transform.DeadCodeElimination(), + mlc_llm.transform.CleanUpTIRAttrs(), + ] + ) - utils.debug_dump_script(mod_deploy, "mod_deploy.py", args) + return tvm.ir.transform.Sequential(seq, name="mlc_llm.core.optimize_mod_pipeline") - return mod_deploy def dump_mlc_chat_config( @@ -867,6 +844,7 @@ def build_model_from_args(args: argparse.Namespace): for qspec_updater_class in param_manager.qspec_updater_classes: qspec_updater = qspec_updater_class(param_manager) qspec_updater.visit_module(mod) + mod = param_manager.transform_dequantize()(mod) if not args.build_model_only: parameter_transforms = [] @@ -958,7 +936,7 @@ def build_model_from_args(args: argparse.Namespace): if args.convert_weights_only: exit(0) - mod = mod_transform_before_build(mod, param_manager, args, model_config) + mod = optimize_mod_pipeline(args, model_config)(mod) if args.num_shards > 1: # We require a "create_sharding_info" function for all # multi-GPU models, even if they are using pre-sharded From 6ce17595e0d944b1203b5aee513a38e6abf31695 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Wed, 28 Feb 2024 14:44:49 +0800 Subject: [PATCH 010/531] [SLM] Add support for InternLM architecture (#1835) * Create __init__.py * Add files via upload * Update model.py * Update model_preset.py * Update conv_templates.cc * Update internlm_loader.py * Update internlm_quantization.py * fix name of notes * Update model.py * Migration * fix pylint issue * fix pylint issue * fix pylint error * Update internlm_loader.py * Update __init__.py * Update __init__.py * Delete python/mlc_chat/model/internlm/__init__.py * Add files via upload --- cpp/conv_templates.cc | 1 + .../model/baichuan/baichuan_loader.py | 6 +- python/mlc_chat/model/internlm/__init__.py | 0 .../model/internlm/internlm_loader.py | 102 +++++ .../mlc_chat/model/internlm/internlm_model.py | 350 ++++++++++++++++++ .../model/internlm/internlm_quantization.py | 53 +++ python/mlc_chat/model/model.py | 15 + python/mlc_chat/model/model_preset.py | 26 ++ 8 files changed, 550 insertions(+), 3 deletions(-) create mode 100644 python/mlc_chat/model/internlm/__init__.py create mode 100644 python/mlc_chat/model/internlm/internlm_loader.py create mode 100644 python/mlc_chat/model/internlm/internlm_model.py create mode 100644 python/mlc_chat/model/internlm/internlm_quantization.py diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index c25c75e129..b0928b7457 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -759,6 +759,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"stablelm-2", StableLM2}, {"baichuan", ChatML}, {"gemma_instruction", GemmaInstruction}, + {"internlm", ChatML}, }; auto it = factory.find(name); if (it == factory.end()) { diff --git a/python/mlc_chat/model/baichuan/baichuan_loader.py b/python/mlc_chat/model/baichuan/baichuan_loader.py index 01b85281ff..2807060438 100644 --- a/python/mlc_chat/model/baichuan/baichuan_loader.py +++ 
b/python/mlc_chat/model/baichuan/baichuan_loader.py @@ -1,5 +1,5 @@ """ -This file specifies how MLC's StableLM parameter maps from other formats, for example HuggingFace +This file specifies how MLC's BaichuanLM parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ @@ -19,8 +19,8 @@ def huggingface(model_config: BaichuanConfig, quantization: Quantization) -> Ext Parameters ---------- - model_config : GPT2Config - The configuration of the GPT-2 model. + model_config : BaichuanConfig + The configuration of the Baichuan model. quantization : Quantization The quantization configuration. diff --git a/python/mlc_chat/model/internlm/__init__.py b/python/mlc_chat/model/internlm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_chat/model/internlm/internlm_loader.py b/python/mlc_chat/model/internlm/internlm_loader.py new file mode 100644 index 0000000000..7e80aeeb64 --- /dev/null +++ b/python/mlc_chat/model/internlm/internlm_loader.py @@ -0,0 +1,102 @@ +""" +This file specifies how MLC's InternLM parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .internlm_model import InternLMConfig, InternLMForCausalLM + + +def huggingface(model_config: InternLMConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : InternLMConfig + The configuration of the InternLM model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = InternLMForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.wqkv_pack.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + mlc_name = f"{attn}.wqkv_pack.bias" + if mlc_name in named_parameters: + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.bias", + f"{attn}.k_proj.bias", + f"{attn}.v_proj.bias", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/internlm/internlm_model.py b/python/mlc_chat/model/internlm/internlm_model.py new file mode 100644 index 0000000000..0f6b92a76f --- /dev/null +++ b/python/mlc_chat/model/internlm/internlm_model.py @@ -0,0 +1,350 @@ +""" +Implementation for InternLM architecture. +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class InternLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the InternLM model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + rms_norm_eps: float + intermediate_size: int + bias: bool + use_cache: bool + pad_token_id: int + bos_token_id: int + eos_token_id: int + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class InternLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: InternLMConfig): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.context_window_size + + self.wqkv_pack = nn.Linear( + self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + qkv = self.wqkv_pack(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) + ) + attn_output = self.o_proj(output) + return attn_output + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + qkv = self.wqkv_pack(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) + ) + attn_output = self.o_proj(output) + return attn_output + + +class InternLMMLP(nn.Module): + def __init__(self, config: InternLMConfig): + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * config.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class InternLMDecoderLayer(nn.Module): + def __init__(self, config: InternLMConfig): + self.self_attn = InternLMAttention(config) + self.mlp = InternLMMLP(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.rms_norm_eps, bias=False + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn.batch_forward( + 
self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + +class InternLMModel(nn.Module): + def __init__(self, config: InternLMConfig): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class InternLMForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: InternLMConfig): + self.model = InternLMModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.vocab_size = config.vocab_size + self.rope_theta = 10000 + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: 
PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/internlm/internlm_quantization.py b/python/mlc_chat/model/internlm/internlm_quantization.py new file mode 100644 index 0000000000..22f2eae2f5 --- /dev/null +++ b/python/mlc_chat/model/internlm/internlm_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's InternLM parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import FTQuantize, 
GroupQuantize, NoQuantize + +from .internlm_model import InternLMConfig, InternLMForCausalLM + + +def group_quant( + model_config: InternLMConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM-architecture model using group quantization.""" + model: nn.Module = InternLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: InternLMConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM model using FasterTransformer quantization.""" + model: nn.Module = InternLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: InternLMConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM model without quantization.""" + model: nn.Module = InternLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py index 68d052c173..730f5eff6b 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_chat/model/model.py @@ -13,6 +13,7 @@ from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization +from .internlm import internlm_loader, internlm_model, internlm_quantization from .llama import llama_loader, llama_model, llama_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization @@ -248,4 +249,18 @@ class Model: "ft-quant": baichuan_quantization.ft_quant, }, ), + "internlm": Model( + name="internlm", + model=internlm_model.InternLMForCausalLM, + config=internlm_model.InternLMConfig, + source={ + "huggingface-torch": internlm_loader.huggingface, + "huggingface-safetensor": internlm_loader.huggingface, + }, + quantize={ + "no-quant": internlm_quantization.no_quant, + "group-quant": internlm_quantization.group_quant, + "ft-quant": internlm_quantization.ft_quant, + }, + ), } diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py index bacfd43ffd..0ec2f633c2 100644 --- a/python/mlc_chat/model/model_preset.py +++ b/python/mlc_chat/model/model_preset.py @@ -447,6 +447,32 @@ "use_cache": True, "vocab_size": 125696, }, + "internlm": { + "architectures": ["InternLMForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_internlm.InternLMConfig", + "AutoModel": "modeling_internlm.InternLMForCausalLM", + "AutoModelForCausalLM": "modeling_internlm.InternLMForCausalLM", + }, + "bias": True, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "internlm", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "pad_token_id": 2, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.33.2", + "use_cache": True, + "vocab_size": 103168, + }, # TODO(mlc-team): enable the model presets 
when stablized. # "gemma_2b": { # "architectures": ["GemmaForCausalLM"], From 1497744277fc8634f41d3ea40fafc0454f492bbc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 28 Feb 2024 08:01:31 -0600 Subject: [PATCH 011/531] [Bugfix] Handle model names with multiple path components (#1851) Prior to this commit, a model name with multiple path components (e.g. `dist/models/group_name/model_name`) would have duplicated path components (e.g. `dist/group_name/artifact_path/group_name/libname.so`). This commit resolves the duplication. --- mlc_llm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index fa415cc36e..614baf74a1 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -761,7 +761,7 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib) - output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}" + output_filename = f"{os.path.split(args.model)[1]}-{args.quantization.name}-{target_kind}.{args.lib_format}" utils.debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args) args.lib_path = os.path.join(args.artifact_path, output_filename) From 74563147759144dd6885f6f9e7d22e018a9a7a80 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Wed, 28 Feb 2024 09:02:38 -0500 Subject: [PATCH 012/531] [KVCache] Add max num threads awareness to KVCache kernels (#1822) * [KVCache] Add max num threads to KVCache kernels, fix WebGPU * Read max_num_threads_per_block when available * Change merge state in place kernel * Make attention decode aware of max num threads, not just webgpu Co-authored-by: Egor Churaev * Change util function name --------- Co-authored-by: Egor Churaev --- .../mlc_chat/compiler_pass/fuse_add_norm.py | 7 +- python/mlc_chat/model/model_preset.py | 24 +++ python/mlc_chat/nn/kv_cache.py | 182 +++++++++++------- python/mlc_chat/op/position_embedding.py | 26 ++- python/mlc_chat/support/max_thread_check.py | 38 ++++ tests/python/model/test_llama.py | 5 +- 6 files changed, 199 insertions(+), 83 deletions(-) create mode 100644 python/mlc_chat/support/max_thread_check.py diff --git a/python/mlc_chat/compiler_pass/fuse_add_norm.py b/python/mlc_chat/compiler_pass/fuse_add_norm.py index 88ed1dc73c..04adefc90d 100644 --- a/python/mlc_chat/compiler_pass/fuse_add_norm.py +++ b/python/mlc_chat/compiler_pass/fuse_add_norm.py @@ -6,6 +6,8 @@ from tvm.relax.dpl.pattern import is_op, wildcard from tvm.script import tir as T +from ..support.max_thread_check import get_max_num_threads_per_block + # mypy: disable-error-code="attr-defined,valid-type" # pylint: disable=too-many-locals,invalid-name @@ -147,8 +149,9 @@ def __init__(self, target: tvm.target.Target) -> None: """ self.TX = 1024 # default - if target.max_num_threads < self.TX: - self.TX = target.max_num_threads + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < self.TX: + self.TX = max_num_threads_per_block def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule: """IRModule-level transformation.""" diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py index 0ec2f633c2..04a20dc210 100644 --- a/python/mlc_chat/model/model_preset.py +++ b/python/mlc_chat/model/model_preset.py @@ -153,6 +153,30 @@ 
"context_window_size": 2048, "prefill_chunk_size": 2048, }, + "tinyllama_1b_chat_v1.0": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5632, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 22, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.35.0", + "use_cache": True, + "vocab_size": 32000, + }, "mistral_7b": { "architectures": ["MistralForCausalLM"], "bos_token_id": 1, diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index e956037411..5e39a614e6 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -18,6 +18,11 @@ rope_freq, ) +from ..support.max_thread_check import ( + check_thread_limits, + get_max_num_threads_per_block, +) + class RopeMode(enum.IntEnum): """The RoPE mode of the Paged KV cache. @@ -477,10 +482,20 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable= group_size = h_q // h_kv sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + bdx = 32 num_warps = 4 tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 L_per_cta = tile_x // group_size + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) + def mask(causal, row, col, kv_len, qo_len): return T.if_then_else( causal > 0, @@ -529,7 +544,7 @@ def batch_prefill_paged_kv( for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(32, thread="threadIdx.x"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() @@ -553,9 +568,9 @@ def batch_prefill_paged_kv( m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") ## get tile_no, batch_idx, batch_tiles, batch_rows tile_id[0] = bx @@ -588,8 +603,8 @@ def batch_prefill_paged_kv( T.tvm_storage_sync("shared") # init states - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: m_smem[row] = -5e4 d_smem[row] = 1.0 @@ -667,8 +682,8 @@ def batch_prefill_paged_kv( T.tvm_storage_sync("shared") # Update S, m, d - for i in 
T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: with T.block("update1"): m_prev[i] = m_smem[row] @@ -683,8 +698,8 @@ def batch_prefill_paged_kv( m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx with T.block("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch @@ -698,8 +713,8 @@ def batch_prefill_paged_kv( else: S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: with T.block("update"): for j in T.serial(tile_z): @@ -752,7 +767,7 @@ def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] loop = sch.fuse(loop_x, loop_y) _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True ) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -764,7 +779,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, 32]) + ty, tx = sch.split(t, factors=[num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -776,7 +791,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, 32]) + ty, tx = sch.split(t, factors=[num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -789,12 +804,12 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument def apply_to_md(sch, block): loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) - tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps) + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) @@ -820,7 +835,8 @@ def _attention_decode( H_kv = num_kv_heads D = head_dim - thread_limit = 512 if str(target.kind) != "webgpu" else 256 + max_num_threads_per_block = get_max_num_threads_per_block(target) + thread_limit = min(max_num_threads_per_block, 512) GROUP_SIZE = H_qo // H_kv VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) @@ -833,6 +849,7 @@ def _attention_decode( bdz = threads_per_CTA // (bdx * bdy) tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1 log2e = math.log2(math.exp(1)) + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) # pylint: 
disable=line-too-long,too-many-arguments,too-many-branches # fmt: off @@ -1049,6 +1066,11 @@ def _merge_state_inplace( VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) bdx = head_dim // VEC_SIZE bdy = num_heads + max_num_threads_per_block = get_max_num_threads_per_block(target) + while bdx * bdy > max_num_threads_per_block and bdy > 1: + bdy //= 2 + gdy = num_heads // bdy + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=1, gdz=1) @T.prim_func def merge_state_inplace( @@ -1068,43 +1090,46 @@ def merge_state_inplace( S_other = T.match_buffer(s_other, (N, H), "float32") for bx in T.thread_binding(N, thread="blockIdx.x"): - for ty in T.thread_binding(bdy, thread="threadIdx.y"): - for tx in T.thread_binding(bdx, thread="threadIdx.x"): - with T.block("merge"): - s_val = _var("float32") - s_other_val = _var("float32") - s_max = _var("float32") - scale = _var("float32") - other_scale = _var("float32") - - v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") - v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") - - s_val[0] = S[bx, ty] - s_other_val[0] = S_other[bx, ty] - s_max[0] = T.max(s_val[0], s_other_val[0]) - s_val[0] = T.exp2(s_val[0] - s_max[0]) - s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) - scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) - other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) - - # load v - for vec in T.vectorized(VEC_SIZE): - v_vec[vec] = V[bx, ty, tx * VEC_SIZE + vec] - # load v_other - for vec in T.vectorized(VEC_SIZE): - v_other_vec[vec] = V_other[bx, ty, tx * VEC_SIZE + vec] - - # merge - for vec in T.serial(VEC_SIZE): - v_vec[vec] = v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] - - # store v - for vec in T.vectorized(VEC_SIZE): - V[bx, ty, tx * VEC_SIZE + vec] = v_vec[vec] - - # store s - S[bx, ty] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + for by in T.thread_binding(gdy, thread="blockIdx.y"): + for ty in T.thread_binding(bdy, thread="threadIdx.y"): + for tx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("merge"): + s_val = _var("float32") + s_other_val = _var("float32") + s_max = _var("float32") + scale = _var("float32") + other_scale = _var("float32") + + v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local") + + s_val[0] = S[bx, ty + by * bdy] + s_other_val[0] = S_other[bx, ty + by * bdy] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + + # load v + for vec in T.vectorized(VEC_SIZE): + v_vec[vec] = V[bx, ty + by * bdy, tx * VEC_SIZE + vec] + # load v_other + for vec in T.vectorized(VEC_SIZE): + v_other_vec[vec] = V_other[bx, ty + by * bdy, tx * VEC_SIZE + vec] + + # merge + for vec in T.serial(VEC_SIZE): + v_vec[vec] = ( + v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0] + ) + + # store v + for vec in T.vectorized(VEC_SIZE): + V[bx, ty + by * bdy, tx * VEC_SIZE + vec] = v_vec[vec] + + # store s + S[bx, ty + by * bdy] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] # pylint: enable=invalid-name return merge_state_inplace @@ -1119,10 +1144,19 @@ def _attention_prefill_ragged( group_size = h_q // h_kv sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + bdx = 32 num_warps = 4 tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 L_per_cta = tile_x // 
group_size + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + def mask(causal, row, col, kv_len, qo_len): return T.if_then_else( causal > 0, @@ -1166,7 +1200,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): for lby in T.thread_binding(h_kv, thread="blockIdx.y"): for lty in T.thread_binding(num_warps, thread="threadIdx.y"): - for ltx in T.thread_binding(32, thread="threadIdx.x"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() @@ -1190,9 +1224,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") - m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") - m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") - d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local") + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") ## get tile_no, batch_idx, batch_tiles, batch_rows tile_id[0] = bx @@ -1218,8 +1252,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran T.tvm_storage_sync("shared") # init states - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: m_smem[row] = -5e4 d_smem[row] = 1.0 @@ -1294,8 +1328,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran T.tvm_storage_sync("shared") # Update S, m, d - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: with T.block("update1"): m_prev[i] = m_smem[row] @@ -1310,8 +1344,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx with T.block("update"): for j in T.serial(tile_z): # this is to avoid sync inside condition branch @@ -1325,8 +1359,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran else: S_smem[row, j] = T.exp2(-5e4 - m_new[i]) - for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)): - row: T.int32 = i * 32 * num_warps + ty * 32 + tx + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx if row < tile_x: with T.block("update"): for j in T.serial(tile_z): @@ -1379,7 +1413,7 @@ def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] loop = 
sch.fuse(loop_x, loop_y) _, ty, tx, vec = sch.split( - loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True ) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -1391,7 +1425,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, 32]) + ty, tx = sch.split(t, factors=[num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -1403,7 +1437,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, 32]) + ty, tx = sch.split(t, factors=[num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -1416,12 +1450,12 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument def apply_to_md(sch, block): loop = sch.get_loops(block)[-1] - _, ty, tx = sch.split(loop, factors=[None, num_warps, 32]) + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps) - tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps) + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) diff --git a/python/mlc_chat/op/position_embedding.py b/python/mlc_chat/op/position_embedding.py index 12bdaaae45..323afc02da 100644 --- a/python/mlc_chat/op/position_embedding.py +++ b/python/mlc_chat/op/position_embedding.py @@ -7,6 +7,11 @@ from tvm.script import tir as T from tvm.target import Target +from ..support.max_thread_check import ( + check_thread_limits, + get_max_num_threads_per_block, +) + # pylint: disable=invalid-name @@ -313,6 +318,15 @@ def llama_inplace_rope( if rotary_dim is None: rotary_dim = head_dim + VEC_SIZE = 4 + bdx = (head_dim + VEC_SIZE - 1) // VEC_SIZE # T.ceildiv(head_dim, VEC_SIZE) + bdy = 32 + max_num_threads_per_block = get_max_num_threads_per_block(target) + # TODO(mlc-team): Check correctness after `bdy` backoff + while bdx * bdy > max_num_threads_per_block and bdy > 1: + bdy //= 2 + check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=1, gdz=1) + def _rope( x: T.Buffer, s: tir.Var, @@ -359,12 +373,12 @@ def tir_rotary( # pylint: disable=too-many-locals instance_offset: T.int32 = append_len_indptr[b] rope_offset: T.int32 = rope_offsets[b] append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] - for s0 in range(T.ceildiv(append_len, 32)): - for s1 in T.thread_binding(32, thread="threadIdx.y"): - for d0 in T.thread_binding(T.ceildiv(head_dim, 4), thread="threadIdx.x"): - for d1 in T.vectorized(4): - s: T.int32 = s0 * 32 + s1 - d: T.int32 = d0 * 4 + d1 + for s0 in range(T.ceildiv(append_len, bdy)): + for s1 in T.thread_binding(bdy, thread="threadIdx.y"): + for d0 in T.thread_binding(bdx, thread="threadIdx.x"): + for d1 in T.vectorized(VEC_SIZE): + s: T.int32 = s0 * bdy + s1 + d: T.int32 = d0 * VEC_SIZE + d1 if s < append_len and d < rotary_dim: if h < num_q_heads: q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset) diff --git 
a/python/mlc_chat/support/max_thread_check.py b/python/mlc_chat/support/max_thread_check.py new file mode 100644 index 0000000000..6c078c3bbf --- /dev/null +++ b/python/mlc_chat/support/max_thread_check.py @@ -0,0 +1,38 @@ +"""Helper functions for checking max num thread.""" + +from tvm.target import Target + + +def get_max_num_threads_per_block(target: Target): + """ + max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. + We add this method since some targets have both fields and `max_threads_per_block` is larger. + """ + max_num_threads = target.max_num_threads + max_threads_per_block = target.attrs.get("max_threads_per_block", None) + if max_threads_per_block is None: + return max_num_threads + return max(max_num_threads, max_threads_per_block) + + +def check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int): + """ + Check whether max num threads exceeded given a target. + + Parameters + ---------- + bdx: threadIdx.x + bdy: threadIdx.y + bdz: threadIdx.z + gdz: blockIdx.z + """ + max_num_threads_per_block = get_max_num_threads_per_block(target) + + assert ( + bdx * bdy * bdz <= max_num_threads_per_block + ), f"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{max_num_threads_per_block}" + + if str(target.kind) == "webgpu": + # https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez + assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got bdz={bdz}" + assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}" diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index 8ea682f7f0..6e1b38dbca 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -4,7 +4,9 @@ from mlc_chat.model import MODEL_PRESETS, MODELS -@pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) +@pytest.mark.parametrize( + "model_name", ["llama2_7b", "llama2_13b", "llama2_70b", "tinyllama_1b_chat_v1.0"] +) def test_llama2_creation(model_name: str): model_info = MODELS["llama"] config = model_info.config.from_dict(MODEL_PRESETS[model_name]) @@ -21,3 +23,4 @@ def test_llama2_creation(model_name: str): test_llama2_creation("llama2_7b") test_llama2_creation("llama2_13b") test_llama2_creation("llama2_70b") + test_llama2_creation("tinyllama_1b_chat_v1") From 52d002fd71eff2789f6335452556524806cb0638 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Wed, 28 Feb 2024 22:11:54 +0800 Subject: [PATCH 013/531] [KVCache] Migrate Baichuan model to PagedKVCache (#1854) --- .../mlc_chat/model/baichuan/baichuan_model.py | 234 ++++++++++++------ 1 file changed, 165 insertions(+), 69 deletions(-) diff --git a/python/mlc_chat/model/baichuan/baichuan_model.py b/python/mlc_chat/model/baichuan/baichuan_model.py index 5bcedd4837..8e8944783e 100644 --- a/python/mlc_chat/model/baichuan/baichuan_model.py +++ b/python/mlc_chat/model/baichuan/baichuan_model.py @@ -10,6 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold @@ -73,7 +74,6 @@ def __post_init__(self): bold("context_window_size"), ) self.prefill_chunk_size = self.context_window_size - assert self.tensor_parallel_shards == 1, "Baichuan currently does not support sharding." 
# pylint: disable=invalid-name,missing-docstring @@ -89,32 +89,27 @@ def __init__(self, config: BaichuanConfig): self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) - self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + qkv = self.W_pack(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) + ) + attn_output = self.o_proj(output) + return attn_output - def forward( # pylint: disable=too-many-locals - self, - hidden_states: Tensor, - attention_mask: Tensor, - total_seq_len: tir.Var, - ): - d, h, t = self.head_dim, self.num_heads, total_seq_len + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h = self.head_dim, self.num_heads b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection qkv = self.W_pack(hidden_states) qkv = op.reshape(qkv, (b, s, 3 * h, d)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope(qkv, t, 10000, h, h) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = op_ext.attention(q, k, v, casual_mask=attention_mask) - # Step 5. 
Apply output projection - return self.o_proj(output) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) + ) + attn_output = self.o_proj(output) + return attn_output class BaichuanMLP(nn.Module): @@ -140,8 +135,17 @@ def __init__(self, config: BaichuanConfig): self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn.batch_forward( + self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) hidden_states = out + hidden_states out = self.mlp(self.post_attention_layernorm(hidden_states)) hidden_states = out + hidden_states @@ -157,19 +161,33 @@ def __init__(self, config: BaichuanConfig): ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) - def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - hidden_states = self.embed_tokens(input_ids) - for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) hidden_states = self.norm(hidden_states) return hidden_states -class BaichuanForCausalLM(nn.Module): +class BaichuanForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: BaichuanConfig): self.model = BaichuanModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.vocab_size = config.vocab_size + self.rope_theta = 10000 + self.tensor_parallel_shards = config.tensor_parallel_shards self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -177,72 +195,150 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = 
self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.model(inputs, total_seq_len, attention_mask) + hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits + return logits, paged_kv_cache - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, total_seq_len], - ) - return self.forward(inputs, total_seq_len, attention_mask) + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), - dtype=self.dtype, - ) - return self.forward(inputs, total_seq_len, attention_mask) + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + 
"$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", From ac57c03ccc1ec8e9d8079d6577c5c135dd80bec0 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 28 Feb 2024 20:59:23 -0500 Subject: [PATCH 014/531] [Python] Lazy import of transformers for tiktoken conversion (#1860) This PR moves the import of transformers into the function body of tiktoken tokenizer conversion, so we do not have a force dependency on transformers. 
--- python/mlc_chat/support/convert_tiktoken.py | 31 +++++++++++++-------- python/setup.py | 1 + 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/python/mlc_chat/support/convert_tiktoken.py b/python/mlc_chat/support/convert_tiktoken.py index 9bf0504565..f022a072c6 100644 --- a/python/mlc_chat/support/convert_tiktoken.py +++ b/python/mlc_chat/support/convert_tiktoken.py @@ -9,18 +9,6 @@ import os from typing import Dict, List, Optional -from transformers import AutoTokenizer -from transformers.models.gpt2.tokenization_gpt2 import ( - bytes_to_unicode, -) - -byte_encoder = bytes_to_unicode() - - -def token_bytes_to_string(b): - """Convert a token from bytes to a string""" - return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) - def bpe( mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None @@ -44,6 +32,17 @@ def bpe( def generate_vocab_and_merges(encoder, mergeable_ranks): """Generate vocab and merges in huggingface tokenizers format""" + + from transformers.models.gpt2.tokenization_gpt2 import ( # pylint: disable=import-outside-toplevel + bytes_to_unicode, + ) + + byte_encoder = bytes_to_unicode() + + def token_bytes_to_string(b): + """Convert a token from bytes to a string""" + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + merges = [] vocab = {} for token, rank in mergeable_ranks.items(): @@ -64,6 +63,14 @@ def generate_vocab_and_merges(encoder, mergeable_ranks): def convert_tiktoken(model_path, output_dir, context_window_size=None): """Convert tiktoken tokenizers to huggingface tokenizers style""" + try: + from transformers import AutoTokenizer # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError( # pylint: disable=raise-missing-from + 'Converting tiktoken tokenizer requires the "transformers" package.' + 'Please install the "transformers" package to convert toktoken tokenizer' + ) + tiktoken_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) encoder = tiktoken_tokenizer.tokenizer diff --git a/python/setup.py b/python/setup.py index f866e9a72a..4602f55cb8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -108,6 +108,7 @@ def main(): "tqdm", "tiktoken", "prompt_toolkit", + "openai", ], distclass=BinaryDistribution, **setup_kwargs, From 1f70d7177c25162d159ad3d526bfb2c8061c5638 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 29 Feb 2024 22:42:46 +0800 Subject: [PATCH 015/531] [SLM] RWKV5 World Support (#1787) This PR adds RWKV5 support with RNNState, a similar interface as PagedAttention. 
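As a rough, self-contained sketch of the RNNState idea (illustrative only, plain NumPy, not the TIR implementation added in python/mlc_chat/nn/rnn_state.py; all names here are hypothetical): every layer keeps a small set of fixed-size states per sequence, reads them with `get`, and overwrites them with `set` at each step, so decode memory does not grow with the number of generated tokens.

    import numpy as np

    class ToyRNNState:
        """Toy stand-in for RNNState: fixed-size storage indexed by (layer_id, state_id)."""

        def __init__(self, num_layers, state_shapes):
            self.storage = [
                [np.zeros(shape, dtype="float32") for shape in state_shapes]
                for _ in range(num_layers)
            ]

        def get(self, layer_id, state_id):
            # Return the current state tensor for this layer and slot.
            return self.storage[layer_id][state_id]

        def set(self, layer_id, state_id, value):
            # Overwrite the state and return the container so calls can be chained,
            # mirroring the functional style used in the model code.
            self.storage[layer_id][state_id] = value
            return self

    # Three per-layer slots, shaped like RWKV5's ATT_X, ATT_KV and FFN_X states.
    state = ToyRNNState(num_layers=2, state_shapes=[(8,), (4, 2, 2), (8,)])
    state = state.set(0, 2, np.ones(8, dtype="float32"))
    assert np.allclose(state.get(0, 2), 1.0)
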
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> --- cpp/llm_chat.cc | 80 ++-- mlc_llm/core.py | 12 +- .../compiler_pass/attach_to_ir_module.py | 3 +- python/mlc_chat/interface/gen_config.py | 1 + python/mlc_chat/model/model.py | 15 + python/mlc_chat/model/model_preset.py | 25 +- python/mlc_chat/model/rwkv5/__init__.py | 0 python/mlc_chat/model/rwkv5/rwkv5_loader.py | 87 ++++ python/mlc_chat/model/rwkv5/rwkv5_model.py | 433 ++++++++++++++++++ .../model/rwkv5/rwkv5_quantization.py | 52 +++ python/mlc_chat/nn/rnn_state.py | 329 +++++++++++++ tests/legacy-python/dump_intermediate.py | 68 ++- 12 files changed, 1046 insertions(+), 59 deletions(-) create mode 100644 python/mlc_chat/model/rwkv5/__init__.py create mode 100644 python/mlc_chat/model/rwkv5/rwkv5_loader.py create mode 100644 python/mlc_chat/model/rwkv5/rwkv5_model.py create mode 100644 python/mlc_chat/model/rwkv5/rwkv5_quantization.py create mode 100644 python/mlc_chat/nn/rnn_state.py diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index d3e0b6d63c..b7a426a17f 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -15,17 +15,12 @@ #include #include -#include #include #include #include #include -#include #include -#include -#include #include -#include #include #include "./metadata/model.h" @@ -244,6 +239,35 @@ struct FunctionTable { } } + void _TryInitKVState() { + PackedFunc f_flashinfer_paged_kv_cache = mod_get_func("create_flashinfer_paged_kv_cache"); + PackedFunc f_tir_paged_kv_cache = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + + if (f_flashinfer_paged_kv_cache.defined() || f_tir_paged_kv_cache.defined() || + f_create_rnn_state.defined()) { + // Prefer to use flashinfer paged kv cache, but fall back to tir paged kv cache + if (f_flashinfer_paged_kv_cache.defined()) { + this->use_kv_state = KVStateKind::kAttention; + this->create_kv_cache_func_ = f_flashinfer_paged_kv_cache; + } else if (f_tir_paged_kv_cache.defined()) { + this->use_kv_state = KVStateKind::kAttention; + this->create_kv_cache_func_ = f_tir_paged_kv_cache; + } else if (f_create_rnn_state.defined()) { + this->use_kv_state = KVStateKind::kRNNState; + this->create_kv_cache_func_ = f_create_rnn_state; + } + this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); + this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence"); + this->kv_cache_remove_sequence_func_ = get_global_func("vm.builtin.kv_state_remove_sequence"); + this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward"); + this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward"); + this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn"); + // TODO(mlc-team): enable backtracing when using paged kvcache + this->support_backtracking_kv_ = true; + } + } + void _InitFunctions() { this->prefill_func_ = mod_get_func("prefill"); this->embed_func_ = mod_get_func("embed"); @@ -251,25 +275,10 @@ struct FunctionTable { this->decode_func_ = mod_get_func("decode"); this->softmax_func_ = mod_get_func("softmax_with_temperature"); this->encoding_without_cache_func_ = mod_get_func("encoding_without_cache"); - PackedFunc f_flashinfer_paged_kv_cache = mod_get_func("create_flashinfer_paged_kv_cache"); - PackedFunc f_tir_paged_kv_cache = mod_get_func("create_tir_paged_kv_cache"); - if (f_flashinfer_paged_kv_cache != nullptr || f_tir_paged_kv_cache != nullptr) { - this->use_paged_kv_cache = true; - this->create_kv_cache_func_ 
= f_flashinfer_paged_kv_cache == nullptr - ? f_tir_paged_kv_cache - : f_flashinfer_paged_kv_cache; - this->reset_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_clear"); - this->kv_cache_add_sequence_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence"); - this->kv_cache_remove_sequence_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_remove_sequence"); - this->kv_cache_begin_forward_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_begin_forward"); - this->kv_cache_end_forward_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_end_forward"); - this->fkvcache_array_popn_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn"); - support_backtracking_kv_ = true; - } else { + _TryInitKVState(); + + // Fall back to the old way of creating kv cache if neither paged kv cache nor rnn state is used + if (!this->use_kv_state) { this->create_kv_cache_func_ = mod_get_func("create_kv_cache"); if (this->create_kv_cache_func_ == nullptr) { this->create_kv_cache_func_ = mod_get_func("_initialize_effect"); @@ -308,7 +317,14 @@ struct FunctionTable { } bool use_disco = false; - bool use_paged_kv_cache = false; + + enum KVStateKind { + kNone = 0, + kAttention = 1, + kRNNState = 2, + }; + + KVStateKind use_kv_state = kNone; Session sess{nullptr}; DRef disco_mod{nullptr}; tvm::runtime::Module local_vm{nullptr}; @@ -630,13 +646,17 @@ class LLMChat { // Step 5. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_); // Step 6. KV cache creation. - if (ft_.use_paged_kv_cache) { + if (ft_.use_kv_state == FunctionTable::KVStateKind::kAttention) { IntTuple max_num_sequence{1}; IntTuple max_total_sequence_length{this->max_window_size_}; IntTuple prefill_chunk_size{this->prefill_chunk_size_}; IntTuple page_size{16}; this->kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, prefill_chunk_size, page_size); + } else if (ft_.use_kv_state == FunctionTable::KVStateKind::kRNNState) { + IntTuple max_num_sequence{1}; + IntTuple max_history_length{1}; + this->kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_history_length); } else { this->kv_cache_ = ft_.create_kv_cache_func_(); } @@ -1307,7 +1327,7 @@ class LLMChat { output_message_ = tokenizer_->Decode(output_ids_); } // resize kv to remove the context - if (ft_.use_paged_kv_cache) { + if (ft_.use_kv_state) { ft_.fkvcache_array_popn_(kv_cache_, /*seq_id=*/0, backoff); } else { ft_.fkvcache_array_popn_(kv_cache_, backoff); @@ -1337,7 +1357,7 @@ class LLMChat { if (input_tokens.size() > 1 && ft_.prefill_func_.defined()) { ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens)); if (sliding_window_size_ == -1) { - if (ft_.use_paged_kv_cache) { + if (ft_.use_kv_state) { IntTuple seq_ids_tuple({0}); ShapeTuple input_len_shape = ShapeTuple({static_cast(input_tokens.size())}); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape); @@ -1373,7 +1393,7 @@ class LLMChat { int64_t pos = cur_pos + i + 1 - input_tokens.size(); ShapeTuple pos_shape = ShapeTuple({pos}); if (sliding_window_size_ == -1) { - if (ft_.use_paged_kv_cache) { + if (ft_.use_kv_state) { IntTuple seq_ids_tuple({0}); IntTuple append_length({1}); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length); @@ -1488,7 +1508,7 @@ class LLMChat { // Clear kv cache void ResetKVCache() { ft_.reset_kv_cache_func_(kv_cache_); - if (ft_.use_paged_kv_cache) { + if (ft_.use_kv_state) 
{ ft_.kv_cache_add_sequence_func_(kv_cache_, 0); } } diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 614baf74a1..35464c8669 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -581,11 +581,13 @@ def optimize_mod_pipeline( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() # pylint: disable=no-value-for-parameter - seq.append(fuse_split_rotary_embedding( - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, + seq.append( + fuse_split_rotary_embedding( + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + ) ) if args.target_kind == "cuda": diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 58507299ac..0b33647509 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -12,7 +12,8 @@ class AttachVariableBounds: # pylint: disable=too-few-public-methods """Attach variable bounds to each Relax function, which primarily helps with memory planning.""" def __init__(self, variable_bounds: Dict[str, int]): - self.variable_bounds = variable_bounds + # Specifically for RWKV workloads, which contains -1 max_seq_len + self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0} def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" diff --git a/python/mlc_chat/interface/gen_config.py b/python/mlc_chat/interface/gen_config.py index 35592dbf29..444c200915 100644 --- a/python/mlc_chat/interface/gen_config.py +++ b/python/mlc_chat/interface/gen_config.py @@ -194,6 +194,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "added_tokens.json", "tokenizer_config.json", ] +# FIXME: Copy RWKV tokenizer file # pylint: disable=fixme CONV_TEMPLATES = { "chatml", diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py index 730f5eff6b..9c82cfe9cb 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_chat/model/model.py @@ -20,6 +20,7 @@ from .phi import phi_loader, phi_model, phi_quantization from .qwen import qwen_loader, qwen_model, qwen_quantization from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization +from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization ModelConfig = Any @@ -263,4 +264,18 @@ class Model: "ft-quant": internlm_quantization.ft_quant, }, ), + "rwkv5": Model( + name="rwkv5", + model=rwkv5_model.RWKV5_ForCasualLM, + config=rwkv5_model.RWKV5Config, + source={ + "huggingface-torch": rwkv5_loader.huggingface, + "huggingface-safetensor": rwkv5_loader.huggingface, + }, + quantize={ + "no-quant": rwkv5_quantization.no_quant, + "group-quant": rwkv5_quantization.group_quant, + "ft-quant": rwkv5_quantization.ft_quant, + }, + ), } diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py index 04a20dc210..409112b6b5 100644 --- a/python/mlc_chat/model/model_preset.py +++ b/python/mlc_chat/model/model_preset.py @@ -497,7 +497,7 @@ "use_cache": True, "vocab_size": 103168, }, - # TODO(mlc-team): enable the model presets when stablized. + # TODO(mlc-team): enable the model presets when stabilized. 
# "gemma_2b": { # "architectures": ["GemmaForCausalLM"], # "attention_bias": False, @@ -542,4 +542,27 @@ # "transformers_version": "4.38.0.dev0", # "vocab_size": 256000, # }, + "rwkv5_3b": { + "architectures": ["RwkvForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_rwkv5.Rwkv5Config", + "AutoModelForCausalLM": "modeling_rwkv5.RwkvForCausalLM", + }, + "attention_hidden_size": 2560, + "bos_token_id": 0, + "context_length": 4096, + "eos_token_id": 0, + "head_size": 64, + "hidden_size": 2560, + "intermediate_size": None, + "layer_norm_epsilon": 1e-05, + "model_type": "rwkv5", + "model_version": "5_2", + "num_hidden_layers": 32, + "rescale_every": 6, + "tie_word_embeddings": True, + "transformers_version": "4.34.0", + "use_cache": True, + "vocab_size": 65536, + }, } diff --git a/python/mlc_chat/model/rwkv5/__init__.py b/python/mlc_chat/model/rwkv5/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_chat/model/rwkv5/rwkv5_loader.py b/python/mlc_chat/model/rwkv5/rwkv5_loader.py new file mode 100644 index 0000000000..72454f4a6e --- /dev/null +++ b/python/mlc_chat/model/rwkv5/rwkv5_loader.py @@ -0,0 +1,87 @@ +""" +This file specifies how MLC's RWKV5 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from ...loader import ExternMapping +from ...quantization import Quantization +from .rwkv5_model import RWKV5_ForCasualLM, RWKV5Config + + +def huggingface(model_config: RWKV5Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : RWKVConfig + The configuration of the Mistral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = RWKV5_ForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # convert time_decay + mlc_name = f"model.blocks.{i}.attention.time_decay" + hf_name = f"rwkv.blocks.{i}.attention.time_decay" + mlc_param = named_parameters[mlc_name] + if mlc_param.dtype != "float32": + raise ValueError(f"RWKV5 time_decay should be float32, got {mlc_param.dtype}") + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: np.exp(-np.exp(x.astype(dtype))), + dtype=mlc_param.dtype, + ), + ) + + # rescale + if model_config.rescale_every > 0: + for name in ["feed_forward.value.weight", "attention.output.weight"]: + mlc_name = f"model.blocks.{i}.{name}" + hf_name = f"rwkv.blocks.{i}.{name}" + mlc_param = named_parameters[mlc_name] + + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype, t: x.astype(dtype) / (2**t), + dtype=mlc_param.dtype, + t=i // model_config.rescale_every, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + hf_name = mlc_name.replace("model", "rwkv") + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_chat/model/rwkv5/rwkv5_model.py b/python/mlc_chat/model/rwkv5/rwkv5_model.py new file mode 100644 index 0000000000..066ff7d9f4 --- /dev/null +++ b/python/mlc_chat/model/rwkv5/rwkv5_model.py @@ -0,0 +1,433 @@ +"""Implementation for RWKV5 architecture.""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Object, Tensor, op +from tvm.script import tir as T + +from mlc_chat.nn.rnn_state import RNNState +from mlc_chat.support import logging +from mlc_chat.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class StateID: + """State ID for RWKV5.""" + + ATT_X = 0 + ATT_KV = 1 + FFN_X = 2 + + +@dataclasses.dataclass +class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the RWKV5 model.""" + + hidden_size: int + intermediate_size: int + num_hidden_layers: int + vocab_size: int + model_version: str + tensor_parallel_shards: int = 1 + rescale_every: int = 0 + head_size: int = 64 + layer_norm_epsilon: float = 1e-5 + context_window_size: int = -1 # RWKV does not have context window limitation. 
+ prefill_chunk_size: int = 4096 + num_heads: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.model_version != "5_2": + raise ValueError(f"Only support RWKV v5_2, got {self.model_version}.") + self.intermediate_size = self.intermediate_size or int((self.hidden_size * 3.5)) // 32 * 32 + self.num_heads = ( + self.hidden_size // self.head_size if self.num_heads == 0 else self.num_heads + ) + if self.num_heads * self.head_size != self.hidden_size: + raise ValueError( + f"hidden_size ({self.hidden_size}) must be diisible " + f"by head_size ({self.head_size})" + ) + if self.tensor_parallel_shards != 1: + raise ValueError("Only support single deice at this moment.") + + +# pylint: disable=invalid-name,missing-docstring +# pylint: disable=too-many-arguments, too-many-locals, redefined-argument-from-local +def create_wkv5_func( + num_heads: int, + head_size: int, + dtype: str, + out_dtype: str, + state_dtype: str, +): + @T.prim_func + def wkv_func( + r: T.handle, + k: T.handle, + v: T.handle, + time_decay: T.handle, + time_faaaa: T.handle, + state: T.handle, + out: T.handle, + out_state: T.handle, + ): + T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1}) + batch_size, seq_len = T.int64(), T.int64() + # Inputs + r_buf = T.match_buffer(r, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + k_buf = T.match_buffer(k, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + v_buf = T.match_buffer(v, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + time_decay_buf = T.match_buffer(time_decay, (num_heads, head_size), dtype="float32") + time_faaaa_buf = T.match_buffer(time_faaaa, (num_heads, head_size), dtype="float32") + state_buf = T.match_buffer( + state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype + ) + # Outputs + out_buf = T.match_buffer(out, (batch_size, seq_len, num_heads, head_size), dtype=out_dtype) + out_state_buf = T.match_buffer( + out_state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype + ) + for b in T.thread_binding(batch_size, thread="blockIdx.y"): + for h in T.thread_binding(num_heads, thread="blockIdx.x"): + for i in T.thread_binding(head_size, thread="threadIdx.x"): + for j in range(head_size): + with T.block("init_state"): + vb, vh, vi, vj = T.axis.remap("SSSS", [b, h, i, j]) + out_state_buf[vb, vh, vi, vj] = state_buf[vb, vh, vi, vj] + + for t in range(seq_len): + with T.block("comput"): + vb = T.axis.spatial(batch_size, b) + vt = T.axis.opaque(seq_len, t) + vh = T.axis.spatial(num_heads, h) + vi = T.axis.spatial(head_size, i) + out_buf[vb, vt, vh, vi] = 0 + + for k in range(head_size): + x = k_buf[vb, vt, vh, k] * v_buf[vb, vt, vh, vi] + out_buf[vb, vt, vh, vi] += T.cast( + r_buf[vb, vt, vh, k], out_dtype + ) * T.cast( + time_faaaa_buf[vh, k] * x + out_state_buf[vb, vh, vi, k], + out_dtype, + ) + out_state_buf[vb, vh, vi, k] = ( + out_state_buf[vb, vh, vi, k] * time_decay_buf[vh, k] + x + ) + + return wkv_func + + +# pylint: enable=too-many-arguments, too-many-locals + + +def token_shift(state: Tensor, x: Tensor): + # x.shape = (batch, seq_len, hidden_size) + # state.shape = (batch, hidden_size) + seq_len = x.shape[1] + + def _te_token_shift(state: te.Tensor, x: te.Tensor): + return te.compute( + x.shape, + lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), + ) + + return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + + +def last_token(x: Tensor): + # x.shape = (batch, 
seq_len, hidden_size) + batch, seq_len, hidden_size = x.shape + assert batch == 1 + + def _te_last_token(x: te.Tensor): + return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) + + return x if seq_len == 1 else op.tensor_expr_op(_te_last_token, "last_token", [x]) + + +class RWKV5_FNN(nn.Module): + def __init__(self, config: RWKV5Config, layer_id: int): + super().__init__() + self.time_mix_key = nn.Parameter((1, 1, config.hidden_size)) + self.time_mix_receptance = nn.Parameter((1, 1, config.hidden_size)) + self.key = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.value = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.layer_id = layer_id + + def forward(self, x: Tensor, state: RNNState): + batch, _, hidden_size = x.shape + state_x = state.get(self.layer_id, StateID.FFN_X, (batch, hidden_size), x.dtype) + state_x = token_shift(state_x, x) + xk = x * self.time_mix_key + state_x * (1.0 - self.time_mix_key) + xr = x * self.time_mix_receptance + state_x * (1.0 - self.time_mix_receptance) + last_x = last_token(x).reshape(batch, hidden_size) + state = state.set(self.layer_id, StateID.FFN_X, last_x) + r = op.sigmoid(self.receptance(xr)) + xv = op.square(op.relu(self.key(xk))) + return r * self.value(xv), state + + +class RWKV5_Attention(nn.Module): # pylint: disable=too-many-instance-attributes + """Attention layer for RWKV.""" + + def __init__(self, config: RWKV5Config, layer_id: int): + super().__init__() + self.time_decay = nn.Parameter((config.num_heads, config.head_size)) + self.time_faaaa = nn.Parameter((config.num_heads, config.head_size)) + + self.time_mix_gate = nn.Parameter((1, 1, config.hidden_size)) + self.time_mix_key = nn.Parameter((1, 1, config.hidden_size)) + self.time_mix_value = nn.Parameter((1, 1, config.hidden_size)) + self.time_mix_receptance = nn.Parameter((1, 1, config.hidden_size)) + + self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.gate = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.ln_x = nn.GroupNorm( + config.num_heads, + config.hidden_size, + ) + self.hidden_size = config.hidden_size + self.head_size = config.head_size + self.num_heads = config.num_heads + self.layer_id = layer_id + self.dtype = "float32" + + def forward(self, x: Tensor, state: RNNState): # pylint: disable=too-many-locals + batch, seq_len, hidden_size = x.shape + assert hidden_size == self.hidden_size + B, T, H, N = ( # pylint: disable=redefined-outer-name + batch, + seq_len, + self.head_size, + self.num_heads, + ) + x_state = state.get(self.layer_id, StateID.ATT_X, (batch, self.hidden_size), x.dtype) + x_state = token_shift(x_state, x) + kv_state = state.get( + self.layer_id, + StateID.ATT_KV, + (batch, self.num_heads, self.head_size, self.head_size), + "float32", # Always use float32 for state KV. 
+ ) + + xk = x * self.time_mix_key + x_state * (1.0 - self.time_mix_key) + xv = x * self.time_mix_value + x_state * (1.0 - self.time_mix_value) + xr = x * self.time_mix_receptance + x_state * (1.0 - self.time_mix_receptance) + xg = x * self.time_mix_gate + x_state * (1.0 - self.time_mix_gate) + + r = op.reshape(self.receptance(xr), (B, T, N, H)) + k = op.reshape(self.key(xk), (B, T, N, H)) + v = op.reshape(self.value(xv), (B, T, N, H)) + g = op.silu(self.gate(xg)) + + out, kv_state = op.tensor_ir_op( + create_wkv5_func( + self.num_heads, + self.head_size, + dtype=self.dtype, + out_dtype="float32", + state_dtype="float32", + ), + "wkv5", + [r, k, v, self.time_decay, self.time_faaaa, kv_state], + [ + Tensor.placeholder([B, T, N, H], "float32"), + Tensor.placeholder([B, N, H, H], "float32"), + ], + ) + + last_x = last_token(x).reshape(batch, hidden_size) + state = state.set(self.layer_id, StateID.ATT_X, last_x) + state = state.set(self.layer_id, StateID.ATT_KV, kv_state) + out = op.astype(self.ln_x(op.reshape(out, x.shape), channel_axis=-1, axes=[]), self.dtype) + return self.output(out * g), state + + def to(self, dtype: Optional[str] = None): + # RWKV uses special dtype, so we need to convert it. + if dtype is not None: + self.dtype = dtype + + self.time_mix_gate.to(dtype) + self.time_mix_key.to(dtype) + self.time_mix_value.to(dtype) + self.time_mix_receptance.to(dtype) + self.key.to(dtype) + self.value.to(dtype) + self.receptance.to(dtype) + self.gate.to(dtype) + self.output.to(dtype) + + # These parameters are necessary to be converted to float32. + self.time_decay.to("float32") + self.time_faaaa.to("float32") + self.ln_x.to("float32") + + +class RWKV5_Layer(nn.Module): + def __init__(self, config: RWKV5Config, layer_id: int): + super().__init__() + if layer_id == 0: + self.pre_ln = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.ln1 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.ln2 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.attention = RWKV5_Attention(config, layer_id) + self.feed_forward = RWKV5_FNN(config, layer_id) + self.layer_id = layer_id + self.rescale_every = config.rescale_every + + def forward(self, x: Tensor, state: RNNState) -> Tensor: + if self.layer_id == 0: + x = self.pre_ln(x) + att_x, state = self.attention(self.ln1(x), state) + x += att_x + ffn_x, state = self.feed_forward(self.ln2(x), state) + x += ffn_x + if self.rescale_every > 0 and (self.layer_id + 1) % self.rescale_every == 0: + x = x / 2.0 + return x, state + + +class RWKV5_Model(nn.Module): + """Exact same as LlamaModel.""" + + def __init__(self, config: RWKV5Config): + super().__init__() + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.blocks = nn.ModuleList( + [RWKV5_Layer(config, i) for i in range(config.num_hidden_layers)] + ) + self.ln_out = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + + def forward(self, input_embed: Tensor, state: RNNState): + """Forward pass of the model, passing through all decoder layers.""" + hidden_states = input_embed + for block in self.blocks: + hidden_states, state = block(hidden_states, state) + return self.ln_out(hidden_states), state + + +class RWKV5_ForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + """Same as LlamaForCausalLM, except for the use of sliding window attention.""" + + def __init__(self, config: RWKV5Config): + self.model = RWKV5_Model(config) + self.head = 
nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.head_size = config.head_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def embed(self, input_ids: Tensor): + return self.model.embeddings(input_ids) + + def forward(self, input_embed: Tensor, state: RNNState): + """Forward pass.""" + hidden_states, state = self.model(input_embed, state) + hidden_states = last_token(hidden_states) + logits = self.head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, state + + def prefill(self, input_embed: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embed, state) + + def decode(self, input_embed: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embed, state) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + """Softmax.""" + return op.softmax(logits / temperature, axis=-1) + + def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + """Create RNN state.""" + init_values = [ + op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X + op.zeros((self.num_heads, self.head_size, self.head_size), dtype="float32"), # ATT_KV + op.zeros((self.hidden_size,), dtype=self.dtype), # FFN_X + ] + return RNNState.create( + max_batch_size=max_batch_size, + num_hidden_layers=self.num_hidden_layers, + max_history=max_history, + init_values=init_values, + ) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor( + [batch_size, "seq_len", self.hidden_size], self.dtype + ), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_rnn_state": { + "max_batch_size": int, + "max_history": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/rwkv5/rwkv5_quantization.py b/python/mlc_chat/model/rwkv5/rwkv5_quantization.py new file mode 100644 index 0000000000..235519774c --- /dev/null +++ b/python/mlc_chat/model/rwkv5/rwkv5_quantization.py @@ -0,0 +1,52 @@ +"""This file specifies how MLC's RWKV5 parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from ...loader import QuantizeMapping +from ...quantization import FTQuantize, GroupQuantize, NoQuantize +from .rwkv5_model import RWKV5_ForCasualLM, RWKV5Config + + +def group_quant( + model_config: RWKV5Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a RWKV4-architecture model using group quantization.""" + model: nn.Module = 
RWKV5_ForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: RWKV5Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM model using FasterTransformer quantization.""" + model: nn.Module = RWKV5_ForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: RWKV5Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode model without quantization.""" + model: nn.Module = RWKV5_ForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_chat/nn/rnn_state.py b/python/mlc_chat/nn/rnn_state.py new file mode 100644 index 0000000000..13dc731a47 --- /dev/null +++ b/python/mlc_chat/nn/rnn_state.py @@ -0,0 +1,329 @@ +"""RNN State modeling.""" + +from typing import Sequence, Union + +from tvm import relax as rx +from tvm import tir +from tvm.relax.frontend.nn import Object, Tensor +from tvm.script import tir as T + + +class RNNState(Object): + """The RNN State used in Space State Models""" + + @staticmethod + def create( + max_batch_size: tir.Var, + num_hidden_layers: int, + max_history: int, + init_values: Sequence[Tensor], + name: str = "rnn_state", + ) -> "RNNState": + """Create a RNN state object. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum batch size. + num_hidden_layers : int + The number of hidden layers. + max_history : int + The maximum history length. + init_values : Sequence[Tensor] + The initial values of the RNN state. + """ + + bb = rx.BlockBuilder.current() + state_infos = [(v.shape, v.dtype) for v in init_values] + + f_gets = [ + bb.add_func( + RNNState.create_get_func(shape, dtype, max_batch_size, max_history, id), + f"rnn_state_get_{id}", + ) + for id, (shape, dtype) in enumerate(state_infos) + ] + f_sets = [ + bb.add_func( + RNNState.create_set_func(shape, dtype, max_batch_size, max_history, id), + f"rnn_state_set_{id}", + ) + for id, (shape, dtype) in enumerate(state_infos) + ] + + ret = RNNState( + _expr=rx.call_pure_packed( + "vm.builtin.rnn_state_create", + rx.PrimValue(num_hidden_layers), + max_batch_size, + max_history, + f_gets, + f_sets, + [v._expr for v in init_values], # pylint: disable=protected-access + sinfo_args=[rx.ObjectStructInfo()], + ), + _name=name, + ) + return ret + + def get( + self, + layer_id: int, + state_id: int, + shape: Sequence[tir.PrimExpr], + dtype: str, + ) -> Tensor: + """Get the state of the RNN layer. + + - If there is only one sequence, we can directly use the storage memory, + without copying the data. + - If there are multiple sequences, we need to copy the data to get a contiguous + memory. + + Parameters + ---------- + layer_id : int + The layer id. + state_id : int + The state id. + shape : Sequence[tir.PrimExpr] + The shape of the state tensor. + dtype: str + The data type of the state tensor. + + Returns + ------- + Tensor + The state tensor, with shape `(batch_size, *state_size)`. 
+ """ + bb = rx.BlockBuilder.current() + + return Tensor( + _expr=bb.emit( + rx.call_dps_packed( + "vm.builtin.rnn_state_get", + [self._expr, layer_id, state_id], + out_sinfo=rx.TensorStructInfo(shape, dtype), + ) + ) + ) + + def set(self, layer_id: int, state_id: int, value: Tensor) -> "RNNState": + """Set the state of the RNN layer. + + Parameters + ---------- + layer_id : int + The layer id. + state_id : int + The state id. + value : Tensor + The state tensor, with shape `(batch_size, *state_size)`. + """ + bb = rx.BlockBuilder.current() + return RNNState( + _expr=bb.emit( + rx.call_pure_packed( + "vm.builtin.rnn_state_set", + self._expr, + rx.PrimValue(layer_id), + rx.PrimValue(state_id), + value._expr, # pylint: disable=protected-access + sinfo_args=[rx.ObjectStructInfo()], + ) + ), + _name="rnn_state_set", + ) + + @staticmethod + def create_get_func( + shape: Sequence[Union[int, tir.Var]], + dtype: str, + max_batch_size: Union[int, tir.Var], + max_history: Union[int, tir.Var], + state_id: int, + ) -> tir.PrimFunc: + """Create the get function with given state shape. + + Parameters + ---------- + shape : Sequence[Union[int, tir.Var]] + The shape of the state tensor. + + dtype: str + The data type of the state tensor. + + max_batch_size : Union[int, tir.Var] + The maximum batch size. + + max_history : Union[int, tir.Var] + The maximum history length. + + state_id : int + The id of the state, used for naming the function. + + Returns + ------- + tir.PrimFunc + The get function. + """ + + def _func_one_dim(): + @T.prim_func + def f( + var_storage: T.handle, + var_seq_slot_ids: T.handle, + var_history_slot_ids: T.handle, + var_output: T.handle, + ): + batch_size = T.int32(is_size_var=True) + T.func_attr({"global_symbol": f"rnn_state_get_{state_id}"}) + + storage = T.match_buffer( + var_storage, (max_batch_size, max_history, shape[0]), dtype + ) + seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32") + history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32") + output = T.match_buffer(var_output, (batch_size, shape[0]), dtype) + + for i in range(batch_size): + for s in range(shape[0]): + with T.block("copy"): + vi, vs = T.axis.remap("SS", [i, s]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = history_slot_ids[vi] + output[vi, vs] = storage[seq_id, history_id, vs] + + return f + + def _func_high_dim(): + # Add a wrapper function to avoid parse the following code when len(shape) = 1 + @T.prim_func + def f( + var_storage: T.handle, + var_seq_slot_ids: T.handle, + var_history_slot_ids: T.handle, + var_output: T.handle, + ): + batch_size = T.int32(is_size_var=True) + T.func_attr({"global_symbol": f"rnn_state_get_{state_id}"}) + + storage = T.match_buffer(var_storage, (max_batch_size, max_history, *shape), dtype) + seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32") + history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32") + output = T.match_buffer(var_output, (batch_size, *shape), dtype) + + for i in range(batch_size): + for s in T.grid(*shape): + with T.block("copy"): + vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = history_slot_ids[vi] + # The following line is equivalent to: + # `output[vi, *vs] = storage[seq_id, history_id, *vs]` + # However, unpacking operator in subscript requires Python 3.11 or newer + T.buffer_store( + output, T.BufferLoad(storage, [seq_id, history_id, *vs]), [vi, *vs] + ) + + return f + + return 
_func_one_dim() if len(shape) == 1 else _func_high_dim() + + @staticmethod + def create_set_func( + shape: Sequence[Union[int, tir.Var]], + dtype: str, + max_batch_size: Union[int, tir.Var], + max_history: Union[int, tir.Var], + state_id: int, + ) -> tir.PrimFunc: + """Create the set function with given state shape. + + Parameters + ---------- + shape : Sequence[Union[int, tir.Var]] + The shape of the state tensor. + + dtype: str + The data type of the state tensor. + + max_batch_size : Union[int, tir.Var] + The maximum batch size. + + max_history : Union[int, tir.Var] + The maximum history length. + + state_id : int + The id of the state, used for naming the function. + + Returns + ------- + tir.PrimFunc + The set function. + """ + + def _func_one_dim(): + @T.prim_func + def f( + var_storage: T.handle, + var_seq_slot_ids: T.handle, + var_history_slot_ids: T.handle, + var_data: T.handle, + ): + batch_size = T.int32(is_size_var=True) + T.func_attr({"global_symbol": f"rnn_state_set_{state_id}"}) + + storage = T.match_buffer( + var_storage, (max_batch_size, max_history, shape[0]), dtype + ) + seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32") + history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32") + data = T.match_buffer(var_data, (batch_size, shape[0]), dtype) + + for i in range(batch_size): + for s in range(shape[0]): + with T.block("copy"): + vi, vs = T.axis.remap("SS", [i, s]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast( + max_history, "int32" + ) + storage[seq_id, history_id, vs] = data[vi, vs] + + return f + + def _func_high_dim(): + @T.prim_func + def f( + var_storage: T.handle, + var_seq_slot_ids: T.handle, + var_history_slot_ids: T.handle, + var_data: T.handle, + ): + batch_size = T.int32(is_size_var=True) + T.func_attr({"global_symbol": f"rnn_state_set_{state_id}"}) + + storage = T.match_buffer(var_storage, (max_batch_size, max_history, *shape), dtype) + seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), "int32") + history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), "int32") + data = T.match_buffer(var_data, (batch_size, *shape), dtype) + + for i in range(batch_size): + for s in T.grid(*shape): + with T.block("copy"): + vi, *vs = T.axis.remap("S" * (len(shape) + 1), [i, *s]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast( + max_history, "int32" + ) + # The following line is equivalent to: + # `storage[seq_id, history_id, *vs] = data[vi, *vs]` + # However, unpacking operator in subscript requires Python 3.11 or newer + T.buffer_store( + storage, T.BufferLoad(data, [vi, *vs]), [seq_id, history_id, *vs] + ) + + return f + + return _func_one_dim() if len(shape) == 1 else _func_high_dim() diff --git a/tests/legacy-python/dump_intermediate.py b/tests/legacy-python/dump_intermediate.py index 59bcd85eca..e1da427c00 100644 --- a/tests/legacy-python/dump_intermediate.py +++ b/tests/legacy-python/dump_intermediate.py @@ -7,10 +7,10 @@ import numpy as np import torch import tvm +from mlc_llm import utils from transformers import AutoTokenizer from tvm import relax - -from mlc_llm import utils +from tvm.runtime import ShapeTuple # pylint: disable=redefined-outer-name @@ -120,33 +120,57 @@ def deploy_to_pipeline(args) -> None: print("Tokenizing...") inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() + inputs = tvm.nd.array(inputs, device=primary_device) first_sampled_token = 
tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) - seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) - second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) - kv_caches = state.vm["_initialize_effect"]() + + kv_cache_method: str + if state.vm.module.implements_function( + "create_tir_paged_kv_cache" + ) or state.vm.module.implements_function("create_flashinfer_paged_kv_cache"): + kv_cache_method = "paged_kv_cache" + raise NotImplementedError() + elif state.vm.module.implements_function("create_rnn_state"): + kv_cache_method = "rnn_state" + max_num_seq, history = ShapeTuple([1]), ShapeTuple([1]) + kv_caches = state.vm.module["create_rnn_state"](max_num_seq, history) + f_add_seq = tvm.get_global_func("vm.builtin.kv_state_add_sequence") + f_begin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") + f_end_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") + elif state.vm.module.implements_function("_initialize_effect"): + kv_cache_method = "effect" + kv_caches = state.vm.module["_initialize_effect"]() + else: + raise ValueError("Unknown how to create KVCache") + + def forward(inputs, kv_caches, total_seq_len): + hidden = state.vm["embed"](inputs, const_params) + if inputs.shape[1] > 1: + f_forward = state.vm["prefill"] + else: + f_forward = state.vm["decode"] + if kv_cache_method == "effect": + logits, kv_caches = f_forward( + hidden, ShapeTuple([total_seq_len]), kv_caches, const_params + ) + else: + seq_ids, input_shape = ShapeTuple([0]), ShapeTuple([inputs.shape[1]]) + f_begin_forward(kv_caches, seq_ids, input_shape) + logits, kv_caches = f_forward(hidden, kv_caches, const_params) + f_end_forward(kv_caches) + + return logits, kv_caches print("Running inference...") - print("======================= Starts Encoding =======================") - try: - prefill_func = state.vm["prefill"] - except AttributeError: - prefill_func = None + print("======================= Starts Prefilling ======================") - if inputs.shape[1] > 1 and prefill_func: - inputs = tvm.nd.array(inputs, device=primary_device) - logits, kv_caches = prefill_func(inputs, seq_len_shape, kv_caches, const_params) - else: - for i in range(inputs.shape[1]): - input_slice = tvm.nd.array(inputs[:, i : i + 1], device=primary_device) - logits, kv_caches = state.vm["decode"]( - input_slice, seq_len_shape, kv_caches, const_params - ) + if kv_cache_method != "effect": + f_add_seq(kv_caches, 0) + logits, kv_caches = forward(inputs, kv_caches, inputs.shape[1]) print("======================= Starts Decoding =======================") - logits, kv_caches = state.vm["decode"]( - first_sampled_token, second_seq_len_shape, kv_caches, const_params - ) + + logits, kv_caches = forward(first_sampled_token, kv_caches, inputs.shape[1] + 1) def _parse_args(): From eb465ec8fdba280dc0a4ebbc287b7bc5ea2a6473 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 1 Mar 2024 05:39:43 +0800 Subject: [PATCH 016/531] [Serving] Register the ChatML conversation template (#1862) Following #1854 , this pr registers the ChatML conversation template. 
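As a small usage sketch (illustrative only, not part of the diff; it assumes the registered Conversation object exposes its constructor fields as attributes), the template can then be fetched by name through the existing get_conv_template accessor:

    from mlc_chat.conversation_template import get_conv_template

    conv = get_conv_template("chatml")
    assert conv is not None
    # These fields mirror the registration added in this patch.
    assert conv.roles["user"] == "<|im_start|>user"
    assert conv.stop_str == ["<|im_end|>"]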
--- python/mlc_chat/conversation_template.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_chat/conversation_template.py index 9ec0a6bfee..a5dd9dfe6a 100644 --- a/python/mlc_chat/conversation_template.py +++ b/python/mlc_chat/conversation_template.py @@ -92,3 +92,25 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_token_ids=[2], ) ) + +# ChatML +ConvTemplateRegistry.register_conv_template( + Conversation( + name="chatml", + system_template=f"<|im_start|>{MessagePlaceholders.SYSTEM.value}<|im_end|> ", + system_message=( + "system A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers." + ), + roles={ + "user": "<|im_start|>user", + "assistant": "<|im_start|>assistant", + "tool": "<|im_start|>user", + }, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2], + ) +) From 5bbe2049c9c9074afde75fc79f4835aac1597c3a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Mar 2024 10:15:15 -0600 Subject: [PATCH 017/531] [Utils][Transform] Added SetEntryFuncs transform (#1855) Sets the entry functions for a module. This utility is intended for cases where only module contains several externally-exposed functions, and only one is desired for use. (e.g. Separating out a `transform_params` function from an `IRModule` that also contains inference functions.) This commit only updates the external visibility, after which `relax.transform.DeadCodeElimination()` can be applied. --- mlc_llm/transform/__init__.py | 1 + mlc_llm/transform/set_entry_funcs.py | 70 ++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 mlc_llm/transform/set_entry_funcs.py diff --git a/mlc_llm/transform/__init__.py b/mlc_llm/transform/__init__.py index 2c67369a8e..758d8a1081 100644 --- a/mlc_llm/transform/__init__.py +++ b/mlc_llm/transform/__init__.py @@ -7,3 +7,4 @@ from .reorder_transform_func import ReorderTransformFunc from .rewrite_attention import rewrite_attention from .transpose_matmul import FuseTransposeMatmul, FuseTranspose1Matmul, FuseTranspose2Matmul +from .set_entry_funcs import SetEntryFuncs diff --git a/mlc_llm/transform/set_entry_funcs.py b/mlc_llm/transform/set_entry_funcs.py new file mode 100644 index 0000000000..714da06dd7 --- /dev/null +++ b/mlc_llm/transform/set_entry_funcs.py @@ -0,0 +1,70 @@ +import re + +from typing import List, Union + +import tvm +from tvm.ir import GlobalVar + + +def SetEntryFuncs(*entry_funcs: List[Union[GlobalVar, str]]) -> tvm.ir.transform.Pass: + """Update which functions are externally-exposed + + All functions whose GlobalVar is contained `entry_funcs` list, or + whose name matches a regular expression in `entry_funcs`, are set + as externally exposed. All other functions are set as internal. + + This pass does not add or remove any functions from the + `IRModule`. This pass may result in functions no longer being + used by any externally-exposed function. In these cases, users + may use the `relax.transform.DeadCodeElimination` pass to remove + any unnecessary functions. + + Parameters + ---------- + entry_funcs: List[Union[GlobalVar, str]] + + Specifies which functions that should be externally exposed, + either by GlobalVar or by regular expression. 
+ + Returns + ------- + transform: tvm.ir.transform.Pass + + The IRModule-to-IRModule transformation + """ + + def is_entry_func(gvar: GlobalVar) -> bool: + for entry_func in entry_funcs: + if isinstance(entry_func, GlobalVar): + if entry_func.same_as(gvar): + return True + elif isinstance(entry_func, str): + if re.fullmatch(entry_func, gvar.name_hint): + return True + else: + raise TypeError( + f"SetEntryFuncs requires all arguments to be a GlobalVar or a str. " + f"However, argument {entry_func} has type {type(entry_func)}." + ) + + def is_exposed(func: tvm.ir.BaseFunc) -> bool: + return func.attrs is not None and "global_symbol" in func.attrs + + @tvm.ir.transform.module_pass(opt_level=0, name="SetEntryFuncs") + def transform(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + updates = {} + for gvar, func in mod.functions.items(): + if is_entry_func(gvar): + if not is_exposed(func): + updates[gvar] = func.with_attr("global_symbol", gvar.name_hint) + else: + if is_exposed(func): + updates[gvar] = func.without_attr("global_symbol") + + if updates: + mod = mod.clone() + mod.update(updates) + + return mod + + return transform From eb6645232ba71b27d9b91eb8fd62dc42c8db5e54 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Mar 2024 10:15:23 -0600 Subject: [PATCH 018/531] [Build] Update transform_params_for_each_rank to IRModule pass (#1856) This allows it to be used as part of a optimization pipeline specified as a `tvm.ir.transform.Sequential`. --- mlc_llm/core.py | 2 +- mlc_llm/relax_model/commons.py | 2 +- mlc_llm/relax_model/param_manager.py | 98 +++++++++++++++------ mlc_llm/transform/reorder_transform_func.py | 17 ++-- 4 files changed, 84 insertions(+), 35 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 35464c8669..065b3a29ac 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -859,7 +859,7 @@ def build_model_from_args(args: argparse.Namespace): # Run pre-sharding if required if args.num_shards > 1 and args.use_presharded_weights: mod_shard = create_shard_transformation_func(param_manager, args, model_config) - mod_shard = transform_params_for_each_rank(mod_shard, num_shards=args.num_shards) + mod_shard = transform_params_for_each_rank(num_shards=args.num_shards)(mod_shard) parameter_transforms.append(mod_shard) # Chain all parameter transforms together. This allows diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index be0c477ebc..d55c2ca5e6 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -286,7 +286,7 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.I ) bb = relax.BlockBuilder() # pylint: disable=invalid-name - with bb.function("transform_params"): + with bb.function("transform_params", attrs={"num_input": 1}): rank = tir.SizeVar("rank", "int64") # TODO(Lunderberg): Support primitive inputs to relax # functions. 
Currently, using a PrimStructInfo as the diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index f776db3f1e..9a59b933b8 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -1081,8 +1081,8 @@ def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: def transform_params_for_each_rank( - mod: tvm.IRModule, num_shards: int, rank_argument_name: str = "rank_arg" -) -> tvm.IRModule: + num_shards: int, rank_argument_name: str = "rank_arg" +) -> tvm.ir.transform.Pass: """Update a parameter transform to apply across all ranks For use in generating a pre-sharded set of weights. Given a @@ -1113,31 +1113,47 @@ def transform_params_for_each_rank( The modified parameter transformation """ - generic_transform = mod["transform_params"] - tensor_params = generic_transform.params[1:] - bb = relax.BlockBuilder() + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_params_for_each_rank") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + generic_transform = mod["transform_params"] - with bb.function("transform_params", params=tensor_params): - output = [] - for rank in range(num_shards): - # TODO(Lunderberg): Implement this in terms of a - # generic utility that inlines local functions. - func = generic_transform - func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])}) - func = relax.utils.copy_with_new_vars(func) - func = func.bind_params( - {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)} - ) - shard_tuple = func.body - output.extend([shard_tuple[i] for i in range(len(tensor_params))]) + if generic_transform.attrs is not None and "num_input" in generic_transform.attrs: + num_input = generic_transform.attrs["num_input"].value + else: + num_input = 0 - with bb.dataflow(): - gv = bb.emit_output(relax.Tuple(output)) - bb.emit_func_output(gv) + if num_input == 0: + return mod + + tensor_params = generic_transform.params[num_input:] + attrs = {"num_input": num_input - 1} - mod["transform_params"] = bb.get()["transform_params"] - return mod + bb = relax.BlockBuilder() + + with bb.function("transform_params", params=tensor_params, attrs=attrs): + output = [] + for rank in range(num_shards): + # TODO(Lunderberg): Implement this in terms of a + # generic utility that inlines local functions. + func = generic_transform + func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])}) + func = relax.utils.copy_with_new_vars(func) + func = func.bind_params( + {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)} + ) + shard_tuple = func.body + output.extend([shard_tuple[i] for i in range(len(tensor_params))]) + + with bb.dataflow(): + gv = bb.emit_output(relax.Tuple(output)) + bb.emit_func_output(gv) + + mod = mod.clone() + mod["transform_params"] = bb.get()["transform_params"] + return mod + + return transform_func def chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm.IRModule: @@ -1181,12 +1197,44 @@ def chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm. 
bb = relax.BlockBuilder() - with bb.function("transform_params", params=func_a.params): + def get_num_input_attr(func): + if func.attrs is None: + return 0 + + attrs = func.attrs + if "num_input" not in attrs: + return 0 + num_input = attrs["num_input"] + + assert isinstance(num_input, tvm.tir.IntImm) + return num_input.value + + # Either func_a or func_b may have parameters that are provided at + # a later point. The chaining of parameter transforms assumes + # that all model weights accepted by func_b are produced by + # func_a. If func_b accepts non-weight parameters (e.g. the GPU + # rank), these must still be provided. + func_a_num_input = get_num_input_attr(func_a) + func_b_num_input = get_num_input_attr(func_b) + + output_num_input = func_a_num_input + func_b_num_input + output_params = [ + *func_a.params[:func_a_num_input], + *func_b.params[:func_b_num_input], + *func_a.params[func_a_num_input:], + ] + + with bb.function( + "transform_params", params=output_params, attrs={"num_input": output_num_input} + ): with bb.dataflow(): # TODO(Lunderberg): Implement this in terms of a # generic utility that inlines local functions. func_a_output = bb.emit(func_a.body) - func_b_param_map = {param: expr for (param, expr) in zip(func_b.params, func_a_output)} + func_b_param_map = { + param: expr + for (param, expr) in zip(func_b.params[func_b_num_input:], func_a_output) + } func_b_output = func_b.bind_params(func_b_param_map).body gv = bb.emit_output(func_b_output) bb.emit_func_output(gv) diff --git a/mlc_llm/transform/reorder_transform_func.py b/mlc_llm/transform/reorder_transform_func.py index 40403c822e..aa5ff9f81b 100644 --- a/mlc_llm/transform/reorder_transform_func.py +++ b/mlc_llm/transform/reorder_transform_func.py @@ -37,11 +37,7 @@ def analyze_func( func: relax.Function, pidx2binname: Dict[int, str], -) -> Tuple[ - List[relax.Binding], - Dict[relax.Var, List[relax.Binding]], - Dict[relax.Binding, int], -]: +) -> Tuple[List[relax.Binding], Dict[relax.Var, List[relax.Binding]], Dict[relax.Binding, int],]: """Binding grouping analysis function. It takes the function to be analyzed, and mapping from each raw tensor index to the name of the binary file where it resides. @@ -85,14 +81,19 @@ def analyze_func( var_users: Dict[relax.Var, List[relax.Binding]] = {} num_depending_vars: Dict[relax.Binding, int] = {} + if func.attrs is not None and "num_input" in func.attrs: + num_input = func.attrs["num_input"].value + else: + num_input = 0 + # Sanity check on the function pattern. - assert len(func.params) == 1 + assert len(func.params) == num_input + 1 assert isinstance(func.body, relax.SeqExpr) assert len(func.body.blocks) == 1 assert isinstance(func.body.blocks[0], relax.DataflowBlock) assert func.body.blocks[0].bindings[-1].var.same_as(func.body.body) - params = func.params[0] + model_param_tuple = func.params[num_input] bindings = func.body.blocks[0].bindings # Go through each binding except the last one. (The last one is the output @@ -102,7 +103,7 @@ def analyze_func( binding_var_set.add(binding.var) var_users[binding.var] = [] - if isinstance(value, relax.TupleGetItem) and value.tuple_value.same_as(params): + if isinstance(value, relax.TupleGetItem) and value.tuple_value.same_as(model_param_tuple): # For weight fetching bindings (`lv = params[idx]`), we group them # according to the binary file name. 
pidx = value.index From 5f2a06e5508eba19bfd5e9156ddad8b88329f7e6 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Sat, 2 Mar 2024 12:11:48 +0800 Subject: [PATCH 019/531] [Serving][Grammar] Integrate JSON grammar into the generation pipeline (#1867) This PR is the 3rd part of the grammar-guided generation. This intregrates the grammar framework into the generation process, and supports JSON output for now. The API this PR provides is compatible with the OpenAI api. ### APIs #### Python API ``` @dataclass class ResponseFormat: type: Literal["text", "json_object"] = "text" json_schema: Optional[str] = None @dataclass class GenerationConfig: response_format: ResponseFormat = ResponseFormat(type="text") ``` #### Rest API ``` response_format: { "type": "text" } # text generation, by default response_format: { "type": "json_object" } # json generation response_format: { "type": "json_object", json_schema="..."} # json generation with schema ``` JSON generation with schema is not supported yet, but has been planned to be realized in the future. ### Performance #### Without JSON ``` Single token prefill latency: 891.2234 ms/tok Single token decode latency: 31.3399 ms/tok Prefill token throughput: 4693.3077 tok/s Decode token throughput: 226.4406 tok/s Overall token throughput: 470.3180 tok/s ``` #### With JSON ``` Single token prefill latency: 219.2287 ms/tok Single token decode latency: 29.1399 ms/tok Prefill token throughput: 7392.1555 tok/s Decode token throughput: 179.2296 tok/s Overall token throughput: 1052.1996 tok/s ``` We observed a slight decrease in performance under JSON mode. This will be further optimized in the future. --- cpp/serve/config.cc | 27 ++++ cpp/serve/config.h | 9 ++ cpp/serve/engine.cc | 9 +- cpp/serve/engine_actions/action_commons.cc | 8 ++ cpp/serve/function_table.cc | 1 + cpp/serve/function_table.h | 2 +- cpp/serve/grammar/grammar.cc | 6 +- cpp/serve/grammar/grammar_state_matcher.cc | 33 ++--- cpp/serve/grammar/grammar_state_matcher.h | 2 +- .../grammar/grammar_state_matcher_preproc.h | 14 +- cpp/serve/grammar/support.h | 4 +- cpp/serve/logit_processor.cc | 63 ++++++--- cpp/serve/request_state.cc | 28 ++-- cpp/serve/request_state.h | 36 +++-- .../compiler_pass/attach_to_ir_module.py | 4 +- .../mlc_chat/protocol/openai_api_protocol.py | 10 +- python/mlc_chat/protocol/protocol_utils.py | 5 +- python/mlc_chat/serve/config.py | 33 ++++- tests/python/serve/server/test_server.py | 113 +++++++++++++++ .../serve/test_grammar_state_matcher.py | 86 ++++++------ .../python/serve/test_serve_engine_grammar.py | 131 ++++++++++++++++++ 21 files changed, 505 insertions(+), 119 deletions(-) create mode 100644 tests/python/serve/test_serve_engine_grammar.py diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index fde09ac32c..341c52b498 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -130,6 +130,26 @@ GenerationConfig::GenerationConfig(String config_json_str) { CHECK(config["ignore_eos"].is()); n->ignore_eos = config["ignore_eos"].get(); } + + if (config.count("response_format")) { + CHECK(config["response_format"].is()); + picojson::object response_format_json = config["response_format"].get(); + ResponseFormat response_format; + if (response_format_json.count("type")) { + CHECK(response_format_json["type"].is()); + response_format.type = response_format_json["type"].get(); + } + if (response_format_json.count("json_schema")) { + if (response_format_json["json_schema"].is()) { + response_format.json_schema = NullOpt; + } else { + 
CHECK(response_format_json["json_schema"].is()); + response_format.json_schema = response_format_json["json_schema"].get(); + } + } + n->response_format = response_format; + } + data_ = std::move(n); } @@ -166,6 +186,13 @@ String GenerationConfigNode::AsJSONString() const { // Params for benchmarking. Not the part of openai spec. config["ignore_eos"] = picojson::value(this->ignore_eos); + picojson::object response_format; + response_format["type"] = picojson::value(this->response_format.type); + response_format["json_schema"] = this->response_format.json_schema + ? picojson::value(this->response_format.json_schema.value()) + : picojson::value(); + config["response_format"] = picojson::value(response_format); + return picojson::value(config).serialize(true); } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 9e316bf370..bd6d0ba0c9 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -13,10 +13,17 @@ namespace mlc { namespace llm { namespace serve { +using namespace tvm; using namespace tvm::runtime; /****************** GenerationConfig ******************/ +/*! \brief The response format of a request. */ +struct ResponseFormat { + String type = "text"; + Optional json_schema = NullOpt; +}; + /*! \brief The generation configuration of a request. */ class GenerationConfigNode : public Object { public: @@ -35,6 +42,8 @@ class GenerationConfigNode : public Object { Array stop_strs; std::vector stop_token_ids; + ResponseFormat response_format; + String AsJSONString() const; static constexpr const char* _type_key = "mlc.serve.GenerationConfig"; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 5c2e2f0be9..1fce1d8ca6 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -19,6 +19,7 @@ #include "engine_actions/action_commons.h" #include "engine_state.h" #include "event_trace_recorder.h" +#include "grammar/grammar_state_matcher.h" #include "logit_processor.h" #include "model.h" #include "request.h" @@ -56,6 +57,8 @@ class EngineImpl : public Engine { this->trace_recorder_ = trace_recorder; this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); this->token_table_ = tokenizer_->TokenTable(); + this->json_grammar_state_init_ctx_ = + GrammarStateMatcher::CreateInitContext(BNFGrammar::GetGrammarOfJSON(), this->token_table_); // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); @@ -133,8 +136,8 @@ class EngineImpl : public Engine { // Append to the waiting queue and create the request state. estate_->waiting_queue.push_back(request); estate_->request_states.emplace( - request->id, - RequestState(request, models_.size(), estate_->id_manager.GetNewId(), token_table_)); + request->id, RequestState(request, models_.size(), estate_->id_manager.GetNewId(), + token_table_, json_grammar_state_init_ctx_)); } void AbortRequest(const String& request_id) final { @@ -208,6 +211,8 @@ class EngineImpl : public Engine { int max_single_sequence_length_; Tokenizer tokenizer_; std::vector token_table_; + // The initial context for the grammar state matching of JSON. 
+ std::shared_ptr json_grammar_state_init_ctx_; // Models Array models_; // Request stream callback function diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 5526bed2d1..e737a048ef 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -66,6 +66,14 @@ void ActionStepPostProcess(Array requests, EngineState estate, Arraymstates[0]->grammar_state_matcher) { + const auto& grammar_state_matcher = rstate->mstates[0]->grammar_state_matcher.value(); + for (auto token_id : delta_token_ids) { + grammar_state_matcher->AcceptToken(token_id); + } + } + callback_delta_outputs.push_back(RequestStreamOutput( request->id, delta_token_ids, request->generation_cfg->logprobs > 0 ? delta_logprob_json_strs : Optional>(), diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index c4ebbe4be3..5f5dc59816 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -87,6 +87,7 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), lib_path, null_device); + this->disco_buffers = Map(); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index e37b0e6f89..956f19e02e 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -55,7 +55,7 @@ struct FunctionTable { bool use_disco = false; Session sess{nullptr}; DRef disco_mod{nullptr}; - Map disco_buffers; + Map disco_buffers{nullptr}; tvm::runtime::Module local_vm{nullptr}; picojson::object model_config; diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index 89d3956501..697fb29d60 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -43,8 +43,8 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String jso const std::string kJSONGrammarString = R"( main ::= ( - "{" ws members_or_embrace ws | - "[" ws elements_or_embrace ws + "{" ws members_or_embrace | + "[" ws elements_or_embrace ) value ::= ( "{" ws members_or_embrace | @@ -102,7 +102,7 @@ elements_rest ::= ( "\n" ws "," ws elements | "\t" ws "," ws elements ) -characters ::= "" | [^"\\] characters | "\\" escape characters +characters ::= "" | [^"\\\r\n] characters | "\\" escape characters escape ::= "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] digits ::= [0-9] | [0-9] digits fraction ::= "" | "." 
digits diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 79cc8a351a..a0b2350a2e 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -137,7 +137,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm void Rollback(int num_tokens) final; - int MaxRollbackSteps() final { return max_rollback_steps_; } + int MaxRollbackSteps() const final { return max_rollback_steps_; } void ResetState() final { stack_tops_history_.Reset(); @@ -176,7 +176,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm }; bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { - CHECK(init_ctx_->codepoint_tokens_lookup.count(token_id) > 0); + CHECK(init_ctx_->codepoint_tokens_lookup.count(token_id) > 0) + << "Token id " << token_id << " is not supported in generation"; const auto& token = init_ctx_->codepoint_tokens_lookup[token_id].token; for (auto codepoint : token) { if (!AcceptCodepoint(codepoint, false)) { @@ -323,7 +324,9 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm } void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { - CHECK(num_tokens <= token_size_history_.size()); + CHECK(num_tokens <= token_size_history_.size()) + << "Intended to rollback " << num_tokens << " tokens, but only the last " + << token_size_history_.size() << " steps of history are saved"; while (num_tokens > 0) { int steps = token_size_history_.back(); RollbackCodepoints(steps); @@ -338,8 +341,9 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, bool can_reach_end) { // accepted_ids = Union(accepted_indices, all_tokens - rejected_indices) // rejected_ids = Intersect(all_tokens - accepted_indices, rejected_indices) - DCHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && - next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape); + CHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && + next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape) + << "The provied bitmask's shape or dtype is not valid."; BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), next_token_bitmask->shape[0]); @@ -411,7 +415,7 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr tokenizer, int max_rollback_steps) { - auto init_ctx = CreateInitContext( + auto init_ctx = GrammarStateMatcher::CreateInitContext( grammar, tokenizer ? tokenizer.value()->TokenTable() : std::vector()); return GrammarStateMatcher(init_ctx, max_rollback_steps); }); @@ -424,7 +428,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") token_table.push_back(args[i]); } int max_rollback_steps = args[args.size() - 1]; - auto init_ctx = CreateInitContext(grammar, token_table); + auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); *rv = GrammarStateMatcher(init_ctx, max_rollback_steps); }); @@ -474,7 +478,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") }); /*! - * \brief Find the ids of the rejected tokens for the next step. + * \brief Find the ids of the rejected tokens for the next step. For test purposes. * \returns A tuple of rejected token ids. 
*/ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { @@ -483,16 +487,15 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { auto bitset_size = BitsetManager::GetBitsetSize(vocab_size); auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); - auto dltensor_manager = ndarray.ToDLPack(); - auto dltensor = ndarray.ToDLPack()->dl_tensor; + auto dltensor = const_cast(ndarray.operator->()); auto start = std::chrono::high_resolution_clock::now(); - matcher->FindNextTokenBitmask(&dltensor); + matcher->FindNextTokenBitmask(dltensor); auto end = std::chrono::high_resolution_clock::now(); - std::cout << "FindNextTokenBitmask takes " + std::cerr << "FindNextTokenBitmask takes " << std::chrono::duration_cast(end - start).count() << "us"; - auto bitset = BitsetManager(reinterpret_cast(dltensor.data), bitset_size); + auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size); std::vector rejected_ids; for (int i = 0; i < vocab_size; i++) { if (bitset[i] == 0) { @@ -500,11 +503,9 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { } } - std::cout << ", found accepted: " << vocab_size - rejected_ids.size() + std::cerr << ", found accepted: " << vocab_size - rejected_ids.size() << ", rejected: " << rejected_ids.size() << std::endl; - dltensor_manager->deleter(dltensor_manager); - auto ret = IntTuple(rejected_ids); return ret; } diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h index 0ea4b12b95..ec6e8f19b1 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -77,7 +77,7 @@ class GrammarStateMatcherNode : public Object { virtual void Rollback(int num_tokens) = 0; /*! \brief Get the maximum number of rollback steps allowed. */ - virtual int MaxRollbackSteps() = 0; + virtual int MaxRollbackSteps() const = 0; /*! \brief Reset the matcher to the initial state. */ virtual void ResetState() = 0; diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index 62a1f2a6af..194d5b2935 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -55,7 +55,8 @@ struct CatagorizedTokens { */ class GrammarStateInitContext { public: - BNFGrammar grammar; + /******************* Information about the tokenizer *******************/ + /*! \brief The vocabulary size of the tokenizer. */ size_t vocab_size; /*! \brief The sorted token and its id. Tokens are sorted to reuse the common prefix during @@ -70,6 +71,12 @@ class GrammarStateInitContext { * matching. */ std::vector special_token_ids; + /******************* Information about the grammar *******************/ + + BNFGrammar grammar; + + /******************* Grammar-specific tokenizer information *******************/ + /*! \brief A sequence id and its position. 
*/ struct SequenceIdAndPosition { int32_t sequence_id; @@ -232,7 +239,7 @@ inline std::string ReplaceUnderscoreWithSpace(const std::string& str, return res; } -inline std::shared_ptr CreateInitContext( +inline std::shared_ptr GrammarStateMatcher::CreateInitContext( const BNFGrammar& grammar, const std::vector& token_table) { using RuleExprType = BNFGrammarNode::RuleExprType; auto ptr = std::make_shared(); @@ -252,7 +259,8 @@ inline std::shared_ptr CreateInitContext( ptr->stop_token_ids.push_back(i); } else if (token.size() == 1 && (static_cast(token[0]) >= 128 || token[0] == 0)) { - // Currently we consider all tokens with one character that >= 128 as special tokens. + // Currently we consider all tokens with one character that >= 128 as special tokens, + // and will ignore generating them during grammar-guided generation. ptr->special_token_ids.push_back(i); } else { // First replace the special underscore with space. diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h index 9ee6ffb3b3..9df1083335 100644 --- a/cpp/serve/grammar/support.h +++ b/cpp/serve/grammar/support.h @@ -50,7 +50,7 @@ class BitsetManager { * \brief Let lhs be the union of lhs and rhs. Suppose that both sets are sorted. * \note No additional vectors are allocated, and the time complexity is O(n) */ -void IntsetUnion(std::vector* lhs, const std::vector& rhs) { +inline void IntsetUnion(std::vector* lhs, const std::vector& rhs) { int original_lhs_size = lhs->size(); int rhs_size = rhs.size(); @@ -91,7 +91,7 @@ void IntsetUnion(std::vector* lhs, const std::vector& rhs) { * \note Support the case where lhs is the universal set by setting lhs to {-1}. The result will be * rhs then. */ -void IntsetIntersection(std::vector* lhs, const std::vector& rhs) { +inline void IntsetIntersection(std::vector* lhs, const std::vector& rhs) { if (lhs->size() == 1 && (*lhs)[0] == -1) { *lhs = rhs; return; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 24ce003fe3..5af7a39d29 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -8,6 +8,7 @@ #include #include #include +#include namespace mlc { namespace llm { @@ -44,7 +45,7 @@ class LogitProcessorImpl : public LogitProcessorObj { token_cnt_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_i32_, device_cpu); token_logit_bias_host_ = NDArray::Empty({max_num_token * vocab_size}, dtype_f32_, device_cpu); penalties_host_ = NDArray::Empty({max_num_token, 3}, dtype_f32_, device_cpu); - bitmask_host_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_i32_, device_cpu); + bitmask_host_ = NDArray::Empty({max_num_token, bitmask_size_}, dtype_u32_, device_cpu); temperature_host_ = NDArray::Empty({max_num_token}, dtype_f32_, device_cpu); // Initialize auxiliary arrays on GPU. seq_ids_device_ = NDArray::Empty({max_num_token}, dtype_i32_, device); @@ -99,7 +100,7 @@ class LogitProcessorImpl : public LogitProcessorObj { // Update 3. Vocabulary mask. 
RECORD_EVENT(trace_recorder_, request_ids, "start apply logit mask"); - UpdateWithMask(logits, mstates, cum_num_token, draft_tokens); + UpdateWithMask(logits, mstates, cum_num_token, draft_tokens, request_ids); RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit mask"); RECORD_EVENT(trace_recorder_, request_ids, "finish update logits"); @@ -301,40 +302,59 @@ class LogitProcessorImpl : public LogitProcessorObj { void UpdateWithMask(NDArray logits, const Array& mstates, const std::vector* cum_num_token, - const std::vector>* draft_tokens) { + const std::vector>* draft_tokens, + const Array& request_ids) { // Construct: // - seq_ids (max_num_token,) int32 // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32 - int* p_seq_ids = static_cast(seq_ids_host_->data); - int* p_bitmask = static_cast(bitmask_host_->data); + int32_t* p_seq_ids = static_cast(seq_ids_host_->data); + uint32_t* p_bitmask = static_cast(bitmask_host_->data); // - Set arrays. - int num_token_for_mask = 0; + ICHECK(mstates.size() == request_ids.size()); + + int batch_size = logits->shape[0]; + ICHECK((cum_num_token == nullptr && batch_size == mstates.size()) || + (cum_num_token != nullptr && batch_size == cum_num_token->size())); + + std::memset(p_seq_ids, 0, batch_size * sizeof(int32_t)); + for (int i = 0; i < static_cast(mstates.size()); ++i) { - int num_token_to_process = + int token_start_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); + int token_number = cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); - int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); - CHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty()); - for (int j = 0; j < num_token_to_process; ++j) { - std::vector bitmask = mstates[i]->GetTokenBitmask(vocab_size_); - if (!bitmask.empty()) { - p_seq_ids[num_token_for_mask] = token_offset + j; - ICHECK_EQ(bitmask.size(), bitmask_size_); - for (int p = 0; p < bitmask_size_; ++p) { - p_bitmask[num_token_for_mask * bitmask_size_ + p] = bitmask[p]; - } - ++num_token_for_mask; + CHECK(token_number == 1 || mstates[i]->draft_output_tokens.empty()); + bool require_mask = mstates[i]->RequireNextTokenBitmask(); + for (int j = 0; j < token_number; ++j) { + if (require_mask) { + // Find a slice of bitmask_host_: bitmask_host_[num_token_for_mask, :] + auto bitmask_dltensor = *bitmask_host_.operator->(); + int64_t bitmask_shape[] = {bitmask_size_}; + bitmask_dltensor.data = p_bitmask + (token_start_offset + j) * bitmask_size_; + bitmask_dltensor.shape = bitmask_shape; + bitmask_dltensor.ndim = 1; + + mstates[i]->FindNextTokenBitmask(&bitmask_dltensor); + p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); } } - if (num_token_to_process != 1) { + if (token_number != 1) { // Roll back. 
mstates[i]->RemoveAllDraftTokens(); } } + int num_token_for_mask = 0; + for (int i = 0; i < batch_size; ++i) { + if (p_seq_ids[i] == 1) { + p_seq_ids[num_token_for_mask] = i; + ++num_token_for_mask; + } + } + if (num_token_for_mask == 0) { return; } @@ -343,8 +363,8 @@ class LogitProcessorImpl : public LogitProcessorObj { int num_seq = num_token_for_mask; NDArray seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_); NDArray seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_); - NDArray bitmask_host = bitmask_host_.CreateView({num_seq, bitmask_size_}, dtype_i32_); - NDArray bitmask_device = bitmask_device_.CreateView({num_seq, bitmask_size_}, dtype_i32_); + NDArray bitmask_host = bitmask_host_.CreateView({batch_size, bitmask_size_}, dtype_i32_); + NDArray bitmask_device = bitmask_device_.CreateView({batch_size, bitmask_size_}, dtype_i32_); // - Copy arrays to GPU. CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); @@ -362,6 +382,7 @@ class LogitProcessorImpl : public LogitProcessorObj { const int vocab_size_; const int bitmask_size_; const DLDataType dtype_i32_ = DataType::Int(32); + const DLDataType dtype_u32_ = DataType::UInt(32); const DLDataType dtype_f32_ = DataType::Float(32); // Packed functions. Device device_; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index cea6af7bff..7519a56adb 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -13,13 +13,20 @@ namespace serve { TVM_REGISTER_OBJECT_TYPE(RequestModelStateNode); -RequestModelState::RequestModelState(Request request, int model_id, int64_t internal_id, - Array inputs) { +RequestModelState::RequestModelState( + Request request, int model_id, int64_t internal_id, Array inputs, + std::shared_ptr json_grammar_state_init_ctx) { ObjectPtr n = make_object(); - n->request = std::move(request); n->model_id = model_id; n->internal_id = internal_id; n->inputs = std::move(inputs); + + if (request->generation_cfg->response_format.type == "json_object") { + // TODO(yixin): add support for stop_token_ids + n->grammar_state_matcher = GrammarStateMatcher(json_grammar_state_init_ctx); + } + + n->request = std::move(request); data_ = std::move(n); } @@ -31,9 +38,12 @@ int RequestModelStateNode::GetInputLength() const { return total_length; } -std::vector RequestModelStateNode::GetTokenBitmask(int vocab_size) const { - // TODO(mlc-team): implement this function. 
- return std::vector(); +bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_state_matcher.defined(); } + +void RequestModelStateNode::FindNextTokenBitmask(DLTensor* bitmask) { + ICHECK(grammar_state_matcher.defined()); + + grammar_state_matcher.value()->FindNextTokenBitmask(bitmask); } void RequestModelStateNode::CommitToken(SampleResult sampled_token) { @@ -67,12 +77,14 @@ void RequestModelStateNode::RemoveAllDraftTokens() { TVM_REGISTER_OBJECT_TYPE(RequestStateNode); RequestState::RequestState(Request request, int num_models, int64_t internal_id, - const std::vector& token_table) { + const std::vector& token_table, + std::shared_ptr json_grammar_state_init_ctx) { ObjectPtr n = make_object(); Array mstates; mstates.reserve(num_models); for (int i = 0; i < num_models; ++i) { - mstates.push_back(RequestModelState(request, i, internal_id, request->inputs)); + mstates.push_back( + RequestModelState(request, i, internal_id, request->inputs, json_grammar_state_init_ctx)); } n->rng = RandomGenerator(request->generation_cfg->seed); n->stop_str_handler = StopStrHandler( diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 134d1df4bd..6cf5928a13 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -13,6 +13,7 @@ #include "../random.h" #include "../streamer.h" #include "config.h" +#include "grammar/grammar_state_matcher.h" #include "request.h" namespace mlc { @@ -70,13 +71,25 @@ class RequestModelStateNode : public Object { /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; + /*! + * \brief The current state of the generated token matching the grammar. Used in grammar-guided + * generation, otherwise it's NullOpt. + */ + Optional grammar_state_matcher; + /*! \brief Return the total length of the input data. */ int GetInputLength() const; /*! - * \brief Return the token bitmask induced by the current state. - * The returned vector should have size "ceildiv(vocab_size, 32)". + * \brief Return whether the next token bitmask is required, i.e. the grammar-guided generation is + * enabled. + */ + bool RequireNextTokenBitmask(); + /*! + * \brief Find the next token bitmask and store it in the given DLTensor. + * \param bitmask The DLTensor to store the next token bitmask. The bitmask should be a tensor + * with dtype uint32_t and shape (ceildiv(vocab_size, 32),). */ - std::vector GetTokenBitmask(int vocab_size) const; + void FindNextTokenBitmask(DLTensor* bitmask); /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ @@ -94,8 +107,8 @@ class RequestModelStateNode : public Object { class RequestModelState : public ObjectRef { public: - explicit RequestModelState(Request request, int model_id, int64_t internal_id, - Array inputs); + explicit RequestModelState(Request request, int model_id, int64_t internal_id, Array inputs, + std::shared_ptr json_grammar_state_init_ctx); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode); }; @@ -131,13 +144,13 @@ class RequestStateNode : public Object { std::chrono::high_resolution_clock::time_point tprefill_finish; /*! - * \brief Get the delta token ids and the logprob JSON strings for this - * request to return since the last time calling into this function, - * and return the finish reason if the request generation has finished. 
+ * \brief Get the delta token ids and the logprob JSON strings for this request to return since + * the last time calling into this function, and return the finish reason if the request + * generation has finished. * \param tokenizer The tokenizer for logprob process. * \param max_single_sequence_length The maximum allowed single sequence length. - * \return The delta token ids to return, the logprob JSON strings of each - * delta token id, and the optional finish reason. + * \return The delta token ids to return, the logprob JSON strings of each delta token id, and + * the optional finish reason. */ DeltaRequestReturn GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length); @@ -150,7 +163,8 @@ class RequestStateNode : public Object { class RequestState : public ObjectRef { public: explicit RequestState(Request request, int num_models, int64_t internal_id, - const std::vector& token_table); + const std::vector& token_table, + std::shared_ptr json_grammar_state_init_ctx); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode); }; diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 0b33647509..06026397a4 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -145,7 +145,7 @@ def _apply_bitmask_inplace( num_seq = T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") - bitmask = T.match_buffer(var_bitmask, (num_seq, (vocab_size + 31 // 32)), "int32") + bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + 1023) // 1024, "blockIdx.x"): for fused_s_v_1 in T.thread_binding(0, 1024, "threadIdx.x"): @@ -154,7 +154,7 @@ def _apply_bitmask_inplace( vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size) T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size) logits[seq_ids[vs], vv] = T.if_then_else( - (bitmask[vs, vv // 32] >> (vv % 32)) & 1 == 1, + (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1, logits[seq_ids[vs], vv], T.float32(-1e10), ) diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index 2ae26bf752..e45711d516 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -65,6 +65,11 @@ class ModelResponse(BaseModel): ################ v1/completions ################ +class ResponseFormat(BaseModel): + type: Literal["text", "json_object"] = "text" + json_schema: Optional[str] = None + + class CompletionRequest(BaseModel): """OpenAI completion request protocol. 
API reference: https://platform.openai.com/docs/api-reference/completions/create @@ -89,6 +94,7 @@ class CompletionRequest(BaseModel): top_p: float = 1.0 user: Optional[str] = None ignore_eos: bool = False + response_format: ResponseFormat = ResponseFormat() @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -193,7 +199,6 @@ class ChatCompletionRequest(BaseModel): logit_bias: Optional[Dict[int, float]] = None max_tokens: Optional[int] = None n: int = 1 - response_format: Literal["text", "json_object"] = "text" seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stream: bool = False @@ -203,6 +208,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None user: Optional[str] = None ignore_eos: bool = False + response_format: ResponseFormat = ResponseFormat() @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -291,7 +297,6 @@ def openai_api_get_unsupported_fields( unsupported_field_default_values: List[Tuple[str, Any]] = [ ("best_of", 1), ("n", 1), - ("response_format", "text"), ] unsupported_fields: List[str] = [] @@ -326,4 +331,5 @@ def openai_api_get_generation_config( kwargs["max_tokens"] = -1 if request.stop is not None: kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop + kwargs["response_format"] = request.response_format.model_dump() return kwargs diff --git a/python/mlc_chat/protocol/protocol_utils.py b/python/mlc_chat/protocol/protocol_utils.py index a9a68a1f82..b515ffc47c 100644 --- a/python/mlc_chat/protocol/protocol_utils.py +++ b/python/mlc_chat/protocol/protocol_utils.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from ..serve.config import GenerationConfig +from ..serve.config import GenerationConfig, ResponseFormat from . import RequestProtocol from .openai_api_protocol import ChatCompletionRequest as OpenAIChatCompletionRequest from .openai_api_protocol import CompletionRequest as OpenAICompletionRequest @@ -43,6 +43,9 @@ def get_generation_config( else: raise RuntimeError("Cannot reach here") + response_format_dict = kwargs.get("response_format", {}) + kwargs["response_format"] = ResponseFormat(**response_format_dict) + if extra_stop_token_ids is not None: stop_token_ids = kwargs.get("stop_token_ids", []) assert isinstance(stop_token_ids, list) diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py index ccc152ab36..00cd53f66f 100644 --- a/python/mlc_chat/serve/config.py +++ b/python/mlc_chat/serve/config.py @@ -2,7 +2,31 @@ import json from dataclasses import asdict, dataclass, field -from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional + + +@dataclass +class ResponseFormat: + """The response format dataclass. + + Parameters + ---------- + type : Literal["text", "json_object"] + The type of response format. Default: "text". + + json_schema : Optional[str] + The JSON schema string for the JSON response format. If None, a legal json string without + special restrictions will be generated. + + Could be specified when the response format is "json_object". Default: None. 
+ """ + + type: Literal["text", "json_object"] = "text" + json_schema: Optional[str] = None + + def __post_init__(self): + if self.json_schema is not None and self.type != "json_object": + raise ValueError("JSON json_schema is only supported in JSON response format") @dataclass @@ -16,7 +40,7 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes top_p : float In sampling, only the most probable tokens with probabilities summed up to - `top_k` are kept for sampling. + `top_p` are kept for sampling. frequency_penalty : float Positive values penalize new tokens based on their existing frequency @@ -63,6 +87,9 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes ignore_eos: bool When it is true, ignore the eos token and generate tokens until `max_tokens`. Default is set to False. + + response_format : ResponseFormat + The response format of the generation output. """ temperature: float = 0.8 @@ -80,6 +107,8 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes stop_token_ids: List[int] = field(default_factory=list) ignore_eos: bool = False + response_format: ResponseFormat = field(default_factory=ResponseFormat) + def asjson(self) -> str: """Return the config in string of JSON format.""" return json.dumps(asdict(self)) diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 324c4b377c..3cb015000f 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -35,6 +35,14 @@ DEBUG_DUMP_EVENT_TRACE_URL = "http://127.0.0.1:8000/debug/dump_event_trace" +def is_json_or_json_prefix(s: str) -> bool: + try: + json.loads(s) + return True + except json.JSONDecodeError as e: + return e.pos == len(s) + + def check_openai_nonstream_response( response: Dict, *, @@ -48,6 +56,7 @@ def check_openai_nonstream_response( suffix: Optional[str] = None, stop: Optional[List[str]] = None, require_substr: Optional[List[str]] = None, + json_mode: bool = False, ): assert response["model"] == model assert response["object"] == object_str @@ -55,6 +64,7 @@ def check_openai_nonstream_response( choices = response["choices"] assert isinstance(choices, list) assert len(choices) == num_choices + for idx, choice in enumerate(choices): assert choice["index"] == idx assert choice["finish_reason"] in finish_reasons @@ -79,6 +89,8 @@ def check_openai_nonstream_response( if require_substr is not None: for substr in require_substr: assert substr in text + if json_mode: + assert is_json_or_json_prefix(text) usage = response["usage"] assert isinstance(usage, dict) @@ -101,6 +113,7 @@ def check_openai_stream_response( suffix: Optional[str] = None, stop: Optional[List[str]] = None, require_substr: Optional[List[str]] = None, + json_mode: bool = False, ): assert len(responses) > 0 @@ -154,6 +167,8 @@ def check_openai_stream_response( if require_substr is not None: for substr in require_substr: assert substr in output + if json_mode: + assert is_json_or_json_prefix(output) def expect_error(response_str: str, msg_prefix: Optional[str] = None): @@ -484,6 +499,55 @@ def test_openai_v1_completions_temperature( ) +# TODO(yixin): support eos_token_id for tokenizer +@pytest.mark.skip("JSON test for completion api requires internal eos_token_id support") +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_json( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures 
+ # defined in conftest.py. + + prompt = "Response with a json object:" + max_tokens = 128 + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "response_format": {"type": "json_object"}, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length", "stop"], + json_mode=True, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length", "stop"], + json_mode=True, + ) + + @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_logit_bias( served_model: Tuple[str, str], @@ -888,6 +952,53 @@ def test_openai_v1_chat_completions_max_tokens( ) +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_chat_completions_json( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + messages = [{"role": "user", "content": "Response with a json object:"}] + max_tokens = 128 + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "max_tokens": max_tokens, + "response_format": {"type": "json_object"}, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reasons=["length", "stop"], + json_mode=True, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reasons=["length", "stop"], + json_mode=True, + ) + + @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_chat_completions_ignore_eos( served_model: Tuple[str, str], @@ -1028,6 +1139,8 @@ def test_debug_dump_event_trace( test_openai_v1_chat_completions_openai_package(MODEL, None, stream=True, messages=msg) test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=False) test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=True) + test_openai_v1_chat_completions_json(MODEL, None, stream=False) + test_openai_v1_chat_completions_json(MODEL, None, stream=True) test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=False) test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True) test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False) diff --git a/tests/python/serve/test_grammar_state_matcher.py b/tests/python/serve/test_grammar_state_matcher.py index cf7229af21..61d6341c48 100644 --- a/tests/python/serve/test_grammar_state_matcher.py +++ b/tests/python/serve/test_grammar_state_matcher.py @@ -1,5 +1,6 @@ # pylint: disable=missing-module-docstring,missing-function-docstring # pylint: 
disable=redefined-outer-name,unbalanced-tuple-unpacking +import sys from typing import List import pytest @@ -17,7 +18,7 @@ def json_grammar(): (json_input_accepted,) = tvm.testing.parameters( ('{"name": "John"}',), - ('{ "name" : "John" } \n',), + ('{ "name" : "John" }',), ("{}",), ("[]",), ('{"name": "Alice", "age": 30, "city": "New York"}',), @@ -54,19 +55,17 @@ def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) -# test_json_accept(json_grammar(), '{"name": "John"}') -# exit() - (json_input_refused,) = tvm.testing.parameters( (r'{ name: "John" }',), - (r'{ "name": "John", "age": 30, }',), # x + (r'{ "name": "John" } ',), # trailing space is not accepted + (r'{ "name": "John", "age": 30, }',), (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), - (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), # x + (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), (r'{ "name": "John", "age": 30.5.7 }',), (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), ( r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' - r'["hiking", "swimming",]}] }', # + r'["hiking", "swimming",]}] }', ), (r'{ "name": "John", "age": 30, "status": "\P\J" }',), ( @@ -203,7 +202,7 @@ def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): "taglib-location": "/WEB-INF/tlds/cofax.tld" } } -} """, +}""", ), ) @@ -215,11 +214,11 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): (input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( ( # short test - '{"id": 1,"name": "Example"} ', + '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, - 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, + 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 # fmt: on ], ), @@ -228,30 +227,29 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): """{ "id": 1, "na": "ex", -"ac": True, +"ac": true, "t": ["t1", "t2"], "ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, "res": "res" -} -""", +}""", [ # fmt: off - 31989, 31907, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 31910, 278, 278, - 278, 31973, 31841, 31841, 271, 271, 271, 31974, 31910, 31910, 278, 278, 278, 31973, - 31841, 31841, 31841, 31841, 31841, 31841, 31841, 31841, 271, 271, 31974, 31974, 31974, - 31974, 31974, 31974, 31974, 31974, 31910, 31910, 278, 278, 278, 31973, 31973, 31973, - 31973, 31973, 31973, 31973, 31973, 31841, 31841, 31903, 278, 278, 278, 278, 31973, - 31841, 31841, 31901, 278, 278, 278, 278, 31973, 31841, 31841, 270, 270, 270, 31968, - 31970, 31910, 31910, 278, 278, 278, 278, 31973, 31841, 31841, 31835, 31943, 31841, - 31841, 31943, 31841, 31841, 31943, 31970, 31974, 31910, 31910, 278, 278, 278, 278, - 31973, 31841, 31841, 271, 271, 271, 271, 31974, 31974, 31980, 31980 + 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, + 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, + 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, + 299, 
299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, + 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, + 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, + 31846, 292, 292, 292, 292, 31974, 31974, 31999 # fmt: on ], ), ) -def test_find_rejected_tokens( +def test_find_next_rejected_tokens( json_grammar: BNFGrammar, input_find_rejected_tokens: str, expected_rejected_sizes: List[int] ): tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" @@ -262,10 +260,11 @@ def test_find_rejected_tokens( for c in input_find_rejected_tokens: rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c) - grammar_state_matcher.debug_accept_char(ord(c)) + print("Accepting char:", c, file=sys.stderr) + assert grammar_state_matcher.debug_accept_char(ord(c)) rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() real_sizes.append(len(rejected_token_ids)) + print(real_sizes) assert real_sizes == expected_rejected_sizes @@ -275,7 +274,7 @@ def test_accept_token(json_grammar: BNFGrammar): "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', # fmt: on ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", "\n"] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] input_ids = [token_table.index(t) for t in input_splitted] grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) @@ -285,16 +284,15 @@ def test_accept_token(json_grammar: BNFGrammar): expected = [ ["{"], ['"', "}", "\n", " ", '"a":true'], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " "], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], [":", "\n", " ", ':"'], ['"', "{", "6", "\n", " "], ["}", ", ", "6", "\n", " "], [" ", "\n", '"', '"a":true'], [" ", "\n", '"', '"a":true'], ["}", ", ", "\n", " "], - ["", "\n", " "], - ["", "\n", " "], + [""], ] for id in input_ids: @@ -303,7 +301,7 @@ def test_accept_token(json_grammar: BNFGrammar): accepted_tokens = [token_table[i] for i in accepted] result.append(accepted_tokens) assert id in accepted - grammar_state_matcher.accept_token(id) + assert grammar_state_matcher.accept_token(id) rejected = grammar_state_matcher.find_next_rejected_tokens() accepted = list(set(range(len(token_table))) - set(rejected)) @@ -319,7 +317,7 @@ def test_rollback(json_grammar: BNFGrammar): "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', # fmt: on ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] input_ids = [token_table.index(t) for t in input_splitted] grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) @@ -331,15 +329,15 @@ def test_rollback(json_grammar: BNFGrammar): for i_1, i_2 in input_ids_splitted: orig_result = [] orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i_1) + assert grammar_state_matcher.accept_token(i_1) orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i_2) + assert 
grammar_state_matcher.accept_token(i_2) grammar_state_matcher.rollback(2) result_after_rollback = [] result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i_1) + assert grammar_state_matcher.accept_token(i_1) result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i_2) + assert grammar_state_matcher.accept_token(i_2) assert orig_result == result_after_rollback @@ -349,7 +347,7 @@ def test_reset(json_grammar: BNFGrammar): "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', # fmt: on ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', " ", "}", "\n"] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] input_ids = [token_table.index(t) for t in input_splitted] grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) @@ -358,7 +356,7 @@ def test_reset(json_grammar: BNFGrammar): for i in input_ids: orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i) + assert grammar_state_matcher.accept_token(i) grammar_state_matcher.reset_state() @@ -366,20 +364,20 @@ def test_reset(json_grammar: BNFGrammar): for i in input_ids: result_after_reset.append(grammar_state_matcher.find_next_rejected_tokens()) - grammar_state_matcher.accept_token(i) + assert grammar_state_matcher.accept_token(i) assert orig_result == result_after_reset if __name__ == "__main__": # Run a benchmark to show the performance before running tests - test_find_rejected_tokens( + test_find_next_rejected_tokens( BNFGrammar.get_grammar_of_json(), - '{"id": 1,"name": "Example"} ', + '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31907, 278, 278, 278, 31973, 31841, 31841, 31948, 31910, 278, 278, 278, 278, - 278, 31973, 31841, 31841, 271, 271, 271, 271, 271, 271, 271, 271, 31974, 31980, 31980 + 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, + 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 # fmt: on ], ) diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py new file mode 100644 index 0000000000..901e6c4d98 --- /dev/null +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -0,0 +1,131 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import asyncio +from typing import List + +import pytest + +from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_chat.serve.async_engine import AsyncThreadedEngine +from mlc_chat.serve.config import ResponseFormat +from mlc_chat.serve.engine import ModelInfo + +prompts_list = [ + "Generate a JSON string containing 20 objects:", + "Generate a JSON containing a list:", + "Generate a JSON with 5 elements:", +] +model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" +model_lib_path = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + + +def test_batch_generation_with_grammar(): + # Initialize model loading info and KV cache config + model = ModelInfo(model_path, model_lib_path=model_lib_path) + kv_cache_config = KVCacheConfig(page_size=16) + # Create engine + engine = Engine(model, kv_cache_config) + + prompts = prompts_list * 2 + + temperature = 1 + repetition_penalty = 1 + max_tokens = 512 + generation_config_no_json = GenerationConfig( + temperature=temperature, + 
repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=[2], + response_format=ResponseFormat(type="text"), + ) + generation_config_json = GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=[2], + response_format=ResponseFormat(type="json_object"), + ) + all_generation_configs = [generation_config_no_json] * 3 + [generation_config_json] * 3 + + # Generate output. + output_texts, _ = engine.generate(prompts, all_generation_configs) + for req_id, output in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + print(f"Output {req_id}: {output}\n") + + +async def run_async_engine(): + # Initialize model loading info and KV cache config + model = ModelInfo(model_path, model_lib_path=model_lib_path) + kv_cache_config = KVCacheConfig(page_size=16) + # Create engine + async_engine = AsyncThreadedEngine(model, kv_cache_config, enable_tracing=True) + + prompts = prompts_list * 20 + + max_tokens = 256 + temperature = 1 + repetition_penalty = 1 + max_tokens = 512 + generation_config = GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=[2], + response_format=ResponseFormat(type="json_object"), + ) + + outputs: List[str] = ["" for _ in range(len(prompts))] + + async def generate_task( + async_engine: AsyncThreadedEngine, + prompt: str, + generation_cfg: GenerationConfig, + request_id: str, + ): + print(f"Start generation task for request {request_id}") + rid = int(request_id) + async for delta_text, _, _, _ in async_engine.generate( + prompt, generation_cfg, request_id=request_id + ): + outputs[rid] += delta_text + + tasks = [ + asyncio.create_task( + generate_task(async_engine, prompts[i], generation_config, request_id=str(i)) + ) + for i in range(len(prompts)) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("All finished") + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {prompts[req_id]}") + print(f"Output {req_id}: {output}\n") + + print(async_engine.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) + + async_engine.terminate() + + +def test_async_engine(): + asyncio.run(run_async_engine()) + + +def test_generation_config_error(): + with pytest.raises(ValueError): + GenerationConfig( + temperature=1.0, + repetition_penalty=1.0, + max_tokens=128, + stop_token_ids=[2], + response_format=ResponseFormat(type="text", json_schema="{}"), + ) + + +if __name__ == "__main__": + test_batch_generation_with_grammar() + test_async_engine() + test_generation_config_error() From 7806dee5c4554f02876ec03bf4e61ff0aaa49be3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 2 Mar 2024 08:02:04 -0500 Subject: [PATCH 020/531] [Serving] Support "n" for parallel generation (#1868) This PR brings field `n` to generation config and thereby supports parallel generation. This parallel generation effectively leverages the "fork" functionality of paged KV cache. This PR supports specifying the number of parallel generation `n` in stardard OpenAI ChatCompletion API. This is the last feature towards the OpenAI API feature completeness. 
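[Editorial illustration, not part of the patch itself.] A minimal sketch of how the new `n`
field is meant to be used from the Python serving API, following the conventions of the test
scripts earlier in this series. The model paths are the usual placeholders from those tests,
and the exact structure of the returned outputs for n > 1 is determined by the engine changes
in this patch rather than by this sketch.

    # Hedged sketch: request three parallel generations of a single prompt.
    from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig
    from mlc_chat.serve.engine import ModelInfo

    model = ModelInfo(
        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        model_lib_path="dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
    )
    engine = Engine(model, KVCacheConfig(page_size=16))

    # n=3 asks the engine to fork the request into three generation branches
    # after prefill, so the prompt's KV cache pages are shared across the
    # branches via the paged KV cache "fork" functionality mentioned above.
    config = GenerationConfig(n=3, temperature=0.7, max_tokens=64, stop_token_ids=[2])
    outputs = engine.generate(["Write a haiku about paged KV caches:"], [config])
    print(outputs)

Through the OpenAI-compatible HTTP entrypoint, the same behavior is requested by setting the
standard "n" field of the ChatCompletion request body.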
--- cpp/serve/config.cc | 6 + cpp/serve/config.h | 1 + cpp/serve/data.cc | 18 +- cpp/serve/data.h | 12 +- cpp/serve/engine.cc | 58 +++- cpp/serve/engine_actions/action.h | 2 + cpp/serve/engine_actions/action_commons.cc | 224 +++++++++++----- cpp/serve/engine_actions/action_commons.h | 26 +- cpp/serve/engine_actions/batch_decode.cc | 68 ++--- cpp/serve/engine_actions/batch_draft.cc | 64 ++--- cpp/serve/engine_actions/batch_verify.cc | 141 +++++----- .../engine_actions/new_request_prefill.cc | 251 ++++++++++++----- cpp/serve/engine_state.cc | 4 +- cpp/serve/function_table.cc | 2 + cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 4 + cpp/serve/model.h | 3 + cpp/serve/request_state.cc | 38 ++- cpp/serve/request_state.h | 67 ++++- cpp/serve/sampler.cc | 16 +- cpp/serve/sampler.h | 18 +- .../mlc_chat/protocol/openai_api_protocol.py | 2 +- python/mlc_chat/serve/async_engine.py | 104 ++++--- python/mlc_chat/serve/config.py | 4 + python/mlc_chat/serve/data.py | 71 +++-- python/mlc_chat/serve/engine.py | 72 ++--- .../serve/entrypoints/openai_entrypoints.py | 253 ++++++++++-------- tests/python/serve/server/test_server.py | 78 ++++-- tests/python/serve/test_serve_async_engine.py | 22 +- .../serve/test_serve_async_engine_spec.py | 18 +- tests/python/serve/test_serve_engine.py | 36 ++- .../python/serve/test_serve_engine_grammar.py | 26 +- tests/python/serve/test_serve_engine_spec.py | 32 ++- 33 files changed, 1137 insertions(+), 605 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 341c52b498..451b3a0279 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -29,6 +29,11 @@ GenerationConfig::GenerationConfig(String config_json_str) { ObjectPtr n = make_object(); picojson::object config = config_json.get(); + if (config.count("n")) { + CHECK(config["n"].is()); + n->n = config["n"].get(); + CHECK_GT(n->n, 0) << "\"n\" should be at least 1"; + } if (config.count("temperature")) { CHECK(config["temperature"].is()); n->temperature = config["temperature"].get(); @@ -155,6 +160,7 @@ GenerationConfig::GenerationConfig(String config_json_str) { String GenerationConfigNode::AsJSONString() const { picojson::object config; + config["n"] = picojson::value(static_cast(this->n)); config["temperature"] = picojson::value(this->temperature); config["top_p"] = picojson::value(this->top_p); config["frequency_penalty"] = picojson::value(this->frequency_penalty); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index bd6d0ba0c9..e9e4d68970 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -27,6 +27,7 @@ struct ResponseFormat { /*! \brief The generation configuration of a request. 
*/ class GenerationConfigNode : public Object { public: + int n = 1; double temperature = 0.8; double top_p = 0.95; double frequency_penalty = 0.0; diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc index 770619f7c3..3e56ad6ec3 100644 --- a/cpp/serve/data.cc +++ b/cpp/serve/data.cc @@ -141,22 +141,22 @@ std::string SampleResult::GetLogProbJSON(const Tokenizer& tokenizer, bool logpro TVM_REGISTER_OBJECT_TYPE(RequestStreamOutputObj); -RequestStreamOutput::RequestStreamOutput(String request_id, - const std::vector& delta_token_ids, - Optional> delta_logprob_json_strs, - Optional finish_reason) { +RequestStreamOutput::RequestStreamOutput( + String request_id, Array group_delta_token_ids, + Optional>> group_delta_logprob_json_strs, + Array> group_finish_reason) { ObjectPtr n = make_object(); n->request_id = std::move(request_id); - n->delta_token_ids = IntTuple{delta_token_ids.begin(), delta_token_ids.end()}; - n->delta_logprob_json_strs = std::move(delta_logprob_json_strs); - n->finish_reason = std::move(finish_reason); + n->group_delta_token_ids = std::move(group_delta_token_ids); + n->group_delta_logprob_json_strs = std::move(group_delta_logprob_json_strs); + n->group_finish_reason = std::move(group_finish_reason); data_ = std::move(n); } TVM_REGISTER_GLOBAL("mlc.serve.RequestStreamOutputUnpack") .set_body_typed([](RequestStreamOutput output) { - return Array{output->request_id, output->delta_token_ids, - output->delta_logprob_json_strs, output->finish_reason}; + return Array{output->request_id, output->group_delta_token_ids, + output->group_delta_logprob_json_strs, output->group_finish_reason}; }); } // namespace serve diff --git a/cpp/serve/data.h b/cpp/serve/data.h index a63bdf81c4..ba92c662eb 100644 --- a/cpp/serve/data.h +++ b/cpp/serve/data.h @@ -128,14 +128,14 @@ class RequestStreamOutputObj : public Object { * \brief The new generated token ids since the last callback invocation * for the input request. */ - IntTuple delta_token_ids; + Array group_delta_token_ids; /*! \brief The logprobs JSON strings of the new generated tokens since last invocation. */ - Optional> delta_logprob_json_strs; + Optional>> group_delta_logprob_json_strs; /*! * \brief The finish reason of the request when it is finished, * of None if the request has not finished yet. 
*/ - Optional finish_reason; + Array> group_finish_reason; static constexpr const char* _type_key = "mlc.serve.RequestStreamOutput"; static constexpr const bool _type_has_method_sequal_reduce = false; @@ -149,9 +149,9 @@ class RequestStreamOutputObj : public Object { */ class RequestStreamOutput : public ObjectRef { public: - explicit RequestStreamOutput(String request_id, const std::vector& delta_token_ids, - Optional> delta_logprob_json_strs, - Optional finish_reason); + explicit RequestStreamOutput(String request_id, Array group_delta_token_ids, + Optional>> group_delta_logprob_json_strs, + Array> finish_reason); TVM_DEFINE_OBJECT_REF_METHODS(RequestStreamOutput, ObjectRef, RequestStreamOutputObj); }; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 1fce1d8ca6..411dbfc908 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -13,6 +13,7 @@ #include #include +#include #include "../tokenizers.h" #include "engine_actions/action.h" @@ -91,6 +92,7 @@ class EngineImpl : public Engine { logit_processor, // sampler, // this->kv_cache_config_, // + this->engine_mode_, // this->trace_recorder_), EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, this->engine_mode_->spec_draft_length), @@ -101,6 +103,7 @@ class EngineImpl : public Engine { logit_processor, // sampler, // this->kv_cache_config_, // + this->engine_mode_, // this->trace_recorder_), EngineAction::BatchDecode(this->models_, logit_processor, sampler, this->trace_recorder_)}; @@ -135,9 +138,27 @@ class EngineImpl : public Engine { ICHECK_NE(request->input_total_length, -1); // Append to the waiting queue and create the request state. estate_->waiting_queue.push_back(request); - estate_->request_states.emplace( - request->id, RequestState(request, models_.size(), estate_->id_manager.GetNewId(), - token_table_, json_grammar_state_init_ctx_)); + + int n = request->generation_cfg->n; + int rng_seed = request->generation_cfg->seed; + + RequestState rstate; + // Create the request state entry for the input. + rstate.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed, + token_table_, json_grammar_state_init_ctx_); + if (n > 1) { + // Then create a request state entry for each parallel generation branch. + // We add a offset to the rng seed so that to make generations different. + rstate.reserve(n + 1); + rstate[0]->children_idx.reserve(n); + for (int i = 0; i < n; ++i) { + rstate[0]->children_idx.push_back(rstate.size()); + rstate.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), + rng_seed + i + 1, token_table_, json_grammar_state_init_ctx_, + /*parent_idx=*/0); + } + } + estate_->request_states.emplace(request->id, rstate); } void AbortRequest(const String& request_id) final { @@ -148,26 +169,39 @@ class EngineImpl : public Engine { } RequestState rstate = it_rstate->second; - Request request = rstate->request; + Request request = rstate[0]->request; // - Check if the request is running or pending. 
auto it_running = std::find(estate_->running_queue.begin(), estate_->running_queue.end(), request); auto it_waiting = std::find(estate_->waiting_queue.begin(), estate_->waiting_queue.end(), request); - ICHECK(it_running != estate_->running_queue.end() || - it_waiting != estate_->waiting_queue.end()); - int64_t req_internal_id = rstate->mstates[0]->internal_id; - estate_->id_manager.RecycleId(req_internal_id); + for (const RequestStateEntry& rsentry : rstate) { + estate_->id_manager.RecycleId(rsentry->mstates[0]->internal_id); + } estate_->request_states.erase(request->id); if (it_running != estate_->running_queue.end()) { // The request to abort is in running queue estate_->running_queue.erase(it_running); - estate_->stats.current_total_seq_len -= - request->input_total_length + rstate->mstates[0]->committed_tokens.size() - 1; - RemoveRequestFromModel(estate_, req_internal_id, models_); - } else { + + // Reduce the input length. + estate_->stats.current_total_seq_len -= request->input_total_length; + // Reduce the generated length. + for (int i = 0; i < static_cast(rstate.size()); ++i) { + if (rstate[i]->status != RequestStateStatus::kAlive) { + continue; + } + estate_->stats.current_total_seq_len -= rstate[i]->mstates[0]->committed_tokens.size(); + RemoveRequestFromModel(estate_, rstate[i]->mstates[0]->internal_id, models_); + if (rstate[i]->children_idx.empty()) { + // For each running leaf state, length 1 is over reduced since the last + // token is not added into KV cache. So we add the length back. + ++estate_->stats.current_total_seq_len; + } + } + } + if (it_waiting != estate_->waiting_queue.end()) { // The request to abort is in waiting queue estate_->waiting_queue.erase(it_waiting); } diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 8e305e26af..d6bd611802 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -56,11 +56,13 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param kv_cache_config The KV cache config to help decide prefill is doable. + * \param engine_mode The engine operation mode. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, KVCacheConfig kv_cache_config, + EngineMode engine_mode, Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index e737a048ef..d665dea778 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -16,32 +16,70 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, Array finished_requests, EngineState estate, - Array models, int max_single_sequence_length) { - // - Remove the finished request. - for (Request request : finished_requests) { - // Remove from running queue. - auto it = std::find(estate->running_queue.begin(), estate->running_queue.end(), request); - ICHECK(it != estate->running_queue.end()); - estate->running_queue.erase(it); - - // Update engine states. 
- RequestState state = estate->GetRequestState(request); - RemoveRequestFromModel(estate, state->mstates[0]->internal_id, models); - estate->id_manager.RecycleId(state->mstates[0]->internal_id); - estate->request_states.erase(request->id); - - // Update engine statistics. - int num_input_tokens = request->input_total_length; - int num_output_tokens = state->mstates[0]->committed_tokens.size() - 1; - estate->stats.current_total_seq_len -= num_input_tokens + num_output_tokens; - auto trequest_finish = std::chrono::high_resolution_clock::now(); - estate->stats.request_total_prefill_time += - static_cast((state->tprefill_finish - state->tadd).count()) / 1e9; - estate->stats.total_prefill_length += num_input_tokens; - estate->stats.request_total_decode_time += - static_cast((trequest_finish - state->tprefill_finish).count()) / 1e9; - estate->stats.total_decode_length += num_output_tokens; +void ProcessFinishedRequestStateEntries(RequestState finished_rsentries, EngineState estate, + Array models, int max_single_sequence_length) { + // - Remove the finished request state entries. + for (const RequestStateEntry& rsentry : finished_rsentries) { + // The finished entry must be a leaf. + ICHECK(rsentry->children_idx.empty()); + // Mark the status of this entry as finished. + rsentry->status = RequestStateStatus::kFinished; + // Remove the request state entry from all the models. + RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models); + estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id); + estate->stats.current_total_seq_len -= + static_cast(rsentry->mstates[0]->committed_tokens.size()) - 1; + + RequestState rstate = estate->GetRequestState(rsentry->request); + int parent_idx = rsentry->parent_idx; + while (parent_idx != -1) { + bool all_children_finished = true; + for (int child_idx : rstate[parent_idx]->children_idx) { + if (rstate[child_idx]->status != RequestStateStatus::kFinished) { + all_children_finished = false; + break; + } + } + if (!all_children_finished) { + break; + } + + // All the children of the parent request state entry have finished. + // So we mark the parent entry as finished. + rstate[parent_idx]->status = RequestStateStatus::kFinished; + // Remove the request state entry from all the models. + RemoveRequestFromModel(estate, rstate[parent_idx]->mstates[0]->internal_id, models); + estate->id_manager.RecycleId(rstate[parent_idx]->mstates[0]->internal_id); + estate->stats.current_total_seq_len -= + static_cast(rstate[parent_idx]->mstates[0]->committed_tokens.size()); + // Climb up to the parent. + parent_idx = rstate[parent_idx]->parent_idx; + } + + if (parent_idx == -1) { + // All request state entries of the request have been removed. + // Reduce the total input length from the engine stats. + estate->stats.current_total_seq_len -= rsentry->request->input_total_length; + // Remove from running queue and engine state. + auto it = + std::find(estate->running_queue.begin(), estate->running_queue.end(), rsentry->request); + ICHECK(it != estate->running_queue.end()); + estate->running_queue.erase(it); + estate->request_states.erase(rsentry->request->id); + + // Update engine statistics. 
+ const RequestStateEntry& root_rsentry = rstate[0]; + auto trequest_finish = std::chrono::high_resolution_clock::now(); + estate->stats.request_total_prefill_time += + static_cast((root_rsentry->tprefill_finish - root_rsentry->tadd).count()) / 1e9; + estate->stats.total_prefill_length += rsentry->request->input_total_length; + estate->stats.request_total_decode_time += + static_cast((trequest_finish - root_rsentry->tprefill_finish).count()) / 1e9; + for (const RequestStateEntry& entry : rstate) { + estate->stats.total_decode_length += entry->mstates[0]->committed_tokens.size(); + } + estate->stats.total_decode_length -= rsentry->request->generation_cfg->n; + } } } @@ -49,85 +87,137 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array finished_requests; - finished_requests.reserve(requests.size()); + std::vector finished_rsentries; + finished_rsentries.reserve(requests.size()); Array callback_delta_outputs; callback_delta_outputs.reserve(requests.size()); // - Collect new generated tokens and finish reasons for requests. for (Request request : requests) { + int n = request->generation_cfg->n; RequestState rstate = estate->GetRequestState(request); - auto [delta_token_ids, delta_logprob_json_strs, finish_reason] = - rstate->GetReturnTokenIds(tokenizer, max_single_sequence_length); - - // When there is no new delta tokens nor a finish reason, no need to invoke callback. - if (delta_token_ids.empty() && !finish_reason.defined()) { - continue; - } + Array group_delta_token_ids; + Array> group_delta_logprob_json_strs; + Array> group_finish_reason; + group_delta_token_ids.reserve(n); + group_delta_logprob_json_strs.reserve(n); + group_finish_reason.reserve(n); + + bool invoke_callback = false; + for (int i = 0; i < n; ++i) { + const RequestStateEntry& rsentry = n == 1 ? rstate[0] : rstate[i + 1]; + const DeltaRequestReturn& delta_request_ret = + rsentry->GetReturnTokenIds(tokenizer, max_single_sequence_length); + group_delta_token_ids.push_back(IntTuple{delta_request_ret.delta_token_ids.begin(), + delta_request_ret.delta_token_ids.end()}); + group_delta_logprob_json_strs.push_back(delta_request_ret.delta_logprob_json_strs); + group_finish_reason.push_back(delta_request_ret.finish_reason); + if (delta_request_ret.finish_reason.defined()) { + invoke_callback = true; + finished_rsentries.push_back(rsentry); + } - // Update the grammar matcher state if it exists. - if (rstate->mstates[0]->grammar_state_matcher) { - const auto& grammar_state_matcher = rstate->mstates[0]->grammar_state_matcher.value(); - for (auto token_id : delta_token_ids) { - grammar_state_matcher->AcceptToken(token_id); + if (!delta_request_ret.delta_token_ids.empty()) { + invoke_callback = true; + // Update the grammar matcher state if it exists. + if (rsentry->mstates[0]->grammar_state_matcher) { + const auto& grammar_state_matcher = rsentry->mstates[0]->grammar_state_matcher.value(); + for (int32_t token_id : delta_request_ret.delta_token_ids) { + grammar_state_matcher->AcceptToken(token_id); + } + } } } - callback_delta_outputs.push_back(RequestStreamOutput( - request->id, delta_token_ids, - request->generation_cfg->logprobs > 0 ? delta_logprob_json_strs : Optional>(), - finish_reason)); - if (finish_reason.defined()) { - finished_requests.push_back(request); + if (invoke_callback) { + callback_delta_outputs.push_back(RequestStreamOutput( + request->id, std::move(group_delta_token_ids), + request->generation_cfg->logprobs > 0 ? 
std::move(group_delta_logprob_json_strs) + : Optional>>(), + std::move(group_finish_reason))); } } // - Invoke the stream callback function once for all collected requests. request_stream_callback(callback_delta_outputs); - ProcessFinishedRequest(std::move(finished_requests), std::move(estate), std::move(models), - max_single_sequence_length); + ProcessFinishedRequestStateEntries(std::move(finished_rsentries), std::move(estate), + std::move(models), max_single_sequence_length); } -void PreemptLastRunningRequest(EngineState estate, const Array& models, - Optional trace_recorder) { +RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, + const Array& models, + Optional trace_recorder) { + ICHECK(!estate->running_queue.empty()); Request request = estate->running_queue.back(); + // Find the last alive request state entry, which is what we want to preempt. + RequestState rstate = estate->GetRequestState(request); + int preempt_rstate_idx = -1; + for (int i = static_cast(rstate.size()) - 1; i >= 0; --i) { + if (rstate[i]->status == RequestStateStatus::kAlive) { + preempt_rstate_idx = i; + break; + } + } + ICHECK_NE(preempt_rstate_idx, -1); + RequestStateEntry rsentry = rstate[preempt_rstate_idx]; + // Remove from models. // - Clear model speculation draft. // - Update `inputs` for future prefill. - RequestState rstate = estate->GetRequestState(request); - RECORD_EVENT(trace_recorder, rstate->request->id, "preempt"); + RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); + rsentry->status = RequestStateStatus::kPending; + estate->stats.current_total_seq_len -= rsentry->mstates[0]->committed_tokens.size(); + if (rsentry->children_idx.empty()) { + // The length was overly decreased by 1 when the entry has no child. + ++estate->stats.current_total_seq_len; + } + if (rsentry->parent_idx == -1) { + // Subtract the input length from the total length when the + // current entry is the root entry of the request. + estate->stats.current_total_seq_len -= request->input_total_length; + } estate->stats.current_total_seq_len -= - request->input_total_length + rstate->mstates[0]->committed_tokens.size() - 1; - for (RequestModelState mstate : rstate->mstates) { + request->input_total_length + rsentry->mstates[0]->committed_tokens.size() - 1; + for (RequestModelState mstate : rsentry->mstates) { mstate->RemoveAllDraftTokens(); ICHECK(mstate->inputs.empty()); - ICHECK(!mstate->committed_tokens.empty()); std::vector committed_token_ids; committed_token_ids.reserve(mstate->committed_tokens.size()); for (const SampleResult& committed_token : mstate->committed_tokens) { committed_token_ids.push_back(committed_token.sampled_token_id.first); } - Array inputs = request->inputs; - if (const auto* token_input = inputs.back().as()) { - // Merge the TokenData so that a single time TokenEmbed is needed. - std::vector token_ids{token_input->token_ids->data, - token_input->token_ids->data + token_input->token_ids.size()}; - token_ids.insert(token_ids.end(), committed_token_ids.begin(), committed_token_ids.end()); - inputs.Set(inputs.size() - 1, TokenData(token_ids)); - } else { + Array inputs; + if (rsentry->parent_idx == -1) { + inputs = request->inputs; + if (const auto* token_input = inputs.back().as()) { + // Merge the TokenData so that a single time TokenEmbed is needed. 
+ std::vector token_ids{token_input->token_ids->data, + token_input->token_ids->data + token_input->token_ids.size()}; + token_ids.insert(token_ids.end(), committed_token_ids.begin(), committed_token_ids.end()); + inputs.Set(inputs.size() - 1, TokenData(token_ids)); + } else if (!committed_token_ids.empty()) { + inputs.push_back(TokenData(committed_token_ids)); + } + } else if (!committed_token_ids.empty()) { inputs.push_back(TokenData(committed_token_ids)); } mstate->inputs = std::move(inputs); } - RemoveRequestFromModel(estate, rstate->mstates[0]->internal_id, models); + RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models); - // Move from running queue to the front of waiting queue. - estate->running_queue.erase(estate->running_queue.end() - 1); - estate->waiting_queue.insert(estate->waiting_queue.begin(), request); + if (preempt_rstate_idx == 0) { + // Remove from running queue. + estate->running_queue.erase(estate->running_queue.end() - 1); + } + if (preempt_rstate_idx == static_cast(rstate.size()) - 1) { + // Add to the front of waiting queue. + estate->waiting_queue.insert(estate->waiting_queue.begin(), request); + } + return rsentry; } } // namespace serve diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 520180beff..bc3d10ee06 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -46,15 +46,31 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder); +RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, + const Array& models, + Optional trace_recorder); + +/*! \brief Get the running request entries from the engine state. */ +inline std::vector GetRunningRequestStateEntries(const EngineState& estate) { + std::vector rsentries; + for (const Request& request : estate->running_queue) { + for (const RequestStateEntry& rsentry : estate->GetRequestState(request)) { + if (rsentry->status == RequestStateStatus::kAlive && rsentry->children_idx.empty()) { + rsentries.push_back(rsentry); + } + } + } + return rsentries; +} } // namespace serve } // namespace llm diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index d7821020a1..0b23541c22 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -37,42 +37,46 @@ class BatchDecodeActionObj : public EngineActionObj { return {}; } - // Preempt requests when decode cannot apply. - int num_available_pages = models_[0]->GetNumAvailablePages(); - while (!CanDecode(estate->running_queue.size())) { - PreemptLastRunningRequest(estate, models_, trace_recorder_); + // Preempt request state entries when decode cannot apply. + std::vector running_rsentries = GetRunningRequestStateEntries(estate); + while (!CanDecode(running_rsentries.size())) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + running_rsentries.pop_back(); + } } auto tstart = std::chrono::high_resolution_clock::now(); - // NOTE: Right now we only support decode all the running requests at a time. - int num_requests = estate->running_queue.size(); - estate->stats.current_total_seq_len += num_requests; + // NOTE: Right now we only support decode all the running request states at a time. 
+ int num_rsentries = running_rsentries.size(); + estate->stats.current_total_seq_len += num_rsentries; // Collect // - the last committed token, - // - the request states, - // - the sampling parameters, - // of each request. + // - the request id, + // - the generation config, + // - the random number generator, + // of each request state entry. std::vector input_tokens; Array request_ids; std::vector request_internal_ids; Array mstates; Array generation_cfg; std::vector rngs; - input_tokens.reserve(num_requests); - request_ids.reserve(num_requests); - request_internal_ids.reserve(num_requests); - mstates.reserve(num_requests); - generation_cfg.reserve(num_requests); - rngs.reserve(num_requests); - for (Request request : estate->running_queue) { - RequestState rstate = estate->GetRequestState(request); - input_tokens.push_back(rstate->mstates[0]->committed_tokens.back().sampled_token_id.first); - request_ids.push_back(request->id); - request_internal_ids.push_back(rstate->mstates[0]->internal_id); - mstates.push_back(rstate->mstates[0]); - generation_cfg.push_back(request->generation_cfg); - rngs.push_back(&rstate->rng); + input_tokens.reserve(num_rsentries); + request_ids.reserve(num_rsentries); + request_internal_ids.reserve(num_rsentries); + mstates.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + rngs.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + input_tokens.push_back(rsentry->mstates[0]->committed_tokens.back().sampled_token_id.first); + request_ids.push_back(rsentry->request->id); + request_internal_ids.push_back(rsentry->mstates[0]->internal_id); + mstates.push_back(rsentry->mstates[0]); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); } // - Compute embeddings. @@ -82,8 +86,8 @@ class BatchDecodeActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish embedding"); ICHECK_EQ(embeddings->ndim, 3); ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], num_requests); - embeddings = embeddings.CreateView({num_requests, 1, embeddings->shape[2]}, embeddings->dtype); + ICHECK_EQ(embeddings->shape[1], num_rsentries); + embeddings = embeddings.CreateView({num_rsentries, 1, embeddings->shape[2]}, embeddings->dtype); // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start decode"); @@ -94,7 +98,7 @@ class BatchDecodeActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[1], 1); // - Update logits. - logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. @@ -104,10 +108,10 @@ class BatchDecodeActionObj : public EngineActionObj { // - Sample tokens. std::vector sample_results = sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), num_requests); + ICHECK_EQ(sample_results.size(), num_rsentries); // - Update the committed tokens of states. - for (int i = 0; i < num_requests; ++i) { + for (int i = 0; i < num_rsentries; ++i) { mstates[i]->CommitToken(sample_results[i]); } @@ -118,10 +122,10 @@ class BatchDecodeActionObj : public EngineActionObj { } private: - /*! \brief Check if the input requests can be decoded under conditions. */ - bool CanDecode(int num_requests) { + /*! 
\brief Check if the input request state entries can be decoded under conditions. */ + bool CanDecode(int num_rsentries) { int num_available_pages = models_[0]->GetNumAvailablePages(); - return num_requests <= num_available_pages; + return num_rsentries <= num_available_pages; } /*! diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index d9eba8e037..da345b6c89 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -36,49 +36,51 @@ class BatchDraftActionObj : public EngineActionObj { return {}; } - // Preempt requests when decode cannot apply. - while (!CanDecode(estate->running_queue.size())) { - PreemptLastRunningRequest(estate, models_, trace_recorder_); + // Preempt request state entries when decode cannot apply. + std::vector running_rsentries = GetRunningRequestStateEntries(estate); + while (!CanDecode(running_rsentries.size())) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + running_rsentries.pop_back(); + } } auto tstart = std::chrono::high_resolution_clock::now(); - // NOTE: Right now we only support decode all the running requests at a time. - int num_requests = estate->running_queue.size(); + int num_rsentries = running_rsentries.size(); Array request_ids; std::vector request_internal_ids; Array generation_cfg; - Array rstates; std::vector rngs; - request_ids.reserve(num_requests); - request_internal_ids.reserve(num_requests); - generation_cfg.reserve(num_requests); - rstates.reserve(num_requests); - for (const Request& request : estate->running_queue) { - RequestState rstate = estate->GetRequestState(request); - request_ids.push_back(request->id); - rstates.push_back(rstate); - request_internal_ids.push_back(rstate->mstates[0]->internal_id); - generation_cfg.push_back(request->generation_cfg); - rngs.push_back(&rstate->rng); + request_ids.reserve(num_rsentries); + request_internal_ids.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + request_ids.push_back(rsentry->request->id); + request_internal_ids.push_back(rsentry->mstates[0]->internal_id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); } // The first model doesn't get involved in draft proposal. for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { // Collect // - the last committed token, - // - the request states, - // - the sampling parameters, + // - the request model state // of each request. std::vector input_tokens; - Array mstates = - rstates.Map([model_id](const RequestState& rstate) { return rstate->mstates[model_id]; }); - input_tokens.reserve(num_requests); + Array mstates; + input_tokens.reserve(num_rsentries); + mstates.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + mstates.push_back(rsentry->mstates[model_id]); + } // draft_length_ rounds of draft proposal. for (int draft_id = 0; draft_id < draft_length_; ++draft_id) { // prepare new input tokens input_tokens.clear(); - for (int i = 0; i < num_requests; ++i) { + for (int i = 0; i < num_rsentries; ++i) { // The first draft proposal uses the last committed token. input_tokens.push_back( draft_id == 0 ? 
mstates[i]->committed_tokens.back().sampled_token_id.first @@ -92,9 +94,9 @@ class BatchDraftActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); ICHECK_EQ(embeddings->ndim, 3); ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], num_requests); + ICHECK_EQ(embeddings->shape[1], num_rsentries); embeddings = - embeddings.CreateView({num_requests, 1, embeddings->shape[2]}, embeddings->dtype); + embeddings.CreateView({num_rsentries, 1, embeddings->shape[2]}, embeddings->dtype); // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); @@ -105,7 +107,7 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(logits->shape[1], 1); // - Update logits. - logits = logits.CreateView({num_requests, logits->shape[2]}, logits->dtype); + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. @@ -115,11 +117,11 @@ class BatchDraftActionObj : public EngineActionObj { // - Sample tokens. std::vector prob_dist; std::vector sample_results = sampler_->BatchSampleTokens( - probs_device, request_ids, generation_cfg, rngs, &prob_dist); - ICHECK_EQ(sample_results.size(), num_requests); + probs_device, request_ids, generation_cfg, rngs, /*prob_indices=*/nullptr, &prob_dist); + ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. - for (int i = 0; i < num_requests; ++i) { + for (int i = 0; i < num_rsentries; ++i) { mstates[i]->AddDraftToken(sample_results[i], prob_dist[i]); estate->stats.total_draft_length += 1; } @@ -134,12 +136,12 @@ class BatchDraftActionObj : public EngineActionObj { private: /*! \brief Check if the input requests can be decoded under conditions. */ - bool CanDecode(int num_requests) { + bool CanDecode(int num_rsentries) { // The first model is not involved in draft proposal. for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { // Check if the model has enough available pages. int num_available_pages = models_[model_id]->GetNumAvailablePages(); - if (num_requests > num_available_pages) { + if (num_rsentries > num_available_pages) { return false; } } diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index b608c5b3b3..3720340589 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -42,15 +42,15 @@ class BatchVerifyActionObj : public EngineActionObj { return {}; } - const auto& [requests, rstates, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); - ICHECK_EQ(requests.size(), rstates.size()); - ICHECK_EQ(requests.size(), draft_lengths.size()); - if (requests.empty()) { + const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), draft_lengths.size()); + if (rsentries.empty()) { return {}; } - int num_requests = requests.size(); - Array request_ids = requests.Map([](const Request& request) { return request->id; }); + int num_rsentries = rsentries.size(); + Array request_ids = + rsentries.Map([](const RequestStateEntry& rstate) { return rstate->request->id; }); auto tstart = std::chrono::high_resolution_clock::now(); // - Get embedding and run verify. 
@@ -61,17 +61,17 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector rngs; std::vector> draft_output_tokens; std::vector> draft_output_prob_dist; - request_internal_ids.reserve(num_requests); + request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_draft_length); - verify_request_mstates.reserve(num_requests); - rngs.reserve(num_requests); - generation_cfg.reserve(num_requests); - draft_output_tokens.reserve(num_requests); - draft_output_prob_dist.reserve(num_requests); - - for (int i = 0; i < num_requests; ++i) { - RequestModelState verify_mstate = rstates[i]->mstates[verify_model_id_]; - RequestModelState draft_mstate = rstates[i]->mstates[draft_model_id_]; + verify_request_mstates.reserve(num_rsentries); + rngs.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + draft_output_tokens.reserve(num_rsentries); + draft_output_prob_dist.reserve(num_rsentries); + + for (int i = 0; i < num_rsentries; ++i) { + RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; + RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!draft_lengths.empty()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); @@ -82,8 +82,8 @@ class BatchVerifyActionObj : public EngineActionObj { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); } verify_request_mstates.push_back(verify_mstate); - generation_cfg.push_back(requests[i]->generation_cfg); - rngs.push_back(&rstates[i]->rng); + generation_cfg.push_back(rsentries[i]->request->generation_cfg); + rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } @@ -103,7 +103,8 @@ class BatchVerifyActionObj : public EngineActionObj { // - Update logits. std::vector cum_verify_lengths = {0}; - for (int i = 0; i < num_requests; ++i) { + cum_verify_lengths.reserve(num_rsentries + 1); + for (int i = 0; i < num_rsentries; ++i) { cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); } logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); @@ -117,14 +118,14 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( probs_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, draft_output_prob_dist); - ICHECK_EQ(sample_results_arr.size(), num_requests); + ICHECK_EQ(sample_results_arr.size(), num_rsentries); - for (int i = 0; i < num_requests; ++i) { + for (int i = 0; i < num_rsentries; ++i) { const std::vector& sample_results = sample_results_arr[i]; int accept_length = sample_results.size(); for (SampleResult sample_result : sample_results) { - rstates[i]->mstates[verify_model_id_]->CommitToken(sample_result); - rstates[i]->mstates[draft_model_id_]->CommitToken(sample_result); + rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); + rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.current_total_seq_len += accept_length; estate->stats.total_accepted_length += accept_length; @@ -137,46 +138,32 @@ class BatchVerifyActionObj : public EngineActionObj { // it is possible to re-compute prefill for the small models. 
if (rollback_length > 0) { models_[verify_model_id_]->PopNFromKVCache( - rstates[i]->mstates[verify_model_id_]->internal_id, rollback_length); - models_[draft_model_id_]->PopNFromKVCache(rstates[i]->mstates[draft_model_id_]->internal_id, - rollback_length); + rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length); + models_[draft_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length); } } - // clear the draft model states - for (int i = 0; i < num_requests; ++i) { - rstates[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + // clear the draft model state entries + for (int i = 0; i < num_rsentries; ++i) { + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); } auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; - return requests; + return estate->running_queue; } private: - /*! \brief Check if the drafts can be verified under conditions. */ - bool CanVerify(EngineState estate, int num_verify_req, int total_draft_length, - int num_required_pages, int num_available_pages) { - int num_running_requests = estate->running_queue.size(); - ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence); - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - if (num_running_requests + num_verify_req > kv_cache_config_->max_num_sequence) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= prefill chunk size. - // Cond 2: at least one verify can be performed. - // Cond 3: number of total tokens does not exceed the limit - int new_batch_size = num_running_requests + num_verify_req; - return total_draft_length <= kv_cache_config_->prefill_chunk_size && - num_required_pages <= num_available_pages && - estate->stats.current_total_seq_len + total_draft_length <= - kv_cache_config_->max_total_sequence_length; - } + struct DraftRequestStateEntries { + /*! \brief The request state entries to verify. */ + Array draft_rsentries; + /*! \brief The draft length of each request state. */ + std::vector draft_lengths; + /*! \brief The total draft length. */ + int total_draft_length; + }; /*! * \brief Decide whether to run verify for the draft of each request. @@ -184,43 +171,43 @@ class BatchVerifyActionObj : public EngineActionObj { * \return The drafts to verify, together with their respective * state and input length. */ - std::tuple, Array, std::vector, int> GetDraftsToVerify( - EngineState estate) { - // - Try to verify pending requests. - std::vector verify_requests; - std::vector rstates; + DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { std::vector draft_lengths; int total_draft_length = 0; int total_required_pages = 0; int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); - int req_id = 1; - for (; req_id <= static_cast(estate->running_queue.size()); ++req_id) { - Request request = estate->running_queue[req_id - 1]; - RequestState rstate = estate->GetRequestState(request); - int draft_length = rstate->mstates[draft_model_id_]->draft_output_tokens.size(); + // Preempt the request state entries that cannot fit the large model for verification. 
+ std::vector running_rsentries = GetRunningRequestStateEntries(estate); + std::vector num_page_requirement; + num_page_requirement.reserve(running_rsentries.size()); + for (const RequestStateEntry& rsentry : running_rsentries) { + int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); int num_require_pages = (draft_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + draft_lengths.push_back(draft_length); + num_page_requirement.push_back(num_require_pages); total_draft_length += draft_length; total_required_pages += num_require_pages; - if (CanVerify(estate, req_id, total_draft_length, total_required_pages, - num_available_pages)) { - verify_requests.push_back(request); - rstates.push_back(rstate); - draft_lengths.push_back(draft_length); - } else { - total_draft_length -= draft_length; - total_required_pages -= num_require_pages; - break; - } } - // preempt all the remaining requests - while (req_id <= static_cast(estate->running_queue.size())) { - PreemptLastRunningRequest(estate, models_, trace_recorder_); - req_id += 1; + while (!CanVerify(total_required_pages)) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + total_draft_length -= draft_lengths.back(); + total_required_pages -= num_page_requirement.back(); + draft_lengths.pop_back(); + num_page_requirement.pop_back(); + running_rsentries.pop_back(); + } } - return {verify_requests, rstates, draft_lengths, total_draft_length}; + return {running_rsentries, draft_lengths, total_draft_length}; + } + + bool CanVerify(int num_required_pages) { + int num_available_pages = models_[0]->GetNumAvailablePages(); + return num_required_pages <= num_available_pages; } /*! diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 72f54388e7..24d431ae7e 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -7,6 +7,7 @@ #include "../model.h" #include "../sampler.h" #include "action.h" +#include "action_commons.h" namespace mlc { namespace llm { @@ -20,32 +21,49 @@ class NewRequestPrefillActionObj : public EngineActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, KVCacheConfig kv_cache_config, + EngineMode engine_mode, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), kv_cache_config_(std::move(kv_cache_config)), + engine_mode_(std::move(engine_mode)), trace_recorder_(std::move(trace_recorder)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. - auto [requests, rstates, prefill_lengths] = GetRequestsToPrefill(estate); - ICHECK_EQ(requests.size(), rstates.size()); - ICHECK_EQ(requests.size(), prefill_lengths.size()); - if (requests.empty()) { + auto [rstates, prefill_lengths] = GetRequestStatesToPrefill(estate); + ICHECK_EQ(rstates.size(), prefill_lengths.size()); + if (rstates.empty()) { return {}; } - int num_requests = requests.size(); - Array request_ids = requests.Map([](const Request& request) { return request->id; }); + int num_rstates = rstates.size(); auto tstart = std::chrono::high_resolution_clock::now(); - // - Move requests from waiting queue to running queue. 
- for (int i = 0; i < num_requests; ++i) { - auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), requests[i]); - ICHECK(it != estate->waiting_queue.end()); - estate->waiting_queue.erase(it); - estate->running_queue.push_back(requests[i]); + // - Update status of request states from pending to alive. + Array request_ids; + std::vector rstates_of_requests; + request_ids.reserve(num_rstates); + rstates_of_requests.reserve(num_rstates); + for (RequestStateEntry rstate : rstates) { + const Request& request = rstate->request; + RequestState request_rstates = estate->GetRequestState(request); + request_ids.push_back(request->id); + rstate->status = RequestStateStatus::kAlive; + + // - Remove the request from waiting queue if all its request states are now alive. + // - Add the request to running queue if all its request states were pending. + bool alive_state_existed = false; + for (const RequestStateEntry& request_state : request_rstates) { + if (request_state->status == RequestStateStatus::kAlive && !request_state.same_as(rstate)) { + alive_state_existed = true; + } + } + if (!alive_state_existed) { + estate->running_queue.push_back(request); + } + rstates_of_requests.push_back(std::move(request_rstates)); } // - Get embedding and run prefill for each model. @@ -53,22 +71,28 @@ class NewRequestPrefillActionObj : public EngineActionObj { for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { Array embeddings; std::vector request_internal_ids; - embeddings.reserve(num_requests); - request_internal_ids.reserve(num_requests); - for (int i = 0; i < num_requests; ++i) { + embeddings.reserve(num_rstates); + request_internal_ids.reserve(num_rstates); + for (int i = 0; i < num_rstates; ++i) { RequestModelState mstate = rstates[i]->mstates[model_id]; ICHECK_EQ(mstate->GetInputLength(), prefill_lengths[i]); ICHECK(mstate->draft_output_tokens.empty()); ICHECK(mstate->draft_output_prob_dist.empty()); ICHECK(!mstate->inputs.empty()); - // Add the sequence to the model. - models_[model_id]->AddNewSequence(mstate->internal_id); + // Add the sequence to the model, or fork the sequence from its parent. + if (rstates[i]->parent_idx == -1) { + models_[model_id]->AddNewSequence(mstate->internal_id); + } else { + models_[model_id]->ForkSequence( + rstates_of_requests[i][rstates[i]->parent_idx]->mstates[model_id]->internal_id, + mstate->internal_id); + } request_internal_ids.push_back(mstate->internal_id); - RECORD_EVENT(trace_recorder_, requests[i]->id, "start embedding"); + RECORD_EVENT(trace_recorder_, rstates[i]->request->id, "start embedding"); for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { embeddings.push_back(mstate->inputs[i]->GetEmbedding(models_[model_id])); } - RECORD_EVENT(trace_recorder_, requests[i]->id, "finish embedding"); + RECORD_EVENT(trace_recorder_, rstates[i]->request->id, "finish embedding"); // Clean up `inputs` after prefill mstate->inputs.clear(); } @@ -79,7 +103,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], num_requests); + ICHECK_EQ(logits->shape[1], num_rstates); if (model_id == 0) { // We only need to sample for model 0 in prefill. @@ -90,19 +114,16 @@ class NewRequestPrefillActionObj : public EngineActionObj { // - Update logits. 
ICHECK(logits_for_sample.defined()); Array generation_cfg; - Array mstates_for_sample; - std::vector rngs; - generation_cfg.reserve(num_requests); - mstates_for_sample.reserve(num_requests); - rngs.reserve(num_requests); - for (int i = 0; i < num_requests; ++i) { - generation_cfg.push_back(requests[i]->generation_cfg); - mstates_for_sample.push_back(rstates[i]->mstates[0]); - rngs.push_back(&rstates[i]->rng); + Array mstates_for_logitproc; + generation_cfg.reserve(num_rstates); + mstates_for_logitproc.reserve(num_rstates); + for (int i = 0; i < num_rstates; ++i) { + generation_cfg.push_back(rstates[i]->request->generation_cfg); + mstates_for_logitproc.push_back(rstates[i]->mstates[0]); } - logits_for_sample = logits_for_sample.CreateView({num_requests, logits_for_sample->shape[2]}, + logits_for_sample = logits_for_sample.CreateView({num_rstates, logits_for_sample->shape[2]}, logits_for_sample->dtype); - logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_sample, + logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_logitproc, request_ids); // - Compute probability distributions. @@ -110,85 +131,172 @@ class NewRequestPrefillActionObj : public EngineActionObj { logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); // - Sample tokens. + // For rstates which are depended by other states, sample + // one token for each rstate that is depending. + // Otherwise, sample a token for the current rstate. + std::vector prob_indices; + RequestState rstates_for_sample; + std::vector rngs; + prob_indices.reserve(num_rstates); + rstates_for_sample.reserve(num_rstates); + rngs.reserve(num_rstates); + request_ids.clear(); + generation_cfg.clear(); + for (int i = 0; i < num_rstates; ++i) { + estate->stats.current_total_seq_len += prefill_lengths[i]; + const RequestStateEntry& rstate = rstates[i]; + for (int child_idx : rstate->children_idx) { + if (rstates_of_requests[i][child_idx]->mstates[0]->committed_tokens.empty()) { + // If rstates_of_requests[i][child_idx] has no committed token, + // the prefill of the current rstate will unblock rstates_of_requests[i][child_idx], + // and thus we want to sample a token for rstates_of_requests[i][child_idx]. + prob_indices.push_back(i); + rstates_for_sample.push_back(rstates_of_requests[i][child_idx]); + request_ids.push_back(rstate->request->id); + generation_cfg.push_back(rstate->request->generation_cfg); + rngs.push_back(&rstates_of_requests[i][child_idx]->rng); + + ICHECK(rstates_of_requests[i][child_idx]->status == RequestStateStatus::kPending); + rstates_of_requests[i][child_idx]->status = RequestStateStatus::kAlive; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + models_[model_id]->ForkSequence( + rstate->mstates[model_id]->internal_id, + rstates_of_requests[i][child_idx]->mstates[model_id]->internal_id); + } + } + } + if (rstate->children_idx.empty()) { + // If rstate has no child, we sample a token for itself. 
+ prob_indices.push_back(i); + rstates_for_sample.push_back(rstate); + request_ids.push_back(rstate->request->id); + generation_cfg.push_back(rstate->request->generation_cfg); + rngs.push_back(&rstate->rng); + } + } std::vector sample_results = - sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), num_requests); + sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs, &prob_indices); + ICHECK_EQ(sample_results.size(), rstates_for_sample.size()); // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. - // - Accumulate the sequence length in engine statistics. - int sum_prefill_lengths = 0; auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < num_requests; ++i) { - for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - rstates[i]->mstates[model_id]->CommitToken(sample_results[i]); + for (int i = 0; i < static_cast(rstates_for_sample.size()); ++i) { + for (const RequestModelState& mstate : rstates_for_sample[i]->mstates) { + mstate->CommitToken(sample_results[i]); } - if (mstates_for_sample[i]->committed_tokens.size() == 1) { - rstates[i]->tprefill_finish = tnow; + if (rstates_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { + rstates_for_sample[i]->tprefill_finish = tnow; } - sum_prefill_lengths += prefill_lengths[i]; } - estate->stats.current_total_seq_len += sum_prefill_lengths; auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - return requests; + std::vector processed_requests; + { + processed_requests.reserve(num_rstates); + std::unordered_set dedup_map; + for (int i = 0; i < static_cast(rstates.size()); ++i) { + const RequestStateEntry& rstate = rstates[i]; + if (dedup_map.find(rstate->request.get()) != dedup_map.end()) { + continue; + } + dedup_map.insert(rstate->request.get()); + processed_requests.push_back(rstate->request); + + bool pending_state_exists = false; + for (const RequestStateEntry& request_state : rstates_of_requests[i]) { + if (request_state->status == RequestStateStatus::kPending) { + pending_state_exists = true; + break; + } + } + if (!pending_state_exists) { + auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), + rstate->request); + ICHECK(it != estate->waiting_queue.end()); + estate->waiting_queue.erase(it); + } + } + } + return processed_requests; } private: /*! - * \brief Find one or multiple requests to run prefill. + * \brief Find one or multiple request states to run prefill. * \param estate The engine state. * \return The requests to prefill, together with their respective * state and input length. */ - std::tuple, Array, std::vector> GetRequestsToPrefill( + std::tuple, std::vector> GetRequestStatesToPrefill( EngineState estate) { if (estate->waiting_queue.empty()) { // No request to prefill. - return {{}, {}, {}}; + return {{}, {}}; } // - Try to prefill pending requests. 
- std::vector prefill_requests; - std::vector rstates; + std::vector rsentries_to_prefill; std::vector prefill_lengths; int total_input_length = 0; int total_required_pages = 0; int num_available_pages = models_[0]->GetNumAvailablePages(); + int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); - for (int i = 1; i <= static_cast(estate->waiting_queue.size()); ++i) { - Request request = estate->waiting_queue[i - 1]; + int num_prefill_rsentries = 0; + for (const Request& request : estate->waiting_queue) { RequestState rstate = estate->GetRequestState(request); - int input_length = rstate->mstates[0]->GetInputLength(); - int num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(estate, i, total_input_length, total_required_pages, num_available_pages)) { - prefill_requests.push_back(request); - rstates.push_back(rstate); - prefill_lengths.push_back(input_length); - } else { - total_input_length -= input_length; - total_required_pages -= num_require_pages; + bool prefill_stops = false; + for (const RequestStateEntry& rsentry : rstate) { + // A request state entry can be prefilled only when: + // - it has inputs, and + // - it is pending, and + // - it has no parent or its parent is alive. + if (rsentry->mstates[0]->inputs.empty() || + rsentry->status != RequestStateStatus::kPending || + (rsentry->parent_idx != -1 && + rstate[rsentry->parent_idx]->status == RequestStateStatus::kPending)) { + continue; + } + + int input_length = rsentry->mstates[0]->GetInputLength(); + int num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->children_idx.size(), + total_input_length, total_required_pages, num_available_pages, + num_running_rsentries)) { + rsentries_to_prefill.push_back(rsentry); + prefill_lengths.push_back(input_length); + ++num_prefill_rsentries; + } else { + total_input_length -= input_length; + total_required_pages -= num_require_pages; + prefill_stops = true; + break; + } + } + if (prefill_stops) { break; } } - return {prefill_requests, rstates, prefill_lengths}; + return {rsentries_to_prefill, prefill_lengths}; } /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(EngineState estate, int num_prefill_req, int total_input_length, - int num_required_pages, int num_available_pages) { - int num_running_requests = estate->running_queue.size(); - ICHECK_LE(num_running_requests, kv_cache_config_->max_num_sequence); + bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, + int num_required_pages, int num_available_pages, int num_running_rsentries) { + ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can // run simultaneously. - if (num_running_requests + num_prefill_req > kv_cache_config_->max_num_sequence) { + int spec_factor = engine_mode_->enable_speculative ? 
engine_mode_->spec_draft_length : 1; + if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > + kv_cache_config_->max_num_sequence) { return false; } @@ -198,7 +306,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { // Cond 3: number of total tokens after 8 times of decode does not // exceed the limit, where 8 is a watermark number can // be configured and adjusted in the future. - int new_batch_size = num_running_requests + num_prefill_req; + int new_batch_size = num_running_rsentries + num_prefill_rsentries; return total_input_length <= kv_cache_config_->prefill_chunk_size && num_required_pages + new_batch_size <= num_available_pages && estate->stats.current_total_seq_len + total_input_length + 8 * new_batch_size <= @@ -213,16 +321,19 @@ class NewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief The KV cache config to help decide prefill is doable. */ KVCacheConfig kv_cache_config_; + /*! \brief The engine operation mode. */ + EngineMode engine_mode_; /*! \brief Event trace recorder. */ Optional trace_recorder_; }; EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, KVCacheConfig kv_cache_config, + EngineMode engine_mode, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), - std::move(trace_recorder))); + std::move(engine_mode), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index e63622550f..3aeac5ffaf 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -50,7 +50,9 @@ void EngineStateObj::Reset() { } RequestState EngineStateObj::GetRequestState(Request request) { - return request_states.at(request->id); + auto it = request_states.find(request->id); + ICHECK(it != request_states.end()); + return it->second; } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 5f5dc59816..512fc21333 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -204,6 +204,8 @@ void FunctionTable::_InitFunctions() { this->reset_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_clear"); this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence"); + this->kv_cache_fork_sequence_func_ = + get_global_func("vm.builtin.paged_attention_kv_cache_fork_sequence"); this->kv_cache_remove_sequence_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_remove_sequence"); this->kv_cache_begin_forward_func_ = diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 956f19e02e..5475886d11 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -78,6 +78,7 @@ struct FunctionTable { PackedFunc reset_kv_cache_func_; bool support_backtracking_kv_; PackedFunc kv_cache_add_sequence_func_; + PackedFunc kv_cache_fork_sequence_func_; PackedFunc kv_cache_remove_sequence_func_; PackedFunc kv_cache_begin_forward_func_; PackedFunc kv_cache_end_forward_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index ecaa5276d8..c89eaaceae 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -371,6 +371,10 @@ class ModelImpl : public ModelObj { void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } + void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final { + ft_.kv_cache_fork_sequence_func_(kv_cache_, 
parent_seq_id, child_seq_id); + } + /*! \brief Remove the given sequence from the KV cache in the model. */ void RemoveSequence(int64_t seq_id) final { ft_.kv_cache_remove_sequence_func_(kv_cache_, seq_id); diff --git a/cpp/serve/model.h b/cpp/serve/model.h index b561b7895e..fe396c4094 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -105,6 +105,9 @@ class ModelObj : public Object { /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; + /*! \brief Fork a sequence from a given parent sequence. */ + virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0; + /*! \brief Remove the given sequence from the KV cache in the model. */ virtual void RemoveSequence(int64_t seq_id) = 0; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 7519a56adb..8b5543d4f1 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -74,31 +74,38 @@ void RequestModelStateNode::RemoveAllDraftTokens() { } } -TVM_REGISTER_OBJECT_TYPE(RequestStateNode); +TVM_REGISTER_OBJECT_TYPE(RequestStateEntryNode); -RequestState::RequestState(Request request, int num_models, int64_t internal_id, - const std::vector& token_table, - std::shared_ptr json_grammar_state_init_ctx) { - ObjectPtr n = make_object(); +RequestStateEntry::RequestStateEntry( + Request request, int num_models, int64_t internal_id, int rng_seed, + const std::vector& token_table, + std::shared_ptr json_grammar_state_init_ctx, int parent_idx) { + ObjectPtr n = make_object(); Array mstates; + Array inputs; + if (parent_idx == -1) { + inputs = request->inputs; + } mstates.reserve(num_models); for (int i = 0; i < num_models; ++i) { mstates.push_back( - RequestModelState(request, i, internal_id, request->inputs, json_grammar_state_init_ctx)); + RequestModelState(request, i, internal_id, inputs, json_grammar_state_init_ctx)); } - n->rng = RandomGenerator(request->generation_cfg->seed); + n->status = RequestStateStatus::kPending; + n->rng = RandomGenerator(rng_seed); n->stop_str_handler = StopStrHandler( !request->generation_cfg->ignore_eos ? request->generation_cfg->stop_strs : Array(), token_table); n->request = std::move(request); + n->parent_idx = parent_idx; n->mstates = std::move(mstates); n->next_callback_token_pos = 0; n->tadd = std::chrono::high_resolution_clock::now(); data_ = std::move(n); } -DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenizer, - int max_single_sequence_length) { +DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tokenizer, + int max_single_sequence_length) { // - Case 0. There is remaining draft output ==> Unfinished // All draft outputs are supposed to be processed before finish. for (RequestModelState mstate : mstates) { @@ -114,7 +121,12 @@ DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenize int num_committed_tokens = committed_tokens.size(); ICHECK_LE(this->next_callback_token_pos, num_committed_tokens); - // Case 1. Any of the stop strings is matched. + // Case 1. There is no new token ids. + if (this->next_callback_token_pos == num_committed_tokens) { + return {{}, {}, Optional()}; + } + + // Case 2. Any of the stop strings is matched. ICHECK(!stop_str_handler->StopTriggered()); while (next_callback_token_pos < num_committed_tokens) { std::vector delta_token_ids = @@ -129,7 +141,7 @@ DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenize } } - // Case 2. 
Any of the stop tokens appears in the committed tokens ===> Finished + // Case 3. Any of the stop tokens appears in the committed tokens ===> Finished // `stop_token_ids` includes the stop tokens from conversation template and user-provided tokens. // This check will be ignored when `ignore_eos` is set for the benchmarking purpose. if (!request->generation_cfg->ignore_eos) { @@ -152,7 +164,7 @@ DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenize return {return_token_ids, logprob_json_strs, finish_reason}; } - // Case 3. Generation reaches the specified max generation length ==> Finished + // Case 4. Generation reaches the specified max generation length ==> Finished // `max_tokens` means the generation length is limited by model capacity. if (request->generation_cfg->max_tokens >= 0 && num_committed_tokens >= request->generation_cfg->max_tokens) { @@ -160,7 +172,7 @@ DeltaRequestReturn RequestStateNode::GetReturnTokenIds(const Tokenizer& tokenize return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); return {return_token_ids, logprob_json_strs, String("length")}; } - // Case 4. Total length of the request reaches the maximum single sequence length ==> Finished + // Case 5. Total length of the request reaches the maximum single sequence length ==> Finished if (request->input_total_length + num_committed_tokens >= max_single_sequence_length) { std::vector remaining = stop_str_handler->Finish(); return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 6cf5928a13..66e36d5b93 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -119,10 +119,57 @@ struct DeltaRequestReturn { Optional finish_reason; }; -class RequestStateNode : public Object { +/****************** Request States ******************/ + +/*! + * \brief For each request, we maintain its "request state" in the + * engine. Generally, the state of a request contains the information + * of the request's generation at the current moment, including + * the generated token ids, the grammar handler, etc. + * + * When a request has multiple parallel generations (e.g., the field + * `n` of its generation config is more than 1), each generation will + * have different states all the time. + * + * Therefore, to better support parallel generations, we denote the + * state of a single generation as a "RequestStateEntry" instance, + * and denote the state of a request's all generations using a vector, + * named as a "RequestState" instance. + * + * A request's all state entries are organized as a tree structure + * when there are parallel generations. + * - the request input has the root status entry, + * - each parallel generation is a child of the root. + * This tree structure may be further extended to more complicated + * cases in the future. As of now, for the case of `n > 1`, there + * will be (n + 1) entries in total. In a "RequestState", the root + * entry always has index 0. And we guarantee that the entry order + * from the vector begin to the end is always a topological order + * of the tree. + */ + +/*! \brief Request state status. */ +enum class RequestStateStatus : int { + kPending = 0, + kAlive = 1, + kFinished = 2, +}; + +class RequestStateEntryNode : public Object { public: + /*! \brief The status of the request state. */ + RequestStateStatus status; /*! \brief The request that this state corresponds to. */ Request request; + /*! 
+ * \brief The idx of the parent request state of this state. + * Being -1 means the state has no parent and is the foremost + * "prefix" state or the only state. + */ + int parent_idx = -1; + /*! \brief The children indices of the request state. */ + std::vector children_idx; + /*! * \brief The state with regard to each model. * \sa RequestModelState @@ -154,21 +201,25 @@ class RequestStateNode : public Object { */ DeltaRequestReturn GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length); - static constexpr const char* _type_key = "mlc.serve.RequestState"; + static constexpr const char* _type_key = "mlc.serve.RequestStateEntry"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateEntryNode, Object); }; -class RequestState : public ObjectRef { +class RequestStateEntry : public ObjectRef { public: - explicit RequestState(Request request, int num_models, int64_t internal_id, - const std::vector& token_table, - std::shared_ptr json_grammar_state_init_ctx); + explicit RequestStateEntry(Request request, int num_models, int64_t internal_id, int rng_seed, + const std::vector& token_table, + std::shared_ptr json_grammar_state_init_ctx, + int parent_idx = -1); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestStateEntry, ObjectRef, RequestStateEntryNode); }; +/*! \brief A request's state, which groups all the request state entries. */ +typedef std::vector RequestState; + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index 6a6bb65de9..d201158628 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -266,6 +266,7 @@ class CPUSampler : public SamplerObj { const Array& request_ids, // const Array& generation_cfg, // const std::vector& rngs, // + const std::vector* prob_indices, // std::vector* output_prob_dist) final { // probs_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); @@ -276,10 +277,12 @@ class CPUSampler : public SamplerObj { RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); // - Sample tokens from probabilities. - ICHECK_EQ(probs_host->shape[0], request_ids.size()); - ICHECK_EQ(probs_host->shape[0], generation_cfg.size()); - ICHECK_EQ(probs_host->shape[0], rngs.size()); - int n = probs_host->shape[0]; + int n = request_ids.size(); + ICHECK_EQ(generation_cfg.size(), n); + ICHECK_EQ(rngs.size(), n); + if (prob_indices == nullptr) { + ICHECK_EQ(probs_host->shape[0], n); + } std::vector sample_results; sample_results.resize(n); @@ -288,12 +291,13 @@ class CPUSampler : public SamplerObj { } tvm::runtime::parallel_for_with_threading_backend( - [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, + [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, prob_indices, output_prob_dist](int i) { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, i, generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, + probs_host, prob_indices == nullptr ? i : prob_indices->at(i), + generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); if (output_prob_dist == nullptr) { // When `output_prob_dist` is not nullptr, it means right now diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index 6f9c6acf47..faa2cffd57 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -39,15 +39,25 @@ class SamplerObj : public Object { * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. + * \param prob_indices The indices of probability distribution in `probs_device` + * that each request in `request_ids` samples from. + * It defaults to nullptr, which means each request samples from the + * corresponding index in `prob_indices`. + * In usual cases, we only sample one token for each prob distribution + * in the batch, and `prob_indices` is nullptr in such cases. + * When we want to sample multiple tokens from a prob distribution (e.g., + * starting parallel generation after prefill the input), we use `prob_indices` + * to represent which distribution a token should be sampled from * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. */ virtual std::vector BatchSampleTokens( - NDArray probs_device, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // + NDArray probs_device, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + const std::vector* prob_indices = nullptr, // std::vector* output_prob_dist = nullptr) = 0; /*! diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index e45711d516..b0d4d56192 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -296,7 +296,6 @@ def openai_api_get_unsupported_fields( """Get the unsupported fields in the request.""" unsupported_field_default_values: List[Tuple[str, Any]] = [ ("best_of", 1), - ("n", 1), ] unsupported_fields: List[str] = [] @@ -312,6 +311,7 @@ def openai_api_get_generation_config( """Create the generation config from the given request.""" kwargs: Dict[str, Any] = {} arg_names = [ + "n", "temperature", "top_p", "max_tokens", diff --git a/python/mlc_chat/serve/async_engine.py b/python/mlc_chat/serve/async_engine.py index 97330fea0d..84037b6fb1 100644 --- a/python/mlc_chat/serve/async_engine.py +++ b/python/mlc_chat/serve/async_engine.py @@ -5,6 +5,7 @@ import asyncio import sys import threading +from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union import tvm @@ -18,6 +19,32 @@ from .request import Request +@dataclass +class AsyncStreamOutput: + """The output of AsyncThreadedEngine.generate + + Attributes + ---------- + delta_text : str + The delta text generated since the last output. + + num_delta_tokens : int + The number of delta tokens generated since the last output. + + delta_logprob_json_strs : Optional[List[str]] + The list of logprob JSON strings since the last output, + or None if the request does not require logprobs. + + finish_reason : Optional[str] + The finish reason of the request, or None if unfinished. + """ + + delta_text: str + num_delta_tokens: int + delta_logprob_json_strs: Optional[List[str]] + finish_reason: Optional[str] + + class AsyncRequestStream: """The asynchronous stream for requests. 
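The `prob_indices` parameter documented in the sampler interface above lets several sampled tokens share one probability row, which is what happens when a single prefill unblocks multiple parallel generations. A toy NumPy sketch of the idea (not the CPU sampler itself; names, shapes, and the RNG choice are illustrative):

```python
import numpy as np

# Each entry of prob_indices names which row of `probs` the corresponding
# request samples from, so several requests can draw independent tokens
# from the same distribution.
def batch_sample(probs: np.ndarray, prob_indices, rngs):
    token_ids = []
    for i, row in enumerate(prob_indices):
        p = probs[row]
        token_ids.append(int(rngs[i].choice(len(p), p=p)))
    return token_ids

probs = np.array([[0.7, 0.2, 0.1],   # request 0's distribution
                  [0.1, 0.1, 0.8]])  # request 1's distribution
# Three samples: two parallel generations share row 0, one request uses row 1.
rngs = [np.random.default_rng(seed) for seed in range(3)]
print(batch_sample(probs, [0, 0, 1], rngs))
```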
@@ -30,14 +57,11 @@ class AsyncRequestStream: can use to iterates all the generated tokens in order asynchronously. """ - # The asynchronous queue to hold elements of - # - either a tuple of (str, int, List[str], Optional[str]), denoting the - # delta output text, the number of delta tokens, the logprob JSON strings - # of delta tokens, and the optional finish reason respectively, - # - or an exception. + # The asynchronous queue to hold elements of either a list of + # AsyncStreamOutput or an exception. if sys.version_info >= (3, 9): _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object - Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception] + Union[List[AsyncStreamOutput], Exception] ] else: _queue: asyncio.Queue @@ -48,10 +72,7 @@ def __init__(self) -> None: self._queue = asyncio.Queue() self._finished = False - def push( - self, - item_or_exception: Union[Tuple[str, int, Optional[List[str]], Optional[str]], Exception], - ) -> None: + def push(self, item_or_exception: Union[List[AsyncStreamOutput], Exception]) -> None: """Push a new token to the stream.""" if self._finished: # No new item is expected after finish. @@ -72,7 +93,7 @@ def finish(self) -> None: def __aiter__(self): return self - async def __anext__(self) -> Tuple[str, int, Optional[List[str]], Optional[str]]: + async def __anext__(self) -> List[AsyncStreamOutput]: result = await self._queue.get() if isinstance(result, StopIteration): raise StopAsyncIteration @@ -156,7 +177,8 @@ def __init__( engine_mode = EngineMode() # The mapping from request ids to request asynchronous stream. - self._request_tools: Dict[str, Tuple[AsyncRequestStream, TextStreamer]] = {} + self._request_tools: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} + self._num_unfinished_generations: Dict[str, int] = {} def _background_loop(): self._ffi["init_background_engine"]( @@ -186,14 +208,11 @@ def terminate(self): async def generate( self, prompt: Union[str, List[int]], generation_config: GenerationConfig, request_id: str - ) -> AsyncGenerator[Tuple[str, int, Optional[List[str]], Optional[str]], Any]: + ) -> AsyncGenerator[List[AsyncStreamOutput], Any]: """Asynchronous text generation interface. - The method is a coroutine that streams a tuple at a time via yield. - Each tuple is contained of - - the delta text in type str, - - the number of delta tokens in type int, - - the logprob JSON strings of delta tokens, - - the optional finish reason in type Optional[str]. + The method is a coroutine that streams a list of AsyncStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. Parameters ---------- @@ -230,7 +249,11 @@ async def generate( ) else: # Record the stream in the tracker - self._request_tools[request_id] = (stream, TextStreamer(self.tokenizer)) + self._request_tools[request_id] = ( + stream, + [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], + ) + self._num_unfinished_generations[request_id] = generation_config.n self._ffi["add_request"](request) # Iterate the stream asynchronously and yield the token. 
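With the stream refactor above, each iteration of `AsyncThreadedEngine.generate` yields one `AsyncStreamOutput` per parallel generation. A consumer sketch, assuming `AsyncThreadedEngine` and `GenerationConfig` are importable along the paths used by the test files in this patch:

```python
from typing import List

# Assumed import paths; adjust to match the package layout.
from mlc_chat.serve import GenerationConfig
from mlc_chat.serve.async_engine import AsyncThreadedEngine


async def collect(engine: AsyncThreadedEngine, prompt: str, request_id: str, n: int) -> List[str]:
    """Accumulate the delta text of each of the n parallel generations."""
    texts = ["" for _ in range(n)]
    config = GenerationConfig(max_tokens=64, n=n)
    async for delta_outputs in engine.generate(prompt, config, request_id=request_id):
        # One AsyncStreamOutput per parallel generation, in index order.
        for i, delta_output in enumerate(delta_outputs):
            texts[i] += delta_output.delta_text
    return texts
```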
@@ -282,28 +305,39 @@ def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput] def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: """The underlying implementation of request stream callback.""" for delta_output in delta_outputs: - ( - request_id, - delta_token_ids, - delta_logprob_json_strs, - finish_reason, - ) = delta_output.unpack() + request_id, stream_outputs = delta_output.unpack() tools = self._request_tools.get(request_id, None) if tools is None: continue self.record_event(request_id, event="start callback") - stream, text_streamer = tools - - self.record_event(request_id, event="start detokenization") - delta_text = text_streamer.put(delta_token_ids) - if finish_reason is not None: - delta_text += text_streamer.finish() - self.record_event(request_id, event="finish detokenization") + stream, text_streamers = tools + outputs = [] + for stream_output, text_streamer in zip(stream_outputs, text_streamers): + self.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.record_event(request_id, event="finish detokenization") + + outputs.append( + AsyncStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self._num_unfinished_generations[request_id] -= 1 # Push new delta text to the stream. - stream.push((delta_text, len(delta_token_ids), delta_logprob_json_strs, finish_reason)) - if finish_reason is not None: + stream.push(outputs) + if self._num_unfinished_generations[request_id] == 0: stream.finish() self._request_tools.pop(request_id, None) self.record_event(request_id, event="finish callback") diff --git a/python/mlc_chat/serve/config.py b/python/mlc_chat/serve/config.py index 00cd53f66f..1b90a4b24a 100644 --- a/python/mlc_chat/serve/config.py +++ b/python/mlc_chat/serve/config.py @@ -35,6 +35,9 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes Parameters ---------- + n : int + How many chat completion choices to generate for each input message. + temperature : float The value that applies to logits and modulates the next token probabilities. @@ -92,6 +95,7 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes The response format of the generation output. """ + n: int = 1 temperature: float = 0.8 top_p: float = 0.95 frequency_penalty: float = 0.0 diff --git a/python/mlc_chat/serve/data.py b/python/mlc_chat/serve/data.py index 15c0a4f205..57532827e9 100644 --- a/python/mlc_chat/serve/data.py +++ b/python/mlc_chat/serve/data.py @@ -1,5 +1,6 @@ """Classes denoting multi-modality data used in MLC LLM serving""" +from dataclasses import dataclass from typing import List, Optional, Tuple import tvm._ffi @@ -57,16 +58,13 @@ def token_ids(self) -> List[int]: return list(_ffi_api.TokenDataGetTokenIds(self)) # type: ignore # pylint: disable=no-member -@tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access -class RequestStreamOutput(Object): - """The generated delta request output that is streamed back - through callback stream function. 
- It contains four fields (in order): - - request_id : str - The id of the request that the function is invoked for. +@dataclass +class SingleRequestStreamOutput: + """The request stream output of a single request. - delta_tokens : List[int] + Attributes + ---------- + delta_token_ids : List[int] The new generated tokens since the last callback invocation for the input request. @@ -77,6 +75,24 @@ class RequestStreamOutput(Object): finish_reason : Optional[str] The finish reason of the request when it is finished, of None if the request has not finished yet. + """ + + delta_token_ids: List[int] + delta_logprob_json_strs: Optional[List[str]] + finish_reason: Optional[str] + + +@tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access +class RequestStreamOutput(Object): + """The generated delta request output that is streamed back + through callback stream function. + It contains four fields (in order): + + request_id : str + The id of the request that the function is invoked for. + + stream_outputs : List[SingleRequestStreamOutput] + The output instances, one for a request. Note ---- @@ -84,7 +100,7 @@ class RequestStreamOutput(Object): instantiates this class. """ - def unpack(self) -> Tuple[str, List[int], Optional[List[str]], Optional[str]]: + def unpack(self) -> Tuple[str, List[SingleRequestStreamOutput]]: """Return the fields of the delta output in a tuple. Returns @@ -92,26 +108,23 @@ def unpack(self) -> Tuple[str, List[int], Optional[List[str]], Optional[str]]: request_id : str The id of the request that the function is invoked for. - delta_tokens : List[int] - The new generated tokens since the last callback invocation - for the input request. - - delta_logprob_json_strs : Optional[List[str]] - The logprobs JSON strings of the new generated tokens - since last invocation. - - finish_reason : Optional[str] - The finish reason of the request when it is finished, - of None if the request has not finished yet. + stream_outputs : List[SingleRequestStreamOutput] + The output instances, one for a request. """ fields = _ffi_api.RequestStreamOutputUnpack(self) # type: ignore # pylint: disable=no-member - return ( - str(fields[0]), - list(fields[1]), - ( - [str(logprob_json_str) for logprob_json_str in fields[2]] + request_id = str(fields[0]) + stream_outputs = [] + for i, (delta_token_ids, finish_reason) in enumerate(zip(fields[1], fields[3])): + delta_logprob_json_strs = ( + [str(logprob_json_str) for logprob_json_str in fields[2][i]] if fields[2] is not None else None - ), - str(fields[3]) if fields[3] is not None else None, - ) + ) + stream_outputs.append( + SingleRequestStreamOutput( + delta_token_ids=list(delta_token_ids), + delta_logprob_json_strs=delta_logprob_json_strs, + finish_reason=str(finish_reason) if finish_reason is not None else None, + ) + ) + return request_id, stream_outputs diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index f5e69e6d54..a55ee09ddb 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -352,11 +352,11 @@ def __init__( # pylint: disable=too-many-arguments ) self.tokenizer = Tokenizer(tokenizer_path) - def generate( + def generate( # pylint: disable=too-many-locals self, prompts: Union[str, List[str], List[int], List[List[int]]], generation_config: Union[GenerationConfig, List[GenerationConfig]], - ) -> Tuple[List[str], List[Optional[List[str]]]]: + ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: """Generate texts for a list of input prompts. 
Each prompt can be a string or a list of token ids. The generation for each prompt is independent. @@ -377,10 +377,12 @@ def generate( Returns ------- - output_text : List[str] - The text generation results, one string for each input prompt. + output_text : List[List[str]] + The text generation results, one list of strings for each input prompt. + The length of each list is the parallel generation `n` in + generation config. - output_logprobs_str : List[Optional[List[str]]] + output_logprobs_str : List[Optional[List[List[str]]]] The logprob strings of each token for each input prompt, or None if an input prompt does not require logprobs. """ @@ -406,14 +408,21 @@ def generate( len(generation_config) == num_requests ), "Number of generation config and number of prompts mismatch" - num_finished_requests = 0 - output_texts: List[str] = [] - output_logprobs_str: List[Optional[List[str]]] = [] - text_streamers: List[TextStreamer] = [] + num_finished_generations = 0 + output_texts: List[List[str]] = [] + output_logprobs_str: List[Optional[List[List[str]]]] = [] + text_streamers: List[List[TextStreamer]] = [] for i in range(num_requests): - output_texts.append("") + output_texts.append([]) output_logprobs_str.append([] if generation_config[i].logprobs else None) - text_streamers.append(TextStreamer(self.tokenizer)) + text_streamers.append([]) + for _ in range(generation_config[i].n): + output_texts[i].append("") + text_streamers[i].append(TextStreamer(self.tokenizer)) + if output_logprobs_str[i] is not None: + output_logprobs_str[i].append([]) + + num_total_generations = sum(cfg.n for cfg in generation_config) # Save a copy of the original function callback since `generate` # overrides the callback function. @@ -422,27 +431,30 @@ def generate( # Define the callback function for request generation results def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): - nonlocal num_finished_requests + nonlocal num_finished_generations for delta_output in delta_outputs: - ( - request_id, - delta_token_ids, - delta_logprob_json_strs, - finish_reason, - ) = delta_output.unpack() + request_id, stream_outputs = delta_output.unpack() rid = int(request_id) - text_streamer = text_streamers[rid] - if output_logprobs_str[rid] is not None: - assert delta_logprob_json_strs is not None - output_logprobs_str[rid] += delta_logprob_json_strs - - delta_text = text_streamer.put(delta_token_ids) - if finish_reason is not None: - delta_text += text_streamer.finish() - output_texts[rid] += delta_text - if finish_reason is not None: - num_finished_requests += 1 + assert len(stream_outputs) == generation_config[rid].n + for i, (stream_output, text_streamer) in enumerate( + zip(stream_outputs, text_streamers[rid]) + ): + if output_logprobs_str[rid] is not None: + assert stream_output.delta_logprob_json_strs is not None + output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs + + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + + output_texts[rid][i] += delta_text + if stream_output.finish_reason is not None: + num_finished_generations += 1 # Override the callback function in engine. 
self._ffi["set_request_stream_callback"](request_stream_callback) @@ -462,7 +474,7 @@ def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): ) ) - while num_finished_requests != num_requests: + while num_finished_generations != num_total_generations: self.step() # Restore the callback function in engine. diff --git a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py index de85ab83f3..15e944e16a 100644 --- a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_chat/serve/entrypoints/openai_entrypoints.py @@ -91,14 +91,15 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re if request.stream: async def completion_stream_generator() -> AsyncGenerator[str, None]: - assert request.n == 1 - # - Echo back the prompt. if request.echo: text = async_engine.tokenizer.decode(prompt) response = CompletionResponse( id=request_id, - choices=[CompletionResponseChoice(text=text)], + choices=[ + CompletionResponseChoice(index=i, text=text) + for i in range(generation_cfg.n) + ], model=request.model, usage=UsageInfo( prompt_tokens=len(prompt), @@ -109,37 +110,45 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # - Generate new tokens. num_completion_tokens = 0 - finish_reason = None + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] async_engine.record_event(request_id, event="invoke generate") - async for ( - delta_text, - num_delta_tokens, - delta_logprob_json_strs, - finish_reason, - ) in async_engine.generate(prompt, generation_cfg, request_id): - num_completion_tokens += num_delta_tokens - if delta_text == "": - # Ignore empty delta text -- do not yield. - continue - - response = CompletionResponse( - id=request_id, - choices=[ + async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + assert len(delta_outputs) == generation_cfg.n + choices = [] + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = delta_output.finish_reason + finish_reason_updated = True + num_completion_tokens += delta_output.num_delta_tokens + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. + continue + + choices.append( CompletionResponseChoice( - finish_reason=finish_reason, - text=delta_text, + index=i, + finish_reason=finish_reasons[i], + text=delta_output.delta_text, logprobs=( LogProbs( content=[ LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_logprob_json_strs + for logprob_json_str in delta_output.delta_logprob_json_strs ] ) - if delta_logprob_json_strs is not None + if delta_output.delta_logprob_json_strs is not None else None ), ) - ], + ) + + if len(choices) == 0: + # Skip yield when there is no delta output. + continue + response = CompletionResponse( + id=request_id, + choices=choices, model=request.model, usage=UsageInfo( prompt_tokens=len(prompt), @@ -151,14 +160,16 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # - Echo the suffix. 
if request.suffix is not None: - assert finish_reason is not None + assert all(finish_reason is not None for finish_reason in finish_reasons) response = CompletionResponse( id=request_id, choices=[ CompletionResponseChoice( + index=i, finish_reason=finish_reason, text=request.suffix, ) + for i, finish_reason in enumerate(finish_reasons) ], model=request.model, usage=UsageInfo( @@ -175,17 +186,15 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. - output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) + init_output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) + output_texts = [init_output_text for _ in range(generation_cfg.n)] num_completion_tokens = 0 - finish_reason: Optional[str] = None - logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + logprob_json_strs_list: Optional[List[List[str]]] = ( + [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + ) async_engine.record_event(request_id, event="invoke generate") - async for ( - delta_text, - num_delta_tokens, - delta_logprob_json_strs, - finish_reason, - ) in async_engine.generate(prompt, generation_cfg, request_id): + async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. @@ -195,31 +204,40 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) - output_text += delta_text - num_completion_tokens += num_delta_tokens - if logprob_json_strs is not None: - assert delta_logprob_json_strs is not None - logprob_json_strs += delta_logprob_json_strs - assert finish_reason is not None + + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = delta_output.finish_reason + output_texts[i] += delta_output.delta_text + num_completion_tokens += delta_output.num_delta_tokens + if logprob_json_strs_list is not None: + assert delta_output.delta_logprob_json_strs is not None + logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs + assert all(finish_reason is not None for finish_reason in finish_reasons) suffix = request.suffix if request.suffix is not None else "" async_engine.record_event(request_id, event="finish") response = CompletionResponse( id=request_id, choices=[ CompletionResponseChoice( + index=i, finish_reason=finish_reason, text=output_text + suffix, logprobs=( LogProbs( content=[ LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs + for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object + i + ] ] ) - if logprob_json_strs is not None + if logprob_json_strs_list is not None else None ), ) + for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) ], model=request.model, usage=UsageInfo( @@ -408,44 +426,55 @@ async def request_chat_completion( if request.stream: async def completion_stream_generator() -> AsyncGenerator[str, None]: - assert request.n == 1 async_engine.record_event(request_id, event="invoke generate") - async for ( - delta_text, - _, - delta_logprob_json_strs, 
- finish_reason, - ) in async_engine.generate(prompt, generation_cfg, request_id): - if delta_text == "": - async_engine.record_event(request_id, event="skip empty delta text") - # Ignore empty delta text -- do not yield. - continue - - if conv_template.use_function_calling: - finish_reason = "tool_calls" + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + assert len(delta_outputs) == generation_cfg.n + choices = [] + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = ( + delta_output.finish_reason + if not conv_template.use_function_calling + else "tool_calls" + ) + finish_reason_updated = True + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. + async_engine.record_event(request_id, event="skip empty delta text") + continue - response = ChatCompletionStreamResponse( - id=request_id, - choices=[ + choices.append( ChatCompletionStreamResponseChoice( - finish_reason=finish_reason, - delta=ChatCompletionMessage(content=delta_text, role="assistant"), + index=i, + finish_reason=finish_reasons[i], + delta=ChatCompletionMessage( + content=delta_output.delta_text, role="assistant" + ), logprobs=( LogProbs( content=[ LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_logprob_json_strs + for logprob_json_str in delta_output.delta_logprob_json_strs ] ) - if delta_logprob_json_strs is not None + if delta_output.delta_logprob_json_strs is not None else None ), ) - ], + ) + + if len(choices) == 0: + # Skip yield when there is no delta output. + continue + response = ChatCompletionStreamResponse( + id=request_id, + choices=choices, model=request.model, system_fingerprint="", ) - async_engine.record_event(request_id, event=f"yield delta text {delta_text}") + async_engine.record_event(request_id, event="yield delta output") yield f"data: {response.model_dump_json()}\n\n" async_engine.record_event(request_id, event="finish") yield "data: [DONE]\n\n" @@ -455,17 +484,14 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. - output_text = "" + output_texts = ["" for _ in range(generation_cfg.n)] num_completion_tokens = 0 - finish_reason: Optional[str] = None - logprob_json_strs: Optional[List[str]] = [] if generation_cfg.logprobs else None + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + logprob_json_strs_list: Optional[List[List[str]]] = ( + [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + ) async_engine.record_event(request_id, event="invoke generate") - async for ( - delta_text, - num_delta_tokens, - delta_logprob_json_strs, - finish_reason, - ) in async_engine.generate(prompt, generation_cfg, request_id): + async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. 
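On the client side, the streamed chunks produced above interleave choices for all `n` generations, each tagged with its `index` and an eventual `finish_reason`. A sketch of accumulating them per choice; field names follow the OpenAI-style chat responses built in this entrypoint, and `chunks` stands for the already-parsed `data:` payloads:

```python
from typing import Dict, List, Optional, Tuple

def accumulate_choices(chunks: List[dict], n: int) -> Tuple[List[str], List[Optional[str]]]:
    """Fold streamed chat completion chunks into n full texts and finish reasons."""
    texts = ["" for _ in range(n)]
    finish_reasons: List[Optional[str]] = [None for _ in range(n)]
    for chunk in chunks:
        for choice in chunk["choices"]:
            idx = choice["index"]
            # Delta content may be absent or null in a final chunk.
            texts[idx] += choice["delta"].get("content") or ""
            if choice["finish_reason"] is not None:
                finish_reasons[idx] = choice["finish_reason"]
    return texts, finish_reasons
```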
@@ -475,61 +501,72 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) - output_text += delta_text - num_completion_tokens += num_delta_tokens - if logprob_json_strs is not None: - assert delta_logprob_json_strs is not None - logprob_json_strs += delta_logprob_json_strs - assert finish_reason is not None + + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = delta_output.finish_reason + output_texts[i] += delta_output.delta_text + num_completion_tokens += delta_output.num_delta_tokens + if logprob_json_strs_list is not None: + assert delta_output.delta_logprob_json_strs is not None + logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs + assert all(finish_reason is not None for finish_reason in finish_reasons) async_engine.record_event(request_id, event="finish") + tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(generation_cfg.n)] if conv_template.use_function_calling: - try: - fn_json_list = convert_function_str_to_json(output_text) - except (SyntaxError, ValueError): - output_text = "Got an invalid function call output from model" - finish_reason = "error" - else: - tool_calls = [ - ChatToolCall( - type="function", - function=ChatFunctionCall( - name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] - ), - ) - for fn_json_obj in fn_json_list - if fn_json_obj is not None - ] - if len(tool_calls) == 0: + for i, output_text in enumerate(output_texts): + try: + fn_json_list = convert_function_str_to_json(output_text) + except (SyntaxError, ValueError): output_text = "Got an invalid function call output from model" - finish_reason = "error" + finish_reasons[i] = "error" else: - finish_reason = "tool_calls" - - message = ( - ChatCompletionMessage(role="assistant", content=output_text) - if (not conv_template.use_function_calling or finish_reason == "error") - else ChatCompletionMessage(role="assistant", content=None, tool_calls=tool_calls) - ) + tool_calls_list[i] = [ + ChatToolCall( + type="function", + function=ChatFunctionCall( + name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] + ), + ) + for fn_json_obj in fn_json_list + if fn_json_obj is not None + ] + if len(tool_calls_list[i]) == 0: + output_texts[i] = "Got an invalid function call output from model" + finish_reasons[i] = "error" + else: + finish_reasons[i] = "tool_calls" return ChatCompletionResponse( id=request_id, choices=[ ChatCompletionResponseChoice( - finish_reason=finish_reason, - message=message, + index=i, + finish_reason=finish_reasons[i], + message=( + ChatCompletionMessage(role="assistant", content=output_text) + if (not conv_template.use_function_calling or finish_reason == "error") + else ChatCompletionMessage(role="assistant", tool_calls=tool_calls) + ), logprobs=( LogProbs( content=[ LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs + for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object + i + ] ] ) - if logprob_json_strs is not None + if logprob_json_strs_list is not None else None ), ) + for i, (output_text, finish_reason, tool_calls) in enumerate( + zip(output_texts, finish_reasons, tool_calls_list) + ) ], model=request.model, system_fingerprint="", diff --git a/tests/python/serve/server/test_server.py 
b/tests/python/serve/server/test_server.py index 3cb015000f..1436de34d7 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -63,34 +63,33 @@ def check_openai_nonstream_response( choices = response["choices"] assert isinstance(choices, list) - assert len(choices) == num_choices - - for idx, choice in enumerate(choices): - assert choice["index"] == idx + assert len(choices) <= num_choices + texts: List[str] = ["" for _ in range(num_choices)] + for choice in choices: + idx = choice["index"] assert choice["finish_reason"] in finish_reasons - text: str if not is_chat_completion: assert isinstance(choice["text"], str) - text = choice["text"] + texts[idx] = choice["text"] if echo_prompt is not None: - assert text + assert texts[idx] if suffix is not None: - assert text + assert texts[idx] else: message = choice["message"] assert message["role"] == "assistant" assert isinstance(message["content"], str) - text = message["content"] + texts[idx] = message["content"] if stop is not None: for stop_str in stop: - assert stop_str not in text + assert stop_str not in texts[idx] if require_substr is not None: for substr in require_substr: - assert substr in text + assert substr in texts[idx] if json_mode: - assert is_json_or_json_prefix(text) + assert is_json_or_json_prefix(texts[idx]) usage = response["usage"] assert isinstance(usage, dict) @@ -125,9 +124,9 @@ def check_openai_stream_response( choices = response["choices"] assert isinstance(choices, list) - assert len(choices) == num_choices - for idx, choice in enumerate(choices): - assert choice["index"] == idx + assert len(choices) <= num_choices + for choice in choices: + idx = choice["index"] if not is_chat_completion: assert isinstance(choice["text"], str) @@ -156,7 +155,7 @@ def check_openai_stream_response( if completion_tokens is not None: assert responses[-1]["usage"]["completion_tokens"] == completion_tokens - for output in outputs: + for i, output in enumerate(outputs): if echo_prompt is not None: assert output.startswith(echo_prompt) if suffix is not None: @@ -864,6 +863,51 @@ def test_openai_v1_chat_completions( ) +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completions_n( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. 
+ + n = 3 + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "n": n, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=n, + finish_reasons=["stop"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=n, + finish_reasons=["stop"], + ) + + @pytest.mark.parametrize("stream", [False, True]) @pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) def test_openai_v1_chat_completions_openai_package( @@ -1135,6 +1179,8 @@ def test_debug_dump_event_trace( for msg in CHAT_COMPLETION_MESSAGES: test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg) + test_openai_v1_chat_completions_n(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completions_n(MODEL, None, stream=True, messages=msg) test_openai_v1_chat_completions_openai_package(MODEL, None, stream=False, messages=msg) test_openai_v1_chat_completions_openai_package(MODEL, None, stream=True, messages=msg) test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=False) diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index df8e64bec0..c7616df5f7 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -26,15 +26,17 @@ async def test_engine_generate(): "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine async_engine = AsyncThreadedEngine(model, kv_cache_config) num_requests = 10 max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens) + generation_cfg = GenerationConfig(max_tokens=max_tokens, n=3) - outputs: List[str] = ["" for _ in range(num_requests)] + output_texts: List[List[str]] = [ + ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) + ] async def generate_task( async_engine: AsyncThreadedEngine, @@ -44,10 +46,12 @@ async def generate_task( ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_text, _, _, _ in async_engine.generate( + async for delta_outputs in async_engine.generate( prompt, generation_cfg, request_id=request_id ): - outputs[rid] += delta_text + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text tasks = [ asyncio.create_task( @@ -60,9 +64,13 @@ async def generate_task( # Print output. 
print("All finished") - for req_id, output in enumerate(outputs): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}:{output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") async_engine.terminate() del async_engine diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 89a113d1bb..becc594622 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -44,7 +44,9 @@ async def test_engine_generate(): max_tokens = 256 generation_cfg = GenerationConfig(max_tokens=max_tokens) - outputs: List[str] = ["" for _ in range(num_requests)] + output_texts: List[List[str]] = [ + ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) + ] async def generate_task( async_engine: AsyncThreadedEngine, @@ -54,10 +56,12 @@ async def generate_task( ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_text, _, _, _ in async_engine.generate( + async for delta_outputs in async_engine.generate( prompt, generation_cfg, request_id=request_id ): - outputs[rid] += delta_text + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text tasks = [ asyncio.create_task( @@ -70,9 +74,13 @@ async def generate_task( # Print output. print("All finished") - for req_id, output in enumerate(outputs): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}:{output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") async_engine.terminate() del async_engine diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 373a97a743..5cd13be91e 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -87,8 +87,9 @@ def test_engine_basic(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, _ = delta_output.unpack() - outputs[int(request_id)] += delta_token_ids + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) @@ -153,10 +154,11 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, finish_reason = delta_output.unpack() - if finish_reason is not None: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_token_ids + outputs[int(request_id)] += stream_outputs[0].delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -231,10 +233,11 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def 
fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, finish_reason = delta_output.unpack() - if finish_reason is not None: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_token_ids + outputs[int(request_id)] += stream_outputs[0].delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -312,11 +315,12 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, finish_reason = delta_output.unpack() - if finish_reason is not None: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") self.finished_requests += 1 - outputs[int(request_id)] += delta_token_ids + outputs[int(request_id)] += stream_outputs[0].delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -368,7 +372,7 @@ def test_engine_generate(): "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine engine = Engine(model, kv_cache_config) @@ -379,9 +383,13 @@ def test_engine_generate(): output_texts, _ = engine.generate( prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) ) - for req_id, output in enumerate(output_texts): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}:{output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") if __name__ == "__main__": diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 901e6c4d98..e96eac9dda 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -49,9 +49,13 @@ def test_batch_generation_with_grammar(): # Generate output. 
output_texts, _ = engine.generate(prompts, all_generation_configs) - for req_id, output in enumerate(output_texts): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}: {output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") async def run_async_engine(): @@ -75,7 +79,9 @@ async def run_async_engine(): response_format=ResponseFormat(type="json_object"), ) - outputs: List[str] = ["" for _ in range(len(prompts))] + output_texts: List[List[str]] = [ + ["" for _ in range(generation_config.n)] for _ in range(len(prompts)) + ] async def generate_task( async_engine: AsyncThreadedEngine, @@ -85,10 +91,12 @@ async def generate_task( ): print(f"Start generation task for request {request_id}") rid = int(request_id) - async for delta_text, _, _, _ in async_engine.generate( + async for delta_outputs in async_engine.generate( prompt, generation_cfg, request_id=request_id ): - outputs[rid] += delta_text + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text tasks = [ asyncio.create_task( @@ -101,9 +109,13 @@ async def generate_task( # Print output. print("All finished") - for req_id, output in enumerate(outputs): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}: {output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") print(async_engine.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 1eee361fd8..663744305d 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -93,8 +93,9 @@ def test_engine_basic(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, _ = delta_output.unpack() - outputs[int(request_id)] += delta_token_ids + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) @@ -164,10 +165,11 @@ class CallbackTimer: def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, finish_reason = delta_output.unpack() - if finish_reason is not None: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += delta_token_ids + outputs[int(request_id)] += stream_outputs[0].delta_token_ids finish_time[int(request_id)] = self.timer return fcallback @@ -225,11 +227,15 @@ def test_engine_generate(): # Generate output. 
output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) ) - for req_id, output in enumerate(output_texts): + for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") - print(f"Output {req_id}:{output}\n") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") def test_engine_efficiency(): @@ -255,8 +261,9 @@ def test_engine_efficiency(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, _ = delta_output.unpack() - outputs[int(request_id)] += delta_token_ids + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) @@ -326,8 +333,9 @@ def test_engine_spec_efficiency(): # Define the callback function for request generation results def fcallback(delta_outputs: List[RequestStreamOutput]): for delta_output in delta_outputs: - request_id, delta_token_ids, _, _ = delta_output.unpack() - outputs[int(request_id)] += delta_token_ids + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine spec_engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) From 63c338b79f7e72738eb33e414282538d4745791b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 2 Mar 2024 10:04:14 -0500 Subject: [PATCH 021/531] [CI] Add retry to scm checkout (#1869) Sometimes scm checkout can timeout, this PR add retry to that --- ci/jenkinsfile.groovy | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ci/jenkinsfile.groovy b/ci/jenkinsfile.groovy index 351c8d4e38..ec8210c172 100644 --- a/ci/jenkinsfile.groovy +++ b/ci/jenkinsfile.groovy @@ -47,7 +47,10 @@ def unpack_lib(name, libs) { def init_git(submodule = false) { cleanWs() - checkout scm + // add retry in case checkout timeouts + retry(5) { + checkout scm + } if (submodule) { retry(5) { timeout(time: 10, unit: 'MINUTES') { From e8b5b0bd9eff8474beda7d20642594f0d65602aa Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 2 Mar 2024 17:48:06 -0500 Subject: [PATCH 022/531] [Attn] Use float32 accumulation in attention kernel (#1870) Prior to this PR, the TIR attention kernels does not cast matmul operands to fp32 before multiplying. For models like Phi-2 which may have large Q/K/V data (at the level of a few hundreds), the fp16 multiplication exceeds the range of fp16, and lead to attention result being NAN sometimes. This PR fixes this issue. 
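
A rough numpy illustration of the failure mode (a sketch for exposition only,
not part of the patch; the head_dim of 80 and the value 300 are assumptions
standing in for Phi-2-sized Q/K data):

    import numpy as np

    q = np.full(80, 300.0, dtype="float16")  # Q/K entries at the level of a few hundreds
    k = np.full(80, 300.0, dtype="float16")

    # Each fp16 product is 90000, beyond the fp16 maximum of 65504, so the
    # products overflow to inf and the fp16 dot product is inf; the following
    # inf - inf in the online softmax can then turn the attention output into NaN.
    print(np.sum(q * k))  # inf

    # Casting the operands to float32 before multiplying, as the kernels now do
    # via T.cast(..., "float32"), keeps the accumulation finite.
    print(np.sum(q.astype("float32") * k.astype("float32")))  # 7200000.0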
--- python/mlc_chat/nn/kv_cache.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index 5e39a614e6..cb0e000b87 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -673,7 +673,7 @@ def batch_prefill_paged_kv( i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -731,7 +731,7 @@ def batch_prefill_paged_kv( i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * V_smem[k, j] + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): @@ -982,7 +982,7 @@ def batch_decode_paged_kv( # compute S = Q * K * sm_scale S_reduce_local[0] = 0 for vec in T.serial(VEC_SIZE): - S_reduce_local[0] += Q_local[vec] * K_local[vec] * attn_score_scaling_factor * sm_scale + S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) @@ -1016,7 +1016,7 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] for vec in T.vectorized(VEC_SIZE): - O_local[vec] += V_local[vec] * S_local[j] + O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j] if bdz > 1: # allreduce over bdz @@ -1319,7 +1319,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): S_local[i, j] = 0.0 - S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale T.tvm_storage_sync("shared") for li, lj in T.grid(tile_x, tile_z): with T.block("S_store"): @@ -1377,7 +1377,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran i, j, k = T.axis.remap("SSR", [li, lj, lk]) with T.init(): O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) - O_local[i, j] += S_smem[i, k] * V_smem[k, j] + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") # Store O from smem to gmem for li, lj in T.grid(tile_x, tile_y): From 91008ae99e6112d2e0cf5d9a692da4a8be37d8c8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 3 Mar 2024 08:16:30 -0600 Subject: [PATCH 023/531] [Utils] Allow ReorderTransformFunc to be used without param manager (#1857) Prior to this commit, the `ReorderTransformFunc` required several components of the `ParamManager` to use. The functionality it provides, reordering dataflow blocks to minimize the liveset, is useful outside of the context of the `ParamManager`. This commit makes the following changes, allowing it to be used independently of the `ParamManager`. - Generate the `pidx2binname` dictionary outside of `ReorderTransformFunc` - Allow parameters to be separate `func.params`, rather than a single bundled tuple parameter. 
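
A minimal usage sketch of the standalone interface (the IRModule and the shard
file names here are hypothetical, shown only to illustrate the call pattern;
the import path follows the file touched below):

    from typing import Dict

    from mlc_llm.transform.reorder_transform_func import ReorderTransformFunc

    # Any mapping from raw parameter index to the binary file holding that
    # tensor works; passing None (the default) is also accepted.
    pidx2binname: Dict[int, str] = {
        0: "pytorch_model-00001-of-00002.bin",
        1: "pytorch_model-00001-of-00002.bin",
        2: "pytorch_model-00002-of-00002.bin",
    }
    reorder_pass = ReorderTransformFunc(pidx2binname)
    # mod = reorder_pass(mod)  # `mod`: an IRModule containing *_transform_params functions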
--- mlc_llm/relax_model/param_manager.py | 12 +- mlc_llm/transform/reorder_transform_func.py | 157 +++++++++++++------- 2 files changed, 110 insertions(+), 59 deletions(-) diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 9a59b933b8..1ad1ee6428 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -837,11 +837,13 @@ def optimize_transform_param_order(self) -> tvm.transform.Pass: tvm.transform.Pass The transformation """ - return ReorderTransformFunc( - self.pidx2pname, - self.torch_pname2binname, - self.f_convert_pname_fwd, - ) + + pidx2binname: Dict[int, str] = { + pidx: self.torch_pname2binname[self.f_convert_pname_fwd(pname)[0]] + for pidx, pname in self.pidx2pname.items() + if self.f_convert_pname_fwd(pname)[0] in self.torch_pname2binname + } + return ReorderTransformFunc(pidx2binname) @mutator diff --git a/mlc_llm/transform/reorder_transform_func.py b/mlc_llm/transform/reorder_transform_func.py index aa5ff9f81b..50b6337e3a 100644 --- a/mlc_llm/transform/reorder_transform_func.py +++ b/mlc_llm/transform/reorder_transform_func.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Set, Tuple +from typing import Callable, Dict, List, Set, Tuple, Optional import tvm from tvm import relax @@ -87,14 +87,22 @@ def analyze_func( num_input = 0 # Sanity check on the function pattern. - assert len(func.params) == num_input + 1 assert isinstance(func.body, relax.SeqExpr) assert len(func.body.blocks) == 1 assert isinstance(func.body.blocks[0], relax.DataflowBlock) assert func.body.blocks[0].bindings[-1].var.same_as(func.body.body) - model_param_tuple = func.params[num_input] - bindings = func.body.blocks[0].bindings + if isinstance(func.params[num_input].struct_info, relax.TupleStructInfo): + model_param_tuple = func.params[num_input] + else: + model_param_tuple = None + for i, var in enumerate(func.params[num_input:]): + binname = pidx2binname.get(i, var.name_hint) + if binname not in binname2get_param_bindings: + binname2get_param_bindings[binname] = [] + binname2get_param_bindings[binname].append(var) + + bindings = list(func.body.blocks[0].bindings) # Go through each binding except the last one. (The last one is the output # binding `gv = (lv, lv1, ...)`) which we ignore for analysis. @@ -103,7 +111,11 @@ def analyze_func( binding_var_set.add(binding.var) var_users[binding.var] = [] - if isinstance(value, relax.TupleGetItem) and value.tuple_value.same_as(model_param_tuple): + if ( + model_param_tuple is not None + and isinstance(value, relax.TupleGetItem) + and value.tuple_value.same_as(model_param_tuple) + ): # For weight fetching bindings (`lv = params[idx]`), we group them # according to the binary file name. pidx = value.index @@ -139,7 +151,7 @@ def fvisit(obj): def reorder_func( func: relax.Function, - pidx2binname: Dict[int, str], + pidx2binname: Optional[Dict[int, str]] = None, ) -> relax.Function: """Reorder the bindings of the input weight transform Relax function according the weight location in binary files. @@ -153,51 +165,95 @@ def reorder_func( func : relax.Function The weight transform function to be analyzed. - pidx2binname : Dict[int, str] - The mapping from each raw tensor index to the name of the binary - file where it resides. + pidx2binname : Optional[Dict[int, str]] + + The mapping from each raw tensor index to the name of the + binary file where it resides. 
If a relax dataflow graph has + multiple valid topological sorts, the order that minimizes the + number of simultaneously open files will be produced + + If `None` (default), the existing order of relax bindings is + preserved in these cases. Returns ------- func_updated : relax.Function The returned function where the bindings are updated with the new order. + """ - get_param_bindings, var_users, num_depending_vars = analyze_func(func, pidx2binname) - - # The bindings in the new order, output by the topological sort. - new_bindings: List[relax.Binding] = [] - # The queue used in the topological sort. - binding_queue: List[relax.Binding] = [] - - for binding, n_depending in list(num_depending_vars.items()): - if n_depending == 0: - binding_queue.append(binding) - del num_depending_vars[binding] - - # Start topological sort: - # each time we emit a weight fetching binding, and then adds all bindings - # that depend on it. - for get_param_binding in get_param_bindings: - binding_queue.append(get_param_binding) - - while len(binding_queue) > 0: - binding = binding_queue.pop(0) - new_bindings.append(binding) - for user_binding in var_users[binding.var]: - num_depending_vars[user_binding] -= 1 - if num_depending_vars[user_binding] == 0: - del num_depending_vars[user_binding] - binding_queue.append(user_binding) - - # Add the output binding. - new_bindings.append(func.body.blocks[0].bindings[-1]) - # Sanity check on the integrity. - assert len(new_bindings) == len(func.body.blocks[0].bindings) - assert len(num_depending_vars) == 0 + + if pidx2binname is None: + pidx2binname = {} + + bindings_to_visit = list(func.body.blocks[0].bindings) + param_lookup = {param: i for i, param in enumerate(func.params)} + binding_lookup = {} + previously_defined = set(func.params) + new_binding_order = [] + + param_tuple = None + if len(func.params) == 1 and isinstance(func.params[0].struct_info, relax.TupleStructInfo): + param_tuple = func.params[0] + + def sort_key(i): + binding = bindings_to_visit[i] + upstream_vars = relax.analysis.free_vars(binding.value) + + valid_ordering = all(var in previously_defined for var in upstream_vars) + last_param_used = max( + (param_lookup[var] for var in upstream_vars if var in param_lookup), default=-1 + ) + earliest_binding_used = min( + (binding_lookup[var] for var in upstream_vars if var in binding_lookup), default=-1 + ) + if ( + param_tuple + and isinstance(binding.value, relax.TupleGetItem) + and binding.value.tuple_value.same_as(param_tuple) + and binding.value.index in pidx2binname + ): + tuple_param_group = pidx2binname[binding.value.index] + else: + tuple_param_group = "" + + return [ + # First, sort by valid orderings, so the min element will + # always be a binding that would be legal to use. + -valid_ordering, + # Next, sort by the function parameter used by this + # binding, in increasing order. That way, we start by + # computing everything that required just the first + # parameter, then move on to variables that can be + # computed with the first two parameters, and so on. + last_param_used, + # Next, sort by the other bindings used. This way, for + # variables that are only used as input in a single + # downstream binding, the variable's required live range + # is minimized. + -earliest_binding_used, + # Finally, if this is a `TupleGetItem(param_tuple, i)`, + # select the option that uses an already-open file. This + # is mainly used relevant when loading from pytorch, which + # require loading the entire file at once. 
+ tuple_param_group, + ] + + while bindings_to_visit: + i_binding = min(range(len(bindings_to_visit)), key=sort_key) + binding = bindings_to_visit.pop(i_binding) + + assert all(var in previously_defined for var in relax.analysis.free_vars(binding.value)) + new_binding_order.append(binding) + previously_defined.add(binding.var) + + assert len(new_binding_order) == len(func.body.blocks[0].bindings) return relax.Function( func.params, - relax.SeqExpr(blocks=[relax.DataflowBlock(new_bindings)], body=func.body.body), + relax.SeqExpr( + blocks=[relax.DataflowBlock(new_binding_order)], + body=func.body.body, + ), func.ret_struct_info, func.is_pure, func.attrs, @@ -206,17 +262,10 @@ def reorder_func( @tvm.transform.module_pass(opt_level=0, name="ReorderTransformFunc") class ReorderTransformFunc: - def __init__( - self, - pidx2pname: Dict[int, str], - pname2binname: Dict[str, str], - f_convert_pname_fwd: Callable[[str], List[str]], - ) -> None: - self.pidx2binname: Dict[int, str] = { - pidx: pname2binname[f_convert_pname_fwd(pname)[0]] - for pidx, pname in pidx2pname.items() - if f_convert_pname_fwd(pname)[0] in pname2binname - } + def __init__(self, pidx2binname: Optional[Dict[int, str]] = None): + if pidx2binname is None: + pidx2binname = {} + self.pidx2binname = pidx2binname def transform_module( self, @@ -225,7 +274,7 @@ def transform_module( ) -> IRModule: mod = mod.clone() for gv, func in list(mod.functions.items()): - if isinstance(func, relax.Function): + if isinstance(func, relax.Function) and func.attrs and "global_symbol" in func.attrs: assert gv.name_hint.endswith("transform_params") func_updated = reorder_func(func, self.pidx2binname) mod[gv] = func_updated From 731616e9ba4e521718114fce693e794a8e8ad90d Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Sun, 3 Mar 2024 07:19:06 -0800 Subject: [PATCH 024/531] [SLM] Migrate Phi-2 to paged KV Cache #1871 (#1872) This PR migrates Phi-2 for Paged KV cache Attention as a part of Model definition migration according to #1749 . Co-authored-by: Shrey Gupta --- python/mlc_chat/model/phi/phi_model.py | 309 +++++++++++++++++-------- 1 file changed, 211 insertions(+), 98 deletions(-) diff --git a/python/mlc_chat/model/phi/phi_model.py b/python/mlc_chat/model/phi/phi_model.py index 421876d16f..04360efbcd 100644 --- a/python/mlc_chat/model/phi/phi_model.py +++ b/python/mlc_chat/model/phi/phi_model.py @@ -2,6 +2,7 @@ Implementation for Phi architecture. 
TODO: add docstring """ + import dataclasses from typing import Any, Dict, Optional, Union @@ -10,6 +11,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support import tensor_parallel as tp from mlc_chat.support.config import ConfigBase @@ -174,20 +176,9 @@ def forward(self, hidden_states: Tensor): return hidden_states -class PhiCrossAttention(nn.Module): - def __init__(self, config: PhiConfig): # pylint: disable=unused-argument - super().__init__() - - def forward(self, q: Tensor, k: Tensor, v: Tensor, attention_mask: Tensor): - output = op_ext.attention(q, k, v, casual_mask=attention_mask, qk_dtype="float32") - return output - - class PhiMHA(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: PhiConfig): - self.rope_theta = config.position_embedding_base - self.rotary_dim = config.rotary_dim - self.n_head = config.n_head // config.tensor_parallel_shards + self.num_q_heads = config.n_head // config.tensor_parallel_shards assert ( config.n_head % config.tensor_parallel_shards == 0 ), f"n_head({config.n_head}) must be divisible by tensor_parallel_shards" @@ -196,32 +187,36 @@ def __init__(self, config: PhiConfig): config.n_head_kv % config.tensor_parallel_shards == 0 ), f"n_head({config.n_head_kv}) must be divisible by tensor_parallel_shards" self.head_dim = config.head_dim - op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) + op_size = self.head_dim * (self.num_q_heads + 2 * self.n_head_kv) hidden_size = config.n_embd self.Wqkv = nn.Linear(hidden_size, op_size, bias=True) - self.out_proj = nn.Linear(self.n_head * self.head_dim, hidden_size, bias=True) - self.inner_cross_attn = PhiCrossAttention(config) - self.k_cache = nn.KVCache(config.context_window_size, [self.n_head_kv, self.head_dim]) - self.v_cache = nn.KVCache(config.context_window_size, [self.n_head_kv, self.head_dim]) - - def forward(self, x: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - d, h_q, h_kv, t = self.head_dim, self.n_head, self.n_head_kv, total_seq_len - b, s, _ = x.shape - assert b == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection - qkv = self.Wqkv(x) + self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, hidden_size, bias=True) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.n_head_kv + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.Wqkv(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.out_proj(output) + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.n_head_kv + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.Wqkv(hidden_states) qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h_q, h_kv, rotary_dim=self.rotary_dim) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = self.inner_cross_attn(q, k, v, attention_mask) - # Step 5. 
Apply output projection + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) return self.out_proj(output) @@ -238,14 +233,17 @@ def _set(param, hint): param.attrs["shard_strategy"] = hint hd = config.head_dim - q = self.mixer.n_head * hd + q = self.mixer.num_q_heads * hd k = self.mixer.n_head_kv * hd v = self.mixer.n_head_kv * hd _set( self.mixer.Wqkv.weight, tp.ShardSingleDim("_shard_qkv_weight", segs=[q, k, v], dim=0), ) - _set(self.mixer.Wqkv.bias, tp.ShardSingleDim("_shard_qkv_bias", segs=[q, k, v], dim=0)) + _set( + self.mixer.Wqkv.bias, + tp.ShardSingleDim("_shard_qkv_bias", segs=[q, k, v], dim=0), + ) _set(self.mixer.out_proj.weight, tp.ShardSingleDim("_shard_o_weight", dim=1)) _set(self.mlp.fc1.weight, tp.ShardSingleDim("_shard_mlp_fc1_weight", dim=0)) _set(self.mlp.fc1.bias, tp.ShardSingleDim("_shard_mlp_fc1_bias", dim=0)) @@ -254,32 +252,45 @@ def _set(param, hint): self.tensor_parallel_shards = config.tensor_parallel_shards _set_tp() - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): residual = hidden_states hidden_states = self.ln(hidden_states) with tp.shard_bias(self.mixer.out_proj, self.tensor_parallel_shards), tp.shard_bias( self.mlp.fc2, self.tensor_parallel_shards ): - attn_outputs = self.mixer( - hidden_states, - attention_mask, - total_seq_len, - ) - + attn_outputs = self.mixer(hidden_states, paged_kv_cache, layer_id) feed_forward_hidden_states = self.mlp(hidden_states) - def _apply_parallel_residual(attn_out, mlp_out, residual): - if self.tensor_parallel_shards > 1: - return op.ccl_allreduce( - attn_out + mlp_out + residual / self.tensor_parallel_shards, "sum" - ) - return attn_out + mlp_out + residual + hidden_states = self._apply_parallel_residual( + attn_outputs, feed_forward_hidden_states, residual + ) + + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + residual = hidden_states + hidden_states = self.ln(hidden_states) - hidden_states = _apply_parallel_residual(attn_outputs, feed_forward_hidden_states, residual) + with tp.shard_bias(self.mixer.out_proj, self.tensor_parallel_shards), tp.shard_bias( + self.mlp.fc2, self.tensor_parallel_shards + ): + attn_outputs = self.mixer.batch_forward(hidden_states, paged_kv_cache, layer_id) + feed_forward_hidden_states = self.mlp(hidden_states) + + hidden_states = self._apply_parallel_residual( + attn_outputs, feed_forward_hidden_states, residual + ) return hidden_states + def _apply_parallel_residual(self, attn_out, mlp_out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce( + attn_out + mlp_out + residual / self.tensor_parallel_shards, "sum" + ) + return attn_out + mlp_out + residual + class PhiCausalLMHead(nn.Module): def __init__(self, config: PhiConfig) -> None: @@ -300,21 +311,31 @@ def forward(self, hidden_states: Tensor): class PhiModel(nn.Module): def __init__(self, config: PhiConfig) -> None: super().__init__() - self.embd = nn.Embedding("vocab_size", config.n_embd) - self.h = nn.ModuleList([PhiParallelBlock(config) for i in range(config.n_layer)]) + self.embd = nn.Embedding(config.vocab_size, config.n_embd) + self.h = nn.ModuleList([PhiParallelBlock(config) for _ in range(config.n_layer)]) self.tensor_parallel_shards = config.tensor_parallel_shards - def forward(self, input_ids: Tensor, total_seq_len: tir.Var, 
attention_mask: Tensor): + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): if self.tensor_parallel_shards > 1: - input_ids = op.ccl_broadcast_from_worker0(input_ids) - hidden_states = self.embd(input_ids) - for layer in self.h: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + input_embed = op.ccl_broadcast_from_worker0(input_embed) + hidden_states = input_embed + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + + return hidden_states + + def batch_forward(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + input_embeds = op.ccl_broadcast_from_worker0(input_embeds) + hidden_states = input_embeds + for layer_id, layer in enumerate(self.h): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) return hidden_states class PhiForCausalLM(nn.Module): + # pylint: disable=too-many-instance-attributes def __init__(self, config: Union[PhiConfig, Phi1Config]) -> None: super().__init__() @@ -323,6 +344,15 @@ def __init__(self, config: Union[PhiConfig, Phi1Config]) -> None: self.transformer = PhiModel(config) self.lm_head = PhiCausalLMHead(config) + self.num_hidden_layers = config.n_layer + self.num_attention_heads = config.n_head + self.num_key_value_heads = config.n_head_kv + self.head_dim = config.head_dim + self.hidden_size = config.n_embd + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.rotary_dim = config.rotary_dim self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -330,71 +360,154 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - def _index(x: te.Tensor): # x[:-1,:] + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + lm_logits = self.lm_head(hidden_states) + if lm_logits.dtype != "float32": + lm_logits = lm_logits.astype("float32") + return lm_logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.transformer(input_ids, total_seq_len, attention_mask) + hidden_states = self.transformer(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) - lm_logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states) - return lm_logits + if logits.dtype != "float32": + logits = logits.astype("float32") - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + return logits, paged_kv_cache - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, 
total_seq_len], - ) - return self.forward(inputs, total_seq_len, attention_mask) + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), - dtype=self.dtype, - ) - return self.forward(inputs, total_seq_len, attention_mask) + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def embed(self, input_ids: Tensor): + embeds = self.transformer.embd(input_ids) + return embeds + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + rotary_dim=self.rotary_dim, + dtype=self.dtype, + ) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": 
nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", From e4341b3088307cc4e944bebce73e5daf671d435e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 3 Mar 2024 15:28:43 -0500 Subject: [PATCH 025/531] [Fix] Fix the use of "call_inplace_packed" and "call_pure_packed" (#1874) The use of `call_inplace_packed` and `call_pure_packed` in the old flow is outdated due to signature changes. This PR fixes the issue. --- mlc_llm/relax_model/chatglm.py | 22 +++++-- mlc_llm/relax_model/gpt_bigcode.py | 16 +++-- mlc_llm/relax_model/gpt_neox.py | 16 +++-- mlc_llm/relax_model/gptj.py | 12 ++-- mlc_llm/relax_model/llama.py | 44 +++++++------ mlc_llm/relax_model/llama_batched_vllm.py | 22 ++++--- mlc_llm/relax_model/mistral.py | 31 ++++----- mlc_llm/relax_model/rwkv.py | 76 +++++++---------------- mlc_llm/relax_model/stablelm_3b.py | 16 +++-- 9 files changed, 131 insertions(+), 124 deletions(-) diff --git a/mlc_llm/relax_model/chatglm.py b/mlc_llm/relax_model/chatglm.py index 9a2afdff8a..f1a5b574dc 100644 --- a/mlc_llm/relax_model/chatglm.py +++ b/mlc_llm/relax_model/chatglm.py @@ -286,7 +286,8 @@ def forward( k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[k_cache, squeezed_k], + k_cache, + squeezed_k, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -294,7 +295,8 @@ def forward( v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[v_cache, squeezed_v], + v_cache, + squeezed_v, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -308,14 +310,16 @@ def forward( k = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], ) ) v = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], ) ) @@ -707,7 +711,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], + zeros, + init_shape, + relax.PrimValue(0), sinfo_args=[relax.ObjectStructInfo()], ) ) @@ -731,7 +737,11 @@ def get_model(args: argparse.Namespace, hf_config): model = args.model dtype = args.quantization.model_dtype - if model.startswith("chatglm2") or model.startswith("codegeex2") or model.startswith("chatglm3"): + if ( + model.startswith("chatglm2") + or model.startswith("codegeex2") + or model.startswith("chatglm3") + ): config = ChatGLMConfig( **hf_config, dtype=dtype, diff --git a/mlc_llm/relax_model/gpt_bigcode.py b/mlc_llm/relax_model/gpt_bigcode.py index a089390853..4f72400e3c 100644 --- 
a/mlc_llm/relax_model/gpt_bigcode.py +++ b/mlc_llm/relax_model/gpt_bigcode.py @@ -223,7 +223,8 @@ def te_slice(x: te.Tensor, start: int, end: int): k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[k_cache, squeezed_k], + k_cache, + squeezed_k, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -231,7 +232,8 @@ def te_slice(x: te.Tensor, start: int, end: int): v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[v_cache, squeezed_v], + v_cache, + squeezed_v, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -245,14 +247,16 @@ def te_slice(x: te.Tensor, start: int, end: int): k = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], ) ) v = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], ) ) @@ -580,7 +584,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> No bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], + zeros, + init_shape, + relax.PrimValue(0), sinfo_args=[relax.ObjectStructInfo()], ) ) diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py index cdf80d1740..30f2d25ac5 100644 --- a/mlc_llm/relax_model/gpt_neox.py +++ b/mlc_llm/relax_model/gpt_neox.py @@ -116,7 +116,8 @@ def forward( k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[k_cache, squeeze(k, axis=0)], + k_cache, + squeeze(k, axis=0), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -124,7 +125,8 @@ def forward( v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[v_cache, squeeze(v, axis=0)], + v_cache, + squeeze(v, axis=0), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -135,14 +137,16 @@ def forward( k = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], ) ) v = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], ) ) @@ -635,7 +639,9 @@ def create_kv_cache_func( bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], + zeros, + init_shape, + relax.PrimValue(0), sinfo_args=[relax.ObjectStructInfo()], ) ) diff --git a/mlc_llm/relax_model/gptj.py b/mlc_llm/relax_model/gptj.py index 90965835ad..ea755a447a 100644 --- a/mlc_llm/relax_model/gptj.py +++ b/mlc_llm/relax_model/gptj.py @@ -155,7 +155,8 @@ def _project(proj): k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[k_cache, squeeze(k, axis=0)], + k_cache, + squeeze(k, axis=0), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -163,7 +164,8 @@ def _project(proj): v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[v_cache, squeeze(v, axis=0)], + v_cache, + squeeze(v, axis=0), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -174,14 +176,16 @@ def _project(proj): k = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], ) ) v = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, 
kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], ) ) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 06272e3a7b..7cad3d6fc4 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1152,29 +1152,27 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N cache = bb.emit_output( relax.call_pure_packed( f_kv_cache_create, - args=[ - cache_config, - relax.PrimValue(config.num_hidden_layers), - relax.PrimValue(num_qo_heads), - relax.PrimValue(num_kv_heads), - relax.PrimValue(head_dim), - relax.PrimValue(1), - relax.PrimValue(config.position_embedding_base), - zeros, - bb.get().get_global_var("kv_cache_transpose_append"), - bb.get().get_global_var("attention_prefill"), - bb.get().get_global_var("attention_decode"), - bb.get().get_global_var("attention_prefill_ragged"), - bb.get().get_global_var("attention_prefill_ragged_begin_forward"), - bb.get().get_global_var("attention_prefill_ragged_end_forward"), - bb.get().get_global_var("attention_prefill_begin_forward"), - bb.get().get_global_var("attention_prefill_end_forward"), - bb.get().get_global_var("attention_decode_begin_forward"), - bb.get().get_global_var("attention_decode_end_forward"), - bb.get().get_global_var("attention_rope_in_place"), - bb.get().get_global_var("attention_merge_state"), - bb.get().get_global_var("kv_cache_debug_get_kv"), - ], + cache_config, + relax.PrimValue(config.num_hidden_layers), + relax.PrimValue(num_qo_heads), + relax.PrimValue(num_kv_heads), + relax.PrimValue(head_dim), + relax.PrimValue(1), + relax.PrimValue(config.position_embedding_base), + zeros, + bb.get().get_global_var("kv_cache_transpose_append"), + bb.get().get_global_var("attention_prefill"), + bb.get().get_global_var("attention_decode"), + bb.get().get_global_var("attention_prefill_ragged"), + bb.get().get_global_var("attention_prefill_ragged_begin_forward"), + bb.get().get_global_var("attention_prefill_ragged_end_forward"), + bb.get().get_global_var("attention_prefill_begin_forward"), + bb.get().get_global_var("attention_prefill_end_forward"), + bb.get().get_global_var("attention_decode_begin_forward"), + bb.get().get_global_var("attention_decode_end_forward"), + bb.get().get_global_var("attention_rope_in_place"), + bb.get().get_global_var("attention_merge_state"), + bb.get().get_global_var("kv_cache_debug_get_kv"), sinfo_args=[relax.ObjectStructInfo()], ) ) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 365500be04..4ff6fb0621 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -3,27 +3,27 @@ import numpy as np import tvm from tvm import relax, te -from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.ir import VDevice +from tvm.relax.op import ccl, concat, expand_dims, repeat, reshape, take, zeros from tvm.relax.op.nn import attention_var_len from tvm.relax.testing import nn -from tvm.ir import VDevice from tvm.script import relax as R from tvm.script.ir_builder import tir as T from ..quantization import QuantizationScheme -from .modules import ModuleList -from .param_manager import ParamManager from .llama import ( - LlamaConfig, - Linear, Embedding, - LlamaRMSNorm, + Linear, LlamaAttentionBase, + LlamaConfig, LlamaDecoderLayer, + LlamaRMSNorm, get_param_quant_kind, - setup_params, rotary_modulate_by_freq, + setup_params, ) +from .modules import ModuleList +from 
.param_manager import ParamManager def apply_rotary_pos_emb(q, k, positions, position_embedding_base): @@ -95,7 +95,11 @@ def forward( kv = nn.emit( relax.op.call_inplace_packed( "tvm.contrib.vllm.reshape_and_cache", - args=[keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping], + keys_to_cache, + values_to_cache, + k_cache, + v_cache, + slot_mapping, inplace_indices=[2, 3], sinfo_args=[k_cache.struct_info, v_cache.struct_info], ) diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py index e08495f2d9..f9959fdb11 100644 --- a/mlc_llm/relax_model/mistral.py +++ b/mlc_llm/relax_model/mistral.py @@ -48,6 +48,7 @@ def __init__( num_shards=1, **kwargs, ): + sliding_window = 4096 if sliding_window is None else sliding_window self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id @@ -345,14 +346,16 @@ def te_squeeze(x): key_cached = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], ) ) value_cached = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], ) ) @@ -402,12 +405,10 @@ def te_squeeze(x): k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_override, - args=[ - k_cache, - squeezed_key, - relax.PrimValue(self.sliding_window), - relax.PrimValue(attention_sink_size), - ], + k_cache, + squeezed_key, + relax.PrimValue(self.sliding_window), + relax.PrimValue(attention_sink_size), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -415,12 +416,10 @@ def te_squeeze(x): v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_override, - args=[ - v_cache, - squeezed_value, - relax.PrimValue(self.sliding_window), - relax.PrimValue(attention_sink_size), - ], + v_cache, + squeezed_value, + relax.PrimValue(self.sliding_window), + relax.PrimValue(attention_sink_size), inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -960,7 +959,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], + zeros, + init_shape, + relax.PrimValue(0), sinfo_args=[relax.ObjectStructInfo()], ) ) diff --git a/mlc_llm/relax_model/rwkv.py b/mlc_llm/relax_model/rwkv.py index 5b47cc31f9..3c1a9ffa0d 100644 --- a/mlc_llm/relax_model/rwkv.py +++ b/mlc_llm/relax_model/rwkv.py @@ -10,7 +10,7 @@ from ..quantization import ParamQuantKind, QuantizationScheme from .commons import create_metadata_func -from .modules import ModuleList, Linear +from .modules import Linear, ModuleList from .param_manager import ParamManager # Reference: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model_run.py @@ -66,7 +66,8 @@ def _load_state(state: Expr, hidden_size: int, dtype: str) -> Expr: cache = nn.emit( relax.call_pure_packed( f_load_cache, - args=[state, R.shape([1, hidden_size])], + state, + R.shape([1, hidden_size]), sinfo_args=[R.Tensor((1, hidden_size), dtype)], ) ) @@ -80,7 +81,8 @@ def _store_state(state: Expr, value: Expr): return nn.emit( relax.op.call_inplace_packed( f_store_cache, - args=[state, value], + state, + value, inplace_indices=[0], sinfo_args=[R.Object()], ) @@ -179,9 +181,7 @@ class RWKV_Embedding(nn.Module): def __init__(self, num_embeddings, embedding_dim, dtype): self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim - self.weight = 
nn.Parameter( - (num_embeddings, embedding_dim), dtype=dtype, name="weight" - ) + self.weight = nn.Parameter((num_embeddings, embedding_dim), dtype=dtype, name="weight") def forward(self, x: relax.Expr) -> relax.Var: x = nn.emit(op.reshape(x, shape=[-1])) @@ -195,9 +195,7 @@ def __init__(self, intermediate_size, dtype, eps=1e-5, name_prefix=""): self.weight = nn.Parameter( (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_weight" ) - self.bias = nn.Parameter( - (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_bias" - ) + self.bias = nn.Parameter((intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_bias") def forward(self, x: relax.Expr) -> relax.Var: x = nn.emit( @@ -227,9 +225,7 @@ def __init__(self, config: RWKVConfig, index: int) -> None: self.key = Linear( self.hidden_size, config.intermediate_size, dtype=config.dtype, bias=False ) - self.receptance = Linear( - self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False - ) + self.receptance = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) self.value = Linear( config.intermediate_size, self.hidden_size, dtype=config.dtype, bias=False ) @@ -244,9 +240,7 @@ def forward(self, x: Expr, state: Expr) -> Expr: saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x) ones = nn.emit(relax.op.ones((hidden_size,), self.dtype)) xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) - xr = nn.emit( - x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance) - ) + xr = nn.emit(x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)) if not is_one(context_length): x = nn.emit_te(_te_get_last_x, x) assert is_one(x.struct_info.shape[0]) @@ -279,18 +273,10 @@ def __init__(self, config: RWKVConfig, index: int) -> None: self.time_mix_receptance = nn.Parameter( (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_r" ) - self.key = Linear( - self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False - ) - self.value = Linear( - self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False - ) - self.receptance = Linear( - self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False - ) - self.output = Linear( - self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False - ) + self.key = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) + self.value = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) + self.receptance = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) + self.output = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) def forward(self, x: Expr, state: Expr) -> Expr: # Load current state @@ -309,9 +295,7 @@ def forward(self, x: Expr, state: Expr) -> Expr: xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) xv = nn.emit(x * self.time_mix_value + saved_x * (ones - self.time_mix_value)) - xr = nn.emit( - x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance) - ) + xr = nn.emit(x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)) r = nn.emit(op.sigmoid(self.receptance(xr))) k = nn.emit(op.astype(self.key(xk), "float32")) @@ -395,9 +379,7 @@ def __init__(self, config: RWKVConfig) -> None: embedding_dim=config.hidden_size, dtype=config.dtype, ) - self.blocks = ModuleList( - [RWKVLayer(config, i) for i in range(config.num_hidden_layers)] - ) + self.blocks = ModuleList([RWKVLayer(config, i) for i in range(config.num_hidden_layers)]) 
self.ln_out = RWKV_LayerNorm( config.hidden_size, config.dtype, @@ -423,9 +405,7 @@ def forward(self, input_ids: Expr, state: Expr) -> Tuple[Expr, List[Expr]]: class RWKVForCausalLM(nn.Module): def __init__(self, config: RWKVConfig): self.rwkv = RWKVModel(config) - self.head = Linear( - config.hidden_size, config.vocab_size, dtype=config.dtype, bias=False - ) + self.head = Linear(config.hidden_size, config.vocab_size, dtype=config.dtype, bias=False) self.vocab_size = config.vocab_size ############ End ############ @@ -443,9 +423,7 @@ def forward( return logits, key_value_cache -def get_param_quant_kind( - name: str, param_info: relax.TensorStructInfo -) -> ParamQuantKind: +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: if name.endswith("embeddings.weight"): return ParamQuantKind.embedding_table elif name == "head.weight": @@ -469,9 +447,7 @@ def create_func( with bb.function(func_name): model = RWKVForCausalLM(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") # Placeholder for compatibility to LLAMA @@ -519,7 +495,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[init_value, init_shape, relax.PrimValue(1)], + init_value, + init_shape, + relax.PrimValue(1), sinfo_args=[R.Object()], ), name_hint=f"{name}_state_{i}", @@ -539,24 +517,18 @@ def create_kv_cache_reset_func(bb: relax.BlockBuilder, config: RWKVConfig) -> No fp32_neg_inf = bb.emit(fp32_zeros - relax.const(1e30, "float32")) caches = [] for i in range(config.num_hidden_layers): - caches.append( - _store_state(state[i * 5 + State.ATT_X], input_dtype_zeros) - ) + caches.append(_store_state(state[i * 5 + State.ATT_X], input_dtype_zeros)) caches.append(_store_state(state[i * 5 + State.ATT_B], fp32_zeros)) caches.append(_store_state(state[i * 5 + State.ATT_A], fp32_zeros)) caches.append(_store_state(state[i * 5 + State.ATT_P], fp32_neg_inf)) - caches.append( - _store_state(state[i * 5 + State.FFN_X], input_dtype_zeros) - ) + caches.append(_store_state(state[i * 5 + State.FFN_X], input_dtype_zeros)) gv = bb.emit_output(caches) bb.emit_func_output(gv) def create_softmax_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, config.vocab_size), dtype="float32", name="logits" - ) + logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index ac1c9a71ad..c39b8018ce 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -269,7 +269,8 @@ def forward( k_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[k_cache, squeezed_key], + k_cache, + squeezed_key, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -277,7 +278,8 @@ def forward( v_cache = nn.emit( relax.op.call_inplace_packed( f_kv_cache_append, - args=[v_cache, squeezed_value], + v_cache, + squeezed_value, inplace_indices=[0], sinfo_args=[relax.ObjectStructInfo()], ) @@ -287,14 +289,16 @@ def forward( k_cache = nn.emit( relax.call_pure_packed( 
f_kv_cache_view, - args=[k_cache, kv_cache_shape], + k_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], ) ) v_cache = nn.emit( relax.call_pure_packed( f_kv_cache_view, - args=[v_cache, kv_cache_shape], + v_cache, + kv_cache_shape, sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], ) ) @@ -721,7 +725,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> No bb.emit( relax.call_pure_packed( f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], + zeros, + init_shape, + relax.PrimValue(0), sinfo_args=[relax.ObjectStructInfo()], ) ) From c0606ecc1789935ba8bf6d11cf48e48d2e58097d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 3 Mar 2024 15:44:33 -0500 Subject: [PATCH 026/531] [Fix] Add the missing BundleModelParams pass (#1875) PR #1852 missed to apply the BundleModelParams pass and thus made the compiled models not runnable through ChatModule (#1864). This PR fixes the issue. --- mlc_llm/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 065b3a29ac..d4855582e6 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -847,6 +847,7 @@ def build_model_from_args(args: argparse.Namespace): qspec_updater = qspec_updater_class(param_manager) qspec_updater.visit_module(mod) mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) if not args.build_model_only: parameter_transforms = [] From 07af0f98b49490c9e4f394ea90e55afde7a95e7c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 3 Mar 2024 16:11:57 -0500 Subject: [PATCH 027/531] [Docs] Update Android APK download link (#1876) As pointed out by #1830, this PR fixes the Android app download link in docs. --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 15ad6ca536..596e5d3877 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -146,7 +146,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png :width: 135 - :target: https://github.com/mlc-ai/binary-mlc-llm-libs/raw/main/mlc-chat.apk + :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk | From 837869ae758330b64ff33cb6d8d2de2f14da5260 Mon Sep 17 00:00:00 2001 From: Diego Cao <50705298+DiegoCao@users.noreply.github.com> Date: Sun, 3 Mar 2024 17:57:10 -0500 Subject: [PATCH 028/531] Fix MLC-LLM website link weight convert not accessible (#1877) Fix website link not accessible --- docs/compilation/compile_models.rst | 4 ++-- docs/compilation/convert_weights.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index e9a3d631c2..24ebbed730 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -5,8 +5,8 @@ Compile Model Libraries To run a model with MLC LLM in any platform, you need: -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC - `_.) +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + `_.) 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). If you are simply adding a model variant, follow :ref:`convert-weights-via-MLC` suffices. 
diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index 6b39cf8b68..ef39cd9efb 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -5,8 +5,8 @@ Convert Weights via MLC To run a model with MLC LLM in any platform, you need: -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC - `_.) +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC + `__). In many cases, we only need to convert weights and reuse existing model library. From d2cfb1edd7a84f9bdb10010e1ed36b9b6e14a520 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Mon, 4 Mar 2024 21:43:04 +0800 Subject: [PATCH 029/531] [Serving][Grammar] Support termination state in GrammarStateMatcher (#1884) --- cpp/serve/engine_actions/action_commons.cc | 7 --- cpp/serve/grammar/grammar_state_matcher.cc | 54 ++++++++++++++++--- cpp/serve/grammar/grammar_state_matcher.h | 11 ++++ .../grammar/grammar_state_matcher_base.h | 4 +- .../grammar/grammar_state_matcher_preproc.h | 31 +++++------ cpp/serve/logit_processor.cc | 7 +-- cpp/serve/request_state.cc | 8 +++ .../mlc_chat/protocol/openai_api_protocol.py | 10 ++-- python/mlc_chat/protocol/protocol_utils.py | 5 +- python/mlc_chat/serve/grammar.py | 20 +++++++ tests/python/serve/test_grammar_parser.py | 2 +- .../serve/test_grammar_state_matcher.py | 39 ++++++++++++-- 12 files changed, 149 insertions(+), 49 deletions(-) diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index d665dea778..85248062a4 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -120,13 +120,6 @@ void ActionStepPostProcess(Array requests, EngineState estate, Arraymstates[0]->grammar_state_matcher) { - const auto& grammar_state_matcher = rsentry->mstates[0]->grammar_state_matcher.value(); - for (int32_t token_id : delta_request_ret.delta_token_ids) { - grammar_state_matcher->AcceptToken(token_id); - } - } } } diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index a0b2350a2e..3087a3d665 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -139,6 +139,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm int MaxRollbackSteps() const final { return max_rollback_steps_; } + bool IsTerminated() const { return stack_tops_history_.GetLatest().empty(); } + void ResetState() final { stack_tops_history_.Reset(); token_size_history_.clear(); @@ -161,6 +163,18 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm void SetTokenBitmask(DLTensor* next_token_bitmask, std::vector& accepted_indices, std::vector& rejected_indices, bool can_reach_end); + /*! \brief Check if a token is a stop token. */ + bool IsStopToken(int32_t token_id) const { + return std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), + token_id) != init_ctx_->stop_token_ids.end(); + } + + /*! + * \brief Accept the stop token and terminates the matcher. + * \returns Whether the stop token can be accepted. 
+ */ + bool AcceptStopToken(); + friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher); std::shared_ptr init_ctx_; @@ -175,10 +189,28 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm std::vector tmp_uncertain_tokens_bitset_; }; +bool GrammarStateMatcherNodeImpl::AcceptStopToken() { + if (!CanReachEnd()) { + return false; + } + stack_tops_history_.PushHistory({}); // Terminate the matcher by setting the stack to empty + return true; +} + bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { - CHECK(init_ctx_->codepoint_tokens_lookup.count(token_id) > 0) + CHECK(!IsTerminated()) + << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " + "accept another token id " + << token_id; + + // Handle the stop token + if (IsStopToken(token_id)) { + return AcceptStopToken(); + } + + CHECK(init_ctx_->id_to_token_codepoints.count(token_id) > 0) << "Token id " << token_id << " is not supported in generation"; - const auto& token = init_ctx_->codepoint_tokens_lookup[token_id].token; + const auto& token = init_ctx_->id_to_token_codepoints[token_id].token; for (auto codepoint : token) { if (!AcceptCodepoint(codepoint, false)) { return false; @@ -193,7 +225,10 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { } void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitmask) { - const auto& tokens_sorted_by_codepoint = init_ctx_->tokens_sorted_by_codepoint; + CHECK(!IsTerminated()) + << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " + "find the next token mask"; + const auto& sorted_token_codepoints = init_ctx_->sorted_token_codepoints; const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar; const auto& latest_stack_tops = stack_tops_history_.GetLatest(); @@ -202,7 +237,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm // The final accepted token set is the union of the accepted token sets of all stacks. // The final rejected token set is the intersection of the rejected token sets of all stacks. - // Note these indices store the indices in tokens_sorted_by_codepoint, instead of the token ids. + // Note these indices store the indices in sorted_token_codepoints, instead of the token ids. tmp_accepted_indices_.clear(); // {-1} means the universal set, i.e. all tokens initially tmp_rejected_indices_.assign({-1}); @@ -245,7 +280,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm if (!is_uncertain_saved) { // unc_tokens = all_tokens - accepted_tokens - rejected_tokens - tmp_uncertain_tokens_bitset_.assign(tokens_sorted_by_codepoint.size(), true); + tmp_uncertain_tokens_bitset_.assign(sorted_token_codepoints.size(), true); for (auto idx : catagorized_tokens.accepted_indices) { tmp_uncertain_tokens_bitset_[idx] = false; } @@ -264,7 +299,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm if (idx == -1) { break; } - const auto& cur_token = tokens_sorted_by_codepoint[idx].token; + const auto& cur_token = sorted_token_codepoints[idx].token; // Step 2.2. Find the longest common prefix with the accepted part of the previous token. // We can reuse the previous matched size to avoid unnecessary matching. 
@@ -353,7 +388,7 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, // accepted_indices next_token_bitset.Reset(init_ctx_->vocab_size, false); for (int idx : accepted_indices) { - next_token_bitset.Set(init_ctx_->tokens_sorted_by_codepoint[idx].id, true); + next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true); } if (can_reach_end) { @@ -372,7 +407,7 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, ++it_acc; } if (it_acc == accepted_indices.end() || *it_acc != i) { - next_token_bitset.Set(init_ctx_->tokens_sorted_by_codepoint[i].id, false); + next_token_bitset.Set(init_ctx_->sorted_token_codepoints[i].id, false); } } @@ -452,6 +487,9 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherRollback") TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherMaxRollbackSteps") .set_body_typed([](GrammarStateMatcher matcher) { return matcher->MaxRollbackSteps(); }); +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherIsTerminated") + .set_body_typed([](GrammarStateMatcher matcher) { return matcher->IsTerminated(); }); + TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") .set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); }); diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h index ec6e8f19b1..443a791edc 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -58,6 +58,11 @@ class GrammarStateMatcherNode : public Object { * \brief Accept one token and update the state of the matcher. * \param token_id The id of the token to accept. * \return Whether the token is accepted. + * \note Termination state. + * When the end of the main rule is reached, the matcher can only accept the stop token. + * The matcher is terminated after accepting the stop token, i.e. no AcceptToken or + * FindNextTokenMask operations can be performed. The termination state can be canceled + * using Rollback(). */ virtual bool AcceptToken(int32_t token_id) = 0; @@ -79,6 +84,12 @@ class GrammarStateMatcherNode : public Object { /*! \brief Get the maximum number of rollback steps allowed. */ virtual int MaxRollbackSteps() const = 0; + /*! + * \brief Check if the matcher has accepted the stop token and terminated. + * \sa AcceptToken + */ + virtual bool IsTerminated() const = 0; + /*! \brief Reset the matcher to the initial state. 
*/ virtual void ResetState() = 0; diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 11623661e7..0028994b3c 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -86,9 +86,9 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool if (verbose) { std::cout << "Stack before accepting: " << PrintStackState() << std::endl; } - tmp_new_stack_tops_.clear(); - const auto& prev_stack_tops = stack_tops_history_.GetLatest(); + + tmp_new_stack_tops_.clear(); for (auto old_top : prev_stack_tops) { const auto& rule_position = tree_[old_top]; auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index 194d5b2935..3d1ffeb754 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -31,7 +31,7 @@ struct TokenAndId { * into three categories: accepted, rejected, and uncertain. * \note Since the union of these three sets is the whole token set, we only need to store the * smaller two sets. The unsaved set is specified by not_saved_index. - * \note These indices are the indices of tokens_sorted_by_codepoint in the GrammarStateInitContext + * \note These indices are the indices of sorted_token_codepoints in the GrammarStateInitContext * object, instead of the token ids. That helps the matching process. */ struct CatagorizedTokens { @@ -59,11 +59,12 @@ class GrammarStateInitContext { /*! \brief The vocabulary size of the tokenizer. */ size_t vocab_size; - /*! \brief The sorted token and its id. Tokens are sorted to reuse the common prefix during - * matching. */ - std::vector tokens_sorted_by_codepoint; - /*! \brief The mapping from token id to token represented by codepoints. */ - std::unordered_map codepoint_tokens_lookup; + /*! \brief All tokens represented by the id and codepoints of each. The tokens are sorted by + * codepoint values to reuse the common prefix during matching. */ + std::vector sorted_token_codepoints; + /*! \brief The mapping from token id to token represented by codepoints. Only contains + * non-special and non-stop tokens. */ + std::unordered_map id_to_token_codepoints; /*! \brief The stop tokens. They can be accepted iff GramamrMatcher can reach the end of the * grammar. */ std::vector stop_token_ids; @@ -104,7 +105,7 @@ class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase { GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position) : GrammarStateMatcherBase(grammar, init_rule_position) {} - CatagorizedTokens GetCatagorizedTokens(const std::vector& tokens_sorted_by_codepoint, + CatagorizedTokens GetCatagorizedTokens(const std::vector& sorted_token_codepoints, bool is_main_rule); private: @@ -155,7 +156,7 @@ inline CatagorizedTokens::CatagorizedTokens(std::vector&& accepted_indi } inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( - const std::vector& tokens_sorted_by_codepoint, bool is_main_rule) { + const std::vector& sorted_token_codepoints, bool is_main_rule) { // Support the current stack contains only one stack with one RulePosition. // Iterate over all tokens. 
Split them into three categories: // - accepted_indices: If a token is accepted by current rule @@ -173,9 +174,9 @@ inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens tmp_can_see_end_stack_.assign({CanReachEnd()}); int prev_matched_size = 0; - for (int i = 0; i < static_cast(tokens_sorted_by_codepoint.size()); ++i) { - const auto& token = tokens_sorted_by_codepoint[i].token; - const auto* prev_token = i > 0 ? &tokens_sorted_by_codepoint[i - 1].token : nullptr; + for (int i = 0; i < static_cast(sorted_token_codepoints.size()); ++i) { + const auto& token = sorted_token_codepoints[i].token; + const auto* prev_token = i > 0 ? &sorted_token_codepoints[i - 1].token : nullptr; // Find the longest common prefix with the accepted part of the previous token. auto prev_useful_size = 0; @@ -268,11 +269,11 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC DCHECK(!codepoints.empty() && codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) << "Invalid token: " << token; - ptr->tokens_sorted_by_codepoint.push_back({codepoints, i}); - ptr->codepoint_tokens_lookup[i] = {codepoints, i}; + ptr->sorted_token_codepoints.push_back({codepoints, i}); + ptr->id_to_token_codepoints[i] = {codepoints, i}; } } - std::sort(ptr->tokens_sorted_by_codepoint.begin(), ptr->tokens_sorted_by_codepoint.end()); + std::sort(ptr->sorted_token_codepoints.begin(), ptr->sorted_token_codepoints.end()); // Find the corresponding catagorized tokens for: // 1. All character elements in the grammar @@ -307,7 +308,7 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, cur_rule_position); auto cur_catagorized_tokens_for_grammar = - grammar_state_matcher.GetCatagorizedTokens(ptr->tokens_sorted_by_codepoint, i == 0); + grammar_state_matcher.GetCatagorizedTokens(ptr->sorted_token_codepoints, i == 0); ptr->catagorized_tokens_for_grammar[{sequence_id, element_id}] = cur_catagorized_tokens_for_grammar; } diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 5af7a39d29..1afcf10c60 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -100,7 +100,7 @@ class LogitProcessorImpl : public LogitProcessorObj { // Update 3. Vocabulary mask. RECORD_EVENT(trace_recorder_, request_ids, "start apply logit mask"); - UpdateWithMask(logits, mstates, cum_num_token, draft_tokens, request_ids); + UpdateWithMask(logits, mstates, cum_num_token, draft_tokens); RECORD_EVENT(trace_recorder_, request_ids, "finish apply logit mask"); RECORD_EVENT(trace_recorder_, request_ids, "finish update logits"); @@ -302,8 +302,7 @@ class LogitProcessorImpl : public LogitProcessorObj { void UpdateWithMask(NDArray logits, const Array& mstates, const std::vector* cum_num_token, - const std::vector>* draft_tokens, - const Array& request_ids) { + const std::vector>* draft_tokens) { // Construct: // - seq_ids (max_num_token,) int32 // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32 @@ -311,8 +310,6 @@ class LogitProcessorImpl : public LogitProcessorObj { uint32_t* p_bitmask = static_cast(bitmask_host_->data); // - Set arrays. 
- ICHECK(mstates.size() == request_ids.size()); - int batch_size = logits->shape[0]; ICHECK((cum_num_token == nullptr && batch_size == mstates.size()) || (cum_num_token != nullptr && batch_size == cum_num_token->size())); diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 8b5543d4f1..7dc9d0b627 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -49,6 +49,14 @@ void RequestModelStateNode::FindNextTokenBitmask(DLTensor* bitmask) { void RequestModelStateNode::CommitToken(SampleResult sampled_token) { committed_tokens.push_back(std::move(sampled_token)); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; + + // Update the grammar matcher state if it exists. + if (grammar_state_matcher) { + bool accepted = + grammar_state_matcher.value()->AcceptToken(sampled_token.sampled_token_id.first); + ICHECK(accepted) << "Token id " << sampled_token.sampled_token_id.first + << " is not accepted by the grammar state matcher."; + } } void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist) { diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_chat/protocol/openai_api_protocol.py index b0d4d56192..8e56d3855f 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_chat/protocol/openai_api_protocol.py @@ -10,6 +10,8 @@ import shortuuid from pydantic import BaseModel, Field, field_validator, model_validator +from mlc_chat.serve.config import ResponseFormat + ################ Commons ################ @@ -65,7 +67,7 @@ class ModelResponse(BaseModel): ################ v1/completions ################ -class ResponseFormat(BaseModel): +class RequestResponseFormat(BaseModel): type: Literal["text", "json_object"] = "text" json_schema: Optional[str] = None @@ -94,7 +96,7 @@ class CompletionRequest(BaseModel): top_p: float = 1.0 user: Optional[str] = None ignore_eos: bool = False - response_format: ResponseFormat = ResponseFormat() + response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -208,7 +210,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None user: Optional[str] = None ignore_eos: bool = False - response_format: ResponseFormat = ResponseFormat() + response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -331,5 +333,5 @@ def openai_api_get_generation_config( kwargs["max_tokens"] = -1 if request.stop is not None: kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop - kwargs["response_format"] = request.response_format.model_dump() + kwargs["response_format"] = ResponseFormat(**request.response_format.model_dump()) return kwargs diff --git a/python/mlc_chat/protocol/protocol_utils.py b/python/mlc_chat/protocol/protocol_utils.py index b515ffc47c..a9a68a1f82 100644 --- a/python/mlc_chat/protocol/protocol_utils.py +++ b/python/mlc_chat/protocol/protocol_utils.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from ..serve.config import GenerationConfig, ResponseFormat +from ..serve.config import GenerationConfig from . 
import RequestProtocol from .openai_api_protocol import ChatCompletionRequest as OpenAIChatCompletionRequest from .openai_api_protocol import CompletionRequest as OpenAICompletionRequest @@ -43,9 +43,6 @@ def get_generation_config( else: raise RuntimeError("Cannot reach here") - response_format_dict = kwargs.get("response_format", {}) - kwargs["response_format"] = ResponseFormat(**response_format_dict) - if extra_stop_token_ids is not None: stop_token_ids = kwargs.get("stop_token_ids", []) assert isinstance(stop_token_ids, list) diff --git a/python/mlc_chat/serve/grammar.py b/python/mlc_chat/serve/grammar.py index 3df954cb22..f6122c5e8a 100644 --- a/python/mlc_chat/serve/grammar.py +++ b/python/mlc_chat/serve/grammar.py @@ -179,6 +179,15 @@ def accept_token(self, token_id: int) -> bool: ------- accepted : bool Whether the token is accepted. + + Note + ---- + Termination state. + + When the end of the main rule is reached, the matcher can only accept the stop token. + The matcher is terminated after accepting the stop token, i.e. no accept_token or + find_next_rejected_tokens operations can be performed. The termination state can be canceled + using Rollback(). """ return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member @@ -218,6 +227,17 @@ def reset_state(self) -> None: """Reset the matcher to the initial state.""" _ffi_api.GrammarStateMatcherResetState(self) # type: ignore # pylint: disable=no-member + def is_terminated(self) -> bool: + """Check if the matcher has accepted the stop token and terminated. See also + GrammarStateMatcher.accept_token. + + Returns + ------- + terminated : bool + Whether the matcher has terminated. + """ + return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member + def debug_accept_char(self, codepoint: int) -> bool: """Accept one unicode codepoint to the current state. 
diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index dd6cc64b5d..ceffd5805d 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -3,7 +3,7 @@ import pytest import tvm.testing -from tvm._ffi.base import TVMError +from tvm import TVMError from mlc_chat.serve import BNFGrammar diff --git a/tests/python/serve/test_grammar_state_matcher.py b/tests/python/serve/test_grammar_state_matcher.py index 61d6341c48..c03a414931 100644 --- a/tests/python/serve/test_grammar_state_matcher.py +++ b/tests/python/serve/test_grammar_state_matcher.py @@ -6,6 +6,7 @@ import pytest import tvm import tvm.testing +from tvm import TVMError from mlc_chat.serve import BNFGrammar, GrammarStateMatcher from mlc_chat.tokenizer import Tokenizer @@ -268,7 +269,8 @@ def test_find_next_rejected_tokens( assert real_sizes == expected_rejected_sizes -def test_accept_token(json_grammar: BNFGrammar): +def test_token_based_operations(json_grammar: BNFGrammar): + """Test accepting token and finding the next token mask.""" token_table = [ # fmt: off "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', @@ -279,8 +281,6 @@ def test_accept_token(json_grammar: BNFGrammar): grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - result = [] - expected = [ ["{"], ['"', "}", "\n", " ", '"a":true'], @@ -295,6 +295,8 @@ def test_accept_token(json_grammar: BNFGrammar): [""], ] + result = [] + for id in input_ids: rejected = grammar_state_matcher.find_next_rejected_tokens() accepted = list(set(range(len(token_table))) - set(rejected)) @@ -369,6 +371,37 @@ def test_reset(json_grammar: BNFGrammar): assert orig_result == result_after_reset +def test_termination(json_grammar: BNFGrammar): + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", ""] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) + + orig_result = [] + + for i in input_ids: + orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) + assert grammar_state_matcher.accept_token(i) + + assert grammar_state_matcher.is_terminated() + + with pytest.raises(TVMError): + grammar_state_matcher.accept_token(0) + + with pytest.raises(TVMError): + grammar_state_matcher.find_next_rejected_tokens() + + grammar_state_matcher.rollback(2) + + assert not grammar_state_matcher.is_terminated() + assert grammar_state_matcher.accept_token(input_ids[-2]) + + if __name__ == "__main__": # Run a benchmark to show the performance before running tests test_find_next_rejected_tokens( From 65ec85d7f8f24c39b631dcc361dbf8e0e8f3ad8d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 4 Mar 2024 08:45:58 -0500 Subject: [PATCH 030/531] [Serving] Make RequestState as a standalone object class (#1878) This PR adopts suggestions from the support of OpenAI API parallel generation `n` in #1868. The main update in this PR is to make the RequestState as a standalone object class, which was a typedef from `std::vector` before. This PR also fixes a bug in prefill that will cause engine failure when `n` is large. 
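To make the shape of this refactor concrete, below is a condensed sketch of the new standalone RequestState class, simplified from the cpp/serve/request_state.h hunk in this patch (the full definitions in the diff also set the TVM type key and the sequal/shash flags):

    // Before this patch, RequestState was only an alias over the entry vector,
    //   typedef std::vector<RequestStateEntry> RequestState;
    // so call sites indexed it directly, e.g. rstate[i]->mstates[0].
    //
    // After this patch it is a TVM object class wrapping the grouped entries,
    // and call sites go through rstate->entries[i] instead.
    class RequestStateNode : public Object {
     public:
      // Grouped request state entries: the prompt/prefix entry plus one entry
      // per parallel generation branch when n > 1.
      std::vector<RequestStateEntry> entries;

      TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object);
    };

    class RequestState : public ObjectRef {
     public:
      explicit RequestState(std::vector<RequestStateEntry> entries);

      TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode);
    };

Most of the diff below is the corresponding mechanical rewrite of call sites (rstate[i] becoming rstate->entries[i]) across the engine and engine actions, together with the sampler interface change that takes explicit sample_indices instead of an optional prob_indices pointer.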
--- cpp/serve/engine.cc | 35 ++++----- cpp/serve/engine_actions/action_commons.cc | 37 +++++----- cpp/serve/engine_actions/action_commons.h | 4 +- cpp/serve/engine_actions/batch_decode.cc | 9 ++- cpp/serve/engine_actions/batch_draft.cc | 7 +- cpp/serve/engine_actions/batch_verify.cc | 4 +- .../engine_actions/new_request_prefill.cc | 71 ++++++++++--------- cpp/serve/request_state.cc | 12 ++++ cpp/serve/request_state.h | 35 ++++++--- cpp/serve/sampler.cc | 43 ++++++----- cpp/serve/sampler.h | 28 +++----- 11 files changed, 159 insertions(+), 126 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 411dbfc908..56cab63927 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -142,23 +142,23 @@ class EngineImpl : public Engine { int n = request->generation_cfg->n; int rng_seed = request->generation_cfg->seed; - RequestState rstate; + std::vector rsentries; // Create the request state entry for the input. - rstate.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed, - token_table_, json_grammar_state_init_ctx_); + rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed, + token_table_, json_grammar_state_init_ctx_); if (n > 1) { // Then create a request state entry for each parallel generation branch. // We add a offset to the rng seed so that to make generations different. - rstate.reserve(n + 1); - rstate[0]->children_idx.reserve(n); + rsentries.reserve(n + 1); + rsentries[0]->child_indices.reserve(n); for (int i = 0; i < n; ++i) { - rstate[0]->children_idx.push_back(rstate.size()); - rstate.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), - rng_seed + i + 1, token_table_, json_grammar_state_init_ctx_, - /*parent_idx=*/0); + rsentries[0]->child_indices.push_back(rsentries.size()); + rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), + rng_seed + i + 1, token_table_, json_grammar_state_init_ctx_, + /*parent_idx=*/0); } } - estate_->request_states.emplace(request->id, rstate); + estate_->request_states.emplace(request->id, RequestState(std::move(rsentries))); } void AbortRequest(const String& request_id) final { @@ -169,7 +169,7 @@ class EngineImpl : public Engine { } RequestState rstate = it_rstate->second; - Request request = rstate[0]->request; + Request request = rstate->entries[0]->request; // - Check if the request is running or pending. auto it_running = @@ -177,7 +177,7 @@ class EngineImpl : public Engine { auto it_waiting = std::find(estate_->waiting_queue.begin(), estate_->waiting_queue.end(), request); - for (const RequestStateEntry& rsentry : rstate) { + for (const RequestStateEntry& rsentry : rstate->entries) { estate_->id_manager.RecycleId(rsentry->mstates[0]->internal_id); } estate_->request_states.erase(request->id); @@ -188,13 +188,14 @@ class EngineImpl : public Engine { // Reduce the input length. estate_->stats.current_total_seq_len -= request->input_total_length; // Reduce the generated length. 
- for (int i = 0; i < static_cast(rstate.size()); ++i) { - if (rstate[i]->status != RequestStateStatus::kAlive) { + for (int i = 0; i < static_cast(rstate->entries.size()); ++i) { + if (rstate->entries[i]->status != RequestStateStatus::kAlive) { continue; } - estate_->stats.current_total_seq_len -= rstate[i]->mstates[0]->committed_tokens.size(); - RemoveRequestFromModel(estate_, rstate[i]->mstates[0]->internal_id, models_); - if (rstate[i]->children_idx.empty()) { + estate_->stats.current_total_seq_len -= + rstate->entries[i]->mstates[0]->committed_tokens.size(); + RemoveRequestFromModel(estate_, rstate->entries[i]->mstates[0]->internal_id, models_); + if (rstate->entries[i]->child_indices.empty()) { // For each running leaf state, length 1 is over reduced since the last // token is not added into KV cache. So we add the length back. ++estate_->stats.current_total_seq_len; diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 85248062a4..133bc4e6e5 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -16,12 +16,13 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, Array models, int max_single_sequence_length) { +void ProcessFinishedRequestStateEntries(std::vector finished_rsentries, + EngineState estate, Array models, + int max_single_sequence_length) { // - Remove the finished request state entries. for (const RequestStateEntry& rsentry : finished_rsentries) { // The finished entry must be a leaf. - ICHECK(rsentry->children_idx.empty()); + ICHECK(rsentry->child_indices.empty()); // Mark the status of this entry as finished. rsentry->status = RequestStateStatus::kFinished; // Remove the request state entry from all the models. @@ -34,8 +35,8 @@ void ProcessFinishedRequestStateEntries(RequestState finished_rsentries, EngineS int parent_idx = rsentry->parent_idx; while (parent_idx != -1) { bool all_children_finished = true; - for (int child_idx : rstate[parent_idx]->children_idx) { - if (rstate[child_idx]->status != RequestStateStatus::kFinished) { + for (int child_idx : rstate->entries[parent_idx]->child_indices) { + if (rstate->entries[child_idx]->status != RequestStateStatus::kFinished) { all_children_finished = false; break; } @@ -46,14 +47,14 @@ void ProcessFinishedRequestStateEntries(RequestState finished_rsentries, EngineS // All the children of the parent request state entry have finished. // So we mark the parent entry as finished. - rstate[parent_idx]->status = RequestStateStatus::kFinished; + rstate->entries[parent_idx]->status = RequestStateStatus::kFinished; // Remove the request state entry from all the models. - RemoveRequestFromModel(estate, rstate[parent_idx]->mstates[0]->internal_id, models); - estate->id_manager.RecycleId(rstate[parent_idx]->mstates[0]->internal_id); + RemoveRequestFromModel(estate, rstate->entries[parent_idx]->mstates[0]->internal_id, models); + estate->id_manager.RecycleId(rstate->entries[parent_idx]->mstates[0]->internal_id); estate->stats.current_total_seq_len -= - static_cast(rstate[parent_idx]->mstates[0]->committed_tokens.size()); + static_cast(rstate->entries[parent_idx]->mstates[0]->committed_tokens.size()); // Climb up to the parent. 
- parent_idx = rstate[parent_idx]->parent_idx; + parent_idx = rstate->entries[parent_idx]->parent_idx; } if (parent_idx == -1) { @@ -68,14 +69,14 @@ void ProcessFinishedRequestStateEntries(RequestState finished_rsentries, EngineS estate->request_states.erase(rsentry->request->id); // Update engine statistics. - const RequestStateEntry& root_rsentry = rstate[0]; + const RequestStateEntry& root_rsentry = rstate->entries[0]; auto trequest_finish = std::chrono::high_resolution_clock::now(); estate->stats.request_total_prefill_time += static_cast((root_rsentry->tprefill_finish - root_rsentry->tadd).count()) / 1e9; estate->stats.total_prefill_length += rsentry->request->input_total_length; estate->stats.request_total_decode_time += static_cast((trequest_finish - root_rsentry->tprefill_finish).count()) / 1e9; - for (const RequestStateEntry& entry : rstate) { + for (const RequestStateEntry& entry : rstate->entries) { estate->stats.total_decode_length += entry->mstates[0]->committed_tokens.size(); } estate->stats.total_decode_length -= rsentry->request->generation_cfg->n; @@ -106,7 +107,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, Arrayentries[0] : rstate->entries[i + 1]; const DeltaRequestReturn& delta_request_ret = rsentry->GetReturnTokenIds(tokenizer, max_single_sequence_length); group_delta_token_ids.push_back(IntTuple{delta_request_ret.delta_token_ids.begin(), @@ -148,14 +149,14 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // Find the last alive request state entry, which is what we want to preempt. RequestState rstate = estate->GetRequestState(request); int preempt_rstate_idx = -1; - for (int i = static_cast(rstate.size()) - 1; i >= 0; --i) { - if (rstate[i]->status == RequestStateStatus::kAlive) { + for (int i = static_cast(rstate->entries.size()) - 1; i >= 0; --i) { + if (rstate->entries[i]->status == RequestStateStatus::kAlive) { preempt_rstate_idx = i; break; } } ICHECK_NE(preempt_rstate_idx, -1); - RequestStateEntry rsentry = rstate[preempt_rstate_idx]; + RequestStateEntry rsentry = rstate->entries[preempt_rstate_idx]; // Remove from models. // - Clear model speculation draft. @@ -163,7 +164,7 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); rsentry->status = RequestStateStatus::kPending; estate->stats.current_total_seq_len -= rsentry->mstates[0]->committed_tokens.size(); - if (rsentry->children_idx.empty()) { + if (rsentry->child_indices.empty()) { // The length was overly decreased by 1 when the entry has no child. ++estate->stats.current_total_seq_len; } @@ -206,7 +207,7 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // Remove from running queue. estate->running_queue.erase(estate->running_queue.end() - 1); } - if (preempt_rstate_idx == static_cast(rstate.size()) - 1) { + if (preempt_rstate_idx == static_cast(rstate->entries.size()) - 1) { // Add to the front of waiting queue. 
estate->waiting_queue.insert(estate->waiting_queue.begin(), request); } diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index bc3d10ee06..aea455a1be 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -63,8 +63,8 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, inline std::vector GetRunningRequestStateEntries(const EngineState& estate) { std::vector rsentries; for (const Request& request : estate->running_queue) { - for (const RequestStateEntry& rsentry : estate->GetRequestState(request)) { - if (rsentry->status == RequestStateStatus::kAlive && rsentry->children_idx.empty()) { + for (const RequestStateEntry& rsentry : estate->GetRequestState(request)->entries) { + if (rsentry->status == RequestStateStatus::kAlive && rsentry->child_indices.empty()) { rsentries.push_back(rsentry); } } diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 0b23541c22..00bf503969 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -102,12 +102,15 @@ class BatchDecodeActionObj : public EngineActionObj { logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. - NDArray probs_device = + NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); // - Sample tokens. - std::vector sample_results = - sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs); + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Update the committed tokens of states. diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index da345b6c89..626e863566 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -111,13 +111,16 @@ class BatchDraftActionObj : public EngineActionObj { logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. - NDArray probs_device = + NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); // - Sample tokens. + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; std::vector sample_results = sampler_->BatchSampleTokens( - probs_device, request_ids, generation_cfg, rngs, /*prob_indices=*/nullptr, &prob_dist); + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 3720340589..fc0d857c00 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -112,11 +112,11 @@ class BatchVerifyActionObj : public EngineActionObj { request_ids, &cum_verify_lengths, &draft_output_tokens); // - Compute probability distributions. 
- NDArray probs_device = logit_processor_->ComputeProbsFromLogits( + NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, + probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 24d431ae7e..b60a125c3f 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -48,14 +48,14 @@ class NewRequestPrefillActionObj : public EngineActionObj { rstates_of_requests.reserve(num_rstates); for (RequestStateEntry rstate : rstates) { const Request& request = rstate->request; - RequestState request_rstates = estate->GetRequestState(request); + RequestState request_rstate = estate->GetRequestState(request); request_ids.push_back(request->id); rstate->status = RequestStateStatus::kAlive; // - Remove the request from waiting queue if all its request states are now alive. // - Add the request to running queue if all its request states were pending. bool alive_state_existed = false; - for (const RequestStateEntry& request_state : request_rstates) { + for (const RequestStateEntry& request_state : request_rstate->entries) { if (request_state->status == RequestStateStatus::kAlive && !request_state.same_as(rstate)) { alive_state_existed = true; } @@ -63,7 +63,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { if (!alive_state_existed) { estate->running_queue.push_back(request); } - rstates_of_requests.push_back(std::move(request_rstates)); + rstates_of_requests.push_back(std::move(request_rstate)); } // - Get embedding and run prefill for each model. @@ -83,9 +83,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { if (rstates[i]->parent_idx == -1) { models_[model_id]->AddNewSequence(mstate->internal_id); } else { - models_[model_id]->ForkSequence( - rstates_of_requests[i][rstates[i]->parent_idx]->mstates[model_id]->internal_id, - mstate->internal_id); + models_[model_id]->ForkSequence(rstates_of_requests[i] + ->entries[rstates[i]->parent_idx] + ->mstates[model_id] + ->internal_id, + mstate->internal_id); } request_internal_ids.push_back(mstate->internal_id); RECORD_EVENT(trace_recorder_, rstates[i]->request->id, "start embedding"); @@ -127,66 +129,67 @@ class NewRequestPrefillActionObj : public EngineActionObj { request_ids); // - Compute probability distributions. - NDArray probs_device = + NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); // - Sample tokens. // For rstates which are depended by other states, sample // one token for each rstate that is depending. // Otherwise, sample a token for the current rstate. 
- std::vector prob_indices; - RequestState rstates_for_sample; + std::vector sample_indices; + std::vector rsentries_for_sample; std::vector rngs; - prob_indices.reserve(num_rstates); - rstates_for_sample.reserve(num_rstates); + sample_indices.reserve(num_rstates); + rsentries_for_sample.reserve(num_rstates); rngs.reserve(num_rstates); request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rstates; ++i) { estate->stats.current_total_seq_len += prefill_lengths[i]; const RequestStateEntry& rstate = rstates[i]; - for (int child_idx : rstate->children_idx) { - if (rstates_of_requests[i][child_idx]->mstates[0]->committed_tokens.empty()) { + for (int child_idx : rstate->child_indices) { + if (rstates_of_requests[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { // If rstates_of_requests[i][child_idx] has no committed token, // the prefill of the current rstate will unblock rstates_of_requests[i][child_idx], // and thus we want to sample a token for rstates_of_requests[i][child_idx]. - prob_indices.push_back(i); - rstates_for_sample.push_back(rstates_of_requests[i][child_idx]); + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstates_of_requests[i]->entries[child_idx]); request_ids.push_back(rstate->request->id); generation_cfg.push_back(rstate->request->generation_cfg); - rngs.push_back(&rstates_of_requests[i][child_idx]->rng); + rngs.push_back(&rstates_of_requests[i]->entries[child_idx]->rng); - ICHECK(rstates_of_requests[i][child_idx]->status == RequestStateStatus::kPending); - rstates_of_requests[i][child_idx]->status = RequestStateStatus::kAlive; + ICHECK(rstates_of_requests[i]->entries[child_idx]->status == + RequestStateStatus::kPending); + rstates_of_requests[i]->entries[child_idx]->status = RequestStateStatus::kAlive; for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { models_[model_id]->ForkSequence( rstate->mstates[model_id]->internal_id, - rstates_of_requests[i][child_idx]->mstates[model_id]->internal_id); + rstates_of_requests[i]->entries[child_idx]->mstates[model_id]->internal_id); } } } - if (rstate->children_idx.empty()) { + if (rstate->child_indices.empty()) { // If rstate has no child, we sample a token for itself. - prob_indices.push_back(i); - rstates_for_sample.push_back(rstate); + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstate); request_ids.push_back(rstate->request->id); generation_cfg.push_back(rstate->request->generation_cfg); rngs.push_back(&rstate->rng); } } - std::vector sample_results = - sampler_->BatchSampleTokens(probs_device, request_ids, generation_cfg, rngs, &prob_indices); - ICHECK_EQ(sample_results.size(), rstates_for_sample.size()); + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. 
auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rstates_for_sample.size()); ++i) { - for (const RequestModelState& mstate : rstates_for_sample[i]->mstates) { + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { mstate->CommitToken(sample_results[i]); } - if (rstates_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { - rstates_for_sample[i]->tprefill_finish = tnow; + if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { + rsentries_for_sample[i]->tprefill_finish = tnow; } } @@ -206,7 +209,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { processed_requests.push_back(rstate->request); bool pending_state_exists = false; - for (const RequestStateEntry& request_state : rstates_of_requests[i]) { + for (const RequestStateEntry& request_state : rstates_of_requests[i]->entries) { if (request_state->status == RequestStateStatus::kPending) { pending_state_exists = true; break; @@ -249,7 +252,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { for (const Request& request : estate->waiting_queue) { RequestState rstate = estate->GetRequestState(request); bool prefill_stops = false; - for (const RequestStateEntry& rsentry : rstate) { + for (const RequestStateEntry& rsentry : rstate->entries) { // A request state entry can be prefilled only when: // - it has inputs, and // - it is pending, and @@ -257,7 +260,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { if (rsentry->mstates[0]->inputs.empty() || rsentry->status != RequestStateStatus::kPending || (rsentry->parent_idx != -1 && - rstate[rsentry->parent_idx]->status == RequestStateStatus::kPending)) { + rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending)) { continue; } @@ -266,12 +269,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; total_input_length += input_length; total_required_pages += num_require_pages; - if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->children_idx.size(), + if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), total_input_length, total_required_pages, num_available_pages, num_running_rsentries)) { rsentries_to_prefill.push_back(rsentry); prefill_lengths.push_back(input_length); - ++num_prefill_rsentries; + num_prefill_rsentries += 1 + rsentry->child_indices.size(); } else { total_input_length -= input_length; total_required_pages -= num_require_pages; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 7dc9d0b627..6eca65f05f 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -82,6 +82,8 @@ void RequestModelStateNode::RemoveAllDraftTokens() { } } +/****************** RequestStateEntry ******************/ + TVM_REGISTER_OBJECT_TYPE(RequestStateEntryNode); RequestStateEntry::RequestStateEntry( @@ -189,6 +191,16 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok return {return_token_ids, logprob_json_strs, Optional()}; } +/****************** RequestState ******************/ + +TVM_REGISTER_OBJECT_TYPE(RequestStateNode); + +RequestState::RequestState(std::vector entries) { + ObjectPtr n = make_object(); + n->entries = std::move(entries); + data_ = std::move(n); +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h 
index 66e36d5b93..83a12fade4 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -155,29 +155,33 @@ enum class RequestStateStatus : int { kFinished = 2, }; +/*! + * \brief A request's state entry. It contains the state of a single + * generation of a request, or the state of a prompt prefix of a request. + */ class RequestStateEntryNode : public Object { public: - /*! \brief The status of the request state. */ + /*! \brief The status of the request state entry. */ RequestStateStatus status; /*! \brief The request that this state corresponds to. */ Request request; /*! - * \brief The idx of the parent request state of this state. + * \brief The idx of the parent request state entry of this state. * Being -1 means the state has no parent and is the foremost - * "prefix" state or the only state. + * "prefix" entry or the only entry. */ int parent_idx = -1; - /*! \brief The children indices of the request state. */ - std::vector children_idx; + /*! \brief The children indices of the request state entry. */ + std::vector child_indices; /*! * \brief The state with regard to each model. * \sa RequestModelState */ Array mstates; - /*! \brief The random number generator of this request. */ + /*! \brief The random number generator of this request state entry. */ RandomGenerator rng; - /*! \brief The stop string handler of this request. */ + /*! \brief The stop string handler of this request state entry. */ StopStrHandler stop_str_handler; /*! * \brief The start position of the committed tokens in the @@ -218,7 +222,22 @@ class RequestStateEntry : public ObjectRef { }; /*! \brief A request's state, which groups all the request state entries. */ -typedef std::vector RequestState; +class RequestStateNode : public Object { + public: + std::vector entries; + + static constexpr const char* _type_key = "mlc.serve.RequestState"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_FINAL_OBJECT_INFO(RequestStateNode, Object); +}; + +class RequestState : public ObjectRef { + public: + explicit RequestState(std::vector entries); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestState, ObjectRef, RequestStateNode); +}; } // namespace serve } // namespace llm diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler.cc index d201158628..4a59cefaff 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler.cc @@ -262,27 +262,24 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_device, // + std::vector BatchSampleTokens(NDArray probs_on_device, // + const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // const std::vector& rngs, // - const std::vector* prob_indices, // std::vector* output_prob_dist) final { - // probs_device: (n, v) + // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - CHECK_EQ(probs_device->ndim, 2); + CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_device); + NDArray probs_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); // - Sample tokens from probabilities. 
int n = request_ids.size(); ICHECK_EQ(generation_cfg.size(), n); ICHECK_EQ(rngs.size(), n); - if (prob_indices == nullptr) { - ICHECK_EQ(probs_host->shape[0], n); - } std::vector sample_results; sample_results.resize(n); @@ -291,12 +288,12 @@ class CPUSampler : public SamplerObj { } tvm::runtime::parallel_for_with_threading_backend( - [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, prob_indices, + [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, sample_indices, output_prob_dist](int i) { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, prob_indices == nullptr ? i : prob_indices->at(i), + probs_host, sample_indices[i], generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); if (output_prob_dist == nullptr) { @@ -314,17 +311,17 @@ class CPUSampler : public SamplerObj { } std::vector> BatchVerifyDraftTokens( - NDArray probs_device, const Array& request_ids, + NDArray probs_on_device, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) final { - // probs_device: (n, v) + // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); - CHECK_EQ(probs_device->ndim, 2); + CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_device); + NDArray probs_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); int num_sequence = static_cast(cum_verify_lengths.size()) - 1; @@ -401,26 +398,26 @@ class CPUSampler : public SamplerObj { private: /*! \brief Copy prob distributions from device to CPU. */ - NDArray CopyProbsToCPU(NDArray probs_device) { - // probs_device: (n, v) - ICHECK(probs_device->device.device_type != kDLCPU); + NDArray CopyProbsToCPU(NDArray probs_on_device) { + // probs_on_device: (n, v) + ICHECK(probs_on_device->device.device_type != kDLCPU); if (probs_host_.defined()) { - ICHECK_EQ(probs_host_->shape[1], probs_device->shape[1]); + ICHECK_EQ(probs_host_->shape[1], probs_on_device->shape[1]); } int64_t init_size = probs_host_.defined() ? probs_host_->shape[0] : 32; - int64_t num_tokens = probs_device->shape[0]; - int64_t vocab_size = probs_device->shape[1]; + int64_t num_tokens = probs_on_device->shape[0]; + int64_t vocab_size = probs_on_device->shape[1]; while (init_size < num_tokens) { init_size *= 2; } if (!probs_host_.defined() || init_size != probs_host_->shape[0]) { probs_host_ = - NDArray::Empty({init_size, vocab_size}, probs_device->dtype, DLDevice{kDLCPU, 0}); + NDArray::Empty({init_size, vocab_size}, probs_on_device->dtype, DLDevice{kDLCPU, 0}); } ICHECK_LE(num_tokens, probs_host_->shape[0]); - NDArray view = probs_host_.CreateView({num_tokens, vocab_size}, probs_device->dtype); - view.CopyFrom(probs_device); + NDArray view = probs_host_.CreateView({num_tokens, vocab_size}, probs_on_device->dtype); + view.CopyFrom(probs_on_device); return view; } diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler.h index faa2cffd57..c48702c0c7 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler.h @@ -34,36 +34,30 @@ class SamplerObj : public Object { public: /*! 
* \brief Sample tokens from the input batch of prob distribution on device. - * \param probs_device The prob distributions on GPU to sample tokens from. + * \param probs_on_device The prob distributions on GPU to sample tokens from. + * \param sample_indices Specifying which request we should sample for + * in i-th output. The output result is sample as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); * \param request_ids The id of each request. * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. - * \param prob_indices The indices of probability distribution in `probs_device` - * that each request in `request_ids` samples from. - * It defaults to nullptr, which means each request samples from the - * corresponding index in `prob_indices`. - * In usual cases, we only sample one token for each prob distribution - * in the batch, and `prob_indices` is nullptr in such cases. - * When we want to sample multiple tokens from a prob distribution (e.g., - * starting parallel generation after prefill the input), we use `prob_indices` - * to represent which distribution a token should be sampled from * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. */ virtual std::vector BatchSampleTokens( - NDArray probs_device, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - const std::vector* prob_indices = nullptr, // + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // std::vector* output_prob_dist = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param probs_device The prob distributions on GPU to sample tokens from. + * \param probs_on_device The prob distributions on GPU to sample tokens from. * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. * \param generation_cfg The generation config of each request @@ -76,7 +70,7 @@ class SamplerObj : public Object { * \return The list of accepted tokens for each request. 
*/ virtual std::vector> BatchVerifyDraftTokens( - NDArray probs_device, const Array& request_ids, + NDArray probs_on_device, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, From ffef890c0650b5bb521447f0275ea9791092b492 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 5 Mar 2024 03:05:22 +0800 Subject: [PATCH 031/531] [SLM] Update StableLM model and migrate it to paged KV Cache (#1882) --- python/mlc_chat/model/model.py | 8 +- python/mlc_chat/model/model_preset.py | 34 +-- .../model/stable_lm/stablelm_loader.py | 10 +- .../model/stable_lm/stablelm_model.py | 280 ++++++++++++------ .../model/stable_lm/stablelm_quantization.py | 14 +- 5 files changed, 221 insertions(+), 125 deletions(-) diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py index 9c82cfe9cb..e03d89762a 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_chat/model/model.py @@ -222,10 +222,10 @@ class Model: "ft-quant": qwen2_quantization.ft_quant, }, ), - "stablelm_epoch": Model( - name="stablelm_epoch", - model=stablelm_model.StableLMEpochForCausalLM, - config=stablelm_model.StableLMEpochConfig, + "stablelm": Model( + name="stablelm", + model=stablelm_model.StableLmForCausalLM, + config=stablelm_model.StableLmConfig, source={ "huggingface-torch": stablelm_loader.huggingface, "huggingface-safetensor": stablelm_loader.huggingface, diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py index 409112b6b5..9314b1143b 100644 --- a/python/mlc_chat/model/model_preset.py +++ b/python/mlc_chat/model/model_preset.py @@ -416,34 +416,28 @@ "use_sliding_window": False, "vocab_size": 151936, }, - "stablelm_epoch": { - "architectures": ["StableLMEpochForCausalLM"], - "auto_map": { - "AutoConfig": "configuration_stablelm_epoch.StableLMEpochConfig", - "AutoModelForCausalLM": "modeling_stablelm_epoch.StableLMEpochForCausalLM", - }, - "bos_token_id": 100257, - "eos_token_id": 100257, + "stablelm": { + "architectures": ["StableLmForCausalLM"], + "bos_token_id": 0, + "eos_token_id": 0, "hidden_act": "silu", - "hidden_size": 2048, + "hidden_size": 2560, "initializer_range": 0.02, - "intermediate_size": 5632, + "intermediate_size": 6912, "max_position_embeddings": 4096, - "model_type": "stablelm_epoch", - "norm_eps": 1e-05, + "model_type": "stablelm", + "layer_norm_eps": 1e-05, "num_attention_heads": 32, - "num_heads": 32, - "num_hidden_layers": 24, + "num_hidden_layers": 32, "num_key_value_heads": 32, - "rope_pct": 0.25, + "partial_rotary_factor": 0.25, "rope_theta": 10000, - "rotary_scaling_factor": 1.0, - "tie_word_embeddings": True, + "tie_word_embeddings": False, "torch_dtype": "bfloat16", - "transformers_version": "4.36.2", + "transformers_version": "4.38.0", "use_cache": True, - "use_qkv_bias": True, - "vocab_size": 100352, + "use_qkv_bias": False, + "vocab_size": 50304, }, "baichuan": { "architectures": ["BaichuanForCausalLM"], diff --git a/python/mlc_chat/model/stable_lm/stablelm_loader.py b/python/mlc_chat/model/stable_lm/stablelm_loader.py index f635c0ed47..d2cc4d93c8 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_loader.py +++ b/python/mlc_chat/model/stable_lm/stablelm_loader.py @@ -10,17 +10,17 @@ from mlc_chat.loader import ExternMapping from mlc_chat.quantization import Quantization -from .stablelm_model import StableLMEpochConfig, StableLMEpochForCausalLM +from .stablelm_model import StableLmConfig, StableLmForCausalLM -def 
huggingface(model_config: StableLMEpochConfig, quantization: Quantization) -> ExternMapping: +def huggingface(model_config: StableLmConfig, quantization: Quantization) -> ExternMapping: """Returns a parameter mapping that maps from the names of MLC LLM parameters to the names of HuggingFace PyTorch parameters. Parameters ---------- - model_config : GPT2Config - The configuration of the GPT-2 model. + model_config : StableLmConfig + The configuration of the StableLm model. quantization : Quantization The quantization configuration. @@ -30,7 +30,7 @@ def huggingface(model_config: StableLMEpochConfig, quantization: Quantization) - param_map : ExternMapping The parameter mapping from MLC to HuggingFace PyTorch. """ - model = StableLMEpochForCausalLM(model_config) + model = StableLmForCausalLM(model_config) if quantization is not None: model.to(quantization.model_dtype) _, _named_params, _ = model.export_tvm( # type: ignore[misc] diff --git a/python/mlc_chat/model/stable_lm/stablelm_model.py b/python/mlc_chat/model/stable_lm/stablelm_model.py index 3a5ce65879..7f5e56e819 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_model.py +++ b/python/mlc_chat/model/stable_lm/stablelm_model.py @@ -11,6 +11,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold @@ -19,7 +20,7 @@ @dataclasses.dataclass -class StableLMEpochConfig(ConfigBase): # pylint: disable=too-many-instance-attributes +class StableLmConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """Configuration of the StableLM model.""" vocab_size: int @@ -27,8 +28,8 @@ class StableLMEpochConfig(ConfigBase): # pylint: disable=too-many-instance-attr num_hidden_layers: int num_attention_heads: int num_key_value_heads: int - norm_eps: float - rope_pct: float + layer_norm_eps: float + partial_rotary_factor: float rope_theta: int intermediate_size: int use_qkv_bias: bool = False # Default to False for Stable-LM 3B model @@ -78,16 +79,15 @@ def __post_init__(self): # pylint: disable=invalid-name,missing-docstring -class StableLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes - def __init__(self, config: StableLMEpochConfig): +class StableLmAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: StableLmConfig): self.hidden_size = config.hidden_size self.rope_theta = config.rope_theta - self.rope_pct = config.rope_pct self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rotary_ndims = int(self.head_dim * config.rope_pct) + self.rotary_ndims = int(config.partial_rotary_factor * self.head_dim) self.qkv_proj = nn.Linear( in_features=config.hidden_size, @@ -103,35 +103,33 @@ def __init__(self, config: StableLMEpochConfig): config.context_window_size, [self.num_key_value_heads, self.head_dim] ) - def forward( # pylint: disable=too-many-locals - self, - hidden_states: Tensor, - attention_mask: Tensor, - total_seq_len: tir.Var, - ): - d, h_q, h_kv, t = self.head_dim, self.num_heads, self.num_key_value_heads, total_seq_len + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape 
+ qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), + (b, s, h_q * d), + ) + attn_output = self.o_proj(output) + return attn_output + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection qkv = self.qkv_proj(hidden_states) qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope( - qkv, t, self.rope_theta, h_q, h_kv, rotary_dim=self.rotary_ndims + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), + (b, s, h_q * d), ) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = op_ext.attention(q, k, v, casual_mask=attention_mask) - # Step 5. Apply output projection - return self.o_proj(output) - - -class StalbeLMMLP(nn.Module): - def __init__(self, config: StableLMEpochConfig): + attn_output = self.o_proj(output) + return attn_output + + +class StableLmMLP(nn.Module): + def __init__(self, config: StableLmConfig): self.intermediate_size = config.intermediate_size self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -146,117 +144,221 @@ def forward(self, x: Tensor): return self.down_proj(op.silu(x1) * x2) -class StableLMDecoderLayer(nn.Module): - def __init__(self, config: StableLMEpochConfig): - norm_eps = config.norm_eps - self.self_attn = StableLMAttention(config) - self.mlp = StalbeLMMLP(config) +class StableLmDecoderLayer(nn.Module): + def __init__(self, config: StableLmConfig): + norm_eps = config.layer_norm_eps + self.self_attn = StableLmAttention(config) + self.mlp = StableLmMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn.batch_forward( + self.input_layernorm(hidden_states), paged_kv_cache, layer_id + ) hidden_states = out + hidden_states out = self.mlp(self.post_attention_layernorm(hidden_states)) hidden_states = out + hidden_states return hidden_states -class StableLMEpochModel(nn.Module): - def __init__(self, config: StableLMEpochConfig): +class StableLmModel(nn.Module): + def __init__(self, config: StableLmConfig): assert config.hidden_size % config.num_attention_heads == 0 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( - [StableLMDecoderLayer(config) for _ in range(config.num_hidden_layers)] + [StableLmDecoderLayer(config) for _ 
in range(config.num_hidden_layers)] ) - self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - hidden_states = self.embed_tokens(input_ids) - for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.norm(hidden_states) return hidden_states + def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states -class StableLMEpochForCausalLM(nn.Module): - def __init__(self, config: StableLMEpochConfig): - self.model = StableLMEpochModel(config) + +class StableLmForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: StableLmConfig): + self.model = StableLmModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.vocab_size = config.vocab_size self.dtype = "float32" + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.vocab_size = config.vocab_size + self.rope_theta = config.rope_theta + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + self.partial_rotary_factor = config.partial_rotary_factor def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.model(inputs, total_seq_len, attention_mask) + hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits + return logits, paged_kv_cache - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + def decode(self, 
input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, total_seq_len], - ) - return self.forward(inputs, total_seq_len, attention_mask) + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), - dtype=self.dtype, - ) - return self.forward(inputs, total_seq_len, attention_mask) + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + rotary_dim=int(self.head_dim * self.partial_rotary_factor), + ) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], 
self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_chat/model/stable_lm/stablelm_quantization.py b/python/mlc_chat/model/stable_lm/stablelm_quantization.py index 0bb6047d2f..327082aeaa 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_quantization.py +++ b/python/mlc_chat/model/stable_lm/stablelm_quantization.py @@ -7,15 +7,15 @@ from mlc_chat.loader import QuantizeMapping from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize -from .stablelm_model import StableLMEpochConfig, StableLMEpochForCausalLM +from .stablelm_model import StableLmConfig, StableLmForCausalLM def group_quant( - model_config: StableLMEpochConfig, + model_config: StableLmConfig, quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a StableLM-architecture model using group quantization.""" - model: nn.Module = StableLMEpochForCausalLM(model_config) + model: nn.Module = StableLmForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( @@ -27,11 +27,11 @@ def group_quant( def ft_quant( - model_config: StableLMEpochConfig, + model_config: StableLmConfig, quantization: FTQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a StableLM model using FasterTransformer quantization.""" - model: nn.Module = StableLMEpochForCausalLM(model_config) + model: nn.Module = StableLmForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( @@ -43,11 +43,11 @@ def ft_quant( def no_quant( - model_config: StableLMEpochConfig, + model_config: StableLmConfig, quantization: NoQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a StableLM model without quantization.""" - model: nn.Module = StableLMEpochForCausalLM(model_config) + model: nn.Module = StableLmForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map From ef2db85873d6e9d5c619d9b15cd839564ef8fdc2 Mon Sep 17 00:00:00 2001 From: Diego Cao <50705298+DiegoCao@users.noreply.github.com> Date: Mon, 4 Mar 2024 17:34:59 -0500 Subject: [PATCH 032/531] [KVCache] Qwen 1.0 Model PagedKV Support (#1887) Support Qwen1.0 Paged KV Cache --- docs/compilation/compile_models.rst | 3 +- docs/compilation/convert_weights.rst | 3 +- python/mlc_chat/model/qwen/qwen_model.py | 236 ++++++++++++++++------- 3 files changed, 172 insertions(+), 70 deletions(-) diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 24ebbed730..855c805094 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -5,8 +5,7 @@ Compile 
Model Libraries To run a model with MLC LLM in any platform, you need: -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - `_.) +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `__.) 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). If you are simply adding a model variant, follow :ref:`convert-weights-via-MLC` suffices. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index ef39cd9efb..7657bca7d8 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -5,8 +5,7 @@ Convert Weights via MLC To run a model with MLC LLM in any platform, you need: -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - `_.) 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). In many cases, we only need to convert weights and reuse existing model library. diff --git a/python/mlc_chat/model/qwen/qwen_model.py b/python/mlc_chat/model/qwen/qwen_model.py index ef4caca009..48c66525fb 100644 --- a/python/mlc_chat/model/qwen/qwen_model.py +++ b/python/mlc_chat/model/qwen/qwen_model.py @@ -10,6 +10,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold @@ -80,10 +81,9 @@ class QWenAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWenConfig): self.hidden_size = config.hidden_size self.rope_theta = config.rotary_emb_base - self.num_heads = config.num_attention_heads + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.head_dim = self.hidden_size // self.num_heads self.projection_size = config.kv_channels * config.num_attention_heads - self.c_attn = nn.Linear( in_features=config.hidden_size, out_features=3 * self.projection_size, @@ -98,31 +98,34 @@ def __init__(self, config: QWenConfig): def forward( # pylint: disable=too-many-locals self, hidden_states: Tensor, - attention_mask: Tensor, - total_seq_len: tir.Var, + paged_kv_cache: PagedKVCache, + layer_id: int, ): - d, h, t = self.head_dim, self.num_heads, total_seq_len + d, h = self.head_dim, self.num_heads b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection + qkv = self.c_attn(hidden_states) qkv = op.reshape(qkv, (b, s, 3 * h, d)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h, h) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = op_ext.attention(q, k, v, casual_mask=attention_mask) - # Step 5. 
Apply output projection + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) + ) + return self.c_proj(output) + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + b, s, _ = hidden_states.shape + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * self.head_dim, self.num_heads)) + # try batch forward + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), + (b, s, self.head_dim * self.num_heads), + ) return self.c_proj(output) class QWenMLP(nn.Module): def __init__(self, config: QWenConfig): - self.intermediate_size = config.intermediate_size + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, out_features=self.intermediate_size, @@ -144,8 +147,15 @@ def __init__(self, config: QWenConfig): self.ln_1 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) self.ln_2 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - out = self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len) + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.ln_2(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id) hidden_states = out + hidden_states out = self.mlp(self.ln_2(hidden_states)) hidden_states = out + hidden_states @@ -159,19 +169,34 @@ def __init__(self, config: QWenConfig): self.h = nn.ModuleList([QWenBlock(config) for _ in range(config.num_hidden_layers)]) self.ln_f = nn.RMSNorm(config.hidden_size, -1, config.layer_norm_epsilon, bias=False) - def forward(self, input_ids: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - hidden_states = self.wte(input_ids) - for layer in self.h: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + # hidden_states = self.wte(input_ids) + hidden_states = inputs + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.ln_f(hidden_states) return hidden_states + def batch_forward(self, inputs, paged_kv_cache: PagedKVCache): + # hidden_states = self.wte(input_ids) + hidden_states = inputs + for layer_id, layer in enumerate(self.h): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.ln_f(hidden_states) + return hidden_states -class QWenLMHeadModel(nn.Module): + +class QWenLMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWenConfig): self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype="float32") + self.hidden_size = config.hidden_size self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + 
self.tensor_parallel_shards = config.tensor_parallel_shards + self.rotary_emb_base = config.rotary_emb_base self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -179,72 +204,151 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def batch_forward( + self, + inputs: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + hidden_states = self.transformer.batch_forward(inputs, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.transformer.wte(input_ids) + + def prefill(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.transformer(inputs, total_seq_len, attention_mask) - hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + hidden_states = self.transformer(inputs, paged_kv_cache) + hidden_states = op.tensor_expr_op( + _index, + name_hint="index", + args=[hidden_states], + ) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits + return logits, paged_kv_cache - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + def decode(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, total_seq_len], - ) - return self.forward(inputs, total_seq_len, attention_mask) + hidden_states = self.transformer(inputs, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), - dtype=self.dtype, - ) - return self.forward(inputs, total_seq_len, attention_mask) + def batch_prefill(self, inputs: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(inputs, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(inputs, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(inputs, paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + 
max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rotary_emb_base, + dtype=self.dtype, + ) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "inputs": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "inputs": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "inputs": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", From 25877f9ff909bbe8c6af7301b5334e832e3af373 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 4 Mar 2024 18:17:47 -0500 Subject: [PATCH 033/531] [Serving] Estimate KV cache memory usage with metadata (#1888) Prior to this PR, the serving engine memory usage estimation reads model config for fields such as `num_key_value_heads`, `num_hidden_layers`, etc.. However, since not every model share the same set of config names (#1854), the estimation fails for models that do not have this set of config field names. This PR makes the following changes. First, it attaches these field values into the model's metadata, in which way we unify the field names for different models effectively. 
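To make this concrete, the attached entry can be queried back from a compiled model library with the metadata tool extended below. A minimal usage sketch follows (the library path is a placeholder; only the flag name and the four keys come from this patch):

```python
# Hypothetical usage sketch: read the "kv_cache" metadata that this patch
# attaches to a compiled model library. The library path below is a placeholder.
import json
import subprocess
import sys

cmd = [
    sys.executable,
    "-m",
    "mlc_chat.cli.model_metadata",
    "dist/libs/model-q4f16_1-cuda.so",  # placeholder path to a compiled model lib
    "--print-kv-cache-metadata-in-json-only",
]
kv_cache_metadata = json.loads(subprocess.check_output(cmd, universal_newlines=True))
print(
    kv_cache_metadata["num_hidden_layers"],
    kv_cache_metadata["num_attention_heads"],
    kv_cache_metadata["num_key_value_heads"],
    kv_cache_metadata["head_dim"],
)
```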
Then, when estimating the memory usage, we read these fields from the metadata, rather than model config, so we are safe for the name inconsistency. --- python/mlc_chat/cli/model_metadata.py | 12 +++++++++++ .../rewrite_kv_cache_creation.py | 15 +++++++++++++ python/mlc_chat/serve/engine.py | 21 +++++++++++++------ 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/python/mlc_chat/cli/model_metadata.py b/python/mlc_chat/cli/model_metadata.py index 9939476d98..2ba9e2aa88 100644 --- a/python/mlc_chat/cli/model_metadata.py +++ b/python/mlc_chat/cli/model_metadata.py @@ -1,4 +1,5 @@ """A tool that inspects the metadata of a model lib.""" + import json import math from dataclasses import asdict @@ -120,6 +121,10 @@ def _print_memory_usage_in_json(metadata: Dict[str, Any], config: Dict) -> None: ) +def _print_kv_cache_metadata_in_json(metadata: Dict[str, Any]) -> None: + print(json.dumps(metadata["kv_cache"])) + + def main(): """Entry point for the model metadata tool.""" parser = ArgumentParser(description="A tool that inspects the metadata of a model lib.") @@ -154,6 +159,11 @@ def main(): action="store_true", help="""If set, only inspect the metadata in memory usage and print usage in raw JSON.""", ) + parser.add_argument( + "--print-kv-cache-metadata-in-json-only", + action="store_true", + help="""If set, only inspect the metadata in KV cache and print usage in raw JSON.""", + ) parsed = parser.parse_args() # Load metadata from model lib try: @@ -174,6 +184,8 @@ def main(): _print_memory_usage_in_json(metadata, cfg) elif parsed.memory_only: _report_memory_usage(metadata, cfg) + elif parsed.print_kv_cache_metadata_in_json_only: + _print_kv_cache_metadata_in_json(metadata) else: _report_all(metadata) diff --git a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py index 808969ea64..89c2710e32 100644 --- a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py +++ b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py @@ -64,6 +64,11 @@ def __init__( flashinfer : bool A boolean indicating if flashinfer is enabled. + + metadata : Dict[str, Any] + The model's metadata for KV cache creation. + Note that the metadata will be updated in this pass -- the + KV cache metadata will be attached. 
""" self.target = target self.flashinfer = flashinfer @@ -88,12 +93,22 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR new_mod = new_mod.with_attrs(mod.attrs) kwargs = extract_creation_args(creation_func) + self.attach_kv_cache_metadata(kwargs) bb = relax.BlockBuilder(new_mod) self.create_tir_paged_kv_cache(bb, kwargs) self.create_flashinfer_paged_kv_cache(bb, kwargs) return bb.finalize() + def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]): + """Attach the KV cache metadata to model metadata.""" + self.metadata["kv_cache"] = { + "num_hidden_layers": kwargs["num_hidden_layers"], + "num_attention_heads": kwargs["num_attention_heads"], + "num_key_value_heads": kwargs["num_key_value_heads"], + "head_dim": kwargs["head_dim"], + } + def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None: """Create the TIR-based PagedKVCache""" max_batch_size = relax.Var( diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index a55ee09ddb..6343658f51 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -166,18 +166,27 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals params_bytes += usage_json["params_bytes"] temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + cmd = [ + sys.executable, + "-m", + "mlc_chat.cli.model_metadata", + model.model_lib_path, + "--print-kv-cache-metadata-in-json", + ] + kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) + kv_cache_metadata = json.loads(kv_cache_metadata_str) + # Read model config and compute the kv size per token. with open(config_file_path, mode="rt", encoding="utf-8") as file: json_object = json.load(file) model_config = json_object["model_config"] - num_layers = model_config["num_hidden_layers"] - hidden_size = model_config["hidden_size"] - head_dim = model_config["head_dim"] vocab_size = model_config["vocab_size"] - tensor_parallel_shards = model_config["tensor_parallel_shards"] - num_qo_heads = model_config["num_attention_heads"] / tensor_parallel_shards - num_kv_heads = model_config["num_key_value_heads"] / tensor_parallel_shards prefill_chunk_size = model_config["prefill_chunk_size"] + num_layers = kv_cache_metadata["num_hidden_layers"] + head_dim = kv_cache_metadata["head_dim"] + num_qo_heads = kv_cache_metadata["num_attention_heads"] + num_kv_heads = kv_cache_metadata["num_key_value_heads"] + hidden_size = head_dim * num_qo_heads kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 kv_aux_workspace_bytes += ( (max_num_sequence + 1) * 88 From aeb55f1e721c4b77a3b3bbfc47a193679d2bda08 Mon Sep 17 00:00:00 2001 From: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Date: Tue, 5 Mar 2024 04:40:57 +0000 Subject: [PATCH 034/531] [KVCache] Migrate bigcode arch to PagedKVCache (#1891) Compilation and runtime smooth. 
I will open follow-up PRs to enable starcoder2 support in the same model definition file --- .../rewrite_kv_cache_creation.py | 2 + .../model/gpt_bigcode/gpt_bigcode_loader.py | 1 + .../model/gpt_bigcode/gpt_bigcode_model.py | 265 ++++++++++++------ 3 files changed, 188 insertions(+), 80 deletions(-) diff --git a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py index 89c2710e32..d167a8bf6d 100644 --- a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py +++ b/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py @@ -147,6 +147,8 @@ def create_flashinfer_paged_kv_cache( "gpt2" in self.metadata["model_type"] ) + # filter by attention group size + or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8] ): return diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py index 8d479d3ad8..1504719045 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's GPTBigCode parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools from mlc_chat.loader import ExternMapping diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py index 10a0291d11..babe901b55 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py @@ -2,6 +2,7 @@ Implementation for GPTBigCode architecture. TODO: add docstring """ + import dataclasses from typing import Any, Dict, Optional @@ -10,6 +11,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support import tensor_parallel as tp from mlc_chat.support.config import ConfigBase @@ -109,34 +111,44 @@ def __init__(self, config: GPTBigCodeConfig): self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) - def forward( # pylint: disable=too-many-locals + def forward( self, hidden_states: Tensor, - attention_mask: Tensor, - total_seq_len: tir.Var, + paged_kv_cache: PagedKVCache, + layer_id: int, ): - d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." 
+ # QKV Projection qkv = self.c_attn(hidden_states) - qkv = op.reshape(qkv, (b, s, h_q + 2 * h_kv, d)) - q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2) - - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - output = op_ext.attention(q, k, v, casual_mask=attention_mask) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, h_q), (b, s, h_q * d) + ) + return self.c_proj(output) + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + + # QKV Projection + qkv = self.c_attn(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, h_q), (b, s, h_q * d) + ) return self.c_proj(output) class GPTBigCodeBlock(nn.Module): def __init__(self, config: GPTBigCodeConfig): - self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.attn = GPTBigCodeAttention(config) - self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.mlp = GPTBigCodeMLP(config) + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) def _set_tp(): def _set(layer, hint): @@ -154,11 +166,18 @@ def _set(layer, hint): self.tensor_parallel_shards = config.tensor_parallel_shards _set_tp() - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - hidden_states = ( - self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len) + hidden_states - ) - hidden_states = self.mlp(self.ln_2(hidden_states)) + hidden_states + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.ln_2(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.ln_2(hidden_states)) + hidden_states = out + hidden_states return hidden_states @@ -171,42 +190,50 @@ def __init__(self, config: GPTBigCodeConfig): self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.tensor_parallel_shards = config.tensor_parallel_shards - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - - # Token Embeddings - t_embd = self.wte(inputs) + input_embed = op.ccl_broadcast_from_worker0(input_embed) # Position Embeddings - # Generate np.arange(offset, offset+seq_len) - def _input_positions(inputs: te.Tensor, total_seq_len: tir.Var): - b, s = inputs.shape - offset = total_seq_len - s - return te.compute( - (b, s), lambda _, j: (offset + j).astype("int32"), name="input_positions" - ) + # shape[1] indicates the total query length in the batch + input_positions = paged_kv_cache.get_query_positions(input_embed.shape[1]) + pos_embd = self.wpe(input_positions) - 
input_positions = op.tensor_expr_op( - _input_positions, - name_hint="input_positions", - args=[inputs, total_seq_len], - ) + # apply position embeddings + hidden_states = input_embed + pos_embd + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.ln_f(hidden_states) + + return hidden_states + + def batch_forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + if self.tensor_parallel_shards > 1: + input_embed = op.ccl_broadcast_from_worker0(input_embed) + + # Position Embeddings + # shape[1] indicates the total query length in the batch + input_positions = paged_kv_cache.get_query_positions(input_embed.shape[1]) pos_embd = self.wpe(input_positions) # apply position embeddings - hidden_states = t_embd + pos_embd - for layer in self.h: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + hidden_states = input_embed + pos_embd + for layer_id, layer in enumerate(self.h): + hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) hidden_states = self.ln_f(hidden_states) return hidden_states -class GPTBigCodeForCausalLM(nn.Module): +class GPTBigCodeForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GPTBigCodeConfig): self.transformer = GPTBigCodeModel(config) self.lm_head = nn.Linear(config.n_embd, "vocab_size", bias=False) + self.n_layer = config.n_layer + self.n_embd = config.n_embd + self.num_q_heads = config.n_head // config.tensor_parallel_shards + self.num_kv_heads = 1 + self.head_dim = config.n_embd // config.n_head self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -214,72 +241,150 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def batch_forward( + self, + input_embed: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer.batch_forward(input_embed, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.transformer.wte(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.transformer(inputs, total_seq_len, attention_mask) + hidden_states = self.transformer(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits + return logits, paged_kv_cache - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - 
_attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, total_seq_len], - ) - return self.forward(inputs, total_seq_len, attention_mask) + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), - dtype=self.dtype, - ) - return self.forward(inputs, total_seq_len, attention_mask) + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.n_layer, + num_attention_heads=self.num_q_heads, + num_key_value_heads=self.num_kv_heads, + head_dim=self.head_dim, + rope_mode=RopeMode.NONE, + rope_scale=-1, + rope_theta=-1, + dtype=self.dtype, + ) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, 1, self.n_embd], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.n_embd], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embd], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": 
"packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", From e7b6cbc9f22eba224914272585688e133e0fc1ea Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Mon, 4 Mar 2024 21:08:05 -0800 Subject: [PATCH 035/531] [Serving] Add Phi-2 conv template to mlc serve (#1890) This PR adds the phi-2 model template to MLC serve. For testing 1. Start server ```python -m mlc_chat.serve.server --model ./dist/phi-2-q4f16_1-MLC/ --model-lib-path ./dist/phi-2-q4f16_1-MLC/phi-2-q4f16_1-cuda.so --device auto --max-batch-size 2 --enable-tracing --host 127.0.0.1 --port 8000 --max-total-seq-length 8000``` 2. Send request ```python test_server_rest_api.py``` ```python # test_server_rest_api.py import requests import json model = "./dist/phi-2-q4f16_1-MLC/" port = 8000 payload = { "model": f"{model}", "messages": [{"role": "user", "content": "Tell me about Machine Learning in 200 words."}], "stream": False, } r = requests.post(f"http://127.0.0.1:{port}/v1/chat/completions", json=payload) if r.status_code != 200: print(r.json()) else: print(r.json()["choices"][0]["message"]["content"]) ``` --- python/mlc_chat/conversation_template.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_chat/conversation_template.py index a5dd9dfe6a..7192cc818b 100644 --- a/python/mlc_chat/conversation_template.py +++ b/python/mlc_chat/conversation_template.py @@ -114,3 +114,22 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_token_ids=[2], ) ) + +# Phi-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="phi-2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={ + "user": "Instruct", + "assistant": "Output", + "tool": "Instruct", + }, + seps=["\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=["<|endoftext|>"], + stop_token_ids=[50256], + ) +) From 8a8c529711b9c74ed2c50c0a622d39de2ea30733 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 5 Mar 2024 07:09:38 -0500 Subject: [PATCH 036/531] [Attn] Fix attention kernel for head dim not divisble by 32 (#1889) Prior to this PR, our TIR prefill attention kernel assumes the head dim to be a multiple of 32. As reported by #1826, this assumption does not always hold. This PR fixes this issue so that models with different head dim can also compile. 
--- python/mlc_chat/nn/kv_cache.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index cb0e000b87..4f14774338 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -779,7 +779,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, bdx]) + ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -791,7 +791,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, bdx]) + ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -1425,7 +1425,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, bdx]) + ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") @@ -1437,7 +1437,7 @@ def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) t = sch.fuse(xo, yo) - ty, tx = sch.split(t, factors=[num_warps, bdx]) + ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") From b345a9e10881deeed6e5297c076328c3ad27c074 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 5 Mar 2024 12:49:26 -0500 Subject: [PATCH 037/531] [Python] Enable "thrust" for CUDA by default (#1866) This PR enables thrust for CUDA targets so that we can dispatch some operators (e.g., cumsum) to thrust. --- cmake/gen_cmake_config.py | 6 ++++++ python/mlc_chat/support/auto_target.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/cmake/gen_cmake_config.py b/cmake/gen_cmake_config.py index c2d9263dc3..f12983c441 100644 --- a/cmake/gen_cmake_config.py +++ b/cmake/gen_cmake_config.py @@ -31,11 +31,14 @@ ), ] + enabled_backends = set() + for backend in backends: while True: use_backend = input(backend.prompt_str) if use_backend in ["yes", "Y", "y"]: cmake_config_str += f"set({backend.cmake_config_name} ON)\n" + enabled_backends.add(backend.name) break elif use_backend in ["no", "N", "n"]: cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" @@ -43,6 +46,9 @@ else: print(f"Invalid input: {use_backend}. 
Please input again.") + if "CUDA" in enabled_backends: + cmake_config_str += f"set(USE_THRUST ON)\n" + # FlashInfer related use_flashInfer = False # pylint: disable=invalid-name while True: diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index 80041db7f7..a4bb853bc7 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -1,4 +1,5 @@ """Helper functions for target auto-detection.""" + import os from typing import TYPE_CHECKING, Callable, List, Optional, Tuple @@ -42,6 +43,12 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T if target.host is None: target = Target(target, host=_detect_target_host(host_hint)) if target.kind.name == "cuda": + # Enable thrust for CUDA + target_dict = dict(target.export()) + target_dict["libs"] = ( + (target_dict["libs"] + ["thrust"]) if "libs" in target_dict else ["thrust"] + ) + target = Target(target_dict) _register_cuda_hook(target) return target, build_func From 2f26e05d4ca1beb006099d0c1f1370a73e0f13ba Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 5 Mar 2024 16:46:14 -0800 Subject: [PATCH 038/531] [Serving] Fix loading presharded weights (#1894) --- cpp/serve/function_table.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 512fc21333..39214d6e8a 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -149,7 +149,10 @@ ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device DRef loader = loader_create(metadata_path, ndarray_cache_metadata, "", this->disco_mod); params = loader_load_all(loader); } else { - PackedFunc loader = this->get_global_func("mlc.loader.LoadMultiGPU"); + auto load_func_name = getenv("MLC_INTERNAL_PRESHARD_NUM") == nullptr + ? "mlc.loader.LoadMultiGPU" + : "mlc.loader.LoadMultiGPUPresharded"; + PackedFunc loader = this->get_global_func(load_func_name); params = loader(model_path, this->disco_mod, picojson::value(this->model_config).serialize()); } return params; From a41f9037c4a7d971f6ba70be935d7bf2b453f635 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 7 Mar 2024 09:07:16 -0500 Subject: [PATCH 039/531] [Serving] Address embedding lookup OOM issue (#1899) This PR addresses the OOM issue that may be caused by embedding lookup when the batch size of a prefill action is large. Prior to this PR, a large embedding tensor will be created for each sequence in the prefilled batch, thus may take unexpectedly large memory when the batch size is large. 
--- cpp/serve/data.cc | 6 +- cpp/serve/data.h | 19 +- cpp/serve/engine.cc | 26 +- cpp/serve/engine_actions/action.h | 6 +- cpp/serve/engine_actions/batch_decode.cc | 8 +- cpp/serve/engine_actions/batch_draft.cc | 9 +- cpp/serve/engine_actions/batch_verify.cc | 2 +- .../engine_actions/new_request_prefill.cc | 161 ++++++------ cpp/serve/function_table.cc | 7 +- cpp/serve/function_table.h | 5 +- cpp/serve/logit_processor.cc | 2 +- cpp/serve/model.cc | 230 +++++++++--------- cpp/serve/model.h | 36 ++- .../compiler_pass/attach_to_ir_module.py | 35 ++- ...ation.py => dispatch_kv_cache_creation.py} | 4 +- python/mlc_chat/compiler_pass/pipeline.py | 6 +- python/mlc_chat/serve/engine.py | 2 +- 17 files changed, 323 insertions(+), 241 deletions(-) rename python/mlc_chat/compiler_pass/{rewrite_kv_cache_creation.py => dispatch_kv_cache_creation.py} (97%) diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc index 3e56ad6ec3..e6155061db 100644 --- a/cpp/serve/data.cc +++ b/cpp/serve/data.cc @@ -31,7 +31,7 @@ int TextDataNode::GetLength() const { "Please tokenize the text and construct a TokenData object."; } -NDArray TextDataNode::GetEmbedding(Model model) const { +ObjectRef TextDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const { LOG(FATAL) << "\"GetEmbedding\" for TextData is not supported. " "Please tokenize the text and construct a TokenData object."; } @@ -62,7 +62,9 @@ TokenData::TokenData(std::vector token_ids) { int TokenDataNode::GetLength() const { return token_ids.size(); } -NDArray TokenDataNode::GetEmbedding(Model model) const { return model->TokenEmbed(token_ids); } +ObjectRef TokenDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const { + return model->TokenEmbed(token_ids, dst, offset); +} TVM_REGISTER_GLOBAL("mlc.serve.TokenData").set_body([](TVMArgs args, TVMRetValue* rv) { std::vector token_ids; diff --git a/cpp/serve/data.h b/cpp/serve/data.h index ba92c662eb..b9558b8fad 100644 --- a/cpp/serve/data.h +++ b/cpp/serve/data.h @@ -29,8 +29,19 @@ class DataNode : public Object { /*! \brief Get the length (equivalent number of tokens) of the data. */ virtual int GetLength() const = 0; - /*! \brief Compute the embedding of this data with regard to the input model. */ - virtual NDArray GetEmbedding(Model model) const = 0; + /*! + * \brief Compute the embedding of this data with regard to the input model. + * When the input destination pointer is not nullptr, it in-place writes the + * embedding into the input destination array at the given offset. + * Otherwise, the embeddings will be directly returned back. + * \param model The model to take embeddings from. + * \param dst The destination array of the embedding lookup. + * \param offset The token offset where the computed embeddings will be written + * into the destination array. + * \return The updated destination embedding array or the computed embeddings. + * \note When `dst` is nullptr, we require `offset` to be 0. 
+ */ + virtual ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const = 0; static constexpr const char* _type_key = "mlc.serve.Data"; static constexpr const bool _type_has_method_sequal_reduce = false; @@ -52,7 +63,7 @@ class TextDataNode : public DataNode { String text; int GetLength() const final; - NDArray GetEmbedding(Model model) const final; + ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final; static constexpr const char* _type_key = "mlc.serve.TextData"; TVM_DECLARE_BASE_OBJECT_INFO(TextDataNode, DataNode); @@ -74,7 +85,7 @@ class TokenDataNode : public DataNode { IntTuple token_ids; int GetLength() const final; - NDArray GetEmbedding(Model model) const final; + ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final; static constexpr const char* _type_key = "mlc.serve.TokenData"; TVM_DECLARE_BASE_OBJECT_INFO(TokenDataNode, DataNode); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 56cab63927..f043b4bcac 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -63,6 +63,7 @@ class EngineImpl : public Engine { // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); + this->model_workspaces_.clear(); for (const auto& model_info : model_infos) { TVMArgValue model_lib = std::get<0>(model_info); String model_path = std::get<1>(model_info); @@ -75,6 +76,7 @@ class EngineImpl : public Engine { << ", is smaller than the pre-defined max single sequence length, " << this->max_single_sequence_length_; this->models_.push_back(model); + this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()}); } int max_logit_processor_num_token = kv_cache_config_->max_num_sequence; if (engine_mode_->enable_speculative) { @@ -88,22 +90,24 @@ class EngineImpl : public Engine { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->kv_cache_config_, // - this->engine_mode_, // + EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + this->kv_cache_config_, // + this->engine_mode_, // this->trace_recorder_), EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, this->engine_mode_->spec_draft_length), EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, this->trace_recorder_)}; } else { - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->kv_cache_config_, // - this->engine_mode_, // + this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + this->kv_cache_config_, // + this->engine_mode_, // this->trace_recorder_), EngineAction::BatchDecode(this->models_, logit_processor, sampler, this->trace_recorder_)}; @@ -250,6 +254,8 @@ class EngineImpl : public Engine { std::shared_ptr json_grammar_state_init_ctx_; // Models Array models_; + // Workspace of each model. + std::vector model_workspaces_; // Request stream callback function Optional request_stream_callback_; // Engine actions. 
diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index d6bd611802..7a5e217569 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -55,14 +55,16 @@ class EngineAction : public ObjectRef { * \param models The models to run prefill in. * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. * \param kv_cache_config The KV cache config to help decide prefill is doable. * \param engine_mode The engine operation mode. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, - EngineMode engine_mode, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, EngineMode engine_mode, Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 00bf503969..23b2e6bca4 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -81,20 +81,16 @@ class BatchDecodeActionObj : public EngineActionObj { // - Compute embeddings. RECORD_EVENT(trace_recorder_, request_ids, "start embedding"); - NDArray embeddings = + ObjectRef embeddings = models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); RECORD_EVENT(trace_recorder_, request_ids, "finish embedding"); - ICHECK_EQ(embeddings->ndim, 3); - ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], num_rsentries); - embeddings = embeddings.CreateView({num_rsentries, 1, embeddings->shape[2]}, embeddings->dtype); // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start decode"); NDArray logits = models_[0]->BatchDecode(embeddings, request_internal_ids); RECORD_EVENT(trace_recorder_, request_ids, "finish decode"); ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], embeddings->shape[0]); + ICHECK_EQ(logits->shape[0], num_rsentries); ICHECK_EQ(logits->shape[1], 1); // - Update logits. diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 626e863566..617d826296 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -89,21 +89,16 @@ class BatchDraftActionObj : public EngineActionObj { // - Compute embeddings. RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); - NDArray embeddings = + ObjectRef embeddings = models_[model_id]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); - ICHECK_EQ(embeddings->ndim, 3); - ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], num_rsentries); - embeddings = - embeddings.CreateView({num_rsentries, 1, embeddings->shape[2]}, embeddings->dtype); // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); NDArray logits = models_[model_id]->BatchDecode(embeddings, request_internal_ids); RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], embeddings->shape[0]); + ICHECK_EQ(logits->shape[0], num_rsentries); ICHECK_EQ(logits->shape[1], 1); // - Update logits. 
diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index fc0d857c00..79c2a17b95 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -89,7 +89,7 @@ class BatchVerifyActionObj : public EngineActionObj { } RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); - NDArray embeddings = models_[verify_model_id_]->TokenEmbed( + ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( {IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}}); RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index b60a125c3f..9a2722ff1c 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -20,81 +20,87 @@ namespace serve { class NewRequestPrefillActionObj : public EngineActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, - EngineMode engine_mode, + Sampler sampler, std::vector model_workspaces, + KVCacheConfig kv_cache_config, EngineMode engine_mode, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), kv_cache_config_(std::move(kv_cache_config)), engine_mode_(std::move(engine_mode)), trace_recorder_(std::move(trace_recorder)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. - auto [rstates, prefill_lengths] = GetRequestStatesToPrefill(estate); - ICHECK_EQ(rstates.size(), prefill_lengths.size()); - if (rstates.empty()) { + auto [rsentries, prefill_lengths] = GetRequestStateEntriesToPrefill(estate); + ICHECK_EQ(rsentries.size(), prefill_lengths.size()); + if (rsentries.empty()) { return {}; } - int num_rstates = rstates.size(); + int num_rsentries = rsentries.size(); auto tstart = std::chrono::high_resolution_clock::now(); // - Update status of request states from pending to alive. Array request_ids; - std::vector rstates_of_requests; - request_ids.reserve(num_rstates); - rstates_of_requests.reserve(num_rstates); - for (RequestStateEntry rstate : rstates) { - const Request& request = rstate->request; + std::vector rstates_of_entries; + request_ids.reserve(num_rsentries); + rstates_of_entries.reserve(num_rsentries); + for (RequestStateEntry rsentry : rsentries) { + const Request& request = rsentry->request; RequestState request_rstate = estate->GetRequestState(request); request_ids.push_back(request->id); - rstate->status = RequestStateStatus::kAlive; + rsentry->status = RequestStateStatus::kAlive; // - Remove the request from waiting queue if all its request states are now alive. // - Add the request to running queue if all its request states were pending. 
bool alive_state_existed = false; - for (const RequestStateEntry& request_state : request_rstate->entries) { - if (request_state->status == RequestStateStatus::kAlive && !request_state.same_as(rstate)) { + for (const RequestStateEntry& rsentry_ : request_rstate->entries) { + if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { alive_state_existed = true; } } if (!alive_state_existed) { estate->running_queue.push_back(request); } - rstates_of_requests.push_back(std::move(request_rstate)); + rstates_of_entries.push_back(std::move(request_rstate)); } // - Get embedding and run prefill for each model. NDArray logits_for_sample{nullptr}; for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - Array embeddings; std::vector request_internal_ids; - embeddings.reserve(num_rstates); - request_internal_ids.reserve(num_rstates); - for (int i = 0; i < num_rstates; ++i) { - RequestModelState mstate = rstates[i]->mstates[model_id]; + request_internal_ids.reserve(num_rsentries); + ObjectRef embeddings = model_workspaces_[model_id].embeddings; + int cum_prefill_length = 0; + bool single_input = num_rsentries == 1 && rsentries[0]->mstates[model_id]->inputs.size() == 1; + for (int i = 0; i < num_rsentries; ++i) { + RequestModelState mstate = rsentries[i]->mstates[model_id]; ICHECK_EQ(mstate->GetInputLength(), prefill_lengths[i]); ICHECK(mstate->draft_output_tokens.empty()); ICHECK(mstate->draft_output_prob_dist.empty()); ICHECK(!mstate->inputs.empty()); // Add the sequence to the model, or fork the sequence from its parent. - if (rstates[i]->parent_idx == -1) { + if (rsentries[i]->parent_idx == -1) { models_[model_id]->AddNewSequence(mstate->internal_id); } else { - models_[model_id]->ForkSequence(rstates_of_requests[i] - ->entries[rstates[i]->parent_idx] + models_[model_id]->ForkSequence(rstates_of_entries[i] + ->entries[rsentries[i]->parent_idx] ->mstates[model_id] ->internal_id, mstate->internal_id); } request_internal_ids.push_back(mstate->internal_id); - RECORD_EVENT(trace_recorder_, rstates[i]->request->id, "start embedding"); + RECORD_EVENT(trace_recorder_, rsentries[i]->request->id, "start embedding"); for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - embeddings.push_back(mstate->inputs[i]->GetEmbedding(models_[model_id])); + embeddings = + mstate->inputs[i]->GetEmbedding(models_[model_id], + /*dst=*/!single_input ? &embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += mstate->inputs[i]->GetLength(); } - RECORD_EVENT(trace_recorder_, rstates[i]->request->id, "finish embedding"); + RECORD_EVENT(trace_recorder_, rsentries[i]->request->id, "finish embedding"); // Clean up `inputs` after prefill mstate->inputs.clear(); } @@ -105,7 +111,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], num_rstates); + ICHECK_EQ(logits->shape[1], num_rsentries); if (model_id == 0) { // We only need to sample for model 0 in prefill. 
@@ -117,13 +123,13 @@ class NewRequestPrefillActionObj : public EngineActionObj { ICHECK(logits_for_sample.defined()); Array generation_cfg; Array mstates_for_logitproc; - generation_cfg.reserve(num_rstates); - mstates_for_logitproc.reserve(num_rstates); - for (int i = 0; i < num_rstates; ++i) { - generation_cfg.push_back(rstates[i]->request->generation_cfg); - mstates_for_logitproc.push_back(rstates[i]->mstates[0]); + generation_cfg.reserve(num_rsentries); + mstates_for_logitproc.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { + generation_cfg.push_back(rsentries[i]->request->generation_cfg); + mstates_for_logitproc.push_back(rsentries[i]->mstates[0]); } - logits_for_sample = logits_for_sample.CreateView({num_rstates, logits_for_sample->shape[2]}, + logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]}, logits_for_sample->dtype); logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_logitproc, request_ids); @@ -133,48 +139,48 @@ class NewRequestPrefillActionObj : public EngineActionObj { logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); // - Sample tokens. - // For rstates which are depended by other states, sample + // For rsentries which have children, sample // one token for each rstate that is depending. // Otherwise, sample a token for the current rstate. std::vector sample_indices; std::vector rsentries_for_sample; std::vector rngs; - sample_indices.reserve(num_rstates); - rsentries_for_sample.reserve(num_rstates); - rngs.reserve(num_rstates); + sample_indices.reserve(num_rsentries); + rsentries_for_sample.reserve(num_rsentries); + rngs.reserve(num_rsentries); request_ids.clear(); generation_cfg.clear(); - for (int i = 0; i < num_rstates; ++i) { + for (int i = 0; i < num_rsentries; ++i) { estate->stats.current_total_seq_len += prefill_lengths[i]; - const RequestStateEntry& rstate = rstates[i]; - for (int child_idx : rstate->child_indices) { - if (rstates_of_requests[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { - // If rstates_of_requests[i][child_idx] has no committed token, - // the prefill of the current rstate will unblock rstates_of_requests[i][child_idx], - // and thus we want to sample a token for rstates_of_requests[i][child_idx]. + const RequestStateEntry& rsentry = rsentries[i]; + for (int child_idx : rsentry->child_indices) { + if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { + // If rstates_of_entries[i]->entries[child_idx] has no committed token, + // the prefill of the current rsentry will unblock + // rstates_of_entries[i]->entries[child_idx], + // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. 
sample_indices.push_back(i); - rsentries_for_sample.push_back(rstates_of_requests[i]->entries[child_idx]); - request_ids.push_back(rstate->request->id); - generation_cfg.push_back(rstate->request->generation_cfg); - rngs.push_back(&rstates_of_requests[i]->entries[child_idx]->rng); + rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); - ICHECK(rstates_of_requests[i]->entries[child_idx]->status == - RequestStateStatus::kPending); - rstates_of_requests[i]->entries[child_idx]->status = RequestStateStatus::kAlive; + ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); + rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { models_[model_id]->ForkSequence( - rstate->mstates[model_id]->internal_id, - rstates_of_requests[i]->entries[child_idx]->mstates[model_id]->internal_id); + rsentry->mstates[model_id]->internal_id, + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id); } } } - if (rstate->child_indices.empty()) { - // If rstate has no child, we sample a token for itself. + if (rsentry->child_indices.empty()) { + // If rsentry has no child, we sample a token for itself. sample_indices.push_back(i); - rsentries_for_sample.push_back(rstate); - request_ids.push_back(rstate->request->id); - generation_cfg.push_back(rstate->request->generation_cfg); - rngs.push_back(&rstate->rng); + rsentries_for_sample.push_back(rsentry); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); } } std::vector sample_results = sampler_->BatchSampleTokens( @@ -198,26 +204,26 @@ class NewRequestPrefillActionObj : public EngineActionObj { std::vector processed_requests; { - processed_requests.reserve(num_rstates); + processed_requests.reserve(num_rsentries); std::unordered_set dedup_map; - for (int i = 0; i < static_cast(rstates.size()); ++i) { - const RequestStateEntry& rstate = rstates[i]; - if (dedup_map.find(rstate->request.get()) != dedup_map.end()) { + for (int i = 0; i < static_cast(rsentries.size()); ++i) { + const RequestStateEntry& rsentry = rsentries[i]; + if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { continue; } - dedup_map.insert(rstate->request.get()); - processed_requests.push_back(rstate->request); + dedup_map.insert(rsentry->request.get()); + processed_requests.push_back(rsentry->request); bool pending_state_exists = false; - for (const RequestStateEntry& request_state : rstates_of_requests[i]->entries) { - if (request_state->status == RequestStateStatus::kPending) { + for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { + if (rsentry_->status == RequestStateStatus::kPending) { pending_state_exists = true; break; } } if (!pending_state_exists) { auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), - rstate->request); + rsentry->request); ICHECK(it != estate->waiting_queue.end()); estate->waiting_queue.erase(it); } @@ -228,12 +234,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { private: /*! - * \brief Find one or multiple request states to run prefill. + * \brief Find one or multiple request state entries to run prefill. * \param estate The engine state. 
- * \return The requests to prefill, together with their respective - * state and input length. + * \return The request entries to prefill, together with their input lengths. */ - std::tuple, std::vector> GetRequestStatesToPrefill( + std::tuple, std::vector> GetRequestStateEntriesToPrefill( EngineState estate) { if (estate->waiting_queue.empty()) { // No request to prefill. @@ -322,6 +327,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; /*! \brief The KV cache config to help decide prefill is doable. */ KVCacheConfig kv_cache_config_; /*! \brief The engine operation mode. */ @@ -331,12 +338,14 @@ class NewRequestPrefillActionObj : public EngineActionObj { }; EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, - EngineMode engine_mode, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, EngineMode engine_mode, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), - std::move(engine_mode), std::move(trace_recorder))); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_mode), + std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 39214d6e8a..46855221d1 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -199,6 +199,7 @@ void FunctionTable::_InitFunctions() { this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true); + this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); @@ -219,7 +220,9 @@ void FunctionTable::_InitFunctions() { this->kv_cache_popn_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn"); this->kv_cache_get_num_available_pages_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages"); - this->view_func_ = get_global_func("vm.builtin.reshape"); + this->nd_view_func_ = get_global_func("vm.builtin.reshape"); + this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); + this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; } @@ -245,7 +248,7 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String tensor_ this->disco_buffers.Set(tensor_name, buffer); } ShapeTuple real_shape = host_array.Shape(); - DRef buffer_view = view_func_(buffer, real_shape); + DRef buffer_view = nd_view_func_(buffer, real_shape); sess->CopyToWorker0(host_array, buffer_view); return buffer_view; } else { diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 5475886d11..9f8d8daed6 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -74,6 +74,7 @@ struct FunctionTable { PackedFunc apply_logit_bias_func_; PackedFunc apply_penalty_func_; 
PackedFunc apply_bitmask_func_; + PackedFunc alloc_embedding_tensor_func_; PackedFunc create_kv_cache_func_; PackedFunc reset_kv_cache_func_; bool support_backtracking_kv_; @@ -85,7 +86,9 @@ struct FunctionTable { PackedFunc kv_cache_attention_func_; PackedFunc kv_cache_popn_func_; PackedFunc kv_cache_get_num_available_pages_func_; - PackedFunc view_func_; + PackedFunc nd_view_func_; + PackedFunc nd_get_shape_func_; + PackedFunc nd_copy_embedding_to_offset_func_; }; } // namespace serve diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 1afcf10c60..f5fe8b661a 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -312,7 +312,7 @@ class LogitProcessorImpl : public LogitProcessorObj { // - Set arrays. int batch_size = logits->shape[0]; ICHECK((cum_num_token == nullptr && batch_size == mstates.size()) || - (cum_num_token != nullptr && batch_size == cum_num_token->size())); + (cum_num_token != nullptr && batch_size == cum_num_token->back())); std::memset(p_seq_ids, 0, batch_size * sizeof(int32_t)); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index c89eaaceae..113648b3a9 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -19,72 +19,6 @@ namespace serve { /*********************** Utils ***********************/ -/*! - * \brief Concatenate the input embeddings along the sequence dimension. - * Store the concatenation result into the input destination NDarray. - * Return concatenation result as an NDArray view of the destination array. - * \param embedding_arr The array of embeddings to concatenate. - * \param total_length The total length of the input embeddings along the sequence dim. - * \param device The device where the embeddings locate. - * \param initial_seq_len The initial sequence length to allocate for embeddings. - * \param dst The destination of the concatenation - * \return The concatenated embeddings. - */ -NDArray ConcatEmbeddings(const Array& embedding_arr, int64_t total_length, DLDevice device, - int initial_seq_len, NDArray* dst) { - ICHECK(!embedding_arr.empty()); - if (embedding_arr.size() == 1) { - return embedding_arr[0]; - } - ICHECK_NOTNULL(dst); - int hidden_size = -1; - DataType dtype; - for (NDArray inp_embeddings : embedding_arr) { - // inp_embedding: (1, n, h) - CHECK_EQ(inp_embeddings->ndim, 3); - CHECK_EQ(inp_embeddings->shape[0], 1); - CHECK_EQ(inp_embeddings->device.device_type, device.device_type); - CHECK_EQ(inp_embeddings->device.device_id, device.device_id); - if (hidden_size == -1) { - hidden_size = inp_embeddings->shape[2]; - dtype = inp_embeddings.DataType(); - } else { - CHECK_EQ(inp_embeddings->shape[2], hidden_size); - CHECK_EQ(inp_embeddings.DataType(), dtype); - } - } - - // - Resize the shared embedding array. - if (dst->defined()) { - ICHECK_EQ((*dst)->ndim, 3); - ICHECK_EQ((*dst)->shape[0], 1); - ICHECK_EQ((*dst)->shape[2], hidden_size); - } - int64_t init_size = dst->defined() ? (*dst)->shape[1] : initial_seq_len; - while (init_size < total_length) { - init_size *= 2; - } - if (!dst->defined() || init_size != (*dst)->shape[1]) { - *dst = NDArray::Empty({1, init_size, hidden_size}, dtype, device); - } - - // - Copy input embeddings. 
- int64_t start_pos = 0; - for (NDArray inp_embeddings : embedding_arr) { - int64_t length = inp_embeddings->shape[1]; - CHECK_LE(start_pos + length, total_length); - - DLTensor copy_dst = *(dst->operator->()); - copy_dst.byte_offset = start_pos * hidden_size * dtype.bytes(); - copy_dst.shape = inp_embeddings->shape; - NDArray::CopyFromTo(inp_embeddings.operator->(), ©_dst); - - start_pos += length; - } - CHECK_EQ(start_pos, total_length); - return dst->CreateView({1, total_length, hidden_size}, dtype); -} - /*! \brief Utility function that copies input array to the device. */ template NDArray CopyArrayToDevice(const std::vector& array, NDArray* dst, DLDataType dtype, @@ -159,37 +93,30 @@ class ModelImpl : public ModelObj { /*********************** Model Computation ***********************/ - NDArray TokenEmbed(IntTuple token_ids) final { + ObjectRef TokenEmbed(IntTuple token_ids, ObjectRef* dst, int offset) final { int num_tokens = token_ids.size(); std::vector vec_token_ids(token_ids->data, token_ids->data + num_tokens); // Copy input token ids to device. DLDataType dtype(DataType::Int(32)); NDArray token_ids_nd = - CopyArrayToDevice(vec_token_ids, &input_token_ids_, dtype, max_window_size_, device_); + CopyArrayToDevice(vec_token_ids, &input_token_ids_, dtype, prefill_chunk_size_, device_); ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); token_ids_nd = token_ids_nd.CreateView({1, num_tokens}, dtype); - - CHECK(ft_.embed_func_.defined()) - << "`embed` function is not found in the model. Please make sure the model is compiled " - "with flag `--sep-embed` and `--enable-batching`"; - auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {max_window_size_}); + auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {prefill_chunk_size_}); ObjectRef embeddings = ft_.embed_func_(token_ids_dref_or_nd, params_); - NDArray embeddings_ndarray; - if (ft_.use_disco) { - embeddings_ndarray = Downcast(embeddings)->DebugGetFromRemote(0); + if (dst != nullptr) { + CHECK(dst->defined()); + ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset); + return *dst; } else { - embeddings_ndarray = Downcast(embeddings); + CHECK_EQ(offset, 0); + return embeddings; } - // embeddings: (1, total_length, hidden_size) - ICHECK_EQ(embeddings_ndarray->ndim, 3); - ICHECK_EQ(embeddings_ndarray->shape[0], 1); - ICHECK_EQ(embeddings_ndarray->shape[1], num_tokens); - return embeddings_ndarray; } - NDArray BatchPrefill(const Array& embedding_arr, const std::vector& seq_ids, + NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -202,15 +129,6 @@ class ModelImpl : public ModelObj { logit_pos.push_back(total_length - 1); } - // embeddings: (1, n, h) - NDArray embeddings = - ConcatEmbeddings(embedding_arr, total_length, device_, max_window_size_, &embeddings_); - ICHECK_EQ(embeddings->ndim, 3); - ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], total_length); - ICHECK_EQ(embeddings->device.device_type, device_.device_type); - ICHECK_EQ(embeddings->device.device_id, device_.device_id); - NDArray logit_pos_nd = CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32, device_); @@ -226,8 +144,23 @@ class ModelImpl : public ModelObj { IntTuple lengths_tuple(lengths.begin(), lengths.end()); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); - ObjectRef 
embeddings_dref_or_nd = ft_.CopyToWorker0( - embeddings, "embedding_prefill", {1, max_window_size_, embeddings.Shape()[2]}); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 3); + ICHECK_EQ(embeddings_nd->shape[0], 1); + ICHECK_GE(embeddings_nd->shape[1], total_length); + ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype); + } else { + ShapeTuple embedding_shape{1, total_length, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); + } ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); // args: embeddings, logit_pos, kv_cache, params @@ -254,13 +187,8 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchDecode(const NDArray& embeddings, const std::vector& seq_ids) final { - // embeddings: (b, 1, h) - CHECK_EQ(embeddings->ndim, 3); - CHECK_EQ(embeddings->shape[0], seq_ids.size()); - CHECK_EQ(embeddings->shape[1], 1); - CHECK_EQ(embeddings->device.device_type, device_.device_type); - CHECK_EQ(embeddings->device.device_id, device_.device_id); + NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { + int num_sequence = seq_ids.size(); CHECK(ft_.decode_func_.defined()) << "`decode_with_embed` function is not found in the model. Please make sure the model is " @@ -272,11 +200,26 @@ class ModelImpl : public ModelObj { // Reserve in KV cache for the lengths of the input. // Begin forward with the sequence ids and new lengths. 
IntTuple seq_ids_tuple(seq_ids); - IntTuple lengths_tuple(std::vector(/*n=*/embeddings->shape[0], /*v=*/1)); + IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); - ObjectRef embeddings_dref_or_nd = ft_.CopyToWorker0( - embeddings, "embedding_decode", {max_num_sequence_, 1, embeddings.Shape()[2]}); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, b, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 3); + ICHECK_EQ(embeddings_nd->shape[0], 1); + ICHECK_GE(embeddings_nd->shape[1], num_sequence); + ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({num_sequence, 1, hidden_size_}, embeddings_nd->dtype); + } else { + ShapeTuple embedding_shape{num_sequence, 1, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); + } // args: embeddings, kv_cache, params ObjectRef ret; @@ -297,12 +240,12 @@ class ModelImpl : public ModelObj { // logits: (b, 1, v) ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], embeddings->shape[0]); + ICHECK_EQ(logits->shape[0], num_sequence); ICHECK_EQ(logits->shape[1], 1); return logits; } - NDArray BatchVerify(const NDArray& embeddings, const std::vector& seq_ids, + NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -312,13 +255,6 @@ class ModelImpl : public ModelObj { total_length += lengths[i]; } - // embeddings: (1, n, h) - ICHECK_EQ(embeddings->ndim, 3); - ICHECK_EQ(embeddings->shape[0], 1); - ICHECK_EQ(embeddings->shape[1], total_length); - ICHECK_EQ(embeddings->device.device_type, device_.device_type); - ICHECK_EQ(embeddings->device.device_id, device_.device_id); - CHECK(ft_.verify_func_.defined()) << "`verify_with_embed` function is not found in the model. 
Please make sure the model is " "compiled with flag `--sep-embed` and `--enable-batching`"; @@ -331,8 +267,23 @@ class ModelImpl : public ModelObj { IntTuple lengths_tuple(lengths.begin(), lengths.end()); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); - ObjectRef embeddings_dref_or_nd = ft_.CopyToWorker0( - embeddings, "embedding_verify", {1, max_window_size_, embeddings.Shape()[2]}); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 3); + ICHECK_EQ(embeddings_nd->shape[0], 1); + ICHECK_GE(embeddings_nd->shape[1], total_length); + ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype); + } else { + ShapeTuple embedding_shape{1, total_length, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); + } // args: embeddings, logit_pos, kv_cache, params ObjectRef ret = ft_.verify_func_(embeddings_dref_or_nd, kv_cache_, params_); NDArray logits; @@ -407,6 +358,26 @@ class ModelImpl : public ModelObj { return max_window_size_; } + ObjectRef AllocEmbeddingTensor() final { + // Allocate the embedding tensor. + ObjectRef embedding = ft_.alloc_embedding_tensor_func_(); + // Get the shape of the embedding tensor for hidden size. + ShapeTuple embedding_shape; + if (ft_.use_disco) { + ICHECK(embedding->IsInstance()); + ObjectRef shape_ref = ft_.nd_get_shape_func_(embedding); + embedding_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + } else { + NDArray embedding_nd = Downcast(embedding); + embedding_shape = embedding_nd.Shape(); + } + ICHECK_EQ(embedding_shape.size(), 3); + ICHECK_EQ(embedding_shape[0], 1); + ICHECK_EQ(embedding_shape[1], prefill_chunk_size_); + this->hidden_size_ = embedding_shape[2]; + return embedding; + } + void Reset() final { // Reset the KV cache. 
if (kv_cache_.defined()) { @@ -437,6 +408,12 @@ class ModelImpl : public ModelObj { } else { LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; } + if (config.count("prefill_chunk_size")) { + CHECK(config["prefill_chunk_size"].is()); + this->prefill_chunk_size_ = config["prefill_chunk_size"].get(); + } else { + LOG(FATAL) << "Key \"prefill_chunk_size\" not found."; + } if (config.count("vocab_size")) { CHECK(config["vocab_size"].is()); this->vocab_size_ = config["vocab_size"].get(); @@ -452,6 +429,8 @@ class ModelImpl : public ModelObj { int max_window_size_ = -1; int num_shards_ = -1; int max_num_sequence_ = -1; + int prefill_chunk_size_ = -1; + int hidden_size_ = -1; int vocab_size_ = -1; //---------------------------- // TVM related states @@ -466,11 +445,28 @@ class ModelImpl : public ModelObj { ObjectRef params_; // Shared NDArray NDArray input_token_ids_{nullptr}; - NDArray embeddings_{nullptr}; NDArray logit_pos_arr_{nullptr}; - NDArray temperature_arr_{nullptr}; }; +TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") + .set_body_typed([](NDArray embedding, NDArray dst, int offset) { + // embedding: (1, m, hidden_size) + // dst: (1, prefill_chunk_size, hidden_size) + ICHECK_EQ(embedding->ndim, 3); + ICHECK_EQ(embedding->shape[0], 1); + ICHECK_EQ(dst->ndim, 3); + ICHECK_EQ(dst->shape[0], 1); + ICHECK_LE(embedding->shape[1] + offset, dst->shape[1]); + ICHECK_EQ(embedding->shape[2], dst->shape[2]); + const DLTensor& copy_src = *(embedding.operator->()); + const DLTensor* p_copy_dst = dst.operator->(); + DLTensor copy_dst = *p_copy_dst; + copy_dst.shape = embedding->shape; + copy_dst.byte_offset = + offset * embedding->shape[2] * ((embedding->dtype.bits * embedding->dtype.lanes + 7) / 8); + NDArray::CopyFromTo(©_src, ©_dst); + }); + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/model.h b/cpp/serve/model.h index fe396c4094..acc50187d2 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -23,6 +23,20 @@ namespace serve { using tvm::Device; using namespace tvm::runtime; +/*! + * \brief The workspace tensors that may be shared across different + * calls to Model. For example, the prefill action use the `embeddings` + * workspace for the concatenated embeddings of different sequences. + * The workspace tensor is created by Model but owned by engine. + */ +struct ModelWorkspace { + /*! + * \brief The embedding tensor. It can be either an NDArray when tensor + * model parallelism is not enabled, or a DRef when using tensor model parallelism. + */ + ObjectRef embeddings{nullptr}; +}; + /*! * \brief The model module for LLM functions. * It runs an LLM, and has an internal KV cache that maintains @@ -53,10 +67,18 @@ class ModelObj : public Object { /*! * \brief Compute embeddings for the input token ids. + * When the input destination pointer is defined, it in-place writes the + * embedding into the input destination array at the given offset. + * Otherwise, the embeddings will be directly returned back. * \param token_ids The token ids to compute embedding for. - * \return The computed embeddings. + * \param dst The destination array of the embedding lookup. + * \param offset The token offset where the computed embeddings will be written + * into the destination array. + * \return The updated destination embedding array or the computed embeddings. + * \note When `dst` is undefined, we require `offset` to be 0. 
*/ - virtual NDArray TokenEmbed(IntTuple batch_token_ids) = 0; + virtual ObjectRef TokenEmbed(IntTuple batch_token_ids, ObjectRef* dst = nullptr, + int offset = 0) = 0; /*! * \brief Batch prefill function. Embedding in, logits out. @@ -67,8 +89,7 @@ class ModelObj : public Object { * \param lengths The length of each sequence to prefill. * \return The logits for the next token. */ - virtual NDArray BatchPrefill(const Array& embedding_arr, - const std::vector& seq_ids, + virtual NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; /*! @@ -79,7 +100,7 @@ class ModelObj : public Object { * \param seq_id The id of the sequence in the KV cache. * \return The logits for the next token for each sequence in the batch. */ - virtual NDArray BatchDecode(const NDArray& embeddings, const std::vector& seq_ids) = 0; + virtual NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) = 0; /*! * \brief Batch verify function. Embedding in, logits out. @@ -91,7 +112,7 @@ class ModelObj : public Object { * That is to say, it does not accept "running a verify step for a subset * of the full batch". */ - virtual NDArray BatchVerify(const NDArray& embeddings, const std::vector& seq_ids, + virtual NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; /*********************** KV Cache Management ***********************/ @@ -135,6 +156,9 @@ class ModelObj : public Object { /*! \brief Get the max window size of the model. */ virtual int GetMaxWindowSize() const = 0; + /*! \brief Allocate an embedding tensor with the prefill chunk size. */ + virtual ObjectRef AllocEmbeddingTensor() = 0; + /*! \brief Reset the model KV cache and other statistics. 
*/ virtual void Reset() = 0; diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 06026397a4..47baacd755 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -1,6 +1,6 @@ """A couple of passes that simply attach additional information onto the IRModule.""" -from typing import Dict +from typing import Any, Dict import tvm from tvm import IRModule, relax, tir @@ -62,6 +62,39 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod +@tvm.transform.module_pass(opt_level=0, name="AttachAllocEmbeddingTensorFunc") +class AttachAllocEmbeddingTensorFunc: # pylint: disable=too-few-public-methods + """Attach embedding tensor allocation Relax function to IRModule.""" + + def __init__(self, metadata: Dict[str, Any]): + self.metadata = metadata + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + embed_func = None + for gv, func in mod.functions_items(): + if gv.name_hint == "embed": + embed_func = func + + if embed_func is None: + return mod + + hidden_size = embed_func.ret_struct_info.shape[-1] + dtype = embed_func.ret_struct_info.dtype + bb = relax.BlockBuilder(mod) + with bb.function("alloc_embedding_tensor", []): + bb.emit_func_output( + bb.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr([1, self.metadata["prefill_chunk_size"], hidden_size]), + dtype, + runtime_device_index=0, + ) + ) + ) + return bb.finalize() + + @T.prim_func def _apply_logit_bias_inplace( var_logits: T.handle, diff --git a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py b/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py similarity index 97% rename from python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py rename to python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py index d167a8bf6d..08cf730f5f 100644 --- a/python/mlc_chat/compiler_pass/rewrite_kv_cache_creation.py +++ b/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py @@ -48,8 +48,8 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: } -@tvm.transform.module_pass(opt_level=0, name="RewriteKVCacheCreation") -class RewriteKVCacheCreation: # pylint: disable=too-many-instance-attributes +@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation") +class DispatchKVCacheCreation: # pylint: disable=too-many-instance-attributes """Rewrite KV cache creation functions to IRModule.""" def __init__( diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 98922c6139..00d0d3c4f8 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -13,12 +13,14 @@ from .attach_to_ir_module import ( AttachAdditionalPrimFuncs, + AttachAllocEmbeddingTensorFunc, AttachLogitProcessFunc, AttachMemoryPlanAttr, AttachVariableBounds, ) from .clean_up_tir_attrs import CleanUpTIRAttrs from .cublas_dispatch import CublasDispatch +from .dispatch_kv_cache_creation import DispatchKVCacheCreation from .estimate_memory_usage import AttachMetadataWithMemoryUsage from .fuse_add_norm import FuseAddRMSNorm from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise @@ -27,7 +29,6 @@ from .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc -from 
.rewrite_kv_cache_creation import RewriteKVCacheCreation from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -88,10 +89,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I seq = tvm.transform.Sequential( [ # Phase 0. Add additional information for compilation and remove unused Relax func - RewriteKVCacheCreation(target, flashinfer, metadata), + DispatchKVCacheCreation(target, flashinfer, metadata), AttachVariableBounds(variable_bounds), AttachLogitProcessFunc(), AttachAdditionalPrimFuncs(additional_tirs), + AttachAllocEmbeddingTensorFunc(metadata), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_chat/serve/engine.py index 6343658f51..c4b3e5d9b4 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_chat/serve/engine.py @@ -215,7 +215,7 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals max_total_sequence_length = int( ( - int(gpu_size_bytes) * 0.85 + int(gpu_size_bytes) * 0.90 - params_bytes - temp_func_bytes - kv_aux_workspace_bytes From 88ac813cf33922fee5924cb2c9fa191c0def3a92 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 7 Mar 2024 09:07:35 -0500 Subject: [PATCH 040/531] [Model] Remove redundant `batch_forward` and move broadcast (#1900) This PR contains four changes: 1. It removes the duplicate `batch_forward` defined in model definitions. This function was widely used prior to our migration to PagedKVCache, since before migration the attention codepath of single sequence forward and batch forward differ. But since our migration, the codepaths are unified into one, and therefore we can safely remove most `batch_forward` functions. 2. It moves `op.ccl_broadcast_from_worker0` from model main forward (which will be called at the beginning of prefill/decode) to embedding. This change has two benefits. Firstly, the token ids taken by `embed` was not broadcasted across workers, and it is possible for workers other than 0 to have illegal token ids which is not in the range of vocab size, and moving the broadcasting to `embed` perfectly address this issue. Secondly, broadcasting token ids in `embed` is more lightweight than broadcasting embeddings in `prefill`/`decode`, since the tensor size of token ids is much smaller. 3. It adds `max_batch_size` to the config class of models, so that they are potentially compatible with batching and MLC serve. 4. It removes the `k_cache` and `v_cache` effects from the models that have switched to PagedKVCache support. Randomly picked a few models (as below) to run the engine test, and all of them are passed: * phi-2 with tp=2, * RedPajama with tp=2, * stablelm with tp=2 (since stablelm does not support TP right now). 
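For point 2, a rough back-of-the-envelope comparison of the broadcast sizes (the numbers below are assumptions for illustration, not measurements from any model):

```python
# Assumed example sizes: a 2048-token prefill with a 4096-dim hidden state.
seq_len, hidden_size = 2048, 4096

token_id_bytes = seq_len * 4                 # int32 token ids
embedding_bytes = seq_len * hidden_size * 2  # fp16 embeddings

print(token_id_bytes, embedding_bytes, embedding_bytes // token_id_bytes)
# 8192 16777216 2048 -> broadcasting token ids in `embed` moves far less data
# than broadcasting embeddings in `prefill`/`decode`.
```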
--- .../mlc_chat/model/baichuan/baichuan_model.py | 31 +-------- python/mlc_chat/model/gemma/gemma_model.py | 5 +- python/mlc_chat/model/gpt2/gpt2_model.py | 64 ++----------------- .../model/gpt_bigcode/gpt_bigcode_model.py | 54 ++-------------- .../mlc_chat/model/gpt_neox/gpt_neox_model.py | 55 ++-------------- .../mlc_chat/model/internlm/internlm_model.py | 30 +-------- python/mlc_chat/model/llama/llama_model.py | 38 +---------- .../mlc_chat/model/mixtral/mixtral_model.py | 5 +- python/mlc_chat/model/phi/phi_model.py | 46 ++----------- python/mlc_chat/model/qwen/qwen_model.py | 35 +--------- .../model/stable_lm/stablelm_model.py | 38 +---------- 11 files changed, 36 insertions(+), 365 deletions(-) diff --git a/python/mlc_chat/model/baichuan/baichuan_model.py b/python/mlc_chat/model/baichuan/baichuan_model.py index 8e8944783e..6119afc10f 100644 --- a/python/mlc_chat/model/baichuan/baichuan_model.py +++ b/python/mlc_chat/model/baichuan/baichuan_model.py @@ -2,6 +2,7 @@ Implementation for BAICHUAN architecture. TODO: add docstring """ + import dataclasses from typing import Any, Dict, Optional @@ -37,6 +38,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -100,17 +102,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: attn_output = self.o_proj(output) return attn_output - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h = self.head_dim, self.num_heads - b, s, _ = hidden_states.shape - qkv = self.W_pack(hidden_states) - qkv = op.reshape(qkv, (b, s, 3 * h, d)) - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) - ) - attn_output = self.o_proj(output) - return attn_output - class BaichuanMLP(nn.Module): def __init__(self, config: BaichuanConfig): @@ -142,15 +133,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = out + hidden_states return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.self_attn.batch_forward( - self.input_layernorm(hidden_states), paged_kv_cache, layer_id - ) - hidden_states = out + hidden_states - out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states - return hidden_states - class BaichuanModel(nn.Module): def __init__(self, config: BaichuanConfig): @@ -168,13 +150,6 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = self.norm(hidden_states) return hidden_states - def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - hidden_states = inputs - for layer_id, layer in enumerate(self.layers): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.norm(hidden_states) - return hidden_states - class BaichuanForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: BaichuanConfig): @@ -203,7 +178,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = 
self.lm_head(hidden_states) diff --git a/python/mlc_chat/model/gemma/gemma_model.py b/python/mlc_chat/model/gemma/gemma_model.py index 01455896a4..080147d393 100644 --- a/python/mlc_chat/model/gemma/gemma_model.py +++ b/python/mlc_chat/model/gemma/gemma_model.py @@ -202,11 +202,8 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) - self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embed = op.ccl_broadcast_from_worker0(input_embed) hidden_states = input_embed hidden_states = hidden_states * (self.hidden_size**0.5) for layer_id, layer in enumerate(self.layers): @@ -250,6 +247,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): diff --git a/python/mlc_chat/model/gpt2/gpt2_model.py b/python/mlc_chat/model/gpt2/gpt2_model.py index 911f0ddaab..1d930ba43d 100644 --- a/python/mlc_chat/model/gpt2/gpt2_model.py +++ b/python/mlc_chat/model/gpt2/gpt2_model.py @@ -35,6 +35,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes scale_attn_by_inverse_layer_idx: bool = False tensor_parallel_shards: int = 1 head_dim: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -96,9 +97,6 @@ def __init__(self, config: GPT2Config): ) self.c_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True) - self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) - self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h = self.head_dim, self.num_heads b, s, _ = hidden_states.shape @@ -120,27 +118,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: ) return self.c_proj(output) - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h = self.head_dim, self.num_heads - b, s, _ = hidden_states.shape - - qkv = self.c_attn(hidden_states) - qkv = op.reshape(qkv, (b, s, 3 * h, d)) - - if self.scale_attn_by_inverse_layer_idx: - attn_score_scaling_factor = 1.0 / float(layer_id + 1) - else: - attn_score_scaling_factor = 1.0 - - # Attention - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv( - layer_id, qkv, self.num_heads, attn_score_scaling_factor - ), - (b, s, h * d), - ) - return self.c_proj(output) - class GPT2MLP(nn.Module): def __init__(self, config: GPT2Config): @@ -200,18 +177,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - with tp.shard_bias(self.attn.c_proj, self.tensor_parallel_shards), tp.shard_bias( - self.mlp.c_proj, self.tensor_parallel_shards - ): - hidden_states = self._apply_residual( - self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id), - hidden_states, - ) - hidden_states = self._apply_residual(self.mlp(self.ln_2(hidden_states)), hidden_states) - - return hidden_states - def 
_apply_residual(self, out, residual): if self.tensor_parallel_shards > 1: return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") @@ -225,13 +190,8 @@ def __init__(self, config: GPT2Config): self.wpe = nn.Embedding(config.context_window_size, config.n_embd) self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - hidden_states = inputs - # Position Embeddings # Generate np.arange(offset, offset+seq_len) # shape[1] indicates the total query length in the batch @@ -245,24 +205,6 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = self.ln_f(hidden_states) return hidden_states - def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - hidden_states = inputs - - # Position Embeddings - # Generate np.arange(offset, offset+seq_len) - # shape[1] indicates the total query length in the batch - input_positions = paged_kv_cache.get_query_positions(inputs.shape[1]) - pos_embd = self.wpe(input_positions) - - # Pass through GPT2Block - hidden_states = hidden_states + pos_embd - for layer_id, layer in enumerate(self.h): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.ln_f(hidden_states) - return hidden_states - class GPT2LMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GPT2Config): @@ -288,7 +230,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.transformer.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.transformer(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.lm_head(hidden_states) @@ -297,6 +239,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.transformer.wte(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py index babe901b55..5557ca1614 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py @@ -34,6 +34,7 @@ class GPTBigCodeConfig(ConfigBase): # pylint: disable=too-many-instance-attribu context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -108,9 +109,6 @@ def __init__(self, config: GPTBigCodeConfig): bias=True, ) - self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) - self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) - def forward( self, hidden_states: Tensor, @@ -129,19 +127,6 @@ def forward( ) return self.c_proj(output) - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads - b, s, _ = hidden_states.shape - - # QKV Projection - 
qkv = self.c_attn(hidden_states) - qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Attention - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, h_q), (b, s, h_q * d) - ) - return self.c_proj(output) - class GPTBigCodeBlock(nn.Module): def __init__(self, config: GPTBigCodeConfig): @@ -173,13 +158,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = out + hidden_states return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states - out = self.mlp(self.ln_2(hidden_states)) - hidden_states = out + hidden_states - return hidden_states - class GPTBigCodeModel(nn.Module): def __init__(self, config: GPTBigCodeConfig): @@ -188,12 +166,8 @@ def __init__(self, config: GPTBigCodeConfig): self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.h = nn.ModuleList([GPTBigCodeBlock(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embed = op.ccl_broadcast_from_worker0(input_embed) - # Position Embeddings # shape[1] indicates the total query length in the batch input_positions = paged_kv_cache.get_query_positions(input_embed.shape[1]) @@ -207,23 +181,6 @@ def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): return hidden_states - def batch_forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embed = op.ccl_broadcast_from_worker0(input_embed) - - # Position Embeddings - # shape[1] indicates the total query length in the batch - input_positions = paged_kv_cache.get_query_positions(input_embed.shape[1]) - pos_embd = self.wpe(input_positions) - - # apply position embeddings - hidden_states = input_embed + pos_embd - for layer_id, layer in enumerate(self.h): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.ln_f(hidden_states) - - return hidden_states - class GPTBigCodeForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GPTBigCodeConfig): @@ -234,6 +191,7 @@ def __init__(self, config: GPTBigCodeConfig): self.num_q_heads = config.n_head // config.tensor_parallel_shards self.num_kv_heads = 1 self.head_dim = config.n_embd // config.n_head + self.tensor_parallel_shards = config.tensor_parallel_shards self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -249,7 +207,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.transformer.batch_forward(input_embed, paged_kv_cache) + hidden_states = self.transformer(input_embed, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.lm_head(hidden_states) @@ -258,6 +216,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.transformer.wte(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -313,8 +273,8 @@ def create_paged_kv_cache( prefill_chunk_size=prefill_chunk_size, page_size=page_size, num_hidden_layers=self.n_layer, - 
num_attention_heads=self.num_q_heads, - num_key_value_heads=self.num_kv_heads, + num_attention_heads=self.num_q_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_kv_heads // self.tensor_parallel_shards, head_dim=self.head_dim, rope_mode=RopeMode.NONE, rope_scale=-1, diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py index 130d8246b3..b5bd89e9a6 100644 --- a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py @@ -38,6 +38,7 @@ class GPTNeoXConfig(ConfigBase): # pylint: disable=too-many-instance-attributes prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 ffn_out_dtype: str = "float32" + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -122,22 +123,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: attn_output = self.dense(output) return attn_output - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - # hidden_states: [batch_size, seq_len, hidden_size] - batch_size, seq_len, _ = hidden_states.shape - - # q/k/v states: [batch_size, seq_len, hidden_size] - qkv = self.query_key_value(hidden_states) - qkv = op.reshape(qkv, (batch_size, seq_len, 3 * self.num_attention_heads, self.head_dim)) - - # Attention - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_attention_heads), - (batch_size, seq_len, self.head_dim * self.num_attention_heads), - ) - attn_output = self.dense(output) - return attn_output - class GPTNeoXMLP(nn.Module): def __init__(self, config: GPTNeoXConfig): @@ -223,27 +208,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = self._apply_residual(mlp_output.astype(dtype), attn_output) return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - dtype = hidden_states.dtype - attn_input = self.input_layernorm(hidden_states) - with tp.shard_bias(self.attention.dense, self.tensor_parallel_shards): - attn_output = self.attention.batch_forward( - attn_input, - paged_kv_cache, - layer_id, - ) - if self.use_parallel_residual: - mlp_input = self.post_attention_layernorm(hidden_states) - mlp_output = self.mlp(mlp_input) - hidden_states = mlp_output + attn_output + hidden_states - else: - attn_output = self._apply_residual(attn_output, hidden_states) - mlp_input = self.post_attention_layernorm(attn_output) - with tp.shard_bias(self.mlp.dense_4h_to_h, self.tensor_parallel_shards): - mlp_output = self.mlp(mlp_input) - hidden_states = self._apply_residual(mlp_output.astype(dtype), attn_output) - return hidden_states - def _apply_residual(self, out, residual): if self.tensor_parallel_shards > 1: return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") @@ -255,11 +219,8 @@ def __init__(self, config: GPTNeoXConfig): self.embed_in = nn.Embedding(num="vocab_size", dim=config.hidden_size) self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) hidden_states = inputs for layer_id, layer in enumerate(self.layers): @@ -267,16 +228,6 
@@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = self.final_layer_norm(hidden_states) return hidden_states - def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - hidden_states = inputs - - for layer_id, layer in enumerate(self.layers): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.final_layer_norm(hidden_states) - return hidden_states - class GPTNeoXForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GPTNeoXConfig): @@ -310,7 +261,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.gpt_neox.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.gpt_neox(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.embed_out(hidden_states) @@ -319,6 +270,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.gpt_neox.embed_in(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): diff --git a/python/mlc_chat/model/internlm/internlm_model.py b/python/mlc_chat/model/internlm/internlm_model.py index 0f6b92a76f..2c88ccaa71 100644 --- a/python/mlc_chat/model/internlm/internlm_model.py +++ b/python/mlc_chat/model/internlm/internlm_model.py @@ -37,6 +37,7 @@ class InternLMConfig(ConfigBase): # pylint: disable=too-many-instance-attribute context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -102,17 +103,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: attn_output = self.o_proj(output) return attn_output - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h = self.head_dim, self.num_heads - b, s, _ = hidden_states.shape - qkv = self.wqkv_pack(hidden_states) - qkv = op.reshape(qkv, (b, s, 3 * h, d)) - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), (b, s, h * d) - ) - attn_output = self.o_proj(output) - return attn_output - class InternLMMLP(nn.Module): def __init__(self, config: InternLMConfig): @@ -145,15 +135,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = out + hidden_states return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.self_attn.batch_forward( - self.input_layernorm(hidden_states), paged_kv_cache, layer_id - ) - hidden_states = out + hidden_states - out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states - return hidden_states - class InternLMModel(nn.Module): def __init__(self, config: InternLMConfig): @@ -170,13 +151,6 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = self.norm(hidden_states) return hidden_states - def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - hidden_states = inputs - for layer_id, layer in enumerate(self.layers): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.norm(hidden_states) - return hidden_states - class 
InternLMForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: InternLMConfig): @@ -205,7 +179,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.lm_head(hidden_states) diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py index 6da1d420ea..8d54829dc0 100644 --- a/python/mlc_chat/model/llama/llama_model.py +++ b/python/mlc_chat/model/llama/llama_model.py @@ -138,19 +138,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: ) return self.o_proj(output) - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads - b, s, _ = hidden_states.shape - # QKV Projection - qkv = self.qkv_proj(hidden_states) - qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Attention - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), - (b, s, h_q * d), - ) - return self.o_proj(output) - class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): @@ -184,15 +171,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.self_attn.batch_forward( - self.input_layernorm(hidden_states), paged_kv_cache, layer_id - ) - hidden_states = self._apply_residual(out, residual=hidden_states) - out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = self._apply_residual(out, residual=hidden_states) - return hidden_states - def _apply_residual(self, out, residual): if self.tensor_parallel_shards > 1: return op.ccl_allreduce(out, "sum") + residual @@ -207,26 +185,14 @@ def __init__(self, config: LlamaConfig): [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) - self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embed = op.ccl_broadcast_from_worker0(input_embed) hidden_states = input_embed for layer_id, layer in enumerate(self.layers): hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.norm(hidden_states) return hidden_states - def batch_forward(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embeds = op.ccl_broadcast_from_worker0(input_embeds) - hidden_states = input_embeds - for layer_id, layer in enumerate(self.layers): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.norm(hidden_states) - return hidden_states - class LlamaForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: LlamaConfig): @@ -255,7 +221,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, 
axis=1) logits = self.lm_head(hidden_states) @@ -264,6 +230,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): diff --git a/python/mlc_chat/model/mixtral/mixtral_model.py b/python/mlc_chat/model/mixtral/mixtral_model.py index a2740f1b5e..2a707b0a77 100644 --- a/python/mlc_chat/model/mixtral/mixtral_model.py +++ b/python/mlc_chat/model/mixtral/mixtral_model.py @@ -1,4 +1,5 @@ """Implementation for Mistral architecture.""" + import dataclasses from tvm import tir @@ -144,9 +145,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: return hidden_states def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.self_attn.batch_forward( - self.input_layernorm(hidden_states), paged_kv_cache, layer_id - ) + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.moe(self.post_attention_layernorm(hidden_states)) hidden_states = self._apply_residual(out, residual=hidden_states) diff --git a/python/mlc_chat/model/phi/phi_model.py b/python/mlc_chat/model/phi/phi_model.py index 04360efbcd..863ecd7298 100644 --- a/python/mlc_chat/model/phi/phi_model.py +++ b/python/mlc_chat/model/phi/phi_model.py @@ -37,6 +37,7 @@ class Phi1Config(ConfigBase): # pylint: disable=too-many-instance-attributes prefill_chunk_size: int = 0 head_dim: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -206,19 +207,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: ) return self.out_proj(output) - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h_q, h_kv = self.head_dim, self.num_q_heads, self.n_head_kv - b, s, _ = hidden_states.shape - # QKV Projection - qkv = self.Wqkv(hidden_states) - qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Attention - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), - (b, s, h_q * d), - ) - return self.out_proj(output) - class PhiParallelBlock(nn.Module): def __init__(self, config: PhiConfig): @@ -268,22 +256,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - residual = hidden_states - hidden_states = self.ln(hidden_states) - - with tp.shard_bias(self.mixer.out_proj, self.tensor_parallel_shards), tp.shard_bias( - self.mlp.fc2, self.tensor_parallel_shards - ): - attn_outputs = self.mixer.batch_forward(hidden_states, paged_kv_cache, layer_id) - feed_forward_hidden_states = self.mlp(hidden_states) - - hidden_states = self._apply_parallel_residual( - attn_outputs, feed_forward_hidden_states, residual - ) - - return hidden_states - def _apply_parallel_residual(self, attn_out, mlp_out, residual): if self.tensor_parallel_shards > 1: return op.ccl_allreduce( @@ -313,26 +285,14 @@ def __init__(self, config: PhiConfig) -> None: super().__init__() self.embd = nn.Embedding(config.vocab_size, config.n_embd) self.h = nn.ModuleList([PhiParallelBlock(config) for _ in range(config.n_layer)]) - 
self.tensor_parallel_shards = config.tensor_parallel_shards def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embed = op.ccl_broadcast_from_worker0(input_embed) hidden_states = input_embed for layer_id, layer in enumerate(self.h): hidden_states = layer(hidden_states, paged_kv_cache, layer_id) return hidden_states - def batch_forward(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): - if self.tensor_parallel_shards > 1: - input_embeds = op.ccl_broadcast_from_worker0(input_embeds) - hidden_states = input_embeds - for layer_id, layer in enumerate(self.h): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - - return hidden_states - class PhiForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes @@ -368,7 +328,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.transformer.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.transformer(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) lm_logits = self.lm_head(hidden_states) @@ -419,6 +379,8 @@ def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) embeds = self.transformer.embd(input_ids) return embeds diff --git a/python/mlc_chat/model/qwen/qwen_model.py b/python/mlc_chat/model/qwen/qwen_model.py index 48c66525fb..b301ff13fe 100644 --- a/python/mlc_chat/model/qwen/qwen_model.py +++ b/python/mlc_chat/model/qwen/qwen_model.py @@ -2,6 +2,7 @@ Implementation for QWEN architecture. 
TODO: add docstring """ + import dataclasses from typing import Any, Dict, Optional @@ -34,6 +35,7 @@ class QWenConfig(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -91,10 +93,6 @@ def __init__(self, config: QWenConfig): ) self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=False) - # KV cache for single sequence - self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) - self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim]) - def forward( # pylint: disable=too-many-locals self, hidden_states: Tensor, @@ -111,17 +109,6 @@ def forward( # pylint: disable=too-many-locals ) return self.c_proj(output) - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - b, s, _ = hidden_states.shape - qkv = self.c_attn(hidden_states) - qkv = op.reshape(qkv, (b, s, 3 * self.head_dim, self.num_heads)) - # try batch forward - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), - (b, s, self.head_dim * self.num_heads), - ) - return self.c_proj(output) - class QWenMLP(nn.Module): def __init__(self, config: QWenConfig): @@ -154,13 +141,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = out + hidden_states return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.attn.batch_forward(self.ln_1(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states - out = self.mlp(self.ln_2(hidden_states)) - hidden_states = out + hidden_states - return hidden_states - class QWenModel(nn.Module): def __init__(self, config: QWenConfig): @@ -170,21 +150,12 @@ def __init__(self, config: QWenConfig): self.ln_f = nn.RMSNorm(config.hidden_size, -1, config.layer_norm_epsilon, bias=False) def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - # hidden_states = self.wte(input_ids) hidden_states = inputs for layer_id, layer in enumerate(self.h): hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.ln_f(hidden_states) return hidden_states - def batch_forward(self, inputs, paged_kv_cache: PagedKVCache): - # hidden_states = self.wte(input_ids) - hidden_states = inputs - for layer_id, layer in enumerate(self.h): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.ln_f(hidden_states) - return hidden_states - class QWenLMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWenConfig): @@ -211,7 +182,7 @@ def batch_forward( logit_positions: Optional[Tensor] = None, ): op_ext.configure() - hidden_states = self.transformer.batch_forward(inputs, paged_kv_cache) + hidden_states = self.transformer(inputs, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.lm_head(hidden_states) diff --git a/python/mlc_chat/model/stable_lm/stablelm_model.py b/python/mlc_chat/model/stable_lm/stablelm_model.py index 7f5e56e819..edb4885123 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_model.py +++ b/python/mlc_chat/model/stable_lm/stablelm_model.py @@ -36,6 +36,7 @@ class StableLmConfig(ConfigBase): # pylint: 
disable=too-many-instance-attribute context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -95,13 +96,6 @@ def __init__(self, config: StableLmConfig): bias=config.use_qkv_bias, ) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - # KV cache for single sequence - self.k_cache = nn.KVCache( - config.context_window_size, [self.num_key_value_heads, self.head_dim] - ) - self.v_cache = nn.KVCache( - config.context_window_size, [self.num_key_value_heads, self.head_dim] - ) def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads @@ -115,18 +109,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: attn_output = self.o_proj(output) return attn_output - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads - b, s, _ = hidden_states.shape - qkv = self.qkv_proj(hidden_states) - qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - output = op.reshape( - paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), - (b, s, h_q * d), - ) - attn_output = self.o_proj(output) - return attn_output - class StableLmMLP(nn.Module): def __init__(self, config: StableLmConfig): @@ -159,15 +141,6 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: hidden_states = out + hidden_states return hidden_states - def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): - out = self.self_attn.batch_forward( - self.input_layernorm(hidden_states), paged_kv_cache, layer_id - ) - hidden_states = out + hidden_states - out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states - return hidden_states - class StableLmModel(nn.Module): def __init__(self, config: StableLmConfig): @@ -185,13 +158,6 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = self.norm(hidden_states) return hidden_states - def batch_forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): - hidden_states = inputs - for layer_id, layer in enumerate(self.layers): - hidden_states = layer.batch_forward(hidden_states, paged_kv_cache, layer_id) - hidden_states = self.norm(hidden_states) - return hidden_states - class StableLmForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: StableLmConfig): @@ -222,7 +188,7 @@ def batch_forward( ): op_ext.configure() - hidden_states = self.model.batch_forward(input_embeds, paged_kv_cache) + hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.lm_head(hidden_states) From 1eaef7c23b03f639f1adf0b44a6d3d813a379d3f Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Fri, 8 Mar 2024 02:19:09 +0800 Subject: [PATCH 041/531] [KVCache]Migrate Qwen2 model to PagedKVCache (#1903) --- python/mlc_chat/model/qwen2/qwen2_model.py | 214 ++++++++++++++------- 1 file changed, 145 insertions(+), 69 deletions(-) diff --git a/python/mlc_chat/model/qwen2/qwen2_model.py b/python/mlc_chat/model/qwen2/qwen2_model.py index f09cceedb2..8fac47fa3e 100644 --- a/python/mlc_chat/model/qwen2/qwen2_model.py +++ 
b/python/mlc_chat/model/qwen2/qwen2_model.py @@ -11,6 +11,7 @@ from tvm.relax.frontend.nn import Tensor, op from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode from mlc_chat.support import logging from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold @@ -31,7 +32,6 @@ class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes rms_norm_eps: float rope_theta: int vocab_size: int - context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 @@ -73,7 +73,6 @@ def __post_init__(self): bold("context_window_size"), ) self.prefill_chunk_size = self.context_window_size - assert self.tensor_parallel_shards == 1, "QWEN currently does not support sharding." # pylint: disable=invalid-name,missing-docstring,too-many-locals @@ -105,26 +104,17 @@ def __init__(self, config: QWen2Config): self.num_key_value_heads = config.num_key_value_heads self.rope_theta = config.rope_theta - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - bsz, sl, _ = hidden_states.shape - assert bsz == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_attention_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape qkv = self.c_attn(hidden_states) - num_heads = 2 * self.num_key_value_heads + self.num_attention_heads - qkv = op.reshape(qkv, (bsz, sl, num_heads, self.head_dim)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope( - qkv, total_seq_len, self.rope_theta, self.num_attention_heads, self.num_key_value_heads + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_attention_heads), + (b, s, h_q * d), ) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(total_seq_len) - v = self.v_cache.view(total_seq_len) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = op_ext.attention(q, k, v, casual_mask=attention_mask) - # Step 5. 
Apply output projection - return self.o_proj(output) + attn_output = self.o_proj(output) + return attn_output ACT2FN = { @@ -157,11 +147,10 @@ def __init__(self, config: QWen2Config): config.hidden_size, -1, config.rms_norm_eps, bias=False ) - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.input_layernorm(hidden_states) - out = self.self_attn(out, attention_mask, total_seq_len) + out = self.self_attn(out, paged_kv_cache, layer_id) hidden_states = out + hidden_states - out = self.post_attention_layernorm(hidden_states) out = self.mlp(out) hidden_states = out + hidden_states @@ -176,92 +165,179 @@ def __init__(self, config: QWen2Config): ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) - def forward(self, input_ids: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): - hidden_states = self.embed_tokens(input_ids) - for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.norm(hidden_states) return hidden_states -class QWen2LMHeadModel(nn.Module): +class QWen2LMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWen2Config): self.model = QWen2Model(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.dtype = config.dtype + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.intermediate_size = config.intermediate_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.rms_norm_eps = config.rms_norm_eps + self.rope_theta = config.rope_theta + self.vocab_size = config.vocab_size + self.tensor_parallel_shards = config.tensor_parallel_shards + self.head_dim = config.hidden_size // config.num_attention_heads def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.model(inputs, attention_mask, total_seq_len) + hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits - - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): - def _attention_mask(batch_size, seq_len, 
total_seq_len): - return te.compute( - (batch_size, 1, seq_len, total_seq_len), - lambda b, _, i, j: tir.if_then_else( - i < j - (total_seq_len - seq_len), - tir.min_value(self.dtype), - tir.max_value(self.dtype), - ), - name="attention_mask_prefill", - ) + return logits, paged_kv_cache - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _attention_mask, - name_hint="attention_mask_prefill", - args=[batch_size, seq_len, total_seq_len], - ) - return self.forward(inputs, attention_mask, total_seq_len) + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, total_seq_len], - fill_value=tir.max_value(self.dtype), + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, dtype=self.dtype, ) - return self.forward(inputs, attention_mask, total_seq_len) - - @staticmethod - def softmax_with_temperature(logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) def get_default_spec(self): - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), 
+ "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", From 068d5ea9ca556f2f7a9603537b4f966da12b11f6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 7 Mar 2024 16:12:59 -0500 Subject: [PATCH 042/531] [CI] Skip not supported quantization in model compilation test (#1904) This PR updates the model compilation test so that it will now skip a quantization when the model does not support. --- tests/python/integration/test_model_compile.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py index 7dbdbf8109..c70b1b5b20 100644 --- a/tests/python/integration/test_model_compile.py +++ b/tests/python/integration/test_model_compile.py @@ -10,6 +10,8 @@ import tvm from mlc_chat.model import MODEL_PRESETS +from mlc_chat.model import MODELS as SUPPORTED_MODELS +from mlc_chat.quantization import QUANTIZATION as SUPPORTED_QUANTS from mlc_chat.support.constants import MLC_TEMP_DIR OPT_LEVEL = "O2" @@ -103,6 +105,11 @@ def test_model_compile(): # pylint: disable=too-many-locals TENSOR_PARALLEL_SHARDS, ) ): + if ( + SUPPORTED_QUANTS[quant].kind + not in SUPPORTED_MODELS[MODEL_PRESETS[model]["model_type"]].quantize + ): + continue if not target.startswith("cuda") and quant == "q4f16_ft": # FasterTransformer only works with cuda continue From 655ae5c188fd800aaf471cb0453a31ea986b8993 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 7 Mar 2024 19:33:30 -0500 Subject: [PATCH 043/531] [Serving] Add missing header for `std::iota` (#1905) The header `` was missed, which may have caused build failure on Windows. This PR adds the header. 
--- cpp/serve/engine_actions/batch_decode.cc | 2 ++ cpp/serve/engine_actions/batch_draft.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 23b2e6bca4..2af5d86404 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -3,6 +3,8 @@ * \file serve/engine_actions/batch_decode.cc */ +#include + #include "../../random.h" #include "../config.h" #include "../model.h" diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 617d826296..cef66443db 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -3,6 +3,8 @@ * \file serve/engine_actions/batch_draft.cc */ +#include + #include "../config.h" #include "../model.h" #include "../sampler.h" From 068091c7800803231dabb7ba609b488de3694eb3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 8 Mar 2024 07:07:40 -0500 Subject: [PATCH 044/531] [Serving] Fix Model TokenEmbed function with TP (#1906) This PR fixes a severe bug introduced by #1899. Since #1899, we no longer copy the embedding back from worker 0 when using tensor parallelism. However, we did not synchronize with the worker 0. This will cause the following issue: in batch prefill, we will continuously call TokenEmbed for multiple times. Each time, we will copy the token ids to the `token_ids` NDArray on worker 0. If we do not synchronize with worker 0, then it is possible that the local token ids have been updated for multiple times, before the first `CopyToWorker0` really starts to execute on the worker 0 side. As a result, at the time of executing the token ids copy to worker 0, the local token ids might be wrong (by "wrong", say we are executing the copying of seq 0's token ids, then the actual local token ids array might have already been seq 3's token ids). As a result, the issue will cause the batch prefill behave completely wrong. This PR adds a synchronization with worker 0 explicitly. --- cpp/serve/model.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 113648b3a9..d7ee205ac0 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -109,6 +109,9 @@ class ModelImpl : public ModelObj { if (dst != nullptr) { CHECK(dst->defined()); ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset); + if (ft_.use_disco) { + ft_.sess->SyncWorker(0); + } return *dst; } else { CHECK_EQ(offset, 0); From 73fa4a27149e8e2874fd4f625c0094824f37f097 Mon Sep 17 00:00:00 2001 From: Ricardo Lu <37237570+gesanqiu@users.noreply.github.com> Date: Fri, 8 Mar 2024 23:23:00 +0800 Subject: [PATCH 045/531] [SLM] Add support for Orion architecture. (#1883) This is a PR for supporting [OrionStarAI/Orion-14B-Chat](https://huggingface.co/OrionStarAI/Orion-14B-Chat). 
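Before the Orion diff below, the synchronization fix in PATCH 044 above deserves a brief illustration: it is a staging-buffer race. The sketch below is a generic, self-contained analogue using plain C++ futures (not the MLC/Disco API, and the names are illustrative only) of why reusing one local buffer for several asynchronous copies requires waiting for each copy to consume the buffer before overwriting it, which is what the added SyncWorker(0) call accomplishes.

    #include <future>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> staging;  // reused local buffer, like the worker-0 token-ids array
      std::vector<std::future<std::vector<int>>> submitted;

      for (int seq = 0; seq < 4; ++seq) {
        staging = {seq, seq, seq};  // stage sequence `seq`'s token ids
        // Asynchronous "copy to worker": reads `staging` by reference when it runs.
        submitted.push_back(std::async(std::launch::async, [&staging] { return staging; }));
        // Without this wait, the next iteration may overwrite `staging` before the
        // copy above executes, so several submissions could all observe the last value.
        submitted.back().wait();
      }
      for (auto& copy : submitted) {
        std::cout << copy.get()[0] << ' ';  // with the wait in place: 0 1 2 3
      }
      std::cout << '\n';
      return 0;
    }

In the real engine the analogous wait is only needed when tensor parallelism is active, hence the `if (ft_.use_disco)` guard around the added synchronization in the model.cc diff above.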
--- cpp/conv_templates.cc | 18 + python/mlc_chat/interface/gen_config.py | 1 + python/mlc_chat/model/model.py | 14 + python/mlc_chat/model/model_preset.py | 30 ++ python/mlc_chat/model/orion/__init__.py | 0 python/mlc_chat/model/orion/orion_loader.py | 88 +++++ python/mlc_chat/model/orion/orion_model.py | 369 ++++++++++++++++++ .../model/orion/orion_quantization.py | 37 ++ 8 files changed, 557 insertions(+) create mode 100644 python/mlc_chat/model/orion/__init__.py create mode 100644 python/mlc_chat/model/orion/orion_loader.py create mode 100644 python/mlc_chat/model/orion/orion_model.py create mode 100644 python/mlc_chat/model/orion/orion_quantization.py diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index b0928b7457..729e6f3b38 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -719,6 +719,23 @@ Conversation GemmaInstruction() { return conv; } +Conversation Orion() { + Conversation conv; + conv.name = "orion"; + conv.system = ""; + conv.roles = {"Human: ", "Assitant: "}; + conv.messages = {}; + conv.offset = 0; + conv.separator_style = SeparatorStyle::kSepRoleMsg; + conv.seps = {"\n\n", ""}; + conv.role_msg_sep = ""; + conv.role_empty_sep = ""; + conv.stop_tokens = {2}; + conv.stop_str = ""; + conv.add_bos = true; + return conv; +} + } // namespace using ConvFactory = Conversation (*)(); @@ -760,6 +777,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"baichuan", ChatML}, {"gemma_instruction", GemmaInstruction}, {"internlm", ChatML}, + {"orion", Orion}, }; auto it = factory.find(name); if (it == factory.end()) { diff --git a/python/mlc_chat/interface/gen_config.py b/python/mlc_chat/interface/gen_config.py index 444c200915..d45e1daff0 100644 --- a/python/mlc_chat/interface/gen_config.py +++ b/python/mlc_chat/interface/gen_config.py @@ -230,4 +230,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "phi-2", "stablelm-2", "gemma_instruction", + "orion", } diff --git a/python/mlc_chat/model/model.py b/python/mlc_chat/model/model.py index e03d89762a..ef67c8e5ab 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_chat/model/model.py @@ -17,6 +17,7 @@ from .llama import llama_loader, llama_model, llama_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization +from .orion import orion_loader, orion_model, orion_quantization from .phi import phi_loader, phi_model, phi_quantization from .qwen import qwen_loader, qwen_model, qwen_quantization from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization @@ -278,4 +279,17 @@ class Model: "ft-quant": rwkv5_quantization.ft_quant, }, ), + "orion": Model( + name="orion", + model=orion_model.OrionForCasualLM, + config=orion_model.OrionConfig, + source={ + "huggingface-torch": orion_loader.huggingface, + "huggingface-safetensor": orion_loader.huggingface, + }, + quantize={ + "no-quant": orion_quantization.no_quant, + "group-quant": orion_quantization.group_quant, + }, + ), } diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_chat/model/model_preset.py index 9314b1143b..561109b77e 100644 --- a/python/mlc_chat/model/model_preset.py +++ b/python/mlc_chat/model/model_preset.py @@ -559,4 +559,34 @@ "use_cache": True, "vocab_size": 65536, }, + "orion": { + "architectures": ["OrionForCausalLM"], + "auto_map": { + "AutoConfig": "configuration_orion.OrionConfig", + "AutoModelForCausalLM": "modeling_orion.OrionForCausalLM", + }, + "tokenizer_class": 
"OrionTokenizer", + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "model_type": "orion", + "initializer_range": 0.02, + "intermediate_size": 15360, + "max_position_embeddings": 4096, + "max_sequence_length": 4096, + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0", + "use_cache": True, + "vocab_size": 84608, + }, } diff --git a/python/mlc_chat/model/orion/__init__.py b/python/mlc_chat/model/orion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_chat/model/orion/orion_loader.py b/python/mlc_chat/model/orion/orion_loader.py new file mode 100644 index 0000000000..61c8138634 --- /dev/null +++ b/python/mlc_chat/model/orion/orion_loader.py @@ -0,0 +1,88 @@ +""" +This file specifies how MLC's Orion parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_chat.loader import ExternMapping +from mlc_chat.quantization import Quantization + +from .orion_model import OrionConfig, OrionForCasualLM + + +def huggingface(model_config: OrionConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : OrionConfig + The configuration of the Orion model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = OrionForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/model/orion/orion_model.py b/python/mlc_chat/model/orion/orion_model.py new file mode 100644 index 0000000000..4692c67907 --- /dev/null +++ b/python/mlc_chat/model/orion/orion_model.py @@ -0,0 +1,369 @@ +""" +Implementation for Orion-14B architecture. +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_chat import op as op_ext +from mlc_chat.nn import PagedKVCache, RopeMode +from mlc_chat.support import logging +from mlc_chat.support import tensor_parallel as tp +from mlc_chat.support.config import ConfigBase +from mlc_chat.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OrionConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Orion model.""" + + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + num_key_value_heads: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.num_attention_heads % self.num_key_value_heads == 0 + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring + + +class OrionFFN(nn.Module): + def __init__(self, config: OrionConfig): + super().__init__() + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(op.silu(x1) * x2) + + +class OrionAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: OrionConfig): + self.head_dim = config.head_dim + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert ( + config.num_key_value_heads % config.tensor_parallel_shards == 0 + ), f"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards" + assert ( + config.num_key_value_heads >= config.tensor_parallel_shards + ), f"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}" + self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.o_proj(output) + + +class OrionDecoderLayer(nn.Module): + def __init__(self, config: OrionConfig): + rms_norm_eps = config.rms_norm_eps + self.self_attn = OrionAttention(config) + self.mlp = OrionFFN(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, rms_norm_eps) + + def _set_tp(): + def _set(layer, hint): + 
layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class OrionModel(nn.Module): + def __init__(self, config: OrionConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [OrionDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm(config.hidden_size, config.rms_norm_eps) + self.tensor_parallel_shards = config.tensor_parallel_shards + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OrionForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: OrionConfig): + self.model = OrionModel(config) + self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = 
self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + 
"effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_chat/model/orion/orion_quantization.py b/python/mlc_chat/model/orion/orion_quantization.py new file mode 100644 index 0000000000..d34f59b2dd --- /dev/null +++ b/python/mlc_chat/model/orion/orion_quantization.py @@ -0,0 +1,37 @@ +"""This file specifies how MLC's Orion parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_chat.loader import QuantizeMapping +from mlc_chat.quantization import GroupQuantize, NoQuantize + +from .orion_model import OrionConfig, OrionForCasualLM + + +def group_quant( + model_config: OrionConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Orion-architecture model using group quantization.""" + model: nn.Module = OrionForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: OrionConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Orion2 model without quantization.""" + model: nn.Module = OrionForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map From 3f3e3fdad467c7eb61904089e2c29d4e81edeee2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 8 Mar 2024 12:53:29 -0500 Subject: [PATCH 046/531] [Model] Eliminate the reshape in embedding func (#1908) Prior to this PR, there is a trailing reshape kernel at the end of the embedding func. The reshape is not necessarily needed to be as a kernel, which consumes extra time during execution. This PR eliminates the reshape in the embedding function by updating the signature of the embedding func, so that now it only takes the plain 1D token ids as input. 
--- cpp/llm_chat.cc | 39 +++++++++++++++- cpp/serve/model.cc | 45 ++++++++----------- .../compiler_pass/attach_to_ir_module.py | 2 +- .../mlc_chat/model/baichuan/baichuan_model.py | 2 +- python/mlc_chat/model/gemma/gemma_model.py | 2 +- python/mlc_chat/model/gpt2/gpt2_model.py | 2 +- .../model/gpt_bigcode/gpt_bigcode_model.py | 2 +- .../mlc_chat/model/gpt_neox/gpt_neox_model.py | 2 +- .../mlc_chat/model/internlm/internlm_model.py | 2 +- python/mlc_chat/model/llama/llama_model.py | 2 +- python/mlc_chat/model/orion/orion_model.py | 2 +- python/mlc_chat/model/phi/phi_model.py | 2 +- python/mlc_chat/model/qwen/qwen_model.py | 2 +- python/mlc_chat/model/qwen2/qwen2_model.py | 2 +- python/mlc_chat/model/rwkv5/rwkv5_model.py | 2 +- .../model/stable_lm/stablelm_model.py | 2 +- 16 files changed, 71 insertions(+), 41 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index b7a426a17f..cfb08082f5 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -292,6 +292,9 @@ struct FunctionTable { } this->fkvcache_array_popn_ = get_global_func("vm.builtin.attention_kv_cache_array_popn"); } + + this->nd_view_func_ = get_global_func("vm.builtin.reshape"); + this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); } ObjectRef Empty(ShapeTuple shape, DataType dtype, Device device) const { @@ -348,6 +351,9 @@ struct FunctionTable { bool support_backtracking_kv_; PackedFunc fkvcache_array_popn_; ModelMetadata model_metadata_; + + PackedFunc nd_view_func_; + PackedFunc nd_get_shape_func_; }; } // namespace @@ -1358,10 +1364,14 @@ class LLMChat { ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens)); if (sliding_window_size_ == -1) { if (ft_.use_kv_state) { + int input_len = input_tokens.size(); IntTuple seq_ids_tuple({0}); - ShapeTuple input_len_shape = ShapeTuple({static_cast(input_tokens.size())}); + ShapeTuple input_len_shape{input_len}; ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape); + input_data = ft_.nd_view_func_(input_data, input_len_shape); auto embed = ft_.embed_func_(input_data, params_); + ShapeTuple embedding_shape = {1, input_len, GetHiddenSizeFromEmbedding(embed)}; + embed = ft_.nd_view_func_(embed, embedding_shape); ret = ft_.prefill_func_(embed, kv_cache_, params_); ft_.kv_cache_end_forward_func_(kv_cache_); } else { @@ -1397,7 +1407,10 @@ class LLMChat { IntTuple seq_ids_tuple({0}); IntTuple append_length({1}); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length); + input_data = ft_.nd_view_func_(input_data, append_length); auto embed = ft_.embed_func_(input_data, params_); + ShapeTuple embedding_shape = {1, 1, GetHiddenSizeFromEmbedding(embed)}; + embed = ft_.nd_view_func_(embed, embedding_shape); ret = ft_.decode_func_(embed, kv_cache_, params_); ft_.kv_cache_end_forward_func_(kv_cache_); } else { @@ -1424,6 +1437,26 @@ class LLMChat { } } + int GetHiddenSizeFromEmbedding(ObjectRef embedding) { + if (this->hidden_size_ != -1) { + return this->hidden_size_; + } + // Get the shape of the embedding tensor for hidden size. 
+ ShapeTuple embedding_shape; + if (ft_.use_disco) { + ICHECK(embedding->IsInstance()); + ObjectRef shape_ref = ft_.nd_get_shape_func_(embedding); + embedding_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + } else { + NDArray embedding_nd = Downcast(embedding); + embedding_shape = embedding_nd.Shape(); + } + ICHECK_EQ(embedding_shape.size(), 2); + ICHECK_GT(embedding_shape[0], 1); + this->hidden_size_ = embedding_shape[1]; + return this->hidden_size_; + } + // run forward compute with embeddings NDArray ForwardEmbeddings(NDArray embeddings, int64_t cur_pos) { if (ft_.use_disco) { @@ -1586,6 +1619,10 @@ class LLMChat { // sliding window cache offset int64_t sliding_window_cache_offset_{0}; //---------------------------- + // Model configurations + //---------------------------- + int hidden_size_ = -1; + //---------------------------- // Tokenizer //---------------------------- // internal tokenizer diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index d7ee205ac0..68bb6f171f 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -102,7 +102,6 @@ class ModelImpl : public ModelObj { CopyArrayToDevice(vec_token_ids, &input_token_ids_, dtype, prefill_chunk_size_, device_); ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); - token_ids_nd = token_ids_nd.CreateView({1, num_tokens}, dtype); auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {prefill_chunk_size_}); ObjectRef embeddings = ft_.embed_func_(token_ids_dref_or_nd, params_); @@ -152,10 +151,9 @@ class ModelImpl : public ModelObj { // embeddings: (1, n, h) NDArray embeddings_nd = Downcast(embeddings); ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(embeddings_nd->ndim, 3); - ICHECK_EQ(embeddings_nd->shape[0], 1); - ICHECK_GE(embeddings_nd->shape[1], total_length); - ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], total_length); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = @@ -211,10 +209,9 @@ class ModelImpl : public ModelObj { // embeddings: (1, b, h) NDArray embeddings_nd = Downcast(embeddings); ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(embeddings_nd->ndim, 3); - ICHECK_EQ(embeddings_nd->shape[0], 1); - ICHECK_GE(embeddings_nd->shape[1], num_sequence); - ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], num_sequence); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = @@ -275,10 +272,9 @@ class ModelImpl : public ModelObj { // embeddings: (1, n, h) NDArray embeddings_nd = Downcast(embeddings); ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(embeddings_nd->ndim, 3); - ICHECK_EQ(embeddings_nd->shape[0], 1); - ICHECK_GE(embeddings_nd->shape[1], total_length); - ICHECK_EQ(embeddings_nd->shape[2], hidden_size_); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], total_length); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = @@ -374,10 +370,9 @@ class ModelImpl : public ModelObj { NDArray embedding_nd = Downcast(embedding); embedding_shape = 
embedding_nd.Shape(); } - ICHECK_EQ(embedding_shape.size(), 3); - ICHECK_EQ(embedding_shape[0], 1); - ICHECK_EQ(embedding_shape[1], prefill_chunk_size_); - this->hidden_size_ = embedding_shape[2]; + ICHECK_EQ(embedding_shape.size(), 2); + ICHECK_EQ(embedding_shape[0], prefill_chunk_size_); + this->hidden_size_ = embedding_shape[1]; return embedding; } @@ -453,20 +448,18 @@ class ModelImpl : public ModelObj { TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") .set_body_typed([](NDArray embedding, NDArray dst, int offset) { - // embedding: (1, m, hidden_size) - // dst: (1, prefill_chunk_size, hidden_size) - ICHECK_EQ(embedding->ndim, 3); - ICHECK_EQ(embedding->shape[0], 1); - ICHECK_EQ(dst->ndim, 3); - ICHECK_EQ(dst->shape[0], 1); - ICHECK_LE(embedding->shape[1] + offset, dst->shape[1]); - ICHECK_EQ(embedding->shape[2], dst->shape[2]); + // embedding: (m, hidden_size) + // dst: (prefill_chunk_size, hidden_size) + ICHECK_EQ(embedding->ndim, 2); + ICHECK_EQ(dst->ndim, 2); + ICHECK_LE(embedding->shape[0] + offset, dst->shape[0]); + ICHECK_EQ(embedding->shape[1], dst->shape[1]); const DLTensor& copy_src = *(embedding.operator->()); const DLTensor* p_copy_dst = dst.operator->(); DLTensor copy_dst = *p_copy_dst; copy_dst.shape = embedding->shape; copy_dst.byte_offset = - offset * embedding->shape[2] * ((embedding->dtype.bits * embedding->dtype.lanes + 7) / 8); + offset * embedding->shape[1] * ((embedding->dtype.bits * embedding->dtype.lanes + 7) / 8); NDArray::CopyFromTo(©_src, ©_dst); }); diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_chat/compiler_pass/attach_to_ir_module.py index 47baacd755..9f1271dcf6 100644 --- a/python/mlc_chat/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_chat/compiler_pass/attach_to_ir_module.py @@ -86,7 +86,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR bb.emit_func_output( bb.emit( relax.op.builtin.alloc_tensor( - relax.ShapeExpr([1, self.metadata["prefill_chunk_size"], hidden_size]), + relax.ShapeExpr([self.metadata["prefill_chunk_size"], hidden_size]), dtype, runtime_device_index=0, ) diff --git a/python/mlc_chat/model/baichuan/baichuan_model.py b/python/mlc_chat/model/baichuan/baichuan_model.py index 6119afc10f..266d9678c3 100644 --- a/python/mlc_chat/model/baichuan/baichuan_model.py +++ b/python/mlc_chat/model/baichuan/baichuan_model.py @@ -254,7 +254,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/gemma/gemma_model.py b/python/mlc_chat/model/gemma/gemma_model.py index 080147d393..94768a0d89 100644 --- a/python/mlc_chat/model/gemma/gemma_model.py +++ b/python/mlc_chat/model/gemma/gemma_model.py @@ -316,7 +316,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/gpt2/gpt2_model.py b/python/mlc_chat/model/gpt2/gpt2_model.py index 1d930ba43d..83f65502f8 100644 --- a/python/mlc_chat/model/gpt2/gpt2_model.py +++ b/python/mlc_chat/model/gpt2/gpt2_model.py @@ -308,7 +308,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": 
nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py index 5557ca1614..302b093125 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py @@ -285,7 +285,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py index b5bd89e9a6..895655d60b 100644 --- a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_chat/model/gpt_neox/gpt_neox_model.py @@ -340,7 +340,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/internlm/internlm_model.py b/python/mlc_chat/model/internlm/internlm_model.py index 2c88ccaa71..153905f55e 100644 --- a/python/mlc_chat/model/internlm/internlm_model.py +++ b/python/mlc_chat/model/internlm/internlm_model.py @@ -255,7 +255,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py index 8d54829dc0..69884e8492 100644 --- a/python/mlc_chat/model/llama/llama_model.py +++ b/python/mlc_chat/model/llama/llama_model.py @@ -299,7 +299,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/orion/orion_model.py b/python/mlc_chat/model/orion/orion_model.py index 4692c67907..5894a5ab61 100644 --- a/python/mlc_chat/model/orion/orion_model.py +++ b/python/mlc_chat/model/orion/orion_model.py @@ -300,7 +300,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/phi/phi_model.py b/python/mlc_chat/model/phi/phi_model.py index 863ecd7298..372598d5ae 100644 --- a/python/mlc_chat/model/phi/phi_model.py +++ b/python/mlc_chat/model/phi/phi_model.py @@ -410,7 +410,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/qwen/qwen_model.py b/python/mlc_chat/model/qwen/qwen_model.py index b301ff13fe..b5879a92a2 100644 --- a/python/mlc_chat/model/qwen/qwen_model.py +++ b/python/mlc_chat/model/qwen/qwen_model.py @@ -260,7 +260,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": 
nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/qwen2/qwen2_model.py b/python/mlc_chat/model/qwen2/qwen2_model.py index 8fac47fa3e..a5dc351a9e 100644 --- a/python/mlc_chat/model/qwen2/qwen2_model.py +++ b/python/mlc_chat/model/qwen2/qwen2_model.py @@ -278,7 +278,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/rwkv5/rwkv5_model.py b/python/mlc_chat/model/rwkv5/rwkv5_model.py index 066ff7d9f4..e88efa4aec 100644 --- a/python/mlc_chat/model/rwkv5/rwkv5_model.py +++ b/python/mlc_chat/model/rwkv5/rwkv5_model.py @@ -389,7 +389,7 @@ def get_default_spec(self): batch_size = 1 mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_chat/model/stable_lm/stablelm_model.py b/python/mlc_chat/model/stable_lm/stablelm_model.py index edb4885123..8193c15ccc 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_model.py +++ b/python/mlc_chat/model/stable_lm/stablelm_model.py @@ -265,7 +265,7 @@ def create_paged_kv_cache( def get_default_spec(self): mod_spec = { "embed": { - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", From 3f05a1f587871cffabcf5884a515e6eff0e38a53 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Fri, 8 Mar 2024 17:12:43 -0500 Subject: [PATCH 047/531] [Pass] Low batch GEMM using GEMV-like schedule (#1769) When batch size is small, GEMM in MLP of decode stage can be dispatched into a specialized GEMV-like schedule to improve efficiency. 
GEMM with a dynamic var in spatial axis will now be lowered into ```python if dyn_var <= 8: low_batch_gemv() else: normal_gemm() ``` --- .../compiler_pass/low_batch_specialization.py | 63 +++++++++++++++++++ python/mlc_chat/compiler_pass/pipeline.py | 2 + 2 files changed, 65 insertions(+) create mode 100644 python/mlc_chat/compiler_pass/low_batch_specialization.py diff --git a/python/mlc_chat/compiler_pass/low_batch_specialization.py b/python/mlc_chat/compiler_pass/low_batch_specialization.py new file mode 100644 index 0000000000..63b29fb2ec --- /dev/null +++ b/python/mlc_chat/compiler_pass/low_batch_specialization.py @@ -0,0 +1,63 @@ +"""A compiler pass that dispatch low-batch-gemm to gemv schedule.""" +import tvm +from tvm import dlight as dl +from tvm import tir +from tvm.ir.module import IRModule + +# pylint: disable=too-many-locals,not-callable + + +@tvm.transform.module_pass(opt_level=0, name="LowBatchGemvSpecialize") +class LowBatchGemvSpecialize: # pylint: disable=too-few-public-methods + """A compiler pass that dispatch low-batch-gemm to gemv schedule.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for g_var, func in mod.functions_items(): + if isinstance(func, tir.PrimFunc): + low_batch_range = [2, 8] + buckets = [2, 4] + low_batch_funcs = [] + for bucket in buckets: + low_batch_mod = IRModule({}) + low_batch_mod["main"] = func + low_batch_mod = dl.ApplyDefaultSchedule( + dl.gpu.LowBatchGEMV(bucket), + )(low_batch_mod) + low_batch_funcs.append(low_batch_mod["main"]) + if any( + tvm.ir.structural_equal(low_batch_func, func) + for low_batch_func in low_batch_funcs + ): + continue + buffers = func.buffer_map.values() + shapes = [buffer.shape for buffer in buffers] + symbolic_vars = set( + expr for shape in shapes for expr in shape if isinstance(expr, tir.Var) + ) + assert len(symbolic_vars) == 1, symbolic_vars + gemm_mod = IRModule({}) + gemm_mod["main"] = func + gemm_mod = dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + )(gemm_mod) + gemm_func = gemm_mod["main"] + sym_var = list(symbolic_vars)[0] + body = gemm_func.body + for i, range_limit in reversed(list(enumerate(low_batch_range))): + body = tir.IfThenElse( + tir.op.tvm_thread_invariant(sym_var <= range_limit), + low_batch_funcs[i].body, + body, + ) + body = tir.Block([], [], [], "root", body) + body = tir.BlockRealize([], True, body) + new_func = func.with_body(body) + new_func = new_func.with_attr("tir.is_scheduled", 1) + new_func = new_func.with_attr("tir.HoistIfThenElseExprWithBlock", 1) + mod.update_func(g_var, new_func) + return mod diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 00d0d3c4f8..e13ff2a404 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -29,6 +29,7 @@ from .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc +from .low_batch_specialization import LowBatchGemvSpecialize from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -122,6 +123,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I _DebugDump("debug-phase3.py", debug_dump, show_meta=False), # Phase 4. 
Low-level Optimizations _LogProgress("Running TVM Dlight low-level optimizations"), + LowBatchGemvSpecialize(), dl.ApplyDefaultSchedule( dl.gpu.Matmul(), dl.gpu.GEMV(), From c2258aef97e6cacea2d111f3a6a9bd72e7c765f5 Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 8 Mar 2024 23:48:45 +0000 Subject: [PATCH 048/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 2c1ce3ab46..f06d486b4a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 2c1ce3ab467f9367c14afd9579ed1388aaae0b90 +Subproject commit f06d486b4a1a27f0bbb072688a5fc41e7b15323c From 1b3cfd599e0493db66168c1ed13c3fa3d00de46e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 8 Mar 2024 20:15:37 -0500 Subject: [PATCH 049/531] [Serving] Avoid unnecessary worker sync in Model (#1909) Following up #1906, this PR removes the synchronization given it is avoidable. We use another approach to avoid the write-after-write issue. The key to address the issue is to make sure the addresses to be copied to worker 0 is not rewritten before the copy actually happens. So we pre-allocate a large host array to hold all the token ids, and for each sequence, we copy its token ids to the offset given when calling TokenEmbed, so that we can make sure an address will not be written twice before copy happens. --- cpp/serve/function_table.cc | 28 ++++++++++++---- cpp/serve/function_table.h | 5 +-- cpp/serve/model.cc | 64 +++++++++++-------------------------- 3 files changed, 43 insertions(+), 54 deletions(-) diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 46855221d1..bbeb23ec89 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -42,6 +42,7 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, } void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object model_config) { + local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; { @@ -53,6 +54,7 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object } } this->model_config = model_config; + this->cached_buffers = Map(); if (num_shards > 1) { String lib_path{nullptr}; @@ -87,7 +89,6 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), lib_path, null_device); - this->disco_buffers = Map(); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { @@ -236,23 +237,36 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) } } -ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String tensor_name, +ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape) { - Device null_device{DLDeviceType(0), 0}; + ICHECK(host_array->device.device_type == DLDeviceType::kDLCPU); if (this->use_disco) { + Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); - if (this->disco_buffers.count(tensor_name)) { - buffer = this->disco_buffers[tensor_name]; + auto it = this->cached_buffers.find(buffer_cache_key); + if (it != this->cached_buffers.end()) { + buffer = Downcast((*it).second); } else { buffer = Downcast(this->Empty(max_reserved_shape, host_array.DataType(), null_device)); - 
this->disco_buffers.Set(tensor_name, buffer); + this->cached_buffers.Set(buffer_cache_key, buffer); } ShapeTuple real_shape = host_array.Shape(); DRef buffer_view = nd_view_func_(buffer, real_shape); sess->CopyToWorker0(host_array, buffer_view); return buffer_view; } else { - return host_array; + auto it = this->cached_buffers.find(buffer_cache_key); + NDArray buffer{nullptr}; + if (it != this->cached_buffers.end()) { + buffer = Downcast((*it).second); + } else { + buffer = NDArray::Empty(max_reserved_shape, host_array->dtype, local_gpu_device); + this->cached_buffers.Set(buffer_cache_key, buffer); + } + buffer = buffer.CreateView(host_array.Shape(), host_array->dtype); + DLTensor copy_dst = *(buffer.operator->()); + NDArray::CopyFromTo(host_array.operator->(), ©_dst); + return buffer; } } diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 9f8d8daed6..9cc0ecb8e2 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -49,13 +49,14 @@ struct FunctionTable { ObjectRef Empty(ShapeTuple shape, DataType dtype, Device device) const; - ObjectRef CopyToWorker0(const NDArray& host_array, String tensor_name, + ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape); bool use_disco = false; + Device local_gpu_device; Session sess{nullptr}; DRef disco_mod{nullptr}; - Map disco_buffers{nullptr}; + Map cached_buffers{nullptr}; tvm::runtime::Module local_vm{nullptr}; picojson::object model_config; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 68bb6f171f..b5cb5c6b5a 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -6,6 +6,7 @@ #include "model.h" #include +#include #include #include @@ -17,38 +18,6 @@ namespace mlc { namespace llm { namespace serve { -/*********************** Utils ***********************/ - -/*! \brief Utility function that copies input array to the device. */ -template -NDArray CopyArrayToDevice(const std::vector& array, NDArray* dst, DLDataType dtype, - int default_init_size, Device device) { - ICHECK(!array.empty()); - ICHECK(dst != nullptr); - ICHECK(!dst->defined() || (*dst)->ndim == 1); - int64_t init_size = dst->defined() ? (*dst)->shape[0] : default_init_size; - while (init_size < static_cast(array.size())) { - init_size *= 2; - } - if (!dst->defined() || init_size != (*dst)->shape[0]) { - (*dst) = NDArray::Empty({init_size}, dtype, device); - } - ICHECK_LE(static_cast(array.size()), (*dst)->shape[0]); - NDArray view = dst->CreateView(ShapeTuple({static_cast(array.size())}), dtype); - - DLTensor copy_dst = *(view.operator->()); - DLTensor copy_src; - copy_src.data = const_cast(array.data()); - copy_src.device = Device{kDLCPU, 0}; - copy_src.ndim = 1; - copy_src.dtype = view->dtype; - copy_src.shape = view->shape; - copy_src.strides = nullptr; - copy_src.byte_offset = 0; - NDArray::CopyFromTo(©_src, ©_dst); - return view; -} - /*********************** Model Implementation ***********************/ class ModelImpl; @@ -89,17 +58,27 @@ class ModelImpl : public ModelObj { this->max_num_sequence_ = max_num_sequence; // Step 5. Reset this->Reset(); + // Step 6. Initialize the shared NDArray. 
+ Device device_host{DLDeviceType::kDLCPU, 0}; + memory::Allocator* allocator = + memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); + ICHECK_NOTNULL(allocator); + token_ids_storage_ = + memory::Storage(allocator->Alloc({prefill_chunk_size_}, DataType::Int(32))); + this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); } /*********************** Model Computation ***********************/ ObjectRef TokenEmbed(IntTuple token_ids, ObjectRef* dst, int offset) final { int num_tokens = token_ids.size(); - std::vector vec_token_ids(token_ids->data, token_ids->data + num_tokens); // Copy input token ids to device. DLDataType dtype(DataType::Int(32)); - NDArray token_ids_nd = - CopyArrayToDevice(vec_token_ids, &input_token_ids_, dtype, prefill_chunk_size_, device_); + NDArray token_ids_nd = token_ids_storage_->AllocNDArray(offset * 4, {num_tokens}, dtype); + int* p_token_ids = static_cast(token_ids_nd->data) + (token_ids_nd->byte_offset) / 4; + for (int i = 0; i < num_tokens; ++i) { + p_token_ids[i] = token_ids[i]; + } ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {prefill_chunk_size_}); @@ -108,9 +87,6 @@ class ModelImpl : public ModelObj { if (dst != nullptr) { CHECK(dst->defined()); ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset); - if (ft_.use_disco) { - ft_.sess->SyncWorker(0); - } return *dst; } else { CHECK_EQ(offset, 0); @@ -124,15 +100,13 @@ class ModelImpl : public ModelObj { CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); int total_length = 0; - std::vector logit_pos; - logit_pos.reserve(num_sequences); + + int* p_logit_pos = static_cast(logit_pos_arr_->data); for (int i = 0; i < num_sequences; ++i) { total_length += lengths[i]; - logit_pos.push_back(total_length - 1); + p_logit_pos[i] = total_length - 1; } - - NDArray logit_pos_nd = - CopyArrayToDevice(logit_pos, &logit_pos_arr_, DataType::Int(32), 32, device_); + NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); CHECK(ft_.prefill_func_.defined()) << "`prefill_with_embed` function is not found in the model. 
Please make sure the model is " @@ -442,7 +416,7 @@ class ModelImpl : public ModelObj { // Model parameters ObjectRef params_; // Shared NDArray - NDArray input_token_ids_{nullptr}; + memory::Storage token_ids_storage_{nullptr}; NDArray logit_pos_arr_{nullptr}; }; From 448c5c408659e45cbd1d351a6e4ec2a2ab3bed2e Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Sun, 10 Mar 2024 04:55:59 +0800 Subject: [PATCH 050/531] [Serving][Grammar] Enhance GrammarStateMatcher to support general grammar (#1917) --- cpp/serve/grammar/grammar.cc | 2 +- cpp/serve/grammar/grammar.h | 2 +- cpp/serve/grammar/grammar_builder.h | 4 +- cpp/serve/grammar/grammar_parser.cc | 46 ++-- cpp/serve/grammar/grammar_serializer.cc | 6 +- cpp/serve/grammar/grammar_serializer.h | 2 +- cpp/serve/grammar/grammar_simplifier.cc | 32 ++- cpp/serve/grammar/grammar_simplifier.h | 8 +- cpp/serve/grammar/grammar_state_matcher.cc | 7 + .../grammar/grammar_state_matcher_base.h | 229 ++++++++++++------ .../grammar/grammar_state_matcher_preproc.h | 10 +- .../grammar/grammar_state_matcher_state.h | 54 ++--- python/mlc_chat/serve/grammar.py | 6 +- tests/python/serve/test_grammar_parser.py | 34 ++- .../test_grammar_state_matcher_custom.py | 214 ++++++++++++++++ ....py => test_grammar_state_matcher_json.py} | 22 +- 16 files changed, 507 insertions(+), 171 deletions(-) create mode 100644 tests/python/serve/test_grammar_state_matcher_custom.py rename tests/python/serve/{test_grammar_state_matcher.py => test_grammar_state_matcher_json.py} (96%) diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index 697fb29d60..e10e6e7e45 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -103,7 +103,7 @@ elements_rest ::= ( "\t" ws "," ws elements ) characters ::= "" | [^"\\\r\n] characters | "\\" escape characters -escape ::= "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] digits ::= [0-9] | [0-9] digits fraction ::= "" | "." digits exponent ::= "" | "e" sign digits | "E" sign digits diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 22e674527d..93d8f0e3c1 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -100,7 +100,7 @@ class BNFGrammarNode : public Object { // data format: [rule_expr_id0, rule_expr_id1, ...] kChoices, // data format: [rule_expr_id] - kStarQuantifier, + kCharacterClassStar, }; /*! \brief The object representing a rule expr. 
*/ diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index eaa8af04f9..6044a76bd9 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -106,11 +106,11 @@ class BNFGrammarBuilder { return AddRuleExpr({RuleExprType::kChoices, data.data(), static_cast(data.size())}); } - int32_t AddStarQuantifier(int32_t element) { + int32_t AddCharacterClassStar(int32_t element) { std::vector data; data.push_back(element); return AddRuleExpr( - {RuleExprType::kStarQuantifier, data.data(), static_cast(data.size())}); + {RuleExprType::kCharacterClassStar, data.data(), static_cast(data.size())}); } size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); } diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index b5f6be1849..6e9de834a5 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -263,29 +263,35 @@ int32_t EBNFParserImpl::ParseElement() { } int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - // rule ::= a* - // We have special support for star quantifier in BNFGrammar AST - auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto new_rule_expr_id = builder_.AddStarQuantifier(rule_expr_id); - return builder_.AddRule({new_rule_name, new_rule_expr_id}); + if (builder_.GetRuleExpr(rule_expr_id).type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { + // We have special handling for character class star, e.g. [a-z]* + return builder_.AddCharacterClassStar(rule_expr_id); + } else { + // For other star quantifiers, we transform it into a rule: + // a* --> rule ::= a rule | "" + auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); + auto new_rule_id = builder_.AddEmptyRule(new_rule_name); + auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id); + auto new_rule_expr_id = builder_.AddChoices( + {builder_.AddSequence({rule_expr_id, ref_to_new_rule}), builder_.AddEmptyStr()}); + builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); + + // Return the reference to the new rule + return builder_.AddRuleRef(new_rule_id); + } } int32_t EBNFParserImpl::HandlePlusQuantifier(int32_t rule_expr_id) { // a+ --> rule ::= a rule | a - // We will use rule_expr a for two times in this case - // So first we create a rule for rule_expr a - auto a_rule_name = builder_.GetNewRuleName(cur_rule_name_); - auto a_rule_id = builder_.AddRule({a_rule_name, rule_expr_id}); - - // Then create the new rule_expr. 
auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); auto new_rule_id = builder_.AddEmptyRule(new_rule_name); - auto a_plus_ref = builder_.AddRuleRef(new_rule_id); - auto a_ref1 = builder_.AddRuleRef(a_rule_id); - auto a_ref2 = builder_.AddRuleRef(a_rule_id); - auto new_rule_expr_id = builder_.AddChoices({builder_.AddSequence({a_ref1, a_plus_ref}), a_ref2}); + auto ref_to_new_rule = builder_.AddRuleRef(new_rule_id); + auto new_rule_expr_id = + builder_.AddChoices({builder_.AddSequence({rule_expr_id, ref_to_new_rule}), rule_expr_id}); builder_.UpdateRuleBody(new_rule_id, new_rule_expr_id); - return new_rule_id; + + // Return the reference to the new rule + return builder_.AddRuleRef(new_rule_id); } int32_t EBNFParserImpl::HandleQuestionQuantifier(int32_t rule_expr_id) { @@ -293,7 +299,7 @@ int32_t EBNFParserImpl::HandleQuestionQuantifier(int32_t rule_expr_id) { auto new_rule_name = builder_.GetNewRuleName(cur_rule_name_); auto new_rule_expr_id = builder_.AddChoices({rule_expr_id, builder_.AddEmptyStr()}); auto new_rule_id = builder_.AddRule({new_rule_name, new_rule_expr_id}); - return new_rule_id; + return builder_.AddRuleRef(new_rule_id); } int32_t EBNFParserImpl::ParseQuantifier() { @@ -308,11 +314,11 @@ int32_t EBNFParserImpl::ParseQuantifier() { switch (Peek(-1)) { case '*': // We assume that the star quantifier should be the body of some rule now - return builder_.AddStarQuantifier(rule_expr_id); + return HandleStarQuantifier(rule_expr_id); case '+': - return builder_.AddRuleRef(HandlePlusQuantifier(rule_expr_id)); + return HandlePlusQuantifier(rule_expr_id); case '?': - return builder_.AddRuleRef(HandleQuestionQuantifier(rule_expr_id)); + return HandleQuestionQuantifier(rule_expr_id); default: LOG(FATAL) << "Unreachable"; } diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index b77e194199..a057921f61 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -40,8 +40,8 @@ std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { return PrintSequence(rule_expr); case RuleExprType::kChoices: return PrintChoices(rule_expr); - case RuleExprType::kStarQuantifier: - return PrintStarQuantifier(rule_expr); + case RuleExprType::kCharacterClassStar: + return PrintCharacterClassStar(rule_expr); default: LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast(rule_expr.type); } @@ -103,7 +103,7 @@ std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) { return result; } -std::string BNFGrammarPrinter::PrintStarQuantifier(const RuleExpr& rule_expr) { +std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { return PrintRuleExpr(rule_expr[0]) + "*"; } diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 2bf47392bc..5837ce2bf6 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -73,7 +73,7 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { /*! \brief Print a RuleExpr for rule_expr choices. */ std::string PrintChoices(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for star quantifier. */ - std::string PrintStarQuantifier(const RuleExpr& rule_expr); + std::string PrintCharacterClassStar(const RuleExpr& rule_expr); }; /*! 
diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_simplifier.cc index ccbfe971f2..234f9d7057 100644 --- a/cpp/serve/grammar/grammar_simplifier.cc +++ b/cpp/serve/grammar/grammar_simplifier.cc @@ -65,7 +65,7 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { } private: - /*! \brief Visit a RuleExpr as the rule body. */ + /*! \brief Visit a RuleExpr as a rule body. */ int32_t VisitRuleBody(const RuleExpr& rule_expr) { switch (rule_expr.type) { case RuleExprType::kSequence: @@ -78,8 +78,8 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kNegCharacterClass: case RuleExprType::kRuleRef: return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})}); - case RuleExprType::kStarQuantifier: - return builder_.AddStarQuantifier(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); + case RuleExprType::kCharacterClassStar: + return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } @@ -109,6 +109,9 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kRuleRef: VisitElementInChoices(choice_expr, &new_choice_ids); break; + case RuleExprType::kCharacterClassStar: + VisitCharacterClassStarInChoices(choice_expr, &new_choice_ids); + break; default: LOG(FATAL) << "Unexpected choice type: " << static_cast(choice_expr.type); } @@ -151,6 +154,16 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { new_choice_ids->push_back(builder_.AddSequence({sub_expr_id})); } + /*! \brief Visit a character class star RuleExpr that is one of a list of choices. */ + void VisitCharacterClassStarInChoices(const RuleExpr& rule_expr, + std::vector* new_choice_ids) { + auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); + auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); + auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); + auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); + new_choice_ids->push_back(builder_.AddSequence({new_rule_ref_id})); + } + /*! * \brief Visit a RuleExpr containing a sequence. * \returns A list of new sequence RuleExpr ids. @@ -173,6 +186,9 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kRuleRef: VisitElementInSequence(seq_expr, &new_sequence_ids); break; + case RuleExprType::kCharacterClassStar: + VisitCharacterClassStarInSequence(seq_expr, &new_sequence_ids); + break; default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(seq_expr.type); } @@ -208,6 +224,16 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr)); } + /*! \brief Visit a character class star RuleExpr that is in a sequence. */ + void VisitCharacterClassStarInSequence(const RuleExpr& rule_expr, + std::vector* new_sequence_ids) { + auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); + auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); + auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); + auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); + new_sequence_ids->push_back(new_rule_ref_id); + } + /*! \brief The name of the current rule being visited. 
*/ std::string cur_rule_name_; }; diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_simplifier.h index 4ccc0b55e7..b9accf09bc 100644 --- a/cpp/serve/grammar/grammar_simplifier.h +++ b/cpp/serve/grammar/grammar_simplifier.h @@ -73,8 +73,8 @@ class BNFGrammarMutator { return VisitCharacterClass(rule_expr); case RuleExprType::kRuleRef: return VisitRuleRef(rule_expr); - case RuleExprType::kStarQuantifier: - return VisitStarQuantifier(rule_expr); + case RuleExprType::kCharacterClassStar: + return VisitCharacterClassStar(rule_expr); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } @@ -135,11 +135,11 @@ class BNFGrammarMutator { virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } /*! \brief Visit a star quantifier RuleExpr. */ - virtual T VisitStarQuantifier(const RuleExpr& rule_expr) { + virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { if constexpr (std::is_same::value) { VisitExpr(grammar_->GetRuleExpr(rule_expr[0])); } else if constexpr (std::is_same::value) { - return builder_.AddStarQuantifier(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); + return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); } else { return T(); } diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 3087a3d665..671b0879e3 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -450,8 +450,15 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr tokenizer, int max_rollback_steps) { + auto preproc_start = std::chrono::high_resolution_clock::now(); auto init_ctx = GrammarStateMatcher::CreateInitContext( grammar, tokenizer ? tokenizer.value()->TokenTable() : std::vector()); + auto preproc_end = std::chrono::high_resolution_clock::now(); + std::cerr << "Preprocess takes " + << std::chrono::duration_cast(preproc_end - + preproc_start) + .count() + << "us"; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 0028994b3c..4c543a2e69 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -58,8 +58,44 @@ class GrammarStateMatcherBase { // If init_rule_position is {}, init the stack with the main rule. void InitStackState(RulePosition init_rule_position = {}); - // Update the old stack top to the next position, and push the new stack tops to new_stack_tops. - void UpdateNewStackTops(int32_t old_node_id, std::vector* new_stack_tops); + // Update the char_class_star_id field of the given rule_position, if it refers to a character + // class star rule. + void UpdateCharClassStarId(RulePosition* rule_position) const; + + /*! + * \brief Find the next position in the rule. If the next position is at the end of the rule, + * the result depends on the consider_parent parameter: + * - false: kInvalidRulePosition will be returned. + * - true: the next position of the parent rule will be returned. If the current rule is the root + * rule, the RulePosition will be returned as is to indicate the end of the grammar. + * \param rule_position The current position. + * \param consider_parent Whether to consider the parent position if the current position is at + * the end of the rule. + */ + RulePosition IterateToNextPosition(const RulePosition& rule_position, bool consider_parent) const; + + /*! 
+ * \brief Expand the given rule position (may be a RuleRef element) s.t. every new position is a + * CharacterClass or refers to a CharacterClassStar rule. Push all new positions into + * new_stack_tops. + * \details This method will start from cur_rule_position and continuously iterate to the next + * position as long as the current position can be empty (e.g. the current position is a + * reference to an rule that can be empty, or to a character class star rule). If the current + * position can not be empty, stop expanding. All positions collected will be pushed into + * new_stack_tops. + * + * If the end of the current rule is reached: + * - If is_outmost_level is true, we can go to the next position in the parent rule. + * - Otherwise, stop iteration. + * \param cur_rule_position The current rule position. + * \param new_stack_tops The vector to store the new stack tops. + * \param is_outmost_level Whether the current position is the outmost level of the rule. + * \param first_id_if_inserted Being not -1 means the first node is already inserted. This is the + * id of the first node. This is used to avoid inserting the same node twice. + * \return Whether the end of the rule can be reached. Used as the condition of recursion. + */ + bool ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, + bool is_outmost_level, int32_t first_id_if_inserted = -1); BNFGrammar grammar_; RulePositionTree tree_; @@ -89,28 +125,34 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool const auto& prev_stack_tops = stack_tops_history_.GetLatest(); tmp_new_stack_tops_.clear(); - for (auto old_top : prev_stack_tops) { - const auto& rule_position = tree_[old_top]; - auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); - if (rule_position.parent_id == RulePosition::kNoParent && - rule_position.element_id == current_sequence.size()) { + for (auto prev_top : prev_stack_tops) { + const auto& cur_rule_position = tree_[prev_top]; + auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); + if (cur_rule_position.parent_id == RulePosition::kNoParent && + cur_rule_position.element_id == current_sequence.size()) { // This RulePosition means previous elements has matched the complete rule. // But we are still need to accept a new character, so this stack will become invalid. continue; } - auto current_char_class = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); - // Special support for star quantifiers of character classes. - if (current_char_class.type == RuleExprType::kRuleRef) { - DCHECK(rule_position.char_class_id != -1); - current_char_class = grammar_->GetRuleExpr(rule_position.char_class_id); - } + + auto current_char_class = + cur_rule_position.char_class_star_id != -1 + ? 
grammar_->GetRuleExpr(cur_rule_position.char_class_star_id) + : grammar_->GetRuleExpr(current_sequence[cur_rule_position.element_id]); DCHECK(current_char_class.type == RuleExprType::kCharacterClass || current_char_class.type == RuleExprType::kNegCharacterClass); auto ok = CharacterClassContains(current_char_class, codepoint); if (!ok) { continue; } - UpdateNewStackTops(old_top, &tmp_new_stack_tops_); + + if (cur_rule_position.char_class_star_id == -1) { + auto next_rule_position = IterateToNextPosition(cur_rule_position, true); + DCHECK(next_rule_position != kInvalidRulePosition); + ExpandRulePosition(next_rule_position, &tmp_new_stack_tops_, true); + } else { + ExpandRulePosition(cur_rule_position, &tmp_new_stack_tops_, true, prev_top); + } } if (tmp_new_stack_tops_.empty()) { if (verbose) { @@ -125,6 +167,9 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool << "\" Accepted" << std::endl; std::cout << "Stack after accepting: " << PrintStackState() << std::endl; } +#if TVM_LOG_DEBUG + stack_tops_history_.CheckWellFormed(); +#endif return true; } @@ -150,12 +195,12 @@ inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_posit if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the main rule. auto main_rule = grammar_->GetRule(0); - auto main_rule_expr = grammar_->GetRuleExpr(main_rule.body_expr_id); + auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); std::vector new_stack_tops; - for (auto i : main_rule_expr) { - DCHECK(grammar_->GetRuleExpr(i).type == RuleExprType::kSequence || - grammar_->GetRuleExpr(i).type == RuleExprType::kEmptyStr); - new_stack_tops.push_back(tree_.NewNode(RulePosition(0, i, 0, RulePosition::kNoParent))); + for (auto i : main_rule_body) { + auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent); + UpdateCharClassStarId(&init_rule_position); + ExpandRulePosition(init_rule_position, &new_stack_tops, true); } stack_tops_history_.PushHistory(new_stack_tops); } else { @@ -163,70 +208,110 @@ inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_posit } } -inline void GrammarStateMatcherBase::UpdateNewStackTops(int32_t old_node_id, - std::vector* new_stack_tops) { - const auto& old_rule_position = tree_[old_node_id]; - // For char_class*, the old rule position itself is also the next position - if (old_rule_position.char_class_id != -1) { - new_stack_tops->push_back(tree_.NewNode(old_rule_position)); +inline void GrammarStateMatcherBase::UpdateCharClassStarId(RulePosition* rule_position) const { + auto rule_expr = grammar_->GetRuleExpr(rule_position->sequence_id); + auto element = grammar_->GetRuleExpr(rule_expr[rule_position->element_id]); + if (element.type == RuleExprType::kRuleRef) { + auto sub_rule_body = grammar_->GetRuleExpr(grammar_->GetRule(element[0]).body_expr_id); + if (sub_rule_body.type == RuleExprType::kCharacterClassStar) { + rule_position->char_class_star_id = sub_rule_body[0]; + } + } +} + +inline RulePosition GrammarStateMatcherBase::IterateToNextPosition( + const RulePosition& rule_position, bool consider_parent) const { + auto next_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, + rule_position.element_id + 1, rule_position.parent_id); + auto rule_expr = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_sequence_length = rule_expr.size(); + DCHECK(next_position.element_id <= current_sequence_length); + + if (next_position.element_id < current_sequence_length) { + // Update 
char_class_star_id if the position refers to a character class star rule. + UpdateCharClassStarId(&next_position); + return next_position; + } + + if (!consider_parent) { + return kInvalidRulePosition; + } + + if (next_position.parent_id == RulePosition::kNoParent) { + return next_position; + } else { + auto parent_rule_position = tree_[next_position.parent_id]; + return IterateToNextPosition(parent_rule_position, true); } +} + +inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_position, + std::vector* new_stack_tops, + bool is_outmost_level, + int32_t first_id_if_inserted) { + bool is_first = false; + + for (; cur_rule_position != kInvalidRulePosition; + cur_rule_position = IterateToNextPosition(cur_rule_position, is_outmost_level)) { + // Insert the node to the tree, if not inserted before. + int32_t new_node_id; + if (is_first && first_id_if_inserted != -1) { + new_node_id = first_id_if_inserted; + } else { + new_node_id = tree_.NewNode(cur_rule_position); + } + is_first = false; - auto cur_rule_position = tree_.GetNextPosition(tree_[old_node_id]); + // Case 1. The current position points to the end of the grammar. + if (is_outmost_level) { + if (tree_.IsEndPosition(cur_rule_position)) { + new_stack_tops->push_back(new_node_id); + return true; + } + } else { + DCHECK(!tree_.IsEndPosition(cur_rule_position)); + } - // Continuously iterate to the next position (if reachs the end of the current rule, go to the - // next position of the parent rule). Push it into new_stack_tops. If this position can not - // be empty, exit the loop. - // Positions that can be empty: reference to a rule that can be empty, or a star quantifier - // rule. - for (; !tree_.IsEndPosition(cur_rule_position); - cur_rule_position = tree_.GetNextPosition(cur_rule_position)) { + // Case 2. The current position refers to a character class star rule. It can be empty. + if (cur_rule_position.char_class_star_id != -1) { + new_stack_tops->push_back(new_node_id); + continue; + } + + // Case 3. Character class: cannot be empty. auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); if (element.type == RuleExprType::kCharacterClass || element.type == RuleExprType::kNegCharacterClass) { - // Character class: cannot be empty. Break the loop. 
- new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); - break; - } else { - // RuleRef - DCHECK(element.type == RuleExprType::kRuleRef); - auto new_rule_id = element[0]; - auto new_rule = grammar_->GetRule(new_rule_id); - auto new_rule_expr = grammar_->GetRuleExpr(new_rule.body_expr_id); - if (new_rule_expr.type == RuleExprType::kStarQuantifier) { - cur_rule_position.char_class_id = new_rule_expr[0]; - new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); - } else { - DCHECK(new_rule_expr.type == RuleExprType::kChoices); - - bool contain_empty = false; - - // For rule containing choices, expand the rule and push all positions into new_stack_tops - for (auto j : new_rule_expr) { - auto sequence = grammar_->GetRuleExpr(j); - if (sequence.type == RuleExprType::kEmptyStr) { - contain_empty = true; - continue; - } - DCHECK(sequence.type == RuleExprType::kSequence); - DCHECK(grammar_->GetRuleExpr(sequence[0]).type == RuleExprType::kCharacterClass || - grammar_->GetRuleExpr(sequence[0]).type == RuleExprType::kNegCharacterClass); - // Note: rule_position is not inserted to the tree yet, so it need to be inserted first - auto parent_id = tree_.NewNode(cur_rule_position); - new_stack_tops->push_back(tree_.NewNode(RulePosition(new_rule_id, j, 0, parent_id))); - } - - if (!contain_empty) { - break; - } + new_stack_tops->push_back(new_node_id); + return false; + } + + // Case 4. The current position refers to a normal rule, i.e. a rule of choices of sequences. + DCHECK(element.type == RuleExprType::kRuleRef); + auto sub_rule_id = element[0]; + auto sub_rule = grammar_->GetRule(sub_rule_id); + auto sub_rule_body = grammar_->GetRuleExpr(sub_rule.body_expr_id); + DCHECK(sub_rule_body.type == RuleExprType::kChoices); + + bool contain_empty = false; + + for (auto sequence_id : sub_rule_body) { + auto sequence = grammar_->GetRuleExpr(sequence_id); + if (sequence.type == RuleExprType::kEmptyStr) { + contain_empty = true; + continue; } + auto sub_rule_position = RulePosition(sub_rule_id, sequence_id, 0, new_node_id); + UpdateCharClassStarId(&sub_rule_position); + contain_empty |= ExpandRulePosition(sub_rule_position, new_stack_tops, false); } - } - // Reaches the end of the main rule. Insert a special node to indicate the end. - if (tree_.IsEndPosition(cur_rule_position)) { - new_stack_tops->push_back(tree_.NewNode(cur_rule_position)); + if (!contain_empty) { + return false; + } } + return true; } } // namespace serve diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index 3d1ffeb754..dbb59f886b 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -277,12 +277,12 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC // Find the corresponding catagorized tokens for: // 1. All character elements in the grammar - // 2. All RuleRef elements that refers to a rule of a StarQuantifier of a character class + // 2. All RuleRef elements that refers to a rule containing a CharacterClassStar RuleExpr. for (int i = 0; i < static_cast(grammar->NumRules()); ++i) { auto rule = grammar->GetRule(i); auto rule_expr = grammar->GetRuleExpr(rule.body_expr_id); - // Skip StarQuantifier since we just handle it at the reference element during matching. - if (rule_expr.type == RuleExprType::kStarQuantifier) { + // Skip CharacterClassStar since we just handle it at the reference element during matching. 
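+ // (For example, with "b ::= [b]*" the body of rule b is a single CharacterClassStar RuleExpr;
+ // it is handled through the char_class_star_id field of the RulePosition that references
+ // rule b, instead of being processed as a standalone rule here.)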
+ if (rule_expr.type == RuleExprType::kCharacterClassStar) { continue; } DCHECK(rule_expr.type == RuleExprType::kChoices); @@ -301,8 +301,8 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC if (ref_rule_expr.type == RuleExprType::kChoices) { continue; } else { - // Reference to a StarQuantifier of a character class. - cur_rule_position.char_class_id = ref_rule_expr[0]; + // Reference to a CharacterClassStar of a character class. + cur_rule_position.char_class_star_id = ref_rule_expr[0]; } } diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index d8f2185f98..fad3365ed9 100644 --- a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -27,11 +27,11 @@ struct RulePosition { /*! \brief Which element of the choice sequence is being visited. */ int32_t element_id = -1; /*! - * \brief If the element refers to another rule, and another rule is a star quantifier of - * a character class, this field will be set to the id of the character class. - * This is part of the special support of star quantifiers of character classes. + * \brief If the element refers to another rule, and the body of another rule is a + * CharacterClassStar RuleExpr, this field will be set to the id of the character class. + * This is for the special support of CharacterClassStar. */ - int32_t char_class_id = -1; + int32_t char_class_star_id = -1; /*! \brief The id of the parent node in the RulePositionTree. */ int32_t parent_id = -1; /*! \brief The reference count of this RulePosition. If reduces to zero, the node will be @@ -43,16 +43,16 @@ struct RulePosition { constexpr RulePosition() = default; constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, - int32_t parent_id = kNoParent, int32_t char_class_id = -1) + int32_t parent_id = kNoParent, int32_t char_class_star_id = -1) : rule_id(rule_id), sequence_id(sequence_id), element_id(element_id), - char_class_id(char_class_id), + char_class_star_id(char_class_star_id), parent_id(parent_id) {} bool operator==(const RulePosition& other) const { return rule_id == other.rule_id && sequence_id == other.sequence_id && - element_id == other.element_id && char_class_id == other.char_class_id && + element_id == other.element_id && char_class_star_id == other.char_class_star_id && parent_id == other.parent_id; } @@ -146,13 +146,10 @@ class RulePositionTree { } /*! - * \brief Update a node in the stack to the next position. Next position means either the next - * element in the current rule, or if the current element is the last element in the rule, the - * next element in the parent rule. If the current node is the last element in the main rule, it - * is at the end position. + * \brief Check if the given RulePosition points to the end of the grammar. We use + * (main_rule_id, sequence_id, length_of_sequence) to represent the end position. Here the + * element_id is the length of the sequence. */ - RulePosition GetNextPosition(RulePosition rule_position) const; - bool IsEndPosition(const RulePosition& rule_position) const; /*! \brief Attach an additional reference to the node with the given id. */ @@ -180,6 +177,7 @@ class RulePositionTree { /*! \brief Get the RulePosition with the given id. 
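+ * The id must refer to a node that is still valid, i.e. not equal to kInvalidRulePosition.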
*/ const RulePosition& operator[](int32_t id) const { DCHECK(id != RulePosition::kNoParent); + DCHECK(node_buffer_[id] != kInvalidRulePosition); return node_buffer_[id]; } @@ -313,34 +311,11 @@ class StackTopsHistory { std::deque> stack_tops_history_; }; -/*! \brief See GetNextPosition. */ inline bool RulePositionTree::IsEndPosition(const RulePosition& rule_position) const { return rule_position.parent_id == RulePosition::kNoParent && grammar_->GetRuleExpr(rule_position.sequence_id).size() == rule_position.element_id; } -/*! - * \brief Update a node in the stack to the next position. Next position means either the next - * element in the current rule, or if the current element is the last element in the rule, the - * next element in the parent rule. If the current node is the last element in the main rule, it - * is at the end position. - */ -inline RulePosition RulePositionTree::GetNextPosition(RulePosition rule_position) const { - if (IsEndPosition(rule_position)) { - return kInvalidRulePosition; - } - rule_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, - rule_position.element_id + 1, rule_position.parent_id); - while (rule_position.parent_id != RulePosition::kNoParent && - grammar_->GetRuleExpr(rule_position.sequence_id).size() == rule_position.element_id) { - auto parent_rule_position = node_buffer_[rule_position.parent_id]; - rule_position = - RulePosition(parent_rule_position.rule_id, parent_rule_position.sequence_id, - parent_rule_position.element_id + 1, parent_rule_position.parent_id); - } - return rule_position; -} - inline std::string RulePositionTree::PrintNode(int32_t id) const { std::stringstream ss; const auto& rule_position = node_buffer_[id]; @@ -348,7 +323,12 @@ inline std::string RulePositionTree::PrintNode(int32_t id) const { ss << ", rule " << rule_position.rule_id << ": " << grammar_->GetRule(rule_position.rule_id).name; ss << ", sequence " << rule_position.sequence_id << ": " << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); - ss << ", element id: " << rule_position.element_id << ", parent id: " << rule_position.parent_id + ss << ", element id: " << rule_position.element_id; + if (rule_position.char_class_star_id != -1) { + ss << ", char class " << rule_position.char_class_star_id << ": " + << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.char_class_star_id) << "*"; + } + ss << ", parent id: " << rule_position.parent_id << ", ref count: " << rule_position.reference_count; return ss.str(); } diff --git a/python/mlc_chat/serve/grammar.py b/python/mlc_chat/serve/grammar.py index f6122c5e8a..b8f4126c1c 100644 --- a/python/mlc_chat/serve/grammar.py +++ b/python/mlc_chat/serve/grammar.py @@ -239,7 +239,7 @@ def is_terminated(self) -> bool: return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member def debug_accept_char(self, codepoint: int) -> bool: - """Accept one unicode codepoint to the current state. + """Accept one unicode codepoint to the current state. For test purposes. Parameters ---------- @@ -251,8 +251,8 @@ def debug_accept_char(self, codepoint: int) -> bool: ) def debug_match_complete_string(self, string: str) -> bool: - """Check if a matcher can accept the complete string, and then reach the end of the - grammar. + """Check if the matcher can accept the complete string, and then reach the end of the + grammar. For test purposes. 
Parameters ---------- diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index ceffd5805d..87228b1c18 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -24,16 +24,16 @@ def test_bnf_simple(): def test_ebnf(): before = """main ::= b c | b main -b ::= "b"* +b ::= "ab"* c ::= [acep-z]+ d ::= "d"? """ expected = """main ::= ((b c) | (b main)) -b ::= [b]* -c ::= ((c_2)) +b ::= ((b_1)) +c ::= ((c_1)) d ::= ((d_1)) -c_1 ::= (([acep-z])) -c_2 ::= ((c_1 c_2) | (c_1)) +b_1 ::= ("" | ([a] [b] b_1)) +c_1 ::= (([acep-z] c_1) | ([acep-z])) d_1 ::= ("" | ([d])) """ bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) @@ -41,6 +41,30 @@ def test_ebnf(): assert after == expected +def test_star_quantifier(): + before = """main ::= b c d +b ::= [b]* +c ::= "b"* +d ::= ([b] [c] [d] | ([p] [q]))* +e ::= [e]* [f]* | [g]* +""" + expected = """main ::= ((b c d)) +b ::= [b]* +c ::= ((c_1)) +d ::= ((d_1)) +e ::= ((e_star e_star_1) | (e_star_2)) +c_1 ::= ("" | ([b] c_1)) +d_1 ::= ("" | (d_1_choice d_1)) +e_star ::= [e]* +e_star_1 ::= [f]* +e_star_2 ::= [g]* +d_1_choice ::= (([b] [c] [d]) | ([p] [q])) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + after = bnf_grammar.to_string() + assert after == expected + + def test_char(): before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py new file mode 100644 index 0000000000..d9a9a09bab --- /dev/null +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -0,0 +1,214 @@ +# pylint: disable=missing-module-docstring,missing-function-docstring +# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking +"""This test is adopted from test_grammar_state_matcher_json.py, but the grammar is parsed from +a unoptimized, non-simplified EBNF string. This is to test the robustness of the grammar state +matcher.""" +import sys +from typing import List, Optional + +import pytest +import tvm +import tvm.testing + +from mlc_chat.serve import BNFGrammar, GrammarStateMatcher +from mlc_chat.tokenizer import Tokenizer + + +def get_json_grammar(): + json_grammar_ebnf = r""" +main ::= basic_array | basic_object +basic_any ::= basic_integer | basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= (([\"] basic_string_1 [\"])) +basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" +basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" +ws ::= [ \n\t]* +""" + grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf) + print(grammar) + return grammar + + +@pytest.fixture(scope="function") +def json_grammar(): + return get_json_grammar() + + +(json_input_accepted,) = tvm.testing.parameters( + ('{"name": "John"}',), + ('{ "name" : "John" }',), + ("{}",), + ("[]",), + ('{"name": "Alice", "age": 30, "city": "New York"}',), + ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), + ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), + ('[{"name": "David"}, {"name": "Sophia"}]',), + ( + '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' + ' "hasPets": false}', + ), + ( + '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' + '{"city": "Chicago", "zipcode": "60601"}}}', + ), + ( + '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' + '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', + ), + ( + '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' + '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' + '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', + ), + ( + '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' + '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' + '["CAD", "Project Management"], "projects": [{"name": "Project A", ' + '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', + ), +) + + +def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) + + +(json_input_refused,) = tvm.testing.parameters( + (r'{ name: "John" }',), + (r'{ "name": "John" } ',), # trailing space is not accepted + (r'{ "name": "John", "age": 30, }',), + (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), + (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), + (r'{ "name": "John", "age": 30.5.7 }',), + (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' + r'["hiking", "swimming",]}] }', + ), + (r'{ "name": "John", "age": 30, "status": "\P\J" }',), + ( + r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' + r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' + r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', + ), +) + + +def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): + assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) + + +(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( + ( + # short test + '{"id": 1,"name": "Example"}', + [ + # fmt: off + 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, + 299, 31973, 31846, 31846, 
292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 + # fmt: on + ], + ), + ( + # long test + """{ +"id": 1, +"na": "ex", +"ac": true, +"t": ["t1", "t2"], +"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, +"res": "res" +}""", + [ + # fmt: off + 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, + 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, + 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, + 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, + 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, + 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, + 31846, 292, 292, 292, 292, 31974, 31974, 31999 + # fmt: on + ], + ), +) + + +def test_find_next_rejected_tokens( + json_grammar: BNFGrammar, + input_find_rejected_tokens: str, + expected_rejected_sizes: Optional[List[int]] = None, +): + tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + tokenizer = Tokenizer(tokenizer_path) + grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) + + real_sizes = [] + for c in input_find_rejected_tokens: + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + print("Accepting char:", c, file=sys.stderr) + assert grammar_state_matcher.debug_accept_char(ord(c)) + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + real_sizes.append(len(rejected_token_ids)) + + if expected_rejected_sizes is not None: + assert real_sizes == expected_rejected_sizes + + +def test_token_based_operations(json_grammar: BNFGrammar): + """Test accepting token and finding the next token mask.""" + token_table = [ + # fmt: off + "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', + # fmt: on + ] + input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] + input_ids = [token_table.index(t) for t in input_splitted] + + grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) + + expected = [ + ["{"], + ['"', "}", "\n", " ", '"a":true'], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], + ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], + [":", "\n", " ", ':"'], + ['"', "{", "6", "\n", " "], + ["}", ", ", "6", "\n", " "], + [" ", "\n", '"', '"a":true'], + [" ", "\n", '"', '"a":true'], + ["}", ", ", "\n", " "], + [""], + ] + + result = [] + + for id in input_ids: + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + assert id in accepted + assert grammar_state_matcher.accept_token(id) + + rejected = grammar_state_matcher.find_next_rejected_tokens() + accepted = list(set(range(len(token_table))) - set(rejected)) + accepted_tokens = [token_table[i] for i in accepted] + result.append(accepted_tokens) + + assert result == expected + + +if __name__ == "__main__": + # Run a benchmark to show the performance before running tests + test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') + + tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher.py 
b/tests/python/serve/test_grammar_state_matcher_json.py similarity index 96% rename from tests/python/serve/test_grammar_state_matcher.py rename to tests/python/serve/test_grammar_state_matcher_json.py index c03a414931..a38a0edefe 100644 --- a/tests/python/serve/test_grammar_state_matcher.py +++ b/tests/python/serve/test_grammar_state_matcher_json.py @@ -1,7 +1,8 @@ # pylint: disable=missing-module-docstring,missing-function-docstring # pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking +"""This test uses the optimized JSON grammar provided by the grammar library.""" import sys -from typing import List +from typing import List, Optional import pytest import tvm @@ -251,7 +252,9 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): def test_find_next_rejected_tokens( - json_grammar: BNFGrammar, input_find_rejected_tokens: str, expected_rejected_sizes: List[int] + json_grammar: BNFGrammar, + input_find_rejected_tokens: str, + expected_rejected_sizes: Optional[List[int]] = None, ): tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" tokenizer = Tokenizer(tokenizer_path) @@ -265,8 +268,8 @@ def test_find_next_rejected_tokens( assert grammar_state_matcher.debug_accept_char(ord(c)) rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() real_sizes.append(len(rejected_token_ids)) - print(real_sizes) - assert real_sizes == expected_rejected_sizes + if expected_rejected_sizes is not None: + assert real_sizes == expected_rejected_sizes def test_token_based_operations(json_grammar: BNFGrammar): @@ -404,15 +407,6 @@ def test_termination(json_grammar: BNFGrammar): if __name__ == "__main__": # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens( - BNFGrammar.get_grammar_of_json(), - '{"id": 1,"name": "Example"}', - [ - # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 - # fmt: on - ], - ) + test_find_next_rejected_tokens(BNFGrammar.get_grammar_of_json(), '{"id": 1,"name": "Example"}') tvm.testing.main() From b44cdc53381bd804ef000775bb280de1c7ca6439 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Sun, 10 Mar 2024 11:13:57 -0400 Subject: [PATCH 051/531] [Android] Improve perf of TIR PagedAttn kernel on Android (#1915) * android perf * Update kv_cache.py --- python/mlc_chat/nn/kv_cache.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index 4f14774338..f63e74d855 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -835,8 +835,13 @@ def _attention_decode( H_kv = num_kv_heads D = head_dim + THREAD_LIMIT = 512 + TILE_SIZE_PER_BDX = 2 + if target.kind.name == "opencl" and "android" in str(target.host): + THREAD_LIMIT = 64 + TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) - thread_limit = min(max_num_threads_per_block, 512) + thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) GROUP_SIZE = H_qo // H_kv VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) @@ -847,7 +852,7 @@ def _attention_decode( gdz = GROUP_SIZE // bdy threads_per_CTA = max(thread_limit, bdx * bdy) bdz = threads_per_CTA // (bdx * bdy) - tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1 + tile_size_per_bdx = TILE_SIZE_PER_BDX if GROUP_SIZE == 1 else 1 log2e = math.log2(math.exp(1)) check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) @@ -994,10 
+999,9 @@ def batch_decode_paged_kv( ) T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], True, t0[0], tx, dtype="handle") + S_local[j] = -5e4 if (iterator * bdz + tz) * bdy * tile_size_per_bdx + j < kv_chunk_len[0]: S_local[j] = t0[0] - else: - S_local[j] = -5e4 # update st_m st_m[0] = T.max(st_m[0], S_local[j]) From 20efccb7628562974794a1d9d96763bea2cd2f90 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 11 Mar 2024 15:07:26 -0400 Subject: [PATCH 052/531] Deprecate old flow (#1928) * Deprecate old flow This PR deprecates the old flow. As of today most of the efforts are centralized around the new flow with SLM compilation. Additionally, we are bringing model definitions through unified kv interface so we can have a single model across all backends, server and local setting. We kept the old flow around for a while, but it is a good time to do the transition. All the documents are updated to point to the new flow. We also created a backup branch https://github.com/mlc-ai/mlc-llm/tree/backup-before-old-flow-deprecation for people who would like to checkout some of the old flow references. * Remove deprecated prebuilts --- docs/prebuilt_models_deprecated.rst | 845 -- mlc_llm/__init__.py | 7 - mlc_llm/build.py | 47 - mlc_llm/core.py | 996 -- mlc_llm/dispatch/__init__.py | 2 - mlc_llm/dispatch/dispatch_tir_operator.py | 53 - .../dispatch/dispatch_tir_operator_adreno.py | 8356 ----------------- mlc_llm/dispatch/gpt_neox/__init__.py | 13 - mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py | 1034 -- mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py | 511 - .../gpt_neox/redpajama_incite_chat_3b_v1.py | 972 -- .../redpajama_incite_chat_3b_v1_mod.py | 722 -- .../redpajama_incite_chat_3b_v1_tune.py | 1010 -- mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py | 840 -- .../dispatch/gpt_neox/redpajama_q4f32_mod.py | 577 -- .../dispatch/gpt_neox/redpajama_q4f32_tune.py | 743 -- mlc_llm/dispatch/llama/__init__.py | 1 - mlc_llm/dispatch/llama/main.py | 6712 ------------- mlc_llm/quantization/__init__.py | 232 - mlc_llm/quantization/autogptq_quantization.py | 193 - mlc_llm/quantization/ft_quantization.py | 219 - mlc_llm/quantization/group_quantization.py | 214 - mlc_llm/quantization/quantization.py | 217 - mlc_llm/quantization/tir_utils.py | 106 - mlc_llm/relax_model/__init__.py | 1 - mlc_llm/relax_model/chatglm.py | 807 -- mlc_llm/relax_model/commons.py | 363 - mlc_llm/relax_model/gpt_bigcode.py | 667 -- mlc_llm/relax_model/gpt_neox.py | 739 -- mlc_llm/relax_model/gptj.py | 692 -- mlc_llm/relax_model/llama.py | 1505 --- mlc_llm/relax_model/llama_batched_vllm.py | 662 -- mlc_llm/relax_model/minigpt.py | 627 -- mlc_llm/relax_model/mistral.py | 1126 --- mlc_llm/relax_model/modules.py | 280 - mlc_llm/relax_model/param_manager.py | 1259 --- mlc_llm/relax_model/rwkv.py | 613 -- mlc_llm/relax_model/stablelm_3b.py | 919 -- mlc_llm/transform/__init__.py | 10 - mlc_llm/transform/clean_up_tir_attrs.py | 25 - mlc_llm/transform/decode_matmul_ewise.py | 84 - mlc_llm/transform/decode_take.py | 71 - mlc_llm/transform/decode_transpose.py | 113 - .../transform/fuse_split_rotary_embedding.py | 284 - .../transform/lift_tir_global_buffer_alloc.py | 197 - mlc_llm/transform/reorder_transform_func.py | 281 - mlc_llm/transform/rewrite_attention.py | 46 - mlc_llm/transform/set_entry_funcs.py | 70 - mlc_llm/transform/transpose_matmul.py | 349 - mlc_llm/utils.py | 738 -- setup.py | 47 - 51 files changed, 37197 deletions(-) delete mode 100644 docs/prebuilt_models_deprecated.rst delete mode 100644 mlc_llm/__init__.py delete mode 100644 
mlc_llm/build.py delete mode 100644 mlc_llm/core.py delete mode 100644 mlc_llm/dispatch/__init__.py delete mode 100644 mlc_llm/dispatch/dispatch_tir_operator.py delete mode 100644 mlc_llm/dispatch/dispatch_tir_operator_adreno.py delete mode 100644 mlc_llm/dispatch/gpt_neox/__init__.py delete mode 100644 mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py delete mode 100644 mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py delete mode 100644 mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py delete mode 100644 mlc_llm/dispatch/llama/__init__.py delete mode 100644 mlc_llm/dispatch/llama/main.py delete mode 100644 mlc_llm/quantization/__init__.py delete mode 100644 mlc_llm/quantization/autogptq_quantization.py delete mode 100644 mlc_llm/quantization/ft_quantization.py delete mode 100644 mlc_llm/quantization/group_quantization.py delete mode 100644 mlc_llm/quantization/quantization.py delete mode 100644 mlc_llm/quantization/tir_utils.py delete mode 100644 mlc_llm/relax_model/__init__.py delete mode 100644 mlc_llm/relax_model/chatglm.py delete mode 100644 mlc_llm/relax_model/commons.py delete mode 100644 mlc_llm/relax_model/gpt_bigcode.py delete mode 100644 mlc_llm/relax_model/gpt_neox.py delete mode 100644 mlc_llm/relax_model/gptj.py delete mode 100644 mlc_llm/relax_model/llama.py delete mode 100644 mlc_llm/relax_model/llama_batched_vllm.py delete mode 100644 mlc_llm/relax_model/minigpt.py delete mode 100644 mlc_llm/relax_model/mistral.py delete mode 100644 mlc_llm/relax_model/modules.py delete mode 100644 mlc_llm/relax_model/param_manager.py delete mode 100644 mlc_llm/relax_model/rwkv.py delete mode 100644 mlc_llm/relax_model/stablelm_3b.py delete mode 100644 mlc_llm/transform/__init__.py delete mode 100644 mlc_llm/transform/clean_up_tir_attrs.py delete mode 100644 mlc_llm/transform/decode_matmul_ewise.py delete mode 100644 mlc_llm/transform/decode_take.py delete mode 100644 mlc_llm/transform/decode_transpose.py delete mode 100644 mlc_llm/transform/fuse_split_rotary_embedding.py delete mode 100644 mlc_llm/transform/lift_tir_global_buffer_alloc.py delete mode 100644 mlc_llm/transform/reorder_transform_func.py delete mode 100644 mlc_llm/transform/rewrite_attention.py delete mode 100644 mlc_llm/transform/set_entry_funcs.py delete mode 100644 mlc_llm/transform/transpose_matmul.py delete mode 100644 mlc_llm/utils.py delete mode 100644 setup.py diff --git a/docs/prebuilt_models_deprecated.rst b/docs/prebuilt_models_deprecated.rst deleted file mode 100644 index c18f3f3b44..0000000000 --- a/docs/prebuilt_models_deprecated.rst +++ /dev/null @@ -1,845 +0,0 @@ -Model Prebuilts from Old Flow (Deprecated) -========================================== - -**This page records the model libraries weights compiled under the old workflow (non-SLM).** - -**We will remove this page soon.** - -.. contents:: Table of Contents - :depth: 3 - :local: - -Overview --------- - -MLC-LLM is a universal solution for deploying different language models. 
Any models that can be described in `TVM Relax `__ -(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the -help of :doc:`TVM Unity `. - -There are two ways to run a model on MLC-LLM: - -1. Compile your own models following :doc:`the model compilation page `. -2. Use off-the-shelf prebuilts models following this current page. - -This page focuses on the second option: - -- Documenting :ref:`how to use prebuilts ` for various platforms, and -- Tracking what current :ref:`prebuilt models we provide `. - -Prerequisite: Model Libraries and Compiled Weights -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In order to run a specific model on MLC-LLM, you need: - -**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). See the full list of all precompiled model libraries `here `__. - -**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model (e.g. https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1). See the full list of all precompiled weights `here `__. - -.. _deprecated-using-model-prebuilts: - -Using Prebuilt Models for Different Platforms ---------------------------------------------- - -We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. - -.. _deprecated-using-prebuilt-models-cli: - - -Prebuilt Models on CLI / Python -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. - -.. collapse:: Click to show details - - First create the conda environment if you have not done so. - - .. code:: shell - - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - conda install git git-lfs - git lfs install - - Download the prebuilt model libraries from github. - - .. code:: shell - - mkdir -p dist/prebuilt - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib - - Download the prebuilt model weights from hugging face for the model variant you want. - - .. code:: shell - - # Say we want to run rwkv-raven-7b-q8f16_0 - cd dist/prebuilt - git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 - cd ../.. - - # The format being: - # cd dist/prebuilt - # git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] - # cd ../.. - # mlc_chat_cli --model [model-code] - - Run the model with CLI: - - .. code:: shell - - # For CLI - mlc_chat_cli --model rwkv-raven-7b-q8f16_0 - - To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). - - -.. for a blank line - -| - -.. _deprecated-using-prebuilt-models-ios: - -Prebuilt Models on iOS -^^^^^^^^^^^^^^^^^^^^^^ - -For more, please see :doc:`the iOS page `. - -.. collapse:: Click to show details - - The `iOS app `_ has builtin RedPajama-3B and Llama-2-7b support. - - All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: - - .. 
list-table:: Prebuilt model libraries integrated in the iOS app - :widths: 15 15 15 - :header-rows: 1 - - * - Model library name - - Model Family - - Quantization Mode - * - `Llama-2-7b-chat-hf-q3f16_1` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `vicuna-v1-7b-q3f16_0` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - GPT-NeoX - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - As for prebuilt model weights, the ones we have integrated into app are listed below: - - .. list-table:: Tested prebuilt model weights for iOS - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q3f16_1` - - `Llama `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `vicuna-v1-7b-q3f16_0` - - `Vicuna `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - - To run a model variant you compiled on your own, you can directly reuse the above - integrated prebuilt model libraries, as long as the model shares the - architecture and is compiled with the same quantization mode. - For example, if you compile `OpenLLaMA-7B `_ - with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone - without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. - Then you can upload the compiled weights to hugging face so that you can download - the weights in the app as shown below (for more on uploading to hugging face, - please check :ref:`distribute-compiled-models`). - - To add a model to the iOS app, follow the steps below: - - .. tabs:: - - .. tab:: Step 1 - - Open "MLCChat" app, click "Add model variant". - - .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-1.png - :align: center - :width: 30% - - .. tab:: Step 2 - - Paste the repository URL of the model built on your own, and click "Add". - - You can refer to the link in the image as an example. - - .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-2.png - :align: center - :width: 30% - - .. tab:: Step 3 - - After adding the model, you can download your model from the URL by clicking the download button. - - .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-3.png - :align: center - :width: 30% - - .. tab:: Step 4 - - When the download is finished, click into the model and enjoy. - - .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-custom-4.png - :align: center - :width: 30% - -.. for a blank line - -| - -.. _deprecated-prebuilt-models-android: - -Prebuilt Models on Android -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For more, please see :doc:`the Android page `. - -.. collapse:: Click to show details - - The apk for demo Android app includes the following models. To add more, check out the Android page. - - .. 
list-table:: Prebuilt Models for Android - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q4f16_1` - - `Llama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ -.. for a blank line - -| - -.. _deprecated-supported-model-architectures: - -Level 1: Supported Model Architectures (The All-In-One Table) -------------------------------------------------------------- - -For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. - -Each entry below hyperlinks to the corresponding level 2 and level 3 tables. - -MLC-LLM supports the following model architectures: - -.. list-table:: Supported Model Architectures - :widths: 10 10 15 15 - :header-rows: 1 - - * - Model Architecture - - Support - - Available MLC Prebuilts - - Unavailable in MLC Prebuilts - * - `LLaMA `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Llama-2 ` - * :ref:`Code Llama ` - * :ref:`Vicuna ` - * :ref:`WizardLM ` - * :ref:`WizardMath ` - * :ref:`OpenOrca Platypus2 ` - * :ref:`FlagAlpha Llama-2 Chinese ` - * :ref:`georgesung Llama-2 Uncensored ` - - * `Alpaca `__ - * `Guanaco `__ - * `OpenLLaMA `__ - * `Gorilla `__ - * `YuLan-Chat `__ - * `WizardCoder (new) `__ - * - `GPT-NeoX `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`RedPajama ` - - * `Dolly `__ - * `Pythia `__ - * `StableCode `__ - * - `GPT-J `__ - - * Prebuilt not compiled yet - * `MLC Implementation `__ - - - - * `MOSS `__ - * - `RWKV `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`RWKV-raven ` - - - * - `MiniGPT `__ - - * Prebuilt not compiled yet - * `MLC Implementation `__ - - - - * `MiniGPT-4 `__ - * - `GPTBigCode `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`WizardCoder (old) ` - - * `StarCoder `__ - * `SantaCoder `__ - * - `ChatGLM `__ - - * Prebuilt not compiled yet - * `MLC Implementation `__ - - - - * `ChatGLM2 `__ - * `CodeGeeX2 `__ - * - `StableLM `__ - - * Prebuilt not compiled yet - * `MLC Implementation `__ - - - - * `StableLM `__ - -If the model variant you are interested in uses one of these model architectures we support, -(but we have not provided the prebuilt weights yet), you can check out -:doc:`/compilation/convert_weights` and :doc:`/compilation/compile_models` on how to compile your own models. -Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt -weights to hugging face, and submit a PR that adds an entry to this page, -contributing to the community. - -For models structured in an architecture we have not supported yet, you could: - -- Either `create a [Model Request] issue `__ which automatically shows up on our `Model Request Tracking Board `__. - -- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. - - -.. 
_deprecated-model-library-tables: - -Level 2: Model Library Tables (Precompiled Binary Files) --------------------------------------------------------- - -As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). - -Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: - -- **Size**: each size of model has its own distinct model library file (e.g. 7B or 13B number of parameters) - -- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) - -- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`model compilation page ` (e.g. ``q3f16_1`` vs. ``q4f16_1``) - -Each entry links to the specific model library file found in `this github repo `__. - -.. _deprecated-llama_library_table: - -Llama -^^^^^ -.. list-table:: Llama - :widths: 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M1/M2) - - Metal - - (Intel) - - iOS - - webgpu - - mali - * - 7B - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q3f16_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - * - 13B - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - * - 34B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - - - - - - - - * - 70B - - - - - - - - - - `q3f16_1 `__ - - `q4f16_1 `__ - - - - - - `q4f16_1 `__ - - - -.. _deprecated-gpt_neox_library_table: - -GPT-NeoX (RedPajama-INCITE) -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. list-table:: GPT-NeoX (RedPajama-INCITE) - :widths: 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M1/M2) - - Metal - - (Intel) - - iOS - - webgpu - - mali - * - 3B - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f16_0 `__ - - `q4f16_1 `__ - - `q4f32_0 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - -.. _deprecated-rwkv_library_table: - -RWKV -^^^^ -.. list-table:: RWKV - :widths: 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M1/M2) - - Metal - - (Intel) - - iOS - - webgpu - - mali - * - 1B5 - - - - - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - - - - - - * - 3B - - - - - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - - - - - - * - 7B - - - - - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - `q8f16_0 `__ - - - - - - - -.. _deprecated-gpt_big_code_library_table: - -GPTBigCode -^^^^^^^^^^ -Note that these all links to model libraries for WizardCoder (the older version released in Jun. 2023). -However, any GPTBigCode model variants should be able to reuse these (e.g. StarCoder, SantaCoder). - -.. 
list-table:: GPTBigCode - :widths: 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M1/M2) - - Metal - - (Intel) - - iOS - - webgpu - - mali - * - 15B - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - -.. _deprecated-model-variant-tables: - -Level 3: Model Variant Tables (Precompiled Weights) ---------------------------------------------------- - -Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. - -Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. - -Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. - -Some of these files are uploaded by our community contributors--thank you! - -.. _deprecated-llama2_variant_table: - -`Llama-2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-2`` - -.. list-table:: Llama-2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q3f16_1 `__ - * `q4f16_1 `__ - * `q4f32_1 `__ - - * - 13B - - * `q4f16_1 `__ - * `q4f32_1 `__ - - * - 70B - - * `q3f16_1 `__ - * `q4f16_1 `__ - -.. _deprecated-code_llama_variant_table: - -`Code Llama `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``codellama_completion`` - -.. list-table:: Code Llama - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 (Base) `__ - * `q4f16_1 (Instruct) `__ - * `q4f16_1 (Python) `__ - - * - 13B - - * `q4f16_1 (Base) `__ - * `q4f16_1 (Instruct) `__ - * `q4f16_1 (Python) `__ - - * - 34B - - * `q4f16_1 (Base) `__ - * `q4f16_1 (Instruct) `__ - * `q4f16_1 (Python) `__ - - -.. _deprecated-vicuna_variant_table: - -`Vicuna `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``vicuna_v1.1`` - -.. list-table:: Vicuna - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q3f16_0 `__ - * `q4f32_0 `__ - * `int3 (demo) `__ - * `int4 (demo) `__ - - -.. _deprecated-WizardLM_variant_table: - -`WizardLM `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``vicuna_v1.1`` - -.. list-table:: WizardLM - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 13B - - * `q4f16_1 (V1.2) `__ - * `q4f32_1 (V1.2) `__ - - * - 70B - - * `q3f16_1 (V1.0) `__ - * `q4f16_1 (V1.0) `__ - - -.. _deprecated-wizard_math_variant_table: - -`WizardMath `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``wizard_coder_or_math`` - -.. list-table:: WizardMath - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - * `q4f32_1 `__ - * - 13B - - `q4f16_1 `__ - * - 70B - - `q4f16_1 `__ - - -.. _deprecated-open_orca_variant_table: - -`OpenOrca Platypus2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-2`` - -.. 
list-table:: OpenOrca Platypus2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 13B - - `q4f16_1 `__ - - -.. _deprecated-flag_alpha_llama2_variant_table: - -`FlagAlpha Llama-2 Chinese `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-2`` - -.. list-table:: FlagAlpha Llama-2 Chinese - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - * `q4f32_1 `__ - - -.. _deprecated-llama2_uncensored_variant_table: - -`Llama2 uncensored (georgesung) `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-default`` - -.. list-table:: Llama2 uncensored - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - * `q4f32_1 `__ - -.. _deprecated-red_pajama_variant_table: - -`RedPajama `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``LM`` - -.. list-table:: Red Pajama - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 3B - - * `q4f16_0 (Instruct) `__ - * `q4f16_0 (Chat) `__ - * `q4f16_1 (Chat) `__ - * `q4f32_0 (Chat) `__ - - -.. _deprecated-rwkv_raven_variant_table: - -`RWKV-raven `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``rwkv`` - -.. list-table:: RWKV-raven - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 1B5 - - `q8f16_0 `__ - - * - 3B - - `q8f16_0 `__ - - * - 7B - - `q8f16_0 `__ - - -.. _deprecated-wizard_coder_variant_table: - -`WizardCoder `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``wizard_coder_or_math`` - -.. list-table:: WizardCoder - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 15B - - `q4f16_1 `__ - ------------------- - - -.. _deprecated-contribute-models-to-mlc-llm: - -Contribute Models to MLC-LLM ----------------------------- - -Ready to contribute your compiled models/new model architectures? Awesome! Please check :ref:`contribute-new-models` on how to contribute new models to MLC-LLM. diff --git a/mlc_llm/__init__.py b/mlc_llm/__init__.py deleted file mode 100644 index b74f00797d..0000000000 --- a/mlc_llm/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from . import dispatch -from . import quantization -from . import relax_model -from . import transform -from . import utils -from . import core -from .core import build_model, BuildArgs diff --git a/mlc_llm/build.py b/mlc_llm/build.py deleted file mode 100644 index b7619aa963..0000000000 --- a/mlc_llm/build.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Script for building/compiling models.""" -import contextlib -import sys - -from mlc_llm import core - - -@contextlib.contextmanager -def debug_on_except(): - try: - yield - finally: - raised_exception = sys.exc_info()[1] - if not isinstance(raised_exception, Exception): - return - - import traceback - - try: - import ipdb as pdb - except ImportError: - import pdb - - traceback.print_exc() - pdb.post_mortem() - - -def main(): - """Main method for building model from command line.""" - empty_args = core.convert_build_args_to_argparser() # Create new ArgumentParser - parsed_args = empty_args.parse_args() # Parse through command line - - with contextlib.ExitStack() as stack: - # Enter an exception-catching context before post-processing - # the arguments, in case the post-processing itself raises an - # exception. 
- if parsed_args.pdb: - stack.enter_context(debug_on_except()) - - # Post processing of arguments - parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access - - core.build_model_from_args(parsed_args) - - -if __name__ == "__main__": - main() diff --git a/mlc_llm/core.py b/mlc_llm/core.py deleted file mode 100644 index d4855582e6..0000000000 --- a/mlc_llm/core.py +++ /dev/null @@ -1,996 +0,0 @@ -# pylint: disable=missing-docstring, redefined-outer-name, not-callable -import argparse -import functools -import json -import os -import pickle -from dataclasses import asdict, dataclass, field, fields -from typing import Any, Dict, Optional - -import mlc_llm -import tvm -import tvm.relax.backend.contrib.cublas as _ -from mlc_llm import utils -from mlc_llm.relax_model import ( - chatglm, - gpt_bigcode, - gpt_neox, - gptj, - llama, - llama_batched_vllm, - minigpt, - mistral, - param_manager, - rwkv, - stablelm_3b, -) -from mlc_llm.relax_model.commons import ( - create_shard_info_func, - create_shard_transformation_func, -) -from mlc_llm.relax_model.param_manager import ( - chain_parameter_transforms, - transform_params_for_each_rank, -) -from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention -from tvm import dlight as dl -from tvm import relax -from tvm.contrib.nvcc import parse_compute_version -from tvm.relax.backend import get_patterns_with_prefix -from tvm.relax.backend.contrib.cutlass import annotate_workspace - - -@dataclass -class BuildArgs: - r"""BuildArgs is the dataclass that organizes the arguments we use in - building a model. - - To use :meth:`mlc_llm.build_model`, users pass in an instance of :class:`BuildArgs`; for - CLI entry points, an equivalent :class:`ArgumentParser` instance is generated based - on the definition of this class using :meth:`mlc_llm.convert_build_args_to_argparser`. - - Parameters - ---------- - model: str - The name of the model to build. If it is ``auto``, we will automatically - set the model name according to ``--model-path``, ``hf-path``, or the model - folders under ``--artifact-path/models``. - - hf_path: str - Hugging Face path from which to download params, tokenizer, and config. - - quantization: str - The quantization mode we use to compile. - - max_seq_len: int - The maximum allowed sequence length for the model. - - target: str - The target platform to compile the model for. - - db_path: str - Path to log database for all models. Default: ``./log_db/``. - - reuse_lib: str - Whether to reuse a previously generated lib. - - artifact_path: str - Where to store the output. - - use_cache: int - Whether to use previously pickled IRModule and skip trace. - - convert_weights_only: bool - Whether to only convert model weights and not build the model. If both - ``convert_weight_only`` and ``build_model_only`` are set, the behavior is undefined. - - build_model_only: bool - Whether to only build model and do not convert model weights. - - debug_dump: bool - Whether to dump debugging files during compilation. - - debug_load_script: bool - Whether to load the script for debugging. - - llvm_mingw: str - ``/path/to/llvm-mingw-root``, use llvm-mingw to cross compile to windows. - - system_lib: bool - A parameter to ``relax.build``. - - sep_embed: bool - Build with separated embedding layer, only applicable to LlaMa. This - feature is in testing stage, and will be formally replaced after massive - overhaul of embedding feature for all models and use cases. 
- - sliding_window: int - The sliding window size in sliding window attention (SWA). This optional field - overrides the `sliding_window` in config.json for those models that use SWA. - Currently only useful when compiling Mistral. - - prefill_chunk_size: int - The chunk size during prefilling. By default, the chunk size is the same as - max sequence length. Currently only useful when compiling Mistral. - - attention_sink_size: int - Number of attention sinks (https://arxiv.org/abs/2309.17453). - Only supported on mistral yet. - - cc_path: str - ``/path/to/cross_compiler_path``; currently only used for cross-compile - for nvidia/jetson device. - - use_safetensors: bool - Specifies whether to use ``.safetensors`` instead of the default ``.bin`` - when loading in model weights. - - enable_batching: bool - Build the model for batched inference. - This is a temporary flag used to control the model execution flow in single- - sequence and batching settings for now. We will eventually merge two flows - in the future and remove this flag then. - - no_cutlass_attn: bool - Disable offloading attention operations to CUTLASS. - - no_cutlass_norm: bool - Disable offloading layer and RMS norm operations to CUTLASS. - - no_cublas: bool - Disable the step that offloads matmul to cuBLAS. Without this flag, - matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or - ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. - - use_cuda_graph: bool - Specifies whether to enable CUDA Graph for the decoder. MLP and QKV - projection between two attention layers are put into a graph. - - num_shards: int - Number of shards to split the model into in tensor parallelism multi-gpu - inference. Only useful when ``build_model_only`` is set. - - use_flash_attn_mqa: bool - Offload multi-query attention workload to Flash Attention. - - pdb: bool - If set, drop into a pdb debugger on error. - - use_vllm_attention: bool - Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. - """ - model: str = field( - default="auto", - metadata={ - "help": ( - 'The name of the model to build. 
If it is "auto", we will ' - 'automatically set the model name according to "--model-path", ' - '"hf-path" or the model folders under "--artifact-path/models"' - ) - }, - ) - hf_path: str = field( - default=None, - metadata={"help": "Hugging Face path from which to download params, tokenizer, and config"}, - ) - quantization: str = field( - default="q4f16_1", - metadata={ - "help": "The quantization mode we use to compile.", - "choices": [*utils.quantization_schemes.keys()], - }, - ) - max_seq_len: int = field( - default=-1, - metadata={"help": "The maximum allowed sequence length for the model."}, - ) - max_vocab_size: int = field( - default=40000, - metadata={"help": "The maximum allowed vocabulary size for the model."}, - ) - target: str = field( - default="auto", - metadata={"help": "The target platform to compile the model for."}, - ) - reuse_lib: str = field( - default=None, metadata={"help": "Whether to reuse a previously generated lib."} - ) - artifact_path: str = field(default="dist", metadata={"help": "Where to store the output."}) - use_cache: int = field( - default=1, - metadata={"help": "Whether to use previously pickled IRModule and skip trace."}, - ) - convert_weights_only: bool = field( - default=False, - metadata={ - "dest": "convert_weights_only", - "action": "store_true", - "help": "Whether to only convert model weights and not build the model.", - }, - ) - build_model_only: bool = field( - default=False, - metadata={ - "help": "Whether to only build model and do not convert model weights.", - "action": "store_true", - }, - ) - debug_dump: bool = field( - default=False, - metadata={ - "help": "Whether to dump debugging files during compilation.", - "action": "store_true", - }, - ) - debug_load_script: bool = field( - default=False, - metadata={ - "help": "Whether to load the script for debugging.", - "action": "store_true", - }, - ) - llvm_mingw: str = field( - default="", - metadata={"help": "/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows."}, - ) - cc_path: str = field( - default="", - metadata={ - "help": ( - "/path/to/cross_compiler_path, Currently only used for " - "cross-compile for nvidia/jetson device." - ) - }, - ) - system_lib: bool = field( - default=False, - metadata={"help": "A parameter to `relax.build`.", "action": "store_true"}, - ) - sep_embed: bool = field( - default=False, - metadata={ - "help": ( - "Build with separated embedding layer, only applicable to LlaMa. " - "This feature is in testing stage, and will be formally replaced after " - "massive overhaul of embedding feature for all models and use cases" - ), - "action": "store_true", - }, - ) - use_safetensors: bool = field( - default=False, - metadata={ - "help": ( - "Specifies whether to use ``.safetensors`` instead of the default " - "``.bin`` when loading in model weights." - ), - "action": "store_true", - }, - ) - enable_batching: bool = field( - default=False, - metadata={ - "help": ( - "Build the model for batched inference." - "This is a temporary flag used to control the model execution flow in single-" - "sequence and batching settings for now. We will eventually merge two flows" - "in the future and remove this flag then." - ), - "action": "store_true", - }, - ) - max_batch_size: int = field( - default=80, - metadata={ - "help": ( - "The maximum batch size for build. It has effect only when batching is enabled." 
- ), - }, - ) - no_cutlass_attn: bool = field( - default=False, - metadata={ - "help": ("Disable offloading attention operations to CUTLASS."), - "action": "store_true", - }, - ) - no_cutlass_norm: bool = field( - default=False, - metadata={ - "help": ("Disable offloading layer and RMS norm operations to CUTLASS."), - "action": "store_true", - }, - ) - no_cublas: bool = field( - default=False, - metadata={ - "help": ( - "Disable the step that offloads matmul to cuBLAS. Without this flag, " - "matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, " - "target is CUDA and TVM has been built with cuBLAS enabled." - ), - "action": "store_true", - }, - ) - use_cuda_graph: bool = field( - default=False, - metadata={ - "help": ( - "Specifies whether to enable CUDA Graph for the decoder. MLP and QKV " - "projection between two attention layers are put into a graph." - ), - "action": "store_true", - }, - ) - num_shards: int = field( - default=1, - metadata={ - "help": ( - "Number of shards to split the model into in tensor parallelism multi-gpu " - "inference. Only useful when --build-model-only is set." - ), - }, - ) - use_presharded_weights: bool = field( - default=False, - metadata={ - "action": "store_true", - "help": "Produce separate weight sets for each shard.", - }, - ) - use_flash_attn_mqa: bool = field( - default=False, - metadata={ - "help": ("Offload multi-query attention workload to Flash Attention."), - "action": "store_true", - }, - ) - sliding_window: int = field( - default=-1, - metadata={ - "help": ( - "The sliding window size in sliding window attention (SWA). " - "This optional field overrides the `sliding_window` in config.json for " - "those models that use SWA. Currently only useful when compiling Mistral." - ), - }, - ) - prefill_chunk_size: int = field( - default=-1, - metadata={ - "help": ( - "The chunk size during prefilling. By default, the chunk size is " - "the same as the sliding window size or the max sequence length. " - "Currently only useful when compiling Mistral." - ), - }, - ) - attention_sink_size: int = field( - default=0, - metadata={ - "help": ( - "The number of attention sinks to keep in cache." - "Only supported on mistral yet." - ), - }, - ) - pdb: bool = field( - default=False, - metadata={ - "help": ("If set, drop into a pdb debugger on error"), - "action": "store_true", - }, - ) - use_vllm_attention: bool = field( - default=False, - metadata={ - "help": ( - "Use vLLM paged KV cache and attention kernel, only relevant when " - "enable_batching=True." - ), - "action": "store_true", - }, - ) - - @property - def convert_weight_only(self): - """A backwards-compatibility helper""" - return self.convert_weights_only - - -def convert_build_args_to_argparser() -> argparse.ArgumentParser: - """Convert from BuildArgs to an equivalent ArgumentParser.""" - args = argparse.ArgumentParser() - for field in fields(BuildArgs): - name = field.name.replace("_", "-") - field_name = f"--{name}" - # `kwargs` contains `help`, `choices`, and `action` - kwargs = field.metadata.copy() - if field.type == bool: - # boolean arguments do not need to specify `type` - args.add_argument(field_name, default=field.default, **kwargs) - else: - args.add_argument(field_name, type=field.type, default=field.default, **kwargs) - - # Most models contain more than a single parameter (citation - # needed), so "weights" should be plural. The initial use of - # "--convert-weight-only" caused enough typos that it is worth - # fixing. 
The old argument spelling is retained for backwards - # compatibility. - args.add_argument( - "--convert-weight-only", - default=False, - dest="convert_weights_only", - action="store_true", - help="Equivalent to --convert-weights-only, retained for backwards compatibility.", - ) - - return args - - -def _parse_args(parsed) -> argparse.Namespace: - assert parsed.max_seq_len == -1 or parsed.max_seq_len > 0 - if parsed.use_safetensors: - try: - import safetensors # pylint: disable=import-outside-toplevel, unused-import - except ImportError as error: - raise ImportError( - "`use_safetensors` option is toggled, please install safetensors package." - ) from error - - parsed.export_kwargs = {} - parsed.lib_format = "so" - parsed.system_lib_prefix = None - parsed = _setup_model_path(parsed) - - utils.parse_target(parsed) - utils.argparse_postproc_common(parsed) - - if parsed.use_vllm_attention: - assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." - assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." - assert tvm.get_global_func( - "tvm.contrib.vllm.single_query_cached_kv_attention", True - ), "TVM needs to be built with -DUSE_VLLM=ON." - - model_name = [ - parsed.model, - parsed.quantization.name, - ] - if parsed.use_presharded_weights: - model_name.append(f"presharded-{parsed.num_shards}gpu") - - parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name)) - - return parsed - - -def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-branches - if args.hf_path: - if args.model != "auto": - assert args.model == os.path.basename(args.hf_path), ( - 'When both "--model" and "--hf-path" is specified, the ' - 'value of "--model" is required to match the basename of "--hf-path". ' - f'Got "--model {args.model}" and "--hf-path {args.hf_path}"' - ) - else: - args.model = os.path.basename(args.hf_path) - args.model_path = os.path.join(args.artifact_path, "models", args.model) - if os.path.exists(args.model_path): - print(f"Weights exist at {args.model_path}, skipping download.") - else: - os.makedirs(args.model_path, exist_ok=True) - os.system("git lfs install") - os.system(f"git clone https://huggingface.co/{args.hf_path} {args.model_path}") - print(f"Downloaded weights to {args.model_path}") - validate_config(args.model_path) - elif args.model != "auto": - if os.path.isdir(args.model): - args.model = os.path.normpath(args.model) # Remove potential trailing `/` - args.model_path = args.model - args.model = os.path.basename(args.model) - else: - args.model_path = os.path.join(args.artifact_path, "models", args.model) - validate_config(args.model_path) - else: - lookup_path = os.path.join(args.artifact_path, "models") - print(f'"--model" is set to "auto". 
Searching in {lookup_path} for existing models.') - for dirname in os.listdir(lookup_path): - if os.path.isdir(os.path.join(lookup_path, dirname)) and os.path.isfile( - os.path.join(lookup_path, dirname, "config.json") - ): - try: - validate_config(os.path.join(lookup_path, dirname)) - except: # pylint: disable=bare-except - pass - else: - args.model_path = os.path.join(lookup_path, dirname) - args.model = dirname - break - if args.model == "auto": - raise ValueError("Please specify either the model_path or the hf_path.") - - print(f'Using path "{args.model_path}" for model "{args.model}"') - return args - - -def validate_config(model_path: str): - if os.path.exists(os.path.join(model_path, "mlc-chat-config.json")): - raise KeyError( - f"The model located in the directory {model_path} has already been compiled " - "by MLC-LLM. There is no need to compile it again. If you wish to compile " - "a new model, please provide a directory (or hf-path) that contains the " - "pre-compiled model in raw HuggingFace format instead." - ) - if model_path.split("/")[-1].startswith("minigpt"): - # minigpt does not contain a config.json file so we skip the check - return - config_path = os.path.join(model_path, "config.json") - assert os.path.exists( - config_path - ), f"Expecting HuggingFace config, but file not found: {config_path}." - with open(config_path, encoding="utf-8") as i_f: - config = json.load(i_f) - assert ( - "model_type" in config - ), f"Invalid config format. Expecting HuggingFace config format in: {config_path}" - assert ( - config["model_type"] in utils.supported_model_types - ), f"Model type {config['model_type']} not supported." - - -def get_cuda_sm_version(): - major, minor = parse_compute_version(tvm.cuda(0).compute_version) - - if major == 8: - sm = 80 - else: - sm = 10 * major + minor - - return sm - - -def optimize_mod_pipeline( - args: argparse.Namespace, - config: Dict, -) -> tvm.ir.transform.Pass: - """First-stage: Legalize ops and trace""" - seq = [] - - use_ft_quant = args.quantization.name in [ - "q4f16_ft", - "q8f16_ft", - "q4f16_ft_group", - "q8f16_ft_group", - ] - seq.append(mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)) - - if ( - not args.enable_batching - and hasattr(config, "num_attention_heads") - and hasattr(config, "hidden_size") - and hasattr(config, "position_embedding_base") - and getattr(config, "dtype", "float16") == "float16" - ): - max_seq_len = None - if args.max_seq_len > 0: - max_seq_len = args.max_seq_len - elif hasattr(config, "max_sequence_length"): - max_seq_len = config.max_sequence_length - - if max_seq_len: - num_key_value_heads = config.get_num_key_value_heads() - # pylint: disable=no-value-for-parameter - seq.append( - fuse_split_rotary_embedding( - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, - ) - ) - - if args.target_kind == "cuda": - patterns = [] - - has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) - - if has_cutlass and not args.no_cutlass_attn: - # pylint: disable=no-value-for-parameter - if args.use_flash_attn_mqa: - seq.append(rewrite_attention(use_flash_mqa=True)) - seq.append(rewrite_attention(use_flash_mqa=False)) - patterns += get_patterns_with_prefix("cutlass.attention") - - if has_cutlass and not args.no_cutlass_norm: - patterns += get_patterns_with_prefix("cutlass.layer_norm") - patterns += get_patterns_with_prefix("cutlass.rms_norm") - - if has_cutlass and use_ft_quant: - patterns 
+= get_patterns_with_prefix("cutlass.decode_matmul") - - has_cublas = tvm.get_global_func("relax.ext.cublas", True) - - if has_cublas and args.quantization.name in ("q0f16", "q0f32") and not args.no_cublas: - patterns += get_patterns_with_prefix("cublas") - - if len(patterns) > 0: - os.makedirs("./tmp", exist_ok=True) - - sm = get_cuda_sm_version() - options = {"cutlass": {"sm": sm, "find_first_valid": False}} - - if hasattr(config, "rms_norm_eps"): - options["cutlass"]["rms_eps"] = config.rms_norm_eps - - seq.extend( - [ - relax.transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=True - ), - annotate_workspace, - relax.transform.AllocateWorkspace(), - relax.transform.RunCodegen(options), - ] - ) - - if args.target_kind == "android": - seq.extend( - [ - mlc_llm.transform.FuseTranspose1Matmul(), - mlc_llm.transform.FuseTranspose2Matmul(), - ] - ) - seq.extend( - [ - mlc_llm.transform.FuseTransposeMatmul(), - relax.pipeline.get_pipeline(), - mlc_llm.transform.FuseDecodeMatmulEwise(), - mlc_llm.transform.FuseDecodeTake(), - relax.transform.DeadCodeElimination(), - mlc_llm.transform.CleanUpTIRAttrs(), - ] - ) - - return tvm.ir.transform.Sequential(seq, name="mlc_llm.core.optimize_mod_pipeline") - - - -def dump_mlc_chat_config( - args: argparse.Namespace, - vocab_size: int, - max_window_size: int, - temperature: float = 0.7, - repetition_penalty: float = 1.0, - top_p: float = 0.95, - mean_gen_len: int = 128, - max_gen_len: int = 512, - shift_fill_factor: float = 0.3, - rwkv_world=False, -): - args.params_path = os.path.join(args.artifact_path, "params") - config: Dict[str, Any] = {} - - if args.reuse_lib: - config["model_lib"] = f"{args.reuse_lib}" - if not args.reuse_lib.endswith(args.quantization.name): - raise RuntimeError(f"Trying to reuse lib without suffix {args.quantization.name}") - else: - config["model_lib"] = f"{args.model}-{args.quantization.name}" - - config["local_id"] = f"{args.model}-{args.quantization.name}" - config["conv_template"] = args.conv_template - config["temperature"] = temperature - config["repetition_penalty"] = repetition_penalty - config["top_p"] = top_p - config["mean_gen_len"] = mean_gen_len - config["max_gen_len"] = max_gen_len - config["num_shards"] = args.num_shards - config["use_presharded_weights"] = args.use_presharded_weights - config["shift_fill_factor"] = shift_fill_factor - if rwkv_world: - config["tokenizer_files"] = ["tokenizer_model"] - else: - config["tokenizer_files"] = utils.get_tokenizer_files(args.params_path) - config["model_category"] = args.model_category - config["model_name"] = args.model - config["vocab_size"] = vocab_size - config["prefill_chunk_size"] = args.prefill_chunk_size - if args.sliding_window != -1: - # Do not add max window size if use sliding window - config["sliding_window"] = args.sliding_window - - # only use sinks if sliding window enabled - if args.attention_sink_size > 0: - config["attention_sink_size"] = args.attention_sink_size - else: - config["max_window_size"] = max_window_size - - args.chat_config_path = os.path.join(args.params_path, "mlc-chat-config.json") - with open(args.chat_config_path, "w", encoding="utf-8") as outfile: - json.dump(config, outfile, indent=4) - print(f"Finish exporting chat config to {args.chat_config_path}") - - -def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: - target_kind = args.target_kind - if args.system_lib_prefix: - mod_deploy = mod_deploy.with_attrs({"system_lib_prefix": args.system_lib_prefix}) - - 
utils.debug_dump_script(mod_deploy, "mod_before_build.py", args) - utils.debug_dump_benchmark_script( - mod_deploy, f"{args.model}_{args.quantization.name}".replace("-", "_"), args - ) - - if target_kind != "cpu": - dispatch_target = ( - args.target - if args.target_kind != "webgpu" - else tvm.target.Target("apple/m1-gpu-restricted") - ) - with dispatch_target: - if args.target_kind == "android": - mod_deploy = mlc_llm.dispatch.DispatchTIROperatorAdreno()( # pylint: disable=not-callable - mod_deploy - ) - mod_deploy = dl.ApplyDefaultSchedule( # pylint: disable=not-callable - dl.gpu.Matmul(), - dl.gpu.GEMV(), - dl.gpu.Reduction(), - dl.gpu.GeneralReduction(), - dl.gpu.Fallback(), - )(mod_deploy) - mod_deploy = ( - mlc_llm.transform.LiftTIRGlobalBufferAlloc()( # pylint: disable=not-callable - mod_deploy - ) - ) - if not args.enable_batching: - mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy) - - if args.debug_load_script: - mod_deploy = utils.debug_load_script("mod_build_stage_debug.py", args) - - utils.debug_dump_script(mod_deploy, "mod_build_stage.py", args) - - use_cuda_graph = args.use_cuda_graph and target_kind == "cuda" - - with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": use_cuda_graph}): - # The num_input attribute is needed to capture transformed weights passed as input - # into a cuda graph. - # NOTE: CUDA graph for batching is not enabled and is left as a TODO item. - if not args.enable_batching: - mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) - ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib) - - output_filename = f"{os.path.split(args.model)[1]}-{args.quantization.name}-{target_kind}.{args.lib_format}" - - utils.debug_dump_shader(ex, f"{args.model}_{args.quantization.name}_{target_kind}", args) - args.lib_path = os.path.join(args.artifact_path, output_filename) - ex.export_library(args.lib_path, **args.export_kwargs) - print(f"Finish exporting to {args.lib_path}") - - -def build_model_from_args(args: argparse.Namespace): - if args.quantization == "q4f16_0": - print( - "WARNING: q4f16_1 is preferred to q4f16_0, " - "and it is highly recommended to use q4f16_1 instead" - ) - - use_ft_quant = args.quantization.name in [ - "q4f16_ft", - "q8f16_ft", - "q4f16_ft_group", - "q8f16_ft_group", - ] - - if args.num_shards > 1: - if (not args.build_model_only) and (not args.convert_weights_only): - raise ValueError( - "`num_shards` should be used together with " - "`--build-model-only` and `--convert-weight-only`" - ) - - if use_ft_quant and not args.use_presharded_weights: - print( - "WARNING: FT quantization with multi-gpus requires presharding weights." - "Forcing --use-presharded-weights." - ) - args.use_presharded_weights = True - - os.makedirs(args.artifact_path, exist_ok=True) - if args.debug_dump: - os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True) - cache_path = os.path.join(args.artifact_path, "mod_cache_before_build.pkl") - args.raw_params_path = os.path.join(args.artifact_path, "raw_params") - use_cache = args.use_cache and os.path.isfile(cache_path) - if args.sep_embed and args.model_category != "llama": - raise ValueError(f"separate embedding not supported on {args.model}") - - if args.model_category == "minigpt": - # Special case for minigpt, which neither provides nor requires a configuration. 
- config = {} - else: - with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: - config = json.load(i_f) - - if not use_cache or args.convert_weights_only: - model_generators = { - "llama": llama, - "mistral": mistral, - "stablelm_epoch": stablelm_3b, - "gpt_neox": gpt_neox, - "gpt_bigcode": gpt_bigcode, - "minigpt": minigpt, - "gptj": gptj, - "rwkv": rwkv, - "rwkv_world": rwkv, - "chatglm": chatglm, - } - - if args.use_vllm_attention: - model_generators["llama"] = llama_batched_vllm - model_generators["mistral"] = llama_batched_vllm - - assert args.model_category in model_generators, f"Model {args.model} not supported" - - mod, param_manager, params, model_config = model_generators[args.model_category].get_model( - args, config - ) - - if args.model_category == "mistral": - args.sliding_window = model_config.sliding_window - args.attention_sink_size = model_config.attention_sink_size - - for qspec_updater_class in param_manager.qspec_updater_classes: - qspec_updater = qspec_updater_class(param_manager) - qspec_updater.visit_module(mod) - mod = param_manager.transform_dequantize()(mod) - mod = relax.transform.BundleModelParams()(mod) - - if not args.build_model_only: - parameter_transforms = [] - - # Run pre-quantization if provided. - args.model_path = param_manager.run_pre_quantize(args.model_path) - param_manager.init_torch_pname_to_bin_name(args.use_safetensors) - parameter_transforms.append(param_manager.create_parameter_transformation()) - - # Run pre-sharding if required - if args.num_shards > 1 and args.use_presharded_weights: - mod_shard = create_shard_transformation_func(param_manager, args, model_config) - mod_shard = transform_params_for_each_rank(num_shards=args.num_shards)(mod_shard) - parameter_transforms.append(mod_shard) - - # Chain all parameter transforms together. This allows - # ReorderTransformFunc to be applied to the single - # resulting parameter transformation function. - mod_transform = functools.reduce(chain_parameter_transforms, parameter_transforms) - - seq = tvm.ir.transform.Sequential( - [ - relax.transform.CanonicalizeBindings(), - relax.transform.EliminateCommonSubexpr(), - relax.transform.DeadCodeElimination(), - # TODO(Lunderberg): Implement - # relax.transform.Simplify() that applies - # canonicalization, CSE, and DCE until - # convergence. 
- relax.transform.CanonicalizeBindings(), - relax.transform.EliminateCommonSubexpr(), - relax.transform.DeadCodeElimination(), - param_manager.optimize_transform_param_order(), - ], - name="SimplifyModTransform", - ) - - mod_transform = seq(mod_transform) - - params = utils.convert_weights(mod_transform, param_manager, params, args) - - if args.num_shards > 1 and use_ft_quant: - preprocessed = [] - weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight") - is_int4 = args.quantization.name in ["q4f16_ft", "q4f16_ft_group"] - sm = get_cuda_sm_version() - - for p in params: - if p.dtype == "int8": - preprocessed.append(weight_preprocess_func(p, sm, is_int4)) - else: - preprocessed.append(p) - - params = preprocessed - - utils.save_params( - params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1 - ) - - if args.model_category != "minigpt": - utils.copy_tokenizer(args) - if args.model_category == "rwkv" or args.model_category == "rwkv_world": - # TODO: refactor config into model definition - dump_mlc_chat_config( - args, - vocab_size=config["vocab_size"], - max_window_size=model_config.max_sequence_length, - max_gen_len=model_config.max_sequence_length, - top_p=0.6, - temperature=1.2, - repetition_penalty=0.996, - rwkv_world=True, - ) - elif args.model_category == "chatglm": - dump_mlc_chat_config( - args, - vocab_size=config["padded_vocab_size"], - max_window_size=model_config.max_sequence_length, - max_gen_len=model_config.max_sequence_length, - ) - else: - dump_mlc_chat_config( - args, - vocab_size=config["vocab_size"], - max_window_size=model_config.max_sequence_length, - max_gen_len=model_config.max_sequence_length, - ) - - if args.convert_weights_only: - exit(0) - - mod = optimize_mod_pipeline(args, model_config)(mod) - if args.num_shards > 1: - # We require a "create_sharding_info" function for all - # multi-GPU models, even if they are using pre-sharded - # weights. When using pre-sharded weights, the list of - # initialization-time transforms to apply is empty. - sharding_module = create_shard_info_func(param_manager, args, model_config) - mod.update(sharding_module) - - with open(cache_path, "wb") as outfile: - pickle.dump(mod, outfile) - print(f"Save a cached module to {cache_path}.") - else: - print( - f"Load cached module from {cache_path} and skip tracing. " - "You can use --use-cache=0 to retrace" - ) - with open(cache_path, "rb") as pkl: - mod = pickle.load(pkl) - if not args.reuse_lib: - build(mod, args) - else: - print(f"Reuse existing prebuilt lib {args.reuse_lib}...") - - -def build_model(args: BuildArgs) -> (Optional[str], Optional[str], Optional[str]): - r"""Builds/compiles a model. - - Parameters - ---------- - args : :class:`BuildArgs` - A dataclass of arguments for building models.mlc_llm/core.py - - Returns - ---------- - lib_path: Optional[str] - The path to the model library file. Return ``None`` if not applicable. - model_path: Optional[str] - The path to the folder of the model's parameters. Return ``None`` if not applicable. - chat_config_path: Optional[str] - The path to the chat config `.json` file. Return ``None`` if not applicable. 
- """ - # Convert BuildArgs to argparse.Namespace so that we can share the rest - # of the code with the command line workflow - build_args_as_dict = asdict(args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - args = _parse_args(build_args_namespace) - build_model_from_args(args) - - # Prepare output; some workflows may or may not have the paths to return - lib_path = args.lib_path if hasattr(args, "lib_path") else None - model_path = args.params_path if hasattr(args, "params_path") else None - chat_config_path = args.chat_config_path if hasattr(args, "chat_config_path") else None - - return lib_path, model_path, chat_config_path diff --git a/mlc_llm/dispatch/__init__.py b/mlc_llm/dispatch/__init__.py deleted file mode 100644 index 234b60a8ad..0000000000 --- a/mlc_llm/dispatch/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .dispatch_tir_operator import DispatchTIROperator -from .dispatch_tir_operator_adreno import DispatchTIROperatorAdreno diff --git a/mlc_llm/dispatch/dispatch_tir_operator.py b/mlc_llm/dispatch/dispatch_tir_operator.py deleted file mode 100644 index 21a7d27218..0000000000 --- a/mlc_llm/dispatch/dispatch_tir_operator.py +++ /dev/null @@ -1,53 +0,0 @@ -# pylint: disable=missing-docstring -import tvm -from tvm import IRModule - - -@tvm.transform.module_pass(opt_level=0, name="DispatchTIROperator") -class DispatchTIROperator: # pylint: disable=too-few-public-methods - def __init__(self, model: str): - # pylint: disable=import-outside-toplevel - if model == "llama": - from .llama import lookup - - elif model == "gpt_neox": - from .gpt_neox import lookup - - elif model == "gpt_bigcode": - lookup = None - - elif model == "minigpt": - lookup = None - - elif model == "rwkv": - lookup = None - - elif model == "rwkv_world": - lookup = None - - elif model == "gptj": - lookup = None - - elif model == "chatglm": - lookup = None - - else: - raise ValueError(f"Model {model} not supported") - self.lookup = lookup - - # pylint: enable=import-outside-toplevel - - def transform_module( - self, - mod: IRModule, - ctx: tvm.transform.PassContext, - ) -> IRModule: - if self.lookup is None: - return mod - for gv in mod.functions: - scheduled_func = self.lookup(mod[gv]) - if scheduled_func is not None: - mod[gv] = scheduled_func - print("- Dispatch to pre-scheduled op:", gv.name_hint) - - return mod diff --git a/mlc_llm/dispatch/dispatch_tir_operator_adreno.py b/mlc_llm/dispatch/dispatch_tir_operator_adreno.py deleted file mode 100644 index 937a158b09..0000000000 --- a/mlc_llm/dispatch/dispatch_tir_operator_adreno.py +++ /dev/null @@ -1,8356 +0,0 @@ -import tvm -from tvm import IRModule -from tvm.script import tir as T - - -@T.prim_func(private=True) -def fused_decode4_matmul3( - lv1587: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv1588: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1587[v_i // T.int64(8), v_j], lv1588[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1587[v_i // 
T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1588[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - - -@T.prim_func(private=True) -def fused_decode4_matmul3_after( - lv1587: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv1588: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" - ) - lv1587_local = T.alloc_buffer( - (T.int64(512), T.int64(4096)), "uint32", scope="local" - ) - lv1588_local = T.alloc_buffer( - (T.int64(128), T.int64(4096)), "float16", scope="local" - ) - lv1583_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(2)): - for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1583_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(2048) - + ax2_1 * T.int64(64) - + (ax2_y * T.int64(8) + ax2_2), - ) - v2k = T.axis.spatial( - T.int64(2048), - ( - ax2_1 * T.int64(64) - + ax2_y * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv1583[v0, v1, v2]) - T.writes(lv1583_shared[v0, v1, v2k]) - lv1583_shared[v0, v1, v2k] = lv1583[v0, v1, v2] - for k_1 in range(T.int64(8)): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - 
var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1588_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(64) - + (k_1 * T.int64(8) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1588[v0, v1]) - T.writes(lv1588_local[v0, v1]) - lv1588_local[v0, v1] = lv1588[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1587_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(256) - + (k_1 * T.int64(8) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1587[v0, v1]) - T.writes(lv1587_local[v0, v1]) - lv1587_local[v0, v1] = lv1587[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(2048) - + (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(2048), - (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv1583_shared[v_i0, v_i1, v_ki], - lv1587_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv1583_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1587_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(64) - + (k_1 * T.int64(8) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv1588_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv1588_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - 
with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1583_shared[v0, v1, v2]) - lv1583_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction_sum"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(4)) - T.reads(lv1583_shared[v0, v1, v2]) - T.writes(lv1583_shared[v0, v1, v2]) - lv1583_shared[v0, v1, v2] = ( - lv1583_shared[v0, v1, v2] - + lv1583_shared[v0, v1, v2 + T.int64(16)] - ) - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1583_shared[v0, v1, v_i2k]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = ( - lv1583_shared[v0, v1, v_i2k] - + lv1583_shared[v0, v1, v_i2k + T.int64(4)] - + lv1583_shared[v0, v1, v_i2k + T.int64(8)] - + lv1583_shared[v0, v1, v_i2k + T.int64(12)] - ) - - -def sch_fused_decode4_matmul3(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[32, 64, 2] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, 
l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l41, l42, l43, l44, l45 = sch.get_loops(block=b22) - l46, l47, l48 = sch.split(loop=l45, factors=[None, 64, 8], preserve_unit_iters=True) - sch.vectorize(loop=l48) - sch.bind(loop=l47, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode6_fused_matmul7_add1( - lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), - lv1624: T.Buffer((T.int64(344), T.int64(4096)), "float16"), - lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), - lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1623[v_i // T.int64(8), v_j], lv1624[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1623[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1624[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv200[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv200[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv198[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv198[v_ax0, v_ax1, v_ax2] - + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode6_fused_matmul7_add1_after( - lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), - lv1624: T.Buffer((T.int64(344), T.int64(4096)), "float16"), - lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), - lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(16384)), "float16", scope="local" - ) - 
var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(16384)), "float16", scope="local" - ) - lv1623_local = T.alloc_buffer( - (T.int64(1376), T.int64(4096)), "uint32", scope="local" - ) - lv1624_local = T.alloc_buffer( - (T.int64(344), T.int64(4096)), "float16", scope="local" - ) - lv200_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2752)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(8), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(16384), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(4)): - for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(1), T.int64(3)): - for ax2_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): - for ax2_2 in T.vectorized(T.int64(2)): - with T.block("lv200_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(11008), - k_0 * T.int64(2752) - + ( - ax2_0 * T.int64(1024) - + ax2_1 * T.int64(8) - + (ax2_y * T.int64(2) + ax2_2) - ), - ) - v2k = T.axis.spatial( - T.int64(2752), - ( - ax2_0 * T.int64(1024) - + ax2_1 * T.int64(8) - + (ax2_y * T.int64(2) + ax2_2) - ), - ) - T.where( - (ax2_0 * T.int64(128) + ax2_1) < T.int64(344) - ) - T.reads(lv200[v0, v1, v2]) - T.writes(lv200_shared[v0, v1, v2k]) - lv200_shared[v0, v1, v2k] = lv200[v0, v1, v2] - for k_1 in range(T.int64(22)): - for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): - with T.block("lv1624_check"): - T.where((k_1 * T.int64(4) + ax2_y) < T.int64(86)) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(16384), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1624_local"): - v0 = T.axis.spatial( - T.int64(344), - k_0 * T.int64(86) - + (k_1 * T.int64(4) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1624[v0, v1]) - T.writes(lv1624_local[v0, v1]) - lv1624_local[v0, v1] = lv1624[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1623_local"): - v0 = T.axis.spatial( - T.int64(1376), - k_0 * T.int64(344) - + (k_1 * T.int64(4) + ax2_y) - * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1623[v0, v1]) - T.writes(lv1623_local[v0, v1]) - lv1623_local[v0, 
v1] = lv1623[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(16384), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(11008), - k_0 * T.int64(2752) - + (k_1 * T.int64(4) + ax2_y) - * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(2752), - (k_1 * T.int64(4) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv200_shared[v_i0, v_i1, v_ki], - lv1623_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv200_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1623_local[ - v_k // T.int64(8), - v_i2, - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(16384), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(344), - k_0 * T.int64(86) - + (k_1 * T.int64(4) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv1624_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - var_matmul_intermediate_local[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local[ - v_i0, v_i1, v_i2k - ] = ( - var_matmul_intermediate_local[ - v_i0, v_i1, v_i2k - ] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv1624_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(16384), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv200_shared[v0, v1, v2]) - lv200_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(4), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(16) - + ax2_y * T.int64(4) - + 
ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv200_shared[v0, v1, v_i2k]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = ( - lv198[v0, v1, v2] - + lv200_shared[v0, v1, v_i2k] - + lv200_shared[v0, v1, v_i2k + T.int64(4)] - + lv200_shared[v0, v1, v_i2k + T.int64(8)] - + lv200_shared[v0, v1, v_i2k + T.int64(12)] - ) - - -def sch_fused_decode6_fused_matmul7_add1(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[8, 256, 2] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[344, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - sch.reverse_compute_inline(block=b28) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode5_fused_matmul6_multiply1( - lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv1618: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1617[v_i // T.int64(8), v_j], lv1618[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1617[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1618[v_i // 
T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode5_fused_matmul6_multiply1_after( - lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv1618: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" - ) - lv1617_local = T.alloc_buffer( - (T.int64(512), T.int64(11008)), "uint32", scope="local" - ) - lv1618_local = T.alloc_buffer( - (T.int64(128), T.int64(11008)), "float16", scope="local" - ) - lv1622_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(4)): - for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1622_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(1024) - + ax2_y * T.int64(512) - + ax2_1 * T.int64(8) - + ax2_2, - ) - v2k = T.axis.spatial( - T.int64(1024), - ( - ax2_y * T.int64(512) - + ax2_1 * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv1622[v0, v1, v2]) - T.writes(lv1622_shared[v0, v1, v2k]) - lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2] - for k_1 in range(T.int64(16)): - for ax2_y in 
T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1618_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1618[v0, v1]) - T.writes(lv1618_local[v0, v1]) - lv1618_local[v0, v1] = lv1618[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1617_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(128) - + (k_1 * T.int64(2) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1617[v0, v1]) - T.writes(lv1617_local[v0, v1]) - lv1617_local[v0, v1] = lv1617[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(1024) - + (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(1024), - (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv1622_shared[v_i0, v_i1, v_ki], - lv1617_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv1622_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1617_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv1618_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - 
var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv1618_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1622_shared[v0, v1, v2]) - lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1622_shared[v0, v1, v_i2k], lv4[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] * ( - lv1622_shared[v0, v1, v_i2k] - + lv1622_shared[v0, v1, v_i2k + T.int64(4)] - ) - - -def sch_fused_decode5_fused_matmul6_multiply1(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = 
sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="T_multiply", func_name="main") - sch.reverse_compute_inline(block=b41) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l42, l43, l44, l45, l46 = sch.get_loops(block=b22) - l47, l48, l49 = sch.split(loop=l46, factors=[None, 64, 8], preserve_unit_iters=True) - sch.vectorize(loop=l49) - sch.bind(loop=l48, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_fused_decode9_matmul7( - lv19: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), - lv20: T.Buffer((T.int64(128), T.int64(22016)), "float16"), - lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(22016)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(22016)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv19[v_i // T.int64(8), v_j], lv20[v_i // T.int64(32), v_j]) - T.writes(p_output0_intermediate[v_i, v_j]) - p_output0_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv19[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv20[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(22016), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2] - ) - - -@T.prim_func(private=True) -def fused_fused_decode9_matmul7_after( - lv19: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), - lv20: T.Buffer((T.int64(128), T.int64(22016)), "float16"), - lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(352256)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(352256)), "float16", scope="local" - ) - lv19_local = T.alloc_buffer((T.int64(512), T.int64(22016)), "uint32", scope="local") - lv20_local = T.alloc_buffer( - (T.int64(128), T.int64(22016)), "float16", scope="local" - ) - lv1654_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(172), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), 
T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(352256), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(1)): - for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1654_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(4096) - + ax2_y * T.int64(256) - + ax2_1 * T.int64(8) - + ax2_2, - ) - v2k = T.axis.spatial( - T.int64(4096), - ( - ax2_y * T.int64(256) - + ax2_1 * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv1654[v0, v1, v2]) - T.writes(lv1654_shared[v0, v1, v2k]) - lv1654_shared[v0, v1, v2k] = lv1654[v0, v1, v2] - for k_1 in range(T.int64(8)): - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(352256), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv20_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(128) - + (k_1 * T.int64(16) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv20[v0, v1]) - T.writes(lv20_local[v0, v1]) - lv20_local[v0, v1] = lv20[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv19_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(512) - + (k_1 * T.int64(16) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv19[v0, v1]) - T.writes(lv19_local[v0, v1]) - lv19_local[v0, v1] = lv19[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(352256), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(4096) - + (k_1 * T.int64(16) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(4096), - (k_1 * T.int64(16) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv1654_shared[v_i0, v_i1, v_ki], - lv19_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - 
] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv1654_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv19_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(352256), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(128) - + (k_1 * T.int64(16) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv20_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv20_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(352256), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1654_shared[v0, v1, v2]) - lv1654_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction_1"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v_i2k = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(8)) - T.reads(lv1654_shared[v0, v1, v_i2k]) - T.writes(lv1654_shared[v0, v1, v_i2k]) - lv1654_shared[v0, v1, v_i2k] = ( - lv1654_shared[v0, v1, v_i2k] - + lv1654_shared[v0, v1, v_i2k + T.int64(32)] - ) - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction_2"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v_i2k = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(64) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(4)) - T.reads(lv1654_shared[v0, v1, v_i2k]) - T.writes(lv1654_shared[v0, v1, v_i2k]) - lv1654_shared[v0, v1, v_i2k] = ( - lv1654_shared[v0, v1, v_i2k] - + lv1654_shared[v0, v1, v_i2k + T.int64(16)] - ) - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(2048), - i0_i1_i2_fused_1 * T.int64(64) - + 
ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1654_shared[v0, v1, v_i2k]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = ( - lv1654_shared[v0, v1, v_i2k] - + lv1654_shared[v0, v1, v_i2k + T.int64(4)] - + lv1654_shared[v0, v1, v_i2k + T.int64(8)] - + lv1654_shared[v0, v1, v_i2k + T.int64(12)] - ) - - -@T.prim_func(private=True) -def fused_fused_decode7_matmul4( - lv3: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), - lv4: T.Buffer((T.int64(128), T.int64(12288)), "float16"), - lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(12288)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(12288)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv3[v_i // T.int64(8), v_j], lv4[v_i // T.int64(32), v_j]) - T.writes(p_output0_intermediate[v_i, v_j]) - p_output0_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv3[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv4[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(12288), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1615[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1615[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2] - ) - - -@T.prim_func(private=True) -def fused_fused_decode7_matmul4_after( - lv3: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), - lv4: T.Buffer((T.int64(128), T.int64(12288)), "float16"), - lv1615: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(12288)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(24576)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(24576)), "float16", scope="local" - ) - lv3_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") - lv4_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") - lv1615_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(24576), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - 
var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(4)): - for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1615_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(1024) - + ax2_y * T.int64(512) - + ax2_1 * T.int64(8) - + ax2_2, - ) - v2k = T.axis.spatial( - T.int64(1024), - ( - ax2_y * T.int64(512) - + ax2_1 * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv1615[v0, v1, v2]) - T.writes(lv1615_shared[v0, v1, v2k]) - lv1615_shared[v0, v1, v2k] = lv1615[v0, v1, v2] - for k_1 in range(T.int64(16)): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(24576), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv4_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv4[v0, v1]) - T.writes(lv4_local[v0, v1]) - lv4_local[v0, v1] = lv4[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv3_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(128) - + (k_1 * T.int64(2) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv3[v0, v1]) - T.writes(lv3_local[v0, v1]) - lv3_local[v0, v1] = lv3[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(24576), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(1024) - + (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(1024), - (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv1615_shared[v_i0, v_i1, v_ki], - lv3_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv1615_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv3_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * 
T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(24576), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv4_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv4_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(24576), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1615_shared[v0, v1, v2]) - lv1615_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1615_shared[v0, v1, v_i2k]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = ( - lv1615_shared[v0, v1, v_i2k] - + lv1615_shared[v0, v1, v_i2k + T.int64(4)] - ) - - -@T.prim_func(private=True) -def fused_decode5_fused_matmul6_silu1( - lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv1612: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ) - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1611[v_i // T.int64(8), v_j], lv1612[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1611[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - 
T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1612[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid( - var_matmul_intermediate[v_i0, v_i1, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2], - compute[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - * compute[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode5_fused_matmul6_silu1_after( - lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv1612: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(11008)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(22016)), "float16", scope="local" - ) - lv1611_local = T.alloc_buffer( - (T.int64(512), T.int64(11008)), "uint32", scope="local" - ) - lv1612_local = T.alloc_buffer( - (T.int64(128), T.int64(11008)), "float16", scope="local" - ) - lv1622_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(43), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(4)): - for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1622_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(1024) - + ax2_y * T.int64(512) - + ax2_1 * 
T.int64(8) - + ax2_2, - ) - v2k = T.axis.spatial( - T.int64(1024), - ( - ax2_y * T.int64(512) - + ax2_1 * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv1622[v0, v1, v2]) - T.writes(lv1622_shared[v0, v1, v2k]) - lv1622_shared[v0, v1, v2k] = lv1622[v0, v1, v2] - for k_1 in range(T.int64(16)): - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1612_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1612[v0, v1]) - T.writes(lv1612_local[v0, v1]) - lv1612_local[v0, v1] = lv1612[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1611_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(128) - + (k_1 * T.int64(2) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1611[v0, v1]) - T.writes(lv1611_local[v0, v1]) - lv1611_local[v0, v1] = lv1611[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(1024) - + (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(1024), - (k_1 * T.int64(2) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv1622_shared[v_i0, v_i1, v_ki], - lv1611_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv1622_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1611_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * 
T.int64(32) - + (k_1 * T.int64(2) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv1612_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv1612_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(512) - + i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1622_shared[v0, v1, v2]) - lv1622_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(2), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1622_shared[v0, v1, v2]) - T.writes(lv1622_shared[v0, v1, v2]) - lv1622_shared[v0, v1, v2] = ( - lv1622_shared[v0, v1, v2] - + lv1622_shared[v0, v1, v2 + T.int64(4)] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(11008), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(512), - i0_i1_i2_fused_1 * T.int64(8) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1622_shared[v0, v1, v_i2k]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv1622_shared[ - v0, v1, v_i2k - ] * T.sigmoid(lv1622_shared[v0, v1, v_i2k]) - - -def sch_fused_decode5_fused_matmul6_silu1(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[43, 64, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, 
read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="compute", func_name="main") - sch.compute_inline(block=b41) - b42 = sch.get_block(name="T_multiply", func_name="main") - sch.reverse_compute_inline(block=b42) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l43, l44, l45, l46, l47 = sch.get_loops(block=b22) - l48, l49, l50 = sch.split(loop=l47, factors=[None, 64, 8], preserve_unit_iters=True) - sch.vectorize(loop=l50) - sch.bind(loop=l49, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - -@T.prim_func(private=True) -def fused_decode81_fused_matmul1_cast2( - lv1576: T.Buffer((T.int64(512), T.int64(64000)), "uint32"), - lv1577: T.Buffer((T.int64(128), T.int64(64000)), "float16"), - lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(64000)), "float32" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(64000)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(64000)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(64000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1576[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1577[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(64000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(64000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] - ) - -def sch_fused_decode81_fused_matmul1_cast2(func): - sch = tvm.tir.Schedule(func) - 
b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[160, 100, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="compute", func_name="main") - sch.reverse_compute_inline(block=b41) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l42, l43, l44, l45, l46 = sch.get_loops(block=b22) - l47, l48, l49 = sch.split( - loop=l46, factors=[None, 100, 2], preserve_unit_iters=True - ) - sch.vectorize(loop=l49) - sch.bind(loop=l48, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - - - -@T.prim_func(private=True) -def fused_decode4_fused_matmul4_add1( - lv1605: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv1606: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1605[v_i // T.int64(8), v_j], lv1606[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1605[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * 
lv1606[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv197[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv197[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv1581[v_ax0, v_ax1, v_ax2], - var_matmul_intermediate[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv1581[v_ax0, v_ax1, v_ax2] - + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode4_fused_matmul4_add1_after( - lv1605: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv1606: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(32768)), "float16", scope="local" - ) - lv1605_local = T.alloc_buffer( - (T.int64(512), T.int64(4096)), "uint32", scope="local" - ) - lv1606_local = T.alloc_buffer( - (T.int64(128), T.int64(4096)), "float16", scope="local" - ) - lv197_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2048)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(2)): - for ax2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv197_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(2048) - + ax2_1 * T.int64(64) - + (ax2_y * T.int64(8) + ax2_2), - ) - v2k = T.axis.spatial( - T.int64(2048), - ( - ax2_1 * T.int64(64) - + ax2_y * T.int64(8) - + ax2_2 - ), - ) - T.reads(lv197[v0, v1, v2]) - T.writes(lv197_shared[v0, v1, v2k]) - lv197_shared[v0, v1, v2k] = lv197[v0, v1, v2] - for k_1 in range(T.int64(8)): - for ax2_y in 
T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1606_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(64) - + (k_1 * T.int64(8) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1606[v0, v1]) - T.writes(lv1606_local[v0, v1]) - lv1606_local[v0, v1] = lv1606[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1605_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(256) - + (k_1 * T.int64(8) + ax2_y) * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv1605[v0, v1]) - T.writes(lv1605_local[v0, v1]) - lv1605_local[v0, v1] = lv1605[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(2048) - + (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - v_ki = T.axis.reduce( - T.int64(2048), - (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - lv197_shared[v_i0, v_i1, v_ki], - lv1605_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] + lv197_shared[ - v_i0, v_i1, v_ki - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1605_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(64) - + (k_1 * T.int64(8) + ax2_y) - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv1606_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ], - ) - T.writes( - 
var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - * lv1606_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(32768), - i0_i1_i2_fused_0 * T.int64(1024) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv197_shared[v0, v1, v2]) - lv197_shared[v0, v1, v2] = var_matmul_intermediate_local[ - v0, v1, v_i2k - ] - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction_sum"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(4)) - T.reads(lv197_shared[v0, v1, v2]) - T.writes(lv197_shared[v0, v1, v2]) - lv197_shared[v0, v1, v2] = ( - lv197_shared[v0, v1, v2] - + lv197_shared[v0, v1, v2 + T.int64(16)] - ) - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - v_i2k = T.axis.spatial( - T.int64(1024), - i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + ax2, - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv197_shared[v0, v1, v_i2k], lv1581[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = ( - lv1581[v0, v1, v2] - + lv197_shared[v0, v1, v_i2k] - + lv197_shared[v0, v1, v_i2k + T.int64(4)] - + lv197_shared[v0, v1, v_i2k + T.int64(8)] - + lv197_shared[v0, v1, v_i2k + T.int64(12)] - ) - -@T.prim_func(private=True) -def fused_decode82_fused_matmul1_cast2( - lv1576: T.Buffer((T.int64(512), T.int64(64000)), "uint32"), - lv1577: T.Buffer((T.int64(128), T.int64(64000)), "float16"), - lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(64000)), "float32" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2048), T.int64(64000)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(64000)), "float16" - ) - for i, j in T.grid(T.int64(2048), T.int64(64000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1576[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1577[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(64000), T.int64(4096)): - 
with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(64000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] - ) - -def sch_fused_decode82_fused_matmul1_cast2(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[160, 100, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="compute", func_name="main") - sch.reverse_compute_inline(block=b41) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l42, l43, l44, l45, l46 = sch.get_loops(block=b22) - l47, l48, l49 = sch.split( - loop=l46, factors=[None, 100, 2], preserve_unit_iters=True - ) - sch.vectorize(loop=l49) - sch.bind(loop=l48, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - -def sch_fused_decode4_fused_matmul4_add1(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, 
preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[32, 64, 2] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="T_add", func_name="main") - sch.reverse_compute_inline(block=b41) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l42, l43, l44, l45, l46 = sch.get_loops(block=b22) - l47, l48, l49 = sch.split(loop=l46, factors=[None, 64, 8], preserve_unit_iters=True) - sch.vectorize(loop=l49) - sch.bind(loop=l48, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - -@T.prim_func(private=True) -def fused_decode3_fused_matmul1_cast2( - lv1576: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), - lv1577: T.Buffer((T.int64(128), T.int64(32000)), "float16"), - lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(32000)), "float32" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(32000)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1576[v_i // T.int64(8), v_j], lv1577[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1576[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1577[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1575[v_i0, v_i1, v_k], 
var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float32", var_matmul_intermediate[v_i0, v_i1, v_i2] - ) - -@T.prim_func(private=True) -def fused_decode3_fused_matmul1_cast2_after( - lv1576: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), - lv1577: T.Buffer((T.int64(128), T.int64(32000)), "float16"), - lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(32000)), "float32" - ), -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(512000)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(512000)), "float16", scope="local" - ) - lv1576_local = T.alloc_buffer( - (T.int64(512), T.int64(32000)), "uint32", scope="local" - ) - lv1577_local = T.alloc_buffer( - (T.int64(128), T.int64(32000)), "float16", scope="local" - ) - lv1575_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - ) - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(125), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(512000), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(1)): - for ax2_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2_2 in T.vectorized(T.int64(8)): - with T.block("lv1575_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(4096), - k_0 * T.int64(4096) - + ax2_y * T.int64(512) - + ax2_1 * T.int64(8) + ax2_2 - ) - v2k = T.axis.spatial( - T.int64(4096), - (ax2_y * T.int64(512) - + ax2_1 * T.int64(8) + ax2_2) - ) - T.reads(lv1575[v0, v1, v2]) - T.writes(lv1575_shared[v0, v1, v2k]) - lv1575_shared[v0, v1, v2k] = lv1575[v0, v1, v2] - for k_1 in range(T.int64(16)): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(512000), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) + ax1 - ) - T.reads() - 
T.writes(var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k]) - var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1577_local"): - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(128) - + (k_1 * T.int64(8) + ax2_y) + ax0 - ) - v1 = T.axis.spatial( - T.int64(32000), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax1 - ) - T.reads(lv1577[v0, v1]) - T.writes(lv1577_local[v0, v1]) - lv1577_local[v0, v1] = lv1577[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv1576_local"): - v0 = T.axis.spatial( - T.int64(512), - k_0 * T.int64(512) - + (k_1 * T.int64(8) + ax2_y) * T.int64(4) - + k_2 + ax0 - ) - v1 = T.axis.spatial( - T.int64(32000), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1 - ) - T.reads(lv1576[v0, v1]) - T.writes(lv1576_local[v0, v1]) - lv1576_local[v0, v1] = lv1576[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial( - T.int64(32000), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2 - ) - v_i2k = T.axis.spatial( - T.int64(512000), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) - + i0_i1_i2_fused_2 - ) - v_k = T.axis.reduce( - T.int64(4096), - k_0 * T.int64(4096) - + (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) + k_3 - ) - v_ki = T.axis.reduce( - T.int64(4096), - (k_1 * T.int64(8) + ax2_y) * T.int64(32) - + k_2 * T.int64(8) + k_3 - ) - T.reads( - var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k], - lv1575_shared[v_i0, v_i1, v_ki], lv1576_local[v_k // T.int64(8), v_i2] - ) - T.writes(var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k]) - var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] - + lv1575_shared[v_i0, v_i1, v_ki] - * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv1576_local[v_k // T.int64(8), v_i2], - T.Cast("uint32", v_k % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7))) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2k = T.axis.spatial( - T.int64(512000), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) + ax1 - ) - v0 = T.axis.spatial( - T.int64(128), - k_0 * T.int64(128) - + (k_1 * T.int64(8) + ax2_y) + ax0 - ) - v1 = T.axis.spatial( - T.int64(32000), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax1 - ) - T.reads( - lv1577_local[v0, v1], - var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] - ) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2k]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2k] - + var_matmul_intermediate_local_batch[v_i0, v_i1, v_i2k] * lv1577_local[v0, v1] - ) - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_update"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(2048), 
- ax2_y * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax2 - ) - v_i2k = T.axis.spatial( - T.int64(512000), - i0_i1_i2_fused_0 * T.int64(2048) - + i0_i1_i2_fused_1 * T.int64(32) - + ax2_y * T.int64(4) + ax2 - ) - T.reads(var_matmul_intermediate_local[v0, v1, v_i2k]) - T.writes(lv1575_shared[v0, v1, v2]) - lv1575_shared[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v_i2k] - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("reduction_2"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v_i2k = T.axis.spatial( - T.int64(2048), - ax2_y * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax2 - ) - T.where(ax2_y < T.int64(4)) - T.reads(lv1575_shared[v0, v1, v_i2k]) - T.writes(lv1575_shared[v0, v1, v_i2k]) - lv1575_shared[v0, v1, v_i2k] = ( - lv1575_shared[v0, v1, v_i2k] + lv1575_shared[v0, v1, v_i2k + T.int64(1024)] - ) - for ax2_y in T.thread_binding(T.int64(16), thread="threadIdx.y"): - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial( - T.int64(32000), - i0_i1_i2_fused_0 * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax2 - ) - v_i2k = T.axis.spatial( - T.int64(2048), - ax2_y * T.int64(256) - + i0_i1_i2_fused_1 * T.int64(4) + ax2 - ) - T.where(ax2_y < T.int64(1)) - T.reads(lv1575_shared[v0, v1, v_i2k]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = T.Cast( - "float32", lv1575_shared[v0, v1, v_i2k] - + lv1575_shared[v0, v1, v_i2k + T.int64(256)] - + lv1575_shared[v0, v1, v_i2k + T.int64(512)] - + lv1575_shared[v0, v1, v_i2k + T.int64(768)] - ) - - -def sch_fused_decode3_fused_matmul1_cast2(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[80, 100, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[512, 8, 1] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local") - b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local") - b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1) - v23 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1 - ) - sch.annotate( - block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23 - ) - sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1) - sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1) - l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20) - sch.vectorize(loop=l29) - l30, l31, l32, 
l33, l34 = sch.get_loops(block=b21) - sch.vectorize(loop=l34) - l35, l36, l37, l38, l39 = sch.get_loops(block=b19) - sch.vectorize(loop=l39) - sch.vectorize(loop=l12) - b40 = sch.decompose_reduction(block=b1, loop=l16) - b41 = sch.get_block(name="compute", func_name="main") - sch.reverse_compute_inline(block=b41) - sch.enter_postproc() - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch") - l42, l43, l44, l45, l46 = sch.get_loops(block=b22) - l47, l48, l49 = sch.split( - loop=l46, factors=[None, 100, 2], preserve_unit_iters=True - ) - sch.vectorize(loop=l49) - sch.bind(loop=l48, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode2_fused_NT_matmul3_add( - lv50: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), - lv51: T.Buffer((T.int64(344), T.int64(4096)), "float16"), - p_lv5: T.handle, - p_lv3: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv5 = T.match_buffer(p_lv5, (T.int64(1), n, T.int64(11008)), "float16") - lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(4096)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(11008)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(4096)), "float16" - ) - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv50[v_i // T.int64(8), v_j], lv51[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv50[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv51[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv5[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv5[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv3[v_ax0, v_ax1, v_ax2], - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv3[v_ax0, v_ax1, v_ax2] - + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode2_fused_NT_matmul3_add_after( - lv8: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), - lv9: T.Buffer((T.int64(344), T.int64(4096)), "float16"), - p_lv5: T.handle, - p_lv3: T.handle, - p_output0: T.handle, -): - 
T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv6 = T.match_buffer(p_lv5, (1, n, 11008), "float16") - lv2 = T.match_buffer(p_lv3, (1, n, 4096), "float16") - var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") - - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" - ) - lv8_local = T.alloc_buffer((T.int64(512), T.int64(4096)), "uint32", scope="local") - lv9_local = T.alloc_buffer( - (T.int64(128), T.int64(4096)), "float16", scope="local" - ) - #lv6_shared = T.alloc_buffer( - # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - #) - for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - with T.block("n_check"): - T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_1 in range(T.int64(344)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2k = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv9_local"): - v0 = T.axis.spatial( - T.int64(344), k_1 - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - lv9_local[v0, v1] = lv9[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv8_local"): - v0 = T.axis.spatial( - T.int64(1376), - k_1 * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(11008), - k_1 * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - 
lv6[v_i0, v_i1, v_k], - lv8_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] + lv6[ - v_i0, v_i1, v_k - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(344), - k_1 - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv9_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - * lv9_local[v0, v1] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2[v_i0, v_i1, v_i2]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_i2] - - -@T.prim_func(private=True) -def fused_decode_NT_matmul( - lv8: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv9: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - p_lv6: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16") - var_NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(4096)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv8[v_i // T.int64(8), v_j], lv9[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv9[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with 
T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv6[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv6[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - - -@T.prim_func(private=True) -def fused_decode_NT_matmul_after( - lv8: T.Buffer((512, 4096), "uint32"), - lv9: T.Buffer((128, 4096), "float16"), - p_lv6: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int32() - lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") - var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") - # with T.block("root"): - decode_local = T.alloc_buffer((4096, 4096), "float16", scope="local") - lv8_local = T.alloc_buffer((512, 4096), "uint32", scope="local") - lv9_local = T.alloc_buffer((128, 4096), "float16", scope="local") - lv6_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + 31) // 32, thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(32, thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): - for i2_1 in T.thread_binding(16, thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(4): - for i2_2_init in T.vectorized(8): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - 4096, i2_0 * 128 + i2_1 * 8 + i2_2_init - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float16(0) - for k_0 in range(128): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv9_local"): - v0 = T.axis.spatial(128, k_0 + ax0) - v1 = T.axis.spatial( - 4096, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - lv9_local[v0, v1] = lv9[v0, v1] - for k_1 in range(4): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv8_local"): - v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0) - v1 = T.axis.spatial( - 4096, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_2 in range(8): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("decode"): - v_i = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax0 - ) - v_j = T.axis.spatial( - 4096, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads( - lv8_local[v_i // 8, v_j], - lv9_local[v_i // 32, v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[v_i // 8, v_j], - T.Cast("uint32", v_i % 8) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv9_local[v_i // 32, v_j] - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(1): - with T.block("lv6_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 
4 - + ax1, - ) - v2 = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 - ) - T.reads(lv6[v0, v1, v2]) - T.writes(lv6_pad_local[v0, v1, v2]) - lv6_pad_local[v0, v1, v2] = T.if_then_else( - v1 < n, lv6[v0, v1, v2], T.float16(0) - ) - for i0_i1_fused_1_2 in range(4): - for i2_2 in T.vectorized(8): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - 4096, i2_0 * 128 + i2_1 * 8 + i2_2 - ) - v_k = T.axis.reduce( - 4096, k_0 * 32 + k_1 * 8 + k_2 - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv6_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = ( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - + lv6_pad_local[v_i0, v_i1, v_k] - * decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(8): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + ax1, - ) - v2 = T.axis.spatial(4096, i2_0 * 128 + i2_1 * 8 + ax2) - T.reads( - var_NT_matmul_intermediate_pad_local[v0, v1, v2] - ) - T.writes(var_NT_matmul_intermediate[v0, v1, v2]) - if v1 < n: - var_NT_matmul_intermediate[ - v0, v1, v2 - ] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul2_silu( - lv36: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv37: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - p_lv45: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(11008)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(11008), T.int64(4096)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(11008)), "float16" - ) - compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv36[v_i // T.int64(8), v_j], lv37[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv36[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv37[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], - compute[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - * compute[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul2_silu_after( - lv36: T.Buffer((512, 11008), "uint32"), - lv37: T.Buffer((128, 11008), "float16"), - p_lv45: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int32() - lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16") - # with T.block("root"): - decode_local = T.alloc_buffer((4096, 11008), "float16", scope="local") - lv36_local = T.alloc_buffer((512, 11008), "uint32", scope="local") - lv37_local = T.alloc_buffer((128, 11008), "float16", scope="local") - lv45_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 11008), "float16", scope="local" - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + 31) // 32, thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(86, thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): - for i2_1 in T.thread_binding(16, thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(4): - for i2_2_init in T.vectorized(8): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + i2_2_init - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float16(0) - for k_0 in range(128): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv37_local"): - v0 = T.axis.spatial(128, k_0 + ax0) - v1 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv37[v0, v1]) - T.writes(lv37_local[v0, v1]) - lv37_local[v0, v1] = lv37[v0, v1] - for k_1 in range(4): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv36_local"): - v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0) - v1 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv36[v0, v1]) - T.writes(lv36_local[v0, v1]) - lv36_local[v0, v1] = lv36[v0, v1] - for k_2 in range(8): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("decode"): - v_i = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax0 - ) - v_j = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - 
T.reads( - lv36_local[v_i // 8, v_j], - lv37_local[v_i // 32, v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv36_local[v_i // 8, v_j], - T.Cast("uint32", v_i % 8) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv37_local[v_i // 32, v_j] - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(1): - with T.block("lv45_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + ax1, - ) - v2 = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 - ) - T.reads(lv45[v0, v1, v2]) - T.writes(lv45_pad_local[v0, v1, v2]) - lv45_pad_local[v0, v1, v2] = T.if_then_else( - v1 < n, lv45[v0, v1, v2], T.float16(0) - ) - for i0_i1_fused_1_2 in range(4): - for i2_2 in T.vectorized(8): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + i2_2 - ) - v_k = T.axis.reduce( - 4096, k_0 * 32 + k_1 * 8 + k_2 - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv45_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = ( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - + lv45_pad_local[v_i0, v_i1, v_k] - * decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(8): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + ax1, - ) - v2 = T.axis.spatial(11008, i2_0 * 128 + i2_1 * 8 + ax2) - T.reads( - var_NT_matmul_intermediate_pad_local[v0, v1, v2] - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[ - v0, v1, v2 - ] = var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] * T.sigmoid( - var_NT_matmul_intermediate_pad_local[v0, v1, v2] - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul2_multiply( - lv43: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), - lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), - p_lv45: T.handle, - p_lv132: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") - lv132 = T.match_buffer(p_lv132, (T.int64(1), n, T.int64(11008)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(11008)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(11008), T.int64(4096)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(11008)), "float16" - ) - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv43[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) 
- - T.float16(7) - ) * lv44[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv132[v_ax0, v_ax1, v_ax2], - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv132[v_ax0, v_ax1, v_ax2] - * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul2_multiply_after( - lv43: T.Buffer((512, 11008), "uint32"), - lv44: T.Buffer((128, 11008), "float16"), - p_lv45: T.handle, - p_lv132: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int32() - lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16") - lv132 = T.match_buffer(p_lv132, (1, n, 11008), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16") - # with T.block("root"): - decode_local = T.alloc_buffer((4096, 11008), "float16", scope="local") - lv43_local = T.alloc_buffer((512, 11008), "uint32", scope="local") - lv44_local = T.alloc_buffer((128, 11008), "float16", scope="local") - lv45_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 4096), "float16", scope="local" - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (1, (n + 31) // 32 * 32, 11008), "float16", scope="local" - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + 31) // 32, thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(86, thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding(8, thread="threadIdx.y"): - for i2_1 in T.thread_binding(16, thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(4): - for i2_2_init in T.vectorized(8): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + i2_2_init - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float16(0) - for k_0 in range(128): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv44_local"): - v0 = T.axis.spatial(128, k_0 + ax0) - v1 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv44[v0, v1]) - T.writes(lv44_local[v0, v1]) - lv44_local[v0, v1] = lv44[v0, v1] - for k_1 in range(4): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("lv43_local"): - v0 = 
T.axis.spatial(512, k_0 * 4 + k_1 + ax0) - v1 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads(lv43[v0, v1]) - T.writes(lv43_local[v0, v1]) - lv43_local[v0, v1] = lv43[v0, v1] - for k_2 in range(8): - for ax0 in range(1): - for ax1 in T.vectorized(8): - with T.block("decode"): - v_i = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax0 - ) - v_j = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + ax1 - ) - T.reads( - lv43_local[v_i // 8, v_j], - lv44_local[v_i // 32, v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv43_local[v_i // 8, v_j], - T.Cast("uint32", v_i % 8) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv44_local[v_i // 32, v_j] - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(1): - with T.block("lv45_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + ax1, - ) - v2 = T.axis.spatial( - 4096, k_0 * 32 + k_1 * 8 + k_2 + ax2 - ) - T.reads(lv45[v0, v1, v2]) - T.writes(lv45_pad_local[v0, v1, v2]) - lv45_pad_local[v0, v1, v2] = T.if_then_else( - v1 < n, lv45[v0, v1, v2], T.float16(0) - ) - for i0_i1_fused_1_2 in range(4): - for i2_2 in T.vectorized(8): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial(1, 0) - v_i1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - 11008, i2_0 * 128 + i2_1 * 8 + i2_2 - ) - v_k = T.axis.reduce( - 4096, k_0 * 32 + k_1 * 8 + k_2 - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv45_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = ( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - + lv45_pad_local[v_i0, v_i1, v_k] - * decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(1, 4): - for ax2 in T.vectorized(8): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial( - (n + 31) // 32 * 32, - i0_i1_fused_0_i0_i1_fused_1_0_fused * 32 - + i0_i1_fused_1_1 * 4 - + ax1, - ) - v2 = T.axis.spatial(11008, i2_0 * 128 + i2_1 * 8 + ax2) - T.reads( - lv132[v0, v1, v2], - var_NT_matmul_intermediate_pad_local[v0, v1, v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = ( - lv132[v0, v1, v2] - * var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - ) - - -@T.prim_func(private=True) -def fused_decode_fused_NT_matmul_add( - lv29: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv30: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - p_lv41: T.handle, - p_lv2: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(4096)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(4096), T.int64(4096)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(4096)), "float16" - ) - for i, j in 
T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv29[v_i // T.int64(8), v_j], lv30[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv29[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv30[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv41[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv41[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - lv2[v_ax0, v_ax1, v_ax2], - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - lv2[v_ax0, v_ax1, v_ax2] - + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode_fused_NT_matmul_add_after( - lv8: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), - lv9: T.Buffer((T.int64(128), T.int64(4096)), "float16"), - p_lv41: T.handle, - p_lv2: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv6 = T.match_buffer(p_lv41, (1, n, 4096), "float16") - lv2 = T.match_buffer(p_lv2, (1, n, 4096), "float16") - var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") - - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(4096)), "float16", scope="local" - ) - lv8_local = T.alloc_buffer((T.int64(512), T.int64(4096)), "uint32", scope="local") - lv9_local = T.alloc_buffer( - (T.int64(128), T.int64(4096)), "float16", scope="local" - ) - #lv6_shared = T.alloc_buffer( - # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - #) - for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - with T.block("n_check"): - T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - 
T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_1 in range(T.int64(128)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2k = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv9_local"): - v0 = T.axis.spatial( - T.int64(128), k_1 - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - lv9_local[v0, v1] = lv9[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv8_local"): - v0 = T.axis.spatial( - T.int64(512), - k_1 * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_1 * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - lv6[v_i0, v_i1, v_k], - lv8_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] + lv6[ - v_i0, v_i1, v_k - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_1 - ) - v1 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv9_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - * lv9_local[v0, v1] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v_i0 = 
T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(4096), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2[v_i0, v_i1, v_i2]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_i2] - - -@T.prim_func(private=True) -def fused_decode4_fused_matmul6_add4( - lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv1364: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - lv2067: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - linear_bias192: T.Buffer((T.int64(2560),), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1363[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1364[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2067[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv2067[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] - ) - - -def sch_fused_decode4_fused_matmul6_add4(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[160, 8, 2] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, 
storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - sch.reverse_compute_inline(block=b28) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode6_fused_matmul9_add7_cast8_cast12_add5( - lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), - lv1394: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - lv2121: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), - linear_bias197: T.Buffer((T.int64(2560),), "float32"), - lv329: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_compute_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - var_compute_intermediate_1 = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1393[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1394[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv2121[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, 
v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ - v_i0, v_i1, v_i2 - ] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], - lv329[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] - + lv329[v_ax0, v_ax1, v_ax2] - ) - - -def sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[640, 2, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - bb1 = sch.get_block(name="compute", func_name="main") - bb2 = sch.get_block(name="compute_1", func_name="main") - bb3 = sch.get_block(name="T_add_1", func_name="main") - sch.compute_inline(block=b28) - sch.compute_inline(block=bb1) - sch.compute_inline(block=bb2) - sch.reverse_compute_inline(block=bb3) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - 
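Every `decode` block in these fused kernels follows the same group-quantized int4 layout: each `uint32` of the first operand packs eight 4-bit weights along the reduction axis, a `float16` scale is shared by each group of 32 rows, and the stored nibble is re-centered by subtracting 7; the accompanying `sch_*` schedules then fuse and split the spatial loops, bind them to `blockIdx.x`/`threadIdx.x`, inline the decode, stage the output in `local` and the activation in `shared` memory with cooperative fetching, and vectorize the innermost loops. The constant `0.70710678118654757` in the `*_gelu1_*` kernels is 1/sqrt(2), i.e. the exact erf-based GELU `x * 0.5 * (1 + erf(x / sqrt(2)))`. Below is a minimal NumPy sketch of the dequantization only, with illustrative array names (`packed`, `scales`) that are not taken from the patch:

import numpy as np

def dequantize_q4(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    # packed: uint32, shape (K // 8, N) -- eight 4-bit weights per word.
    # scales: float16, shape (K // 32, N) -- one scale per group of 32 rows.
    # Returns the float16 (K, N) weight matrix that the TIR `decode` blocks produce.
    k = packed.shape[0] * 8
    rows = np.arange(k)
    shift = ((rows % 8) * 4).astype(np.uint32)[:, None]
    nibbles = (packed[rows // 8] >> shift) & np.uint32(15)
    return ((nibbles.astype(np.float16) - np.float16(7)) * scales[rows // 32]).astype(np.float16)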
-@T.prim_func(private=True) -def fused_decode5_fused_matmul8_add6_gelu1_cast11( - lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), - lv1388: T.Buffer((T.int64(80), T.int64(10240)), "float16"), - lv2115: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - linear_bias196: T.Buffer((T.int64(10240),), "float32"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(10240)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - var_T_multiply_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(10240)) - ) - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1387[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1388[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv2115[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[ - v_ax0, v_ax1, v_ax2 - ] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, 
v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[ - v_ax0, v_ax1, v_ax2 - ] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = ( - T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2] - ) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2] - ) - - -def sch_fused_decode5_fused_matmul8_add6_gelu1_cast11(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 4] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[80, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - bb1 = sch.get_block(name="T_multiply", func_name="main") - bb2 = sch.get_block(name="compute", func_name="main") - bb3 = sch.get_block(name="T_multiply_1", func_name="main") - bb4 = sch.get_block(name="T_add_1", func_name="main") - bb5 = sch.get_block(name="T_multiply_2", func_name="main") - bb6 = sch.get_block(name="compute_1", func_name="main") - sch.compute_inline(block=b28) - sch.compute_inline(block=bb1) - sch.compute_inline(block=bb2) - sch.compute_inline(block=bb3) - sch.compute_inline(block=bb4) - sch.compute_inline(block=bb5) - sch.reverse_compute_inline(block=bb6) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, 
l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode4_fused_matmul6_add4_add5( - lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv1382: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - lv328: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - linear_bias195: T.Buffer((T.int64(2560),), "float16"), - lv2062: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - var_T_add_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1381[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1382[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv328[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv328[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2062[v_ax0, v_ax1, v_ax2] - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] - + lv2062[v_ax0, v_ax1, v_ax2] - ) - - -def sch_fused_decode4_fused_matmul6_add4_add5(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[10, 256, 1] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, 
n=3, max_innermost_factor=8, decision=[160, 8, 2] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - bb4 = sch.get_block(name="T_add_1", func_name="main") - sch.compute_inline(block=b28) - sch.reverse_compute_inline(block=bb4) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode3_matmul3( - lv2515: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), - lv2516: T.Buffer((T.int64(80), T.int64(50432)), "float32"), - lv705: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), - var_matmul_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(50432)), "float32" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) - for i, j in T.grid(T.int64(2560), T.int64(50432)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv2515[v_i // T.int64(8), v_j], lv2516[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float32", - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv2515[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7), - ) - * lv2516[v_i // T.int64(32), v_j] - ) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate[v_i0, v_i1, v_i2] - + lv705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - ) - - -def sch_fused_decode3_matmul3(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[197, 128, 2] - ) - l10, l11, l12 = sch.split(loop=l6, 
factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[80, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 128, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7( - lv2509: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), - lv2510: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - lv4105: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), - linear_bias383: T.Buffer((T.int64(2560),), "float32"), - lv701: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), - p_output0_intermediate: T.Buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float32" - ), -): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_compute_intermediate = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - var_compute_intermediate_1 = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - var_T_add_intermediate_1 = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv2509[v_i // T.int64(8), v_j], lv2510[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv2509[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv2510[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv4105[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_decode_intermediate[v_k, v_i2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias383[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ - v_i0, v_i1, v_i2 - ] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], - lv701[v_ax0, v_ax1, v_ax2], - ) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = ( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] - + lv701[v_ax0, v_ax1, v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2] - ) - - -def sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block(name="decode", func_name="main") - b1 = sch.get_block(name="matmul", func_name="main") - l2, l3, l4, l5 = sch.get_loops(block=b1) - l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True) - v7, v8, v9 = sch.sample_perfect_tile( - loop=l6, n=3, max_innermost_factor=4, decision=[5, 256, 2] - ) - l10, l11, l12 = sch.split(loop=l6, factors=[v7, v8, v9], preserve_unit_iters=True) - v13, v14, v15 = sch.sample_perfect_tile( - loop=l5, n=3, max_innermost_factor=8, decision=[320, 4, 8] - ) - l16, l17, l18 = sch.split( - loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True - ) - sch.reorder(l10, l11, l16, l17, l18, l12) - sch.bind(loop=l10, thread_axis="blockIdx.x") - sch.bind(loop=l11, thread_axis="threadIdx.x") - sch.compute_inline(block=b0) - b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1) - b20 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared") - sch.compute_at(block=b20, loop=l11, preserve_unit_loops=True, index=-1) - v21 = sch.sample_categorical( - candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3 - ) - sch.annotate( - block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch", ann_val=v21 - ) - l22, l23, l24, l25, l26 = 
sch.get_loops(block=b19) - sch.vectorize(loop=l26) - sch.vectorize(loop=l12) - b27 = sch.decompose_reduction(block=b1, loop=l16) - b28 = sch.get_block(name="T_add", func_name="main") - bb1 = sch.get_block(name="compute", func_name="main") - bb2 = sch.get_block(name="compute_1", func_name="main") - bb3 = sch.get_block(name="T_add_1", func_name="main") - bb4 = sch.get_block(name="compute_2", func_name="main") - sch.compute_inline(block=b28) - sch.compute_inline(block=bb1) - sch.compute_inline(block=bb2) - sch.compute_inline(block=bb3) - sch.reverse_compute_inline(block=bb4) - sch.enter_postproc() - sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.cooperative_fetch") - l29, l30, l31, l32, l33 = sch.get_loops(block=b20) - l34, l35, l36 = sch.split( - loop=l33, factors=[None, 256, 8], preserve_unit_iters=True - ) - sch.vectorize(loop=l36) - sch.bind(loop=l35, thread_axis="threadIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func(private=True) -def fused_decode2_fused_NT_matmul3_add6_gelu1_cast11( - lv36: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), - lv37: T.Buffer((T.int64(80), T.int64(10240)), "float16"), - p_lv57: T.handle, - linear_bias4: T.Buffer((T.int64(10240),), "float32"), - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(10240)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(10240), T.int64(2560)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv36[v_i // T.int64(8), v_j], lv37[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv36[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv37[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv57[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv57[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in 
T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2] - ) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[ - v_ax0, v_ax1, v_ax2 - ] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[ - v_ax0, v_ax1, v_ax2 - ] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = ( - T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2] - ) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2] - ) - - -@T.prim_func(private=True) -def fused_decode2_fused_NT_matmul3_add6_gelu1_cast11_after( - lv36: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), - lv37: T.Buffer((T.int64(80), T.int64(10240)), "float16"), - p_lv57: T.handle, - linear_bias4: T.Buffer((T.int64(10240),), "float32"), - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) - n = T.int64() - lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(10240)), "float16" - ) - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) - decode_local = T.alloc_buffer( - (T.int64(2560), T.int64(10240)), "float16", scope="local" - ) - lv36_local = T.alloc_buffer( - (T.int64(320), T.int64(10240)), "uint32", scope="local" - ) - lv37_local = T.alloc_buffer( - (T.int64(80), T.int64(10240)), "float16", scope="local" - ) - lv57_pad_local = T.alloc_buffer( - (T.int64(1), (n + 
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - "float16", - scope="local", - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - ( - T.int64(1), - (n + T.int64(31)) // T.int64(32) * T.int64(32), - T.int64(10240), - ), - scope="local", - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(T.int64(80), thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding( - T.int64(8), thread="threadIdx.y" - ): - for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(T.int64(4)): - for i2_2_init in T.vectorized(T.int64(8)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2_init, - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float32(0) - for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv37_local"): - v0 = T.axis.spatial( - T.int64(80), - k_0_0 * T.int64(4) + k_0_1 + ax0, - ) - v1 = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv37[v0, v1]) - T.writes(lv37_local[v0, v1]) - lv37_local[v0, v1] = lv37[v0, v1] - for k_1 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv36_local"): - v0 = T.axis.spatial( - T.int64(320), - k_0_0 * T.int64(16) - + k_0_1 * T.int64(4) - + k_1 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv36[v0, v1]) - T.writes(lv36_local[v0, v1]) - lv36_local[v0, v1] = lv36[v0, v1] - for k_2 in range(T.int64(8)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("decode"): - v_i = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax0, - ) - v_j = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads( - lv36_local[v_i // T.int64(8), v_j], - lv37_local[v_i // T.int64(32), v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv36_local[ - v_i // T.int64(8), - v_j, - ], - T.Cast( - "uint32", - v_i % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv37_local[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(1)): - with T.block("lv57_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax2, - ) - T.reads(lv57[v0, v1, v2]) - T.writes(lv57_pad_local[v0, v1, v2]) - lv57_pad_local[ - v0, v1, v2 - ] = T.if_then_else( - v1 < n, - lv57[v0, v1, v2], - T.float16(0), - ) - for i0_i1_fused_1_2 in range(T.int64(4)): - for 
i2_2 in T.vectorized(T.int64(8)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2, - ) - v_k = T.axis.reduce( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv57_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] + T.Cast( - "float32", - lv57_pad_local[v_i0, v_i1, v_k], - ) * T.Cast( - "float32", decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(10240), - i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ], - linear_bias4[v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = T.Cast( - "float16", - ( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - + linear_bias4[v2] - ) - * ( - T.float32(0.5) - + T.erf( - ( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - + linear_bias4[v2] - ) - * T.float32(0.70710678118654757) - ) - * T.float32(0.5) - ), - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul1_add4( - lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - p_lv9: T.handle, - linear_bias: T.Buffer((T.int64(2560),), "float16"), - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv9_1 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(2560), T.int64(2560)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv8[v_i // T.int64(8), v_j], lv9[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv9[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, 
T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv9_1[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv9_1[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul1_add4_after( - lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - p_lv9: T.handle, - linear_bias: T.Buffer((T.int64(2560),), "float16"), - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) - n = T.int64() - lv9_1 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) - decode_local = T.alloc_buffer( - (T.int64(2560), T.int64(2560)), "float16", scope="local" - ) - lv8_local = T.alloc_buffer( - (T.int64(320), T.int64(2560)), "uint32", scope="local" - ) - lv9_local = T.alloc_buffer( - (T.int64(80), T.int64(2560)), "float16", scope="local" - ) - lv9_1_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - "float16", - scope="local", - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - "float16", - scope="local", - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding( - T.int64(8), thread="threadIdx.y" - ): - for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(T.int64(4)): - for i2_2_init in T.vectorized(T.int64(8)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2_init, - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float16(0) - for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv9_local"): - v0 = T.axis.spatial( - T.int64(80), - k_0_0 * T.int64(4) + k_0_1 + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - 
lv9_local[v0, v1] = lv9[v0, v1] - for k_1 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv8_local"): - v0 = T.axis.spatial( - T.int64(320), - k_0_0 * T.int64(16) - + k_0_1 * T.int64(4) - + k_1 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_2 in range(T.int64(8)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("decode"): - v_i = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax0, - ) - v_j = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads( - lv8_local[v_i // T.int64(8), v_j], - lv9_local[v_i // T.int64(32), v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[ - v_i // T.int64(8), - v_j, - ], - T.Cast( - "uint32", - v_i % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv9_local[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(1)): - with T.block("lv9_1_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax2, - ) - T.reads(lv9_1[v0, v1, v2]) - T.writes(lv9_1_pad_local[v0, v1, v2]) - lv9_1_pad_local[ - v0, v1, v2 - ] = T.if_then_else( - v1 < n, - lv9_1[v0, v1, v2], - T.float16(0), - ) - for i0_i1_fused_1_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2, - ) - v_k = T.axis.reduce( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv9_1_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = ( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - + lv9_1_pad_local[v_i0, v_i1, v_k] - * decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ], - linear_bias[v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = ( - var_NT_matmul_intermediate_pad_local[ - 
v0, v1, v2 - ] - + linear_bias[v2] - ) - - -@T.prim_func(private=True) -def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5( - lv43: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), - lv44: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - p_lv63: T.handle, - linear_bias5: T.Buffer((T.int64(2560),), "float32"), - p_lv7: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") - lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(2560), T.int64(10240)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer( - (T.int64(1), n, T.int64(2560)), "float16" - ) - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv43[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv44[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv63[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv63[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2] - ) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - 
T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ - v_i0, v_i1, v_i2 - ] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], - lv7[v_ax0, v_ax1, v_ax2], - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] - + lv7[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_after( - lv43: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), - lv44: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - p_lv63: T.handle, - linear_bias5: T.Buffer((T.int64(2560),), "float32"), - p_lv7: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) - n = T.int64() - lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") - lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) - decode_local = T.alloc_buffer( - (T.int64(10240), T.int64(2560)), "float16", scope="local" - ) - lv43_local = T.alloc_buffer( - (T.int64(1280), T.int64(2560)), "uint32", scope="local" - ) - lv44_local = T.alloc_buffer( - (T.int64(320), T.int64(2560)), "float16", scope="local" - ) - lv63_pad_local = T.alloc_buffer( - ( - T.int64(1), - (n + T.int64(31)) // T.int64(32) * T.int64(32), - T.int64(10240), - ), - "float16", - scope="local", - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - scope="local", - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding( - T.int64(8), thread="threadIdx.y" - ): - for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(T.int64(4)): - for i2_2_init in T.vectorized(T.int64(8)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2_init, - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float32(0) - for k_0_0, k_0_1 in T.grid(T.int64(80), T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv44_local"): - v0 = T.axis.spatial( - T.int64(320), - k_0_0 * T.int64(4) + k_0_1 + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv44[v0, v1]) - T.writes(lv44_local[v0, v1]) - lv44_local[v0, v1] = lv44[v0, v1] - for k_1 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv43_local"): - v0 = T.axis.spatial( - 
T.int64(1280), - k_0_0 * T.int64(16) - + k_0_1 * T.int64(4) - + k_1 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv43[v0, v1]) - T.writes(lv43_local[v0, v1]) - lv43_local[v0, v1] = lv43[v0, v1] - for k_2 in range(T.int64(8)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("decode"): - v_i = T.axis.spatial( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax0, - ) - v_j = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads( - lv43_local[v_i // T.int64(8), v_j], - lv44_local[v_i // T.int64(32), v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv43_local[ - v_i // T.int64(8), - v_j, - ], - T.Cast( - "uint32", - v_i % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv44_local[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(1)): - with T.block("lv63_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax2, - ) - T.reads(lv63[v0, v1, v2]) - T.writes(lv63_pad_local[v0, v1, v2]) - lv63_pad_local[ - v0, v1, v2 - ] = T.if_then_else( - v1 < n, - lv63[v0, v1, v2], - T.float16(0), - ) - for i0_i1_fused_1_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2, - ) - v_k = T.axis.reduce( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv63_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] + T.Cast( - "float32", - lv63_pad_local[v_i0, v_i1, v_k], - ) * T.Cast( - "float32", decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ], - linear_bias5[v2], - lv7[v0, v1, v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = ( - T.Cast( - "float16", - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - + linear_bias5[v2], - ) - + lv7[v0, v1, v2] - ) - - -@T.prim_func(private=True) -def 
fused_decode1_fused_NT_matmul1_add4_add5( - lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - p_lv49: T.handle, - linear_bias3: T.Buffer((T.int64(2560),), "float16"), - p_lv2: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(2560), T.int64(2560)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer( - (T.int64(1), n, T.int64(2560)), "float16" - ) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv29[v_i // T.int64(8), v_j], lv30[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv29[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv30[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv49[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2] - ) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2] - ) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - ) - - -@T.prim_func(private=True) -def fused_decode1_fused_NT_matmul1_add4_add5_after( - lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), - lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), - p_lv49: T.handle, - linear_bias3: T.Buffer((T.int64(2560),), "float16"), - p_lv2: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, 
(T.int64(1), n, T.int64(2560)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(2560)), "float16" - ) - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) - decode_local = T.alloc_buffer( - (T.int64(2560), T.int64(2560)), "float16", scope="local" - ) - lv29_local = T.alloc_buffer( - (T.int64(320), T.int64(2560)), "uint32", scope="local" - ) - lv30_local = T.alloc_buffer( - (T.int64(80), T.int64(2560)), "float16", scope="local" - ) - lv49_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - "float16", - scope="local", - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - "float16", - scope="local", - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding( - T.int64(8), thread="threadIdx.y" - ): - for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(T.int64(4)): - for i2_2_init in T.vectorized(T.int64(8)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2_init, - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float16(0) - for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv30_local"): - v0 = T.axis.spatial( - T.int64(80), - k_0_0 * T.int64(4) + k_0_1 + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv30[v0, v1]) - T.writes(lv30_local[v0, v1]) - lv30_local[v0, v1] = lv30[v0, v1] - for k_1 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv29_local"): - v0 = T.axis.spatial( - T.int64(320), - k_0_0 * T.int64(16) - + k_0_1 * T.int64(4) - + k_1 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv29[v0, v1]) - T.writes(lv29_local[v0, v1]) - lv29_local[v0, v1] = lv29[v0, v1] - for k_2 in range(T.int64(8)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("decode"): - v_i = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax0, - ) - v_j = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads( - lv29_local[v_i // T.int64(8), v_j], - lv30_local[v_i // T.int64(32), v_j], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv29_local[ - v_i // T.int64(8), - v_j, - ], - T.Cast( - "uint32", - v_i % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv30_local[v_i // T.int64(32), v_j] - for ax0, ax1 in 
T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(1)): - with T.block("lv49_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax2, - ) - T.reads(lv49[v0, v1, v2]) - T.writes(lv49_pad_local[v0, v1, v2]) - lv49_pad_local[ - v0, v1, v2 - ] = T.if_then_else( - v1 < n, - lv49[v0, v1, v2], - T.float16(0), - ) - for i0_i1_fused_1_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2, - ) - v_k = T.axis.reduce( - T.int64(2560), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv49_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = ( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - + lv49_pad_local[v_i0, v_i1, v_k] - * decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ], - linear_bias3[v2], - lv2[v0, v1, v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = ( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - + linear_bias3[v2] - + lv2[v0, v1, v2] - ) - - -@T.prim_func(private=True) -def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7( - lv1345: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), - lv1346: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - p_lv2047: T.handle, - linear_bias191: T.Buffer((T.int64(2560),), "float32"), - p_lv317: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") - lv317 = T.match_buffer(p_lv317, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_T_transpose_intermediate = T.alloc_buffer( - (T.int64(2560), T.int64(10240)), "float16" - ) - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer( - (T.int64(1), n, T.int64(2560)), 
"float16" - ) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1345[v_i // T.int64(8), v_j], lv1346[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1345[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1346[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1]) - var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2047[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[ - v_i0, v_i1, v_i2 - ] + T.Cast("float32", lv2047[v_i0, v_i1, v_k]) * T.Cast( - "float32", var_T_transpose_intermediate[v_i2, v_k] - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2] - ) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = ( - var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float16", var_T_add_intermediate[v_i0, v_i1, v_i2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[ - v_i0, v_i1, v_i2 - ] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], - lv317[v_ax0, v_ax1, v_ax2], - ) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = ( - var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] - + lv317[v_ax0, v_ax1, v_ax2] - ) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast( - "float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2] - ) - - -@T.prim_func(private=True) -def fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7_after( - lv1345: T.Buffer((T.int64(1280), T.int64(2560)), 
"uint32"), - lv1346: T.Buffer((T.int64(320), T.int64(2560)), "float16"), - p_lv2047: T.handle, - linear_bias191: T.Buffer((T.int64(2560),), "float32"), - p_lv317: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.noalias": T.bool(True)}) - n = T.int64() - lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") - lv317 = T.match_buffer(p_lv317, (T.int64(1), n, T.int64(2560)), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - with T.block("root"): - T.reads() - T.writes() - T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32}) - decode_local = T.alloc_buffer( - (T.int64(10240), T.int64(2560)), "float16", scope="local" - ) - lv1345_local = T.alloc_buffer( - (T.int64(1280), T.int64(2560)), "uint32", scope="local" - ) - lv1346_local = T.alloc_buffer( - (T.int64(320), T.int64(2560)), "float16", scope="local" - ) - lv2047_pad_local = T.alloc_buffer( - ( - T.int64(1), - (n + T.int64(31)) // T.int64(32) * T.int64(32), - T.int64(10240), - ), - "float16", - scope="local", - ) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer( - (T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), - scope="local", - ) - for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding( - (n + T.int64(31)) // T.int64(32), thread="blockIdx.y" - ): - for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"): - for i0_i1_fused_1_1 in T.thread_binding( - T.int64(8), thread="threadIdx.y" - ): - for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"): - for i0_i1_fused_1_2_init in range(T.int64(4)): - for i2_2_init in T.vectorized(T.int64(8)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2_init, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2_init, - ) - T.reads() - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = T.float32(0) - for k_0_0, k_0_1 in T.grid(T.int64(80), T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv1346_local"): - v0 = T.axis.spatial( - T.int64(320), - k_0_0 * T.int64(4) + k_0_1 + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv1346[v0, v1]) - T.writes(lv1346_local[v0, v1]) - lv1346_local[v0, v1] = lv1346[v0, v1] - for k_1 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("lv1345_local"): - v0 = T.axis.spatial( - T.int64(1280), - k_0_0 * T.int64(16) - + k_0_1 * T.int64(4) - + k_1 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads(lv1345[v0, v1]) - T.writes(lv1345_local[v0, v1]) - lv1345_local[v0, v1] = lv1345[v0, v1] - for k_2 in range(T.int64(8)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(8)): - with T.block("decode"): - v_i = T.axis.spatial( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax0, - ) - v_j = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + ax1, - ) - T.reads( - lv1345_local[ - v_i // T.int64(8), v_j - ], - lv1346_local[ - v_i // 
T.int64(32), v_j - ], - ) - T.writes(decode_local[v_i, v_j]) - decode_local[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv1345_local[ - v_i // T.int64(8), - v_j, - ], - T.Cast( - "uint32", - v_i % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv1346_local[ - v_i // T.int64(32), v_j - ] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(1)): - with T.block("lv2047_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2 - + ax2, - ) - T.reads(lv2047[v0, v1, v2]) - T.writes(lv2047_pad_local[v0, v1, v2]) - lv2047_pad_local[ - v0, v1, v2 - ] = T.if_then_else( - v1 < n, - lv2047[v0, v1, v2], - T.float16(0), - ) - for i0_i1_fused_1_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial( - T.int64(1), T.int64(0) - ) - v_i1 = T.axis.spatial( - (n + T.int64(31)) - // T.int64(32) - * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + i0_i1_fused_1_2, - ) - v_i2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) - + i2_1 * T.int64(8) - + i2_2, - ) - v_k = T.axis.reduce( - T.int64(10240), - k_0_0 * T.int64(128) - + k_0_1 * T.int64(32) - + k_1 * T.int64(8) - + k_2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ], - lv2047_pad_local[v_i0, v_i1, v_k], - decode_local[v_k, v_i2], - ) - T.writes( - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] - ) - var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] = var_NT_matmul_intermediate_pad_local[ - v_i0, v_i1, v_i2 - ] + T.Cast( - "float32", - lv2047_pad_local[v_i0, v_i1, v_k], - ) * T.Cast( - "float32", decode_local[v_k, v_i2] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial( - (n + T.int64(31)) // T.int64(32) * T.int64(32), - i0_i1_fused_0_i0_i1_fused_1_0_fused - * T.int64(32) - + i0_i1_fused_1_1 * T.int64(4) - + ax1, - ) - v2 = T.axis.spatial( - T.int64(2560), - i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2, - ) - T.reads( - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ], - linear_bias191[v2], - lv317[v0, v1, v2], - ) - T.writes(p_output0_intermediate[v0, v1, v2]) - if v1 < n: - p_output0_intermediate[v0, v1, v2] = T.Cast( - "float32", - T.Cast( - "float16", - var_NT_matmul_intermediate_pad_local[ - v0, v1, v2 - ] - + linear_bias191[v2], - ) - + lv317[v0, v1, v2], - ) - - -@T.prim_func(private=True) -def fused_decode2_NT_matmul( - lv4: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), - lv5: T.Buffer((T.int64(128), T.int64(12288)), "float16"), - p_lv6: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16") - var_NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(12288)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16") - p_output0_intermediate = T.alloc_buffer((T.int64(12288), T.int64(4096)), "float16") - for i, j in 
T.grid(T.int64(4096), T.int64(12288)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv4[v_i // T.int64(8), v_j], lv5[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv4[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv5[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(12288), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(p_output0_intermediate[v_ax0, v_ax1]) - p_output0_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(12288), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv6[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv6[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] - ) - - -@T.prim_func(private=True) -def fused_decode2_NT_matmul_after( - lv8: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), - lv9: T.Buffer((T.int64(128), T.int64(12288)), "float16"), - p_lv6: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") - var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 12288), "float16") - - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(12288)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(12288)), "float16", scope="local" - ) - lv8_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") - lv9_local = T.alloc_buffer( - (T.int64(128), T.int64(12288)), "float16", scope="local" - ) - #lv6_shared = T.alloc_buffer( - # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - #) - for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(96), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - with T.block("n_check"): - T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_1 in range(T.int64(128)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2k = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - 
var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv9_local"): - v0 = T.axis.spatial( - T.int64(128), k_1 - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - lv9_local[v0, v1] = lv9[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv8_local"): - v0 = T.axis.spatial( - T.int64(512), - k_1 * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_1 * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - lv6[v_i0, v_i1, v_k], - lv8_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] + lv6[ - v_i0, v_i1, v_k - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_1 - ) - v1 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv9_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - * lv9_local[v0, v1] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - with T.block("var_matmul_intermediate_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(12288), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] - - -@T.prim_func(private=True) -def fused_decode4_NT_matmul3( - lv13: 
T.Buffer((T.int64(512), T.int64(22016)), "uint32"), - lv14: T.Buffer((T.int64(128), T.int64(22016)), "float16"), - p_lv45: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16") - var_NT_matmul_intermediate = T.match_buffer( - p_output0, (T.int64(1), n, T.int64(22016)), "float16" - ) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(22016)), "float16") - p_output0_intermediate = T.alloc_buffer((T.int64(22016), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(22016)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv13[v_i // T.int64(8), v_j], lv14[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv13[v_i // T.int64(8), v_j], - T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) * lv14[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(22016), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(p_output0_intermediate[v_ax0, v_ax1]) - p_output0_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(22016), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv45[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = ( - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] - + lv45[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] - ) - - -@T.prim_func(private=True) -def fused_decode4_NT_matmul3_after( - lv8: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), - lv9: T.Buffer((T.int64(128), T.int64(22016)), "float16"), - p_lv6: T.handle, - p_output0: T.handle, -): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") - var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16") - - var_matmul_intermediate_local = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(22016)), "float16", scope="local" - ) - var_matmul_intermediate_local_batch = T.alloc_buffer( - (T.int64(1), ((n+7)//8) * 8, T.int64(22016)), "float16", scope="local" - ) - lv8_local = T.alloc_buffer((T.int64(512), T.int64(22016)), "uint32", scope="local") - lv9_local = T.alloc_buffer( - (T.int64(128), T.int64(22016)), "float16", scope="local" - ) - #lv6_shared = T.alloc_buffer( - # (T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared" - #) - for i0_i1_i2_fused_n in T.thread_binding(((n+7)//8), thread="blockIdx.y"): - for i0_i1_i2_fused_0 in T.thread_binding(T.int64(172), thread="blockIdx.x"): - for i0_i1_i2_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax2_y in T.thread_binding(T.int64(8), thread="threadIdx.y"): - with T.block("n_check"): - T.where((i0_i1_i2_fused_n * T.int64(8) + ax2_y) < n) - for i0_i1_i2_fused_2_init in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + 
i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2_init - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_1 in range(T.int64(128)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("matmul_init_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2k = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads() - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2k - ] = T.float16(0) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv9_local"): - v0 = T.axis.spatial( - T.int64(128), k_1 - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv9[v0, v1]) - T.writes(lv9_local[v0, v1]) - lv9_local[v0, v1] = lv9[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("lv8_local"): - v0 = T.axis.spatial( - T.int64(512), - k_1 * T.int64(4) - + k_2 - + ax0, - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads(lv8[v0, v1]) - T.writes(lv8_local[v0, v1]) - lv8_local[v0, v1] = lv8[v0, v1] - for k_3 in range(T.int64(8)): - for i0_i1_i2_fused_2 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + i0_i1_i2_fused_2, - ) - v_k = T.axis.reduce( - T.int64(4096), - k_1 * T.int64(32) - + k_2 * T.int64(8) - + k_3, - ) - T.reads( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - lv6[v_i0, v_i1, v_k], - lv8_local[v_k // T.int64(8), v_i2], - ) - T.writes( - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - ) - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] = var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] + lv6[ - v_i0, v_i1, v_k - ] * ( - ( - T.Cast( - "float16", - T.bitwise_and( - T.shift_right( - lv8_local[ - v_k // T.int64(8), v_i2 - ], - T.Cast( - "uint32", - v_k % T.int64(8), - ) - * T.uint32(4), - ), - T.uint32(15), - ), - ) - - T.float16(7) - ) - ) - for ax0 in range(T.int64(1)): - for ax1 in T.vectorized(T.int64(4)): - with T.block("multiple_scale"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - v0 = T.axis.spatial( - T.int64(128), - k_1 - ) - v1 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax1, - ) - T.reads( - lv9_local[v0, v1], - var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ], - ) - T.writes( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - ) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = ( - var_matmul_intermediate_local[v_i0, v_i1, v_i2] - + var_matmul_intermediate_local_batch[ - v_i0, v_i1, v_i2 - ] - * lv9_local[v0, v1] - ) - for ax0, ax1 in T.grid(T.int64(1), T.int64(1)): - for ax2 in T.vectorized(T.int64(4)): - 
with T.block("var_matmul_intermediate_local"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+7)//8) * 8, i0_i1_i2_fused_n * T.int64(8) + ax2_y) - v_i2 = T.axis.spatial( - T.int64(22016), - i0_i1_i2_fused_0 * T.int64(128) - + i0_i1_i2_fused_1 * T.int64(4) - + ax2, - ) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] - - - -@T.prim_func(private=True) -def fused_NT_matmul1_divide2_maximum1_minimum1_cast3(lv1593: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), p_lv1603: T.handle, p_lv1582: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1603 = T.match_buffer(p_lv1603, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") - lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1593[v_i0, v_i2, v_i1, v_k], lv1603[v_i0, v_i3, v_i1, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1593[v_i0, v_i2, v_i1, v_k] * lv1603[v_i0, v_i3, v_i1, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), 
T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - -@T.prim_func(private=True) -def fused_NT_matmul1_divide2_maximum1_minimum1_cast3_after( - lv1593: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), - p_lv1603: T.handle, - p_lv1582: T.handle, - p_output0: T.handle -): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - n = T.int64() - lv1603 = T.match_buffer(p_lv1603, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") - lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - var_matmul_intermediate_local = T.alloc_buffer( - (1, ((n + 7) // 8) * 8, 4096), "float16", scope="local" - ) - lv1593_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" - ) - for i_by in T.thread_binding(T.int64((n + 7) // 8), thread="blockIdx.y"): - for i_bx in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) - v_i2 = T.axis.spatial( - T.int64(4096), - i_bx * T.int64(128) - + i_tx * T.int64(4) - + i_v8, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - with T.block("lv1593_shared"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32), i_bx) - v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) - T.reads(lv1593[v_i0, v_i1, v_i2, v_i3]) - T.writes(lv1593_shared[v_i0, v_i1, v_i3]) - lv1593_shared[v_i0, v_i1, v_i3] = lv1593[v_i0, v_i1, v_i2, v_i3] - with T.block("matmul_compute"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) - v_i2 = T.axis.spatial(T.int64(32), i_bx) - v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) - v_ik = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_by * T.int64(8) + i_ty < n) - T.reads(lv1593_shared[v_i0, v_i1_1, v_i3], lv1603[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_ik]) - var_matmul_intermediate_local[v_i0, v_i1, v_ik] = var_matmul_intermediate_local[v_i0, v_i1, v_ik] + lv1603[v_i0, v_i1, v_i2, v_i3] * lv1593_shared[v_i0, v_i1_1, v_i3] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) - v_ik = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - 
T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_ik]) - T.writes(lv1593_shared[v_i0, v_i1_1, v_i2]) - lv1593_shared[v_i0, v_i1_1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_ik] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_1"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_tx < T.int64(16)) - T.reads(lv1593_shared[v_i0, v_i1, v_i2]) - T.writes(lv1593_shared[v_i0, v_i1, v_i2]) - lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(64)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_2"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_tx < T.int64(8)) - T.reads(lv1593_shared[v_i0, v_i1, v_i2]) - T.writes(lv1593_shared[v_i0, v_i1, v_i2]) - lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(32)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_3"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_tx < T.int64(4)) - T.reads(lv1593_shared[v_i0, v_i1, v_i2]) - T.writes(lv1593_shared[v_i0, v_i1, v_i2]) - lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(16)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_4"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_tx < T.int64(2)) - T.reads(lv1593_shared[v_i0, v_i1, v_i2]) - T.writes(lv1593_shared[v_i0, v_i1, v_i2]) - lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(8)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_4"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_tx < T.int64(1)) - T.reads(lv1593_shared[v_i0, v_i1, v_i2]) - T.writes(lv1593_shared[v_i0, v_i1, v_i2]) - lv1593_shared[v_i0, v_i1, v_i2] = lv1593_shared[v_i0, v_i1, v_i2] + lv1593_shared[v_i0, v_i1, v_i2 + T.int64(4)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for ax0 in range(T.int64(1)): - with T.block("Output_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = 
T.axis.spatial(T.int64(32), i_bx) - v_i2 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i3 = T.axis.spatial(T.int64(n), i_by * T.int64(8) + i_ty) - v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128)) - T.where(i_by * T.int64(8) + i_ty < n) - T.reads(lv1593_shared[v_i0, v_i2, v_ik]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", T.min(T.max((lv1593_shared[v_i0, v_i2, v_ik] + lv1593_shared[v_i0, v_i2, v_ik + T.int64(1)] - + lv1593_shared[v_i0, v_i2, v_ik + T.int64(2)] + lv1593_shared[v_i0, v_i2, v_ik + T.int64(3)]) - * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[v_i0, T.int64(0), v_i2, v_i3])) - - - -# [gx,gy, gz] [lx, ly, lz] - -@T.prim_func(private=True) -def NT_matmul3(var_A: T.handle, var_B: T.handle, NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128), n): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_k, v_i2, v_i3], B[v_i0, v_i2, v_i1, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - NT_matmul[v_i0, v_i1, v_i2, v_i3] = NT_matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_k, v_i2, v_i3] * B[v_i0, v_i2, v_i1, v_k] - -@T.prim_func(private=True) -def NT_matmul3_after( - var_A: T.handle, - var_B: T.handle, - NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16") -): - - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_matmul_intermediate_local = T.alloc_buffer( - (1, 8, 4096), "float16", scope="local" - ) - B_shared = T.alloc_buffer( - (T.int64(1), T.int64(1), T.int64(1024)), "float16", scope="shared" - ) - for i_bx in T.thread_binding(T.int64(32), thread="blockIdx.x"): - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(8), i_ty) - v_i2 = T.axis.spatial( - T.int64(4096), - i_bx * T.int64(128) + i_tx * T.int64(4) - + i_v8, - ) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float16(0) - for ax0 in range((n+255)//256): - with T.block("B_shared"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), i_bx) - v_i2 = T.axis.spatial(((n+255)//256) * 256, ax0 * T.int64(256) + i_ty * T.int64(32) + i_tx) - v_i2k = T.axis.spatial(T.int64(256), i_ty * T.int64(32) + i_tx) - #T.where(ax0 * T.int64(256) + i_ty * T.int64(32) + i_tx < n) - T.reads(B[v_i0, v_i1, T.int64(0), v_i2]) - T.writes(B_shared[v_i0, v_i1, v_i2k]) - B_shared[v_i0, T.int64(0), v_i2k] = T.if_then_else(v_i2 < n, B[v_i0, v_i1, T.int64(0), v_i2], T.float16(0)) - for ax1 in range(32): - #with T.block("n_check"): - # T.where(ax0 * T.int64(256) + ax1 * T.int64(8) + i_ty < 
n) - for i_v8 in T.vectorized(T.int64(4)): - with T.block("matmul_compute"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(((n+255)//256) * 256, ax0 * T.int64(256) + ax1 * T.int64(8) + i_ty) - v_i1_1 = T.axis.spatial(T.int64(8), i_ty) - v_i2 = T.axis.spatial(T.int64(32), i_bx) - v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) - v_ik = T.axis.spatial(T.int64(256), ax1 * T.int64(8) + i_ty) - v_ik1 = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.reads(B_shared[v_i0, T.int64(0), v_ik], A[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1]) - var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1] = var_matmul_intermediate_local[v_i0, v_i1_1, v_ik1] + T.if_then_else(v_i1 < n, A[v_i0, v_i1, v_i2, v_i3], T.float16(0)) * B_shared[v_i0, T.int64(0), v_ik] - - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(8), i_ty) - v_i2 = T.axis.spatial(T.int64(4096), i_bx * T.int64(128) + i_tx * T.int64(4) + i_v8) - v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - T.writes(B_shared[v_i0, T.int64(0), v_ik]) - B_shared[v_i0, T.int64(0), v_ik] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("reduction_1"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_ty < T.int64(4)) - T.reads(B_shared[v_i0, v_i1, v_i2]) - T.writes(B_shared[v_i0, v_i1, v_i2]) - B_shared[v_i0, v_i1, v_i2] = B_shared[v_i0, v_i1, v_i2] + B_shared[v_i0, v_i1, v_i2 + T.int64(512)] - for i_tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i_ty in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i_v8 in T.vectorized(T.int64(4)): - with T.block("Output_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32), i_bx) - v_i3 = T.axis.spatial(T.int64(128), i_tx * T.int64(4) + i_v8) - v_ik = T.axis.spatial(T.int64(1024), i_ty * T.int64(128) + i_tx * T.int64(4) + i_v8) - T.where(i_ty < 1) - T.reads(B_shared[v_i0, v_i1, v_ik]) - T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) - NT_matmul[v_i0, v_i1, v_i2, v_i3] = B_shared[v_i0, v_i1, v_ik] + B_shared[v_i0, v_i1, v_ik + T.int64(128)] + B_shared[v_i0, v_i1, v_ik + T.int64(256)] + B_shared[v_i0, v_i1, v_ik + T.int64(384)] - -@T.prim_func(private=True) -def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") - rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - Ared_temp = T.alloc_buffer((T.int64(1), n)) - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("Ared_temp"): - v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) - T.reads(A[v_bsz, v_i, v_k]) - T.writes(Ared_temp[v_bsz, v_i]) - with T.init(): 
- Ared_temp[v_bsz, v_i] = T.float32(0) - Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k]) - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): - v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) - T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i]) - T.writes(rms_norm_1[v_bsz, v_i, v_k]) - rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) - -@T.prim_func(private=True) -def rms_norm_after(var_A: T.handle, B: T.Buffer((4096,), "float16"), var_rms_norm: T.handle): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - n = T.int32() - A = T.match_buffer(var_A, (1, n, 4096), "float16") - rms_norm_1 = T.match_buffer(var_rms_norm, (1, n, 4096), "float16") - # with T.block("root"): - Ared_temp_shared = T.alloc_buffer((1, n), scope="shared") - Ared_temp_rf_local = T.alloc_buffer((64, 1, n), scope="local") - for ax0_fused in T.thread_binding(n, thread="blockIdx.x"): - for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - with T.block("Ared_temp_rf_init"): - vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused]) - T.reads() - T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) - Ared_temp_rf_local[vax1_fused_1, 0, v0] = T.float32(0) - for ax1_fused_0, u in T.grid(64, 1): - with T.block("Ared_temp_rf_update"): - vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) - T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0], A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) - T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) - Ared_temp_rf_local[vax1_fused_1, 0, v0] = Ared_temp_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) - for ax1_fused in range(1): - for ax0 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("Ared_temp"): - vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) - T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0]) - T.writes(Ared_temp_shared[0, v0]) - with T.init(): - Ared_temp_shared[0, v0] = T.float32(0) - Ared_temp_shared[0, v0] = Ared_temp_shared[0, v0] + Ared_temp_rf_local[vax1_fused_1, 0, v0] - for ax0_fused_0 in range(64): - for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): - with T.block("rms_norm"): - v0 = T.axis.spatial(n, ax0_fused) - v1 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1) - T.reads(B[v1], A[0, v0, v1], Ared_temp_shared[0, v0]) - T.writes(rms_norm_1[0, v0, v1]) - rms_norm_1[0, v0, v1] = T.Cast("float16", T.Cast("float32", B[v1]) * (T.Cast("float32", A[0, v0, v1]) / T.sqrt(Ared_temp_shared[0, v0] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) - -@T.prim_func(private=True) -def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("slice"): - v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k]) - T.reads(A[v_i, n - T.int64(1), v_k]) - T.writes(slice_1[v_i, v_j, v_k]) - slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k] - 
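The two unscheduled kernels above fix the reference semantics that the hand-scheduled `rms_norm_after` and `slice_after` variants below must reproduce: `rms_norm` normalizes each token vector by the root of its mean square over the 4096-wide hidden dimension (the factor 0.000244140625 is 1/4096, the epsilon is 1e-6) and rescales by the weight `B`, while `slice` keeps only the last token's hidden state. A minimal NumPy sketch of that reference behaviour, assuming the same `A`/`B` layouts as the TIR (the helper names here are illustrative only, not part of the patch):

    import numpy as np

    def rms_norm_ref(A: np.ndarray, B: np.ndarray) -> np.ndarray:
        # A: (1, n, 4096) float16 activations, B: (4096,) float16 weight.
        x = A.astype(np.float32)
        # Mean of squares over the hidden dimension; 0.000244140625 == 1 / 4096 in the TIR.
        mean_sq = np.mean(x * x, axis=-1, keepdims=True)
        return (B.astype(np.float32) * (x / np.sqrt(mean_sq + 1e-6))).astype(np.float16)

    def slice_ref(A: np.ndarray) -> np.ndarray:
        # Keep only the last token's hidden state: (1, n, 4096) -> (1, 1, 4096).
        return A[:, -1:, :]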
-@T.prim_func(private=True) -def slice_after(var_A: T.handle, slice_1: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - n = T.int32() - A = T.match_buffer(var_A, (1, n, 4096), "float16") - # with T.block("root"): - for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): - with T.block("slice"): - v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) - T.reads(A[0, n - 1, v0]) - T.writes(slice_1[0, 0, v0]) - slice_1[0, 0, v0] = A[0, n - 1, v0] - -@T.prim_func(private=True) -def NT_matmul2(var_A: T.handle, var_B: T.handle, var_NT_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), m, T.int64(32), T.int64(128)), "float16") - n = T.int64() - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, m), "float16") - NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), n, T.int64(32), T.int64(128), m): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_k, v_i2, v_i3], B[v_i0, v_i2, v_i1, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - NT_matmul[v_i0, v_i1, v_i2, v_i3] = NT_matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_k, v_i2, v_i3] * B[v_i0, v_i2, v_i1, v_k] - -@T.prim_func(private=True) -def NT_matmul2_after(var_A: T.handle, var_B: T.handle, var_NT_matmul: T.handle): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - m = T.int32() - A = T.match_buffer(var_A, (1, m, 32, 128), "float16") - n = T.int32() - B = T.match_buffer(var_B, (1, 32, n, m), "float16") - NT_matmul = T.match_buffer(var_NT_matmul, (1, n, 32, 128), "float16") - # with T.block("root"): - NT_matmul_reindex_pad_local = T.alloc_buffer((32, 128, (n + 63) // 64 * 64), "float16", scope="local") - A_reindex_pad_shared = T.alloc_buffer((32, 128, (m + 15) // 16 * 16), "float16", scope="shared") - B_reindex_pad_shared = T.alloc_buffer((32, (n + 63) // 64 * 64, (m + 15) // 16 * 16), "float16", scope="shared") - for ax0_ax2_0_fused in T.thread_binding((n + 63) // 64 * 32, thread="blockIdx.y"): - for ax1_0 in T.thread_binding(4, thread="blockIdx.x"): - for ax2_1 in T.thread_binding(1, thread="vthread.y"): - for ax1_1 in T.thread_binding(1, thread="vthread.x"): - for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): - for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_init in T.grid(4, 4): - with T.block("NT_matmul_init"): - v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) - v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) - v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) - T.reads() - T.writes(NT_matmul_reindex_pad_local[v0, v1, v2]) - NT_matmul_reindex_pad_local[v0, v1, v2] = T.float16(0) - for ax3_0 in range((m + 15) // 16): - for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in range(2): - for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("A_reindex_pad_shared"): - v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) - 
v1 = T.axis.spatial(128, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) - v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.reads(A[0, v2, v0, v1]) - T.writes(A_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) - A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < m, A[0, v2, v0, v1], T.float16(0)) - for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in range(4): - for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("B_reindex_pad_shared"): - v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) - v1 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) - v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.reads(B[0, v0, v1, v2]) - T.writes(B_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) - B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, B[0, v0, v1, v2], T.float16(0)) - for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): - with T.block("NT_matmul_update"): - v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64)) - v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) - v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) - v3 = T.axis.reduce((m + 15) // 16 * 16, ax3_0 * 16 + ax3_1) - T.reads(NT_matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3]) - T.writes(NT_matmul_reindex_pad_local[v0, v1, v2]) - NT_matmul_reindex_pad_local[v0, v1, v2] = NT_matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3] - for ax0, ax1, ax2_0 in T.grid(1, 4, 2): - for ax2_1_1 in T.vectorized(2): - with T.block("NT_matmul_reindex_pad_local"): - v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((n + 63) // 64) + ax0) - v1 = T.axis.spatial(128, ax1_0 * 32 + ax1_2 * 4 + ax1) - v2 = T.axis.spatial((n + 63) // 64 * 64, ax0_ax2_0_fused % ((n + 63) // 64) * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) - T.reads(NT_matmul_reindex_pad_local[v0, v1, v2]) - T.writes(NT_matmul[0, v2, v0, v1]) - if v2 < n: - NT_matmul[0, v2, v0, v1] = NT_matmul_reindex_pad_local[v0, v1, v2] - - -def get_dict_key(func): - return tvm.ir.structural_hash(func), func - - -tir_dispatch_dict = { - get_dict_key(fused_decode4_matmul3): fused_decode4_matmul3_after, - get_dict_key( - fused_decode6_fused_matmul7_add1 - ): fused_decode6_fused_matmul7_add1_after, - get_dict_key( - fused_decode5_fused_matmul6_multiply1 - ): fused_decode5_fused_matmul6_multiply1_after, - get_dict_key( - fused_decode5_fused_matmul6_silu1 - ): fused_decode5_fused_matmul6_silu1_after, - get_dict_key( - fused_decode4_fused_matmul4_add1 - ): fused_decode4_fused_matmul4_add1_after, - get_dict_key( - fused_decode3_fused_matmul1_cast2 - ): fused_decode3_fused_matmul1_cast2_after, - get_dict_key( - fused_decode2_fused_NT_matmul3_add - ): fused_decode2_fused_NT_matmul3_add_after, - get_dict_key(fused_decode_NT_matmul): fused_decode_NT_matmul_after, - 
get_dict_key(fused_decode2_NT_matmul): fused_decode2_NT_matmul_after, - get_dict_key(fused_decode4_NT_matmul3): fused_decode4_NT_matmul3_after, - get_dict_key( - fused_decode1_fused_NT_matmul2_silu - ): fused_decode1_fused_NT_matmul2_silu_after, - get_dict_key( - fused_decode1_fused_NT_matmul2_multiply - ): fused_decode1_fused_NT_matmul2_multiply_after, - get_dict_key( - fused_decode_fused_NT_matmul_add - ): fused_decode_fused_NT_matmul_add_after, - get_dict_key( - fused_decode4_fused_matmul6_add4 - ): sch_fused_decode4_fused_matmul6_add4(fused_decode4_fused_matmul6_add4), - get_dict_key( - fused_decode6_fused_matmul9_add7_cast8_cast12_add5 - ): sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5( - fused_decode6_fused_matmul9_add7_cast8_cast12_add5 - ), - get_dict_key( - fused_decode5_fused_matmul8_add6_gelu1_cast11 - ): sch_fused_decode5_fused_matmul8_add6_gelu1_cast11( - fused_decode5_fused_matmul8_add6_gelu1_cast11 - ), - get_dict_key(fused_decode81_fused_matmul1_cast2 - ): sch_fused_decode81_fused_matmul1_cast2(fused_decode81_fused_matmul1_cast2 - ), - get_dict_key( - fused_decode4_fused_matmul6_add4_add5 - ): sch_fused_decode4_fused_matmul6_add4_add5(fused_decode4_fused_matmul6_add4_add5), - get_dict_key(fused_decode3_matmul3): sch_fused_decode3_matmul3( - fused_decode3_matmul3 - ), - get_dict_key( - fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7 - ): sch_fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7( - fused_decode6_fused_matmul9_add7_cast8_cast12_add5_cast7 - ), - get_dict_key( - fused_decode2_fused_NT_matmul3_add6_gelu1_cast11 - ): fused_decode2_fused_NT_matmul3_add6_gelu1_cast11_after, - get_dict_key( - fused_decode1_fused_NT_matmul1_add4 - ): fused_decode1_fused_NT_matmul1_add4_after, - get_dict_key( - fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5 - ): fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_after, - get_dict_key( - fused_decode1_fused_NT_matmul1_add4_add5 - ): fused_decode1_fused_NT_matmul1_add4_add5_after, - get_dict_key( - fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7 - ): fused_decode3_fused_NT_matmul4_add7_cast8_cast12_add5_cast7_after, - get_dict_key(fused_fused_decode9_matmul7): fused_fused_decode9_matmul7_after, - get_dict_key(fused_fused_decode7_matmul4): fused_fused_decode7_matmul4_after, - get_dict_key(fused_NT_matmul1_divide2_maximum1_minimum1_cast3): fused_NT_matmul1_divide2_maximum1_minimum1_cast3_after, - get_dict_key(NT_matmul3): NT_matmul3_after, - get_dict_key(slice): slice_after, - get_dict_key(rms_norm): rms_norm_after, - get_dict_key(NT_matmul2): NT_matmul2_after, -} - - -def lookup_func(func): - for (hash_value, func_before), f_after in tir_dispatch_dict.items(): - if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( - func, func_before - ): - return f_after - return None - - -@tvm.transform.module_pass(opt_level=0, name="DispatchTIROperatorAdreno") -class DispatchTIROperatorAdreno: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: - for gv in mod.functions: - scheduled_func = lookup_func(mod[gv]) - if scheduled_func is not None: - mod[gv] = scheduled_func - - return mod diff --git a/mlc_llm/dispatch/gpt_neox/__init__.py b/mlc_llm/dispatch/gpt_neox/__init__.py deleted file mode 100644 index cdf7c94f46..0000000000 --- a/mlc_llm/dispatch/gpt_neox/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -def lookup(func): - from . 
import dolly_v2_3b, redpajama_incite_chat_3b_v1, redpajama_q4f32 - - ret = dolly_v2_3b.lookup(func) - if ret is not None: - return ret - ret = redpajama_incite_chat_3b_v1.lookup(func) - if ret is not None: - return ret - ret = redpajama_q4f32.lookup(func) - if ret is not None: - return ret - return None diff --git a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py deleted file mode 100644 index 274f08131f..0000000000 --- a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py +++ /dev/null @@ -1,1034 +0,0 @@ -# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals -import tvm -from tvm import tir -from tvm.script import tir as T - -from .dolly_v2_3b_mod import Module as MOD - - -# fmt: off -def fused_NT_matmul1_add3(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[40, 2, 16, 2, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[320, 8, 1]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - l52, l53, l54 = sch.get_loops(block=b47)[-3:] - sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - l62, l63 = 
sch.get_loops(block=b57)[-2:] - sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=2) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - l71 = sch.get_loops(block=b47)[-1] - _, l73, l74 = sch.split(loop=l71, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l79 = sch.get_loops(block=b57)[-1] - _, l81, l82 = sch.split(loop=l79, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - b120 = sch.get_block(name="NT_matmul", func_name="main") - l124 = sch.get_loops(block=b120)[4] - sch.decompose_reduction(block=b120, loop=l124) - - b1 = sch.get_block("lv10_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b84, b85, b86, b87 = sch.get_child_blocks(b83) - l88 = sch.get_loops(block=b84)[0] - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - l95 = sch.get_loops(block=b85)[0] - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - l102 = sch.get_loops(block=b86)[0] - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - l114 = sch.get_loops(block=b87)[0] - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul1_add3_add5_add5(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_add_1", func_name="main") - b3 = sch.get_block(name="T_add_2", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l5, l6, l7, l8 = sch.get_loops(block=b0) - v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l14, l15, l16, l17, l18 = sch.split(loop=l5, factors=[v9, v10, v11, v12, v13], preserve_unit_iters=True) - v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[2, 8, 4, 2, 1]) - l24, l25, l26, l27, l28 = sch.split(loop=l6, factors=[v19, v20, v21, v22, v23], preserve_unit_iters=True) - v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[20, 1, 64, 2, 1]) - 
l34, l35, l36, l37, l38 = sch.split(loop=l7, factors=[v29, v30, v31, v32, v33], preserve_unit_iters=True) - v39, v40, v41 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[320, 1, 8]) - l42, l43, l44 = sch.split(loop=l8, factors=[v39, v40, v41], preserve_unit_iters=True) - sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l43, l17, l27, l37, l44, l18, l28, l38) - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="blockIdx.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="vthread.x") - l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) - sch.bind(loop=l47, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b48 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b48, loop=l47, preserve_unit_loops=True, index=-1) - b49 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b49, loop=l42, preserve_unit_loops=True, index=-1) - l54, l55, l56 = sch.get_loops(block=b49)[-3:] - sch.fuse(l54, l55, l56, preserve_unit_iters=True) - v58 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v58) - b59 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b59, loop=l42, preserve_unit_loops=True, index=-1) - l64, l65 = sch.get_loops(block=b59)[-2:] - sch.fuse(l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v68 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v68) - sch.enter_postproc() - sch.unannotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch") - l73 = sch.get_loops(block=b49)[-1] - _, l75, l76 = sch.split(loop=l73, factors=[None, 256, 4], preserve_unit_iters=True) - sch.vectorize(loop=l76) - sch.bind(loop=l75, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch") - l81 = sch.get_loops(block=b59)[-1] - _, l83, l84 = sch.split(loop=l81, factors=[None, 256, 4], preserve_unit_iters=True) - sch.vectorize(loop=l84) - sch.bind(loop=l83, thread_axis="threadIdx.x") - b85 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.unroll_explicit") - b122 = sch.get_block(name="NT_matmul", func_name="main") - l126 = sch.get_loops(block=b122)[4] - sch.decompose_reduction(block=b122, loop=l126) - - b1 = sch.get_block("lv48_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b86, b87, b88, b89 = sch.get_child_blocks(b85) - l90 = sch.get_loops(block=b86)[0] - sch.annotate(block_or_loop=l90, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l90, ann_key="pragma_unroll_explicit", 
ann_val=1) - l97 = sch.get_loops(block=b87)[0] - sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) - l104 = sch.get_loops(block=b88)[0] - sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) - l116 = sch.get_loops(block=b89)[0] - sch.annotate(block_or_loop=l116, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l116, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul1_add3_add5_add5_cast5(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_add_1", func_name="main") - b3 = sch.get_block(name="T_add_2", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l6, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 2, 16, 2, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l7, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[64, 2, 10, 1, 2]) - l35, l36, l37, l38, l39 = sch.split(loop=l8, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[64, 20, 2]) - l43, l44, l45 = sch.split(loop=l9, factors=[v40, v41, v42], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l16, l26, l36, l17, l27, l37, l43, l44, l18, l28, l38, l45, l19, l29, l39) - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="blockIdx.x") - l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) - sch.bind(loop=l47, thread_axis="vthread.x") - l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) - sch.bind(loop=l48, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b49 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b49, loop=l48, preserve_unit_loops=True, index=-1) - b50 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b50, loop=l43, preserve_unit_loops=True, index=-1) - l55, l56, l57 = sch.get_loops(block=b50)[-3:] - sch.fuse(l55, l56, l57, preserve_unit_iters=True) - v59 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch", ann_val=v59) - b60 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l43, 
preserve_unit_loops=True, index=-1) - l65, l66 = sch.get_loops(block=b60)[-2:] - sch.fuse(l65, l66, preserve_unit_iters=True) - v68 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v68) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v69 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v69) - sch.enter_postproc() - sch.unannotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch") - l74 = sch.get_loops(block=b50)[-1] - _, l76, l77 = sch.split(loop=l74, factors=[None, 160, 4], preserve_unit_iters=True) - sch.vectorize(loop=l77) - sch.bind(loop=l76, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - l82 = sch.get_loops(block=b60)[-1] - _, l84, l85 = sch.split(loop=l82, factors=[None, 160, 2], preserve_unit_iters=True) - sch.vectorize(loop=l85) - sch.bind(loop=l84, thread_axis="threadIdx.x") - b86 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b86, ann_key="meta_schedule.unroll_explicit") - b123 = sch.get_block(name="NT_matmul", func_name="main") - l127 = sch.get_loops(block=b123)[4] - sch.decompose_reduction(block=b123, loop=l127) - - b1 = sch.get_block("lv1815_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b87, b88, b89, b90 = sch.get_child_blocks(b86) - l91 = sch.get_loops(block=b87)[0] - sch.annotate(block_or_loop=l91, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l91, ann_key="pragma_unroll_explicit", ann_val=1) - l98 = sch.get_loops(block=b88)[0] - sch.annotate(block_or_loop=l98, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l98, ann_key="pragma_unroll_explicit", ann_val=1) - l105 = sch.get_loops(block=b89)[0] - sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) - l117 = sch.get_loops(block=b90)[0] - sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul3_add4_gelu1(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_multiply", func_name="main") - b3 = sch.get_block(name="compute", func_name="main") - b4 = sch.get_block(name="compute_1", func_name="main") - b5 = sch.get_block(name="compute_2", func_name="main") - b6 = sch.get_block(name="T_multiply_1", func_name="main") - b7 = sch.get_block(name="T_add_1", func_name="main") - b8 = sch.get_block(name="T_multiply_2", func_name="main") - b9 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l10, l11, l12, l13 = sch.get_loops(block=b0) - v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[1, 1, 
1, 1, 1]) - l19, l20, l21, l22, l23 = sch.split(loop=l10, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True) - v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) - l29, l30, l31, l32, l33 = sch.split(loop=l11, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True) - v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l12, n=5, max_innermost_factor=64, decision=[160, 4, 16, 1, 1]) - l39, l40, l41, l42, l43 = sch.split(loop=l12, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True) - v44, v45, v46 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64, decision=[64, 20, 2]) - l47, l48, l49 = sch.split(loop=l13, factors=[v44, v45, v46], preserve_unit_iters=True) - sch.reorder(l19, l29, l39, l20, l30, l40, l21, l31, l41, l47, l48, l22, l32, l42, l49, l23, l33, l43) - l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) - sch.bind(loop=l50, thread_axis="blockIdx.x") - l51 = sch.fuse(l20, l30, l40, preserve_unit_iters=True) - sch.bind(loop=l51, thread_axis="vthread.x") - l52 = sch.fuse(l21, l31, l41, preserve_unit_iters=True) - sch.bind(loop=l52, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b53 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b53, loop=l52, preserve_unit_loops=True, index=-1) - b54 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b54, loop=l47, preserve_unit_loops=True, index=-1) - l59, l60, l61 = sch.get_loops(block=b54)[-3:] - sch.fuse(l59, l60, l61, preserve_unit_iters=True) - v63 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch", ann_val=v63) - b64 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b64, loop=l47, preserve_unit_loops=True, index=-1) - l69, l70 = sch.get_loops(block=b64)[-2:] - sch.fuse(l69, l70, preserve_unit_iters=True) - v72 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b64, ann_key="meta_schedule.cooperative_fetch", ann_val=v72) - sch.compute_inline(block=b7) - sch.compute_inline(block=b6) - sch.compute_inline(block=b5) - sch.compute_inline(block=b4) - sch.compute_inline(block=b3) - sch.compute_inline(block=b2) - sch.compute_inline(block=b1) - sch.reverse_compute_inline(block=b8) - v73 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b9, ann_key="meta_schedule.unroll_explicit", ann_val=v73) - sch.enter_postproc() - sch.unannotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch") - l85 = sch.get_loops(block=b54)[-1] - _, l87, l88 = sch.split(loop=l85, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l88) - sch.bind(loop=l87, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b64, ann_key="meta_schedule.cooperative_fetch") - l93 = sch.get_loops(block=b64)[-1] - _, l95, l96 = sch.split(loop=l93, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l96) - sch.bind(loop=l95, thread_axis="threadIdx.x") - b97 = sch.get_block(name="root", func_name="main") - 
sch.unannotate(block_or_loop=b97, ann_key="meta_schedule.unroll_explicit") - b138 = sch.get_block(name="NT_matmul", func_name="main") - l142 = sch.get_loops(block=b138)[4] - sch.decompose_reduction(block=b138, loop=l142) - - b1 = sch.get_block("lv52_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - b98, b99, b100, b101, b102 = sch.get_child_blocks(b97) - l103 = sch.get_loops(block=b98)[0] - sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) - l110 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l110, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l110, ann_key="pragma_unroll_explicit", ann_val=1) - l117 = sch.get_loops(block=b100)[0] - sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) - l129 = sch.get_loops(block=b101)[0] - sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) - l135 = sch.get_loops(block=b102)[0] - sch.annotate(block_or_loop=l135, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l135, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul4_add3(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 16, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 1, 5, 2, 2]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[256, 20, 2]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, 
loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - l52, l53, l54 = sch.get_loops(block=b47)[-3:] - sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - l62, l63 = sch.get_loops(block=b57)[-2:] - sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=2) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - l71 = sch.get_loops(block=b47)[-1] - _, l73, l74 = sch.split(loop=l71, factors=[None, 80, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l79 = sch.get_loops(block=b57)[-1] - _, l81 = sch.split(loop=l79, factors=[None, 80], preserve_unit_iters=True) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - b118 = sch.get_block(name="NT_matmul", func_name="main") - l122 = sch.get_loops(block=b118)[4] - sch.decompose_reduction(block=b118, loop=l122) - - b1 = sch.get_block("lv56_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b83, b84, b85, b86 = sch.get_child_blocks(b82) - l87 = sch.get_loops(block=b83)[0] - sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) - l94 = sch.get_loops(block=b84)[0] - sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) - l100 = sch.get_loops(block=b85)[0] - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - l112 = sch.get_loops(block=b86)[0] - sch.annotate(block_or_loop=l112, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l112, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul_divide_maximum_minimum_cast2(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - - sch.pad_einsum(b0, [1, 1, 1, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l4, [None, 32]) - sch.reorder(l6, l1, l2, l3, l7, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = 
sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) - v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) - v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[8, 2, 2, 1, 1]) - l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) - v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) - v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 1, 32, 1, 1]) - l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) - v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[2, 1, 40]) - l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) - sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="blockIdx.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="vthread.x") - l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) - sch.bind(loop=l59, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) - b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) - l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] - sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) - v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) - b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) - l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] - sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) - v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) - sch.enter_postproc() - sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") - l88 = sch.get_loops(block=b61)[-1] - _, l90 = 
sch.split(loop=l88, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") - l95 = sch.get_loops(block=b72)[-1] - _, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l97, thread_axis="threadIdx.x") - b98 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") - - b136 = sch.get_block(name="NT_matmul", func_name="main") - l140 = sch.get_loops(block=b136)[4] - sch.decompose_reduction(block=b136, loop=l140) - - b1 = sch.get_block("lv1870_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b99, b100, b101, b102 = sch.get_child_blocks(b98) - l103 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) - l109 = sch.get_loops(block=b100)[0] - sch.annotate(block_or_loop=l109, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l109, ann_key="pragma_unroll_explicit", ann_val=1) - l115 = sch.get_loops(block=b101)[0] - sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) - l129 = sch.get_loops(block=b102)[0] - sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul2_divide1_maximum1_minimum1_cast7(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l4, [None, 32]) - sch.reorder(l6, l8, l1, l2, l7, l9, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) - v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) - v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) - l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) - v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 4, 1, 16]) - l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) - v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 2, 16, 1, 1]) - l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) - v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[10, 1, 8]) - l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) - sch.reorder(l16, l26, l36, l46, l17, 
l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="blockIdx.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="vthread.x") - l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) - sch.bind(loop=l59, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) - b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) - l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] - sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) - v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) - b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) - l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] - sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) - v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) - sch.enter_postproc() - sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") - l88 = sch.get_loops(block=b61)[-1] - _, l90, l91 = sch.split(loop=l88, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l91) - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") - l96 = sch.get_loops(block=b72)[-1] - _, l98, l99 = sch.split(loop=l96, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l99) - sch.bind(loop=l98, thread_axis="threadIdx.x") - b100 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") - b140 = sch.get_block(name="NT_matmul", func_name="main") - l144 = sch.get_loops(block=b140)[5] - sch.decompose_reduction(block=b140, loop=l144) - - b1 = sch.get_block("lv35_pad") - sch.compute_inline(b1) - b1 = sch.get_block("lv36_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, b101, b102, b103, b104 = sch.get_child_blocks(b100) - l105 = sch.get_loops(block=b101)[0] - sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) - l112 = sch.get_loops(block=b102)[0] - sch.annotate(block_or_loop=l112, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - 
sch.annotate(block_or_loop=l112, ann_key="pragma_unroll_explicit", ann_val=1) - l119 = sch.get_loops(block=b103)[0] - sch.annotate(block_or_loop=l119, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l119, ann_key="pragma_unroll_explicit", ann_val=1) - l133 = sch.get_loops(block=b104)[0] - sch.annotate(block_or_loop=l133, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l133, ann_key="pragma_unroll_explicit", ann_val=1) - - -def matmul1(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 1, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - k0, k1 = sch.split(k, [None, 32]) - sch.reorder(l1, l2, l3, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[8, 2, 2, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[5, 1, 16, 1, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[4, 8, 1]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - 
sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86 = sch.split(loop=l84, factors=[None, 32], preserve_unit_iters=True) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l91 = sch.get_loops(block=b68)[-1] - _, l93, l94 = sch.split(loop=l91, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l94) - sch.bind(loop=l93, thread_axis="threadIdx.x") - b95 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - - b96, b97, b98, b99 = sch.get_child_blocks(b95) - l100 = sch.get_loops(block=b96)[0] - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - l106 = sch.get_loops(block=b97)[0] - sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) - l113 = sch.get_loops(block=b98)[0] - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - l127 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - b134 = sch.get_block(name="matmul", func_name="main") - l138 = sch.get_loops(block=b134)[3] - sch.decompose_reduction(block=b134, loop=l138) - - -def matmul8(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - s0, s1 = sch.split(l3, [None, 32]) - k0, k1 = sch.split(k, [None, 32]) - sch.reorder(s0, l1, l2, s1, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 32, 4, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 2, 5, 1, 4]) - l42, l43, l44, l45, l46 = 
sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[2, 2, 8]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86, l87 = sch.split(loop=l84, factors=[None, 40, 2], preserve_unit_iters=True) - sch.vectorize(loop=l87) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l92 = sch.get_loops(block=b68)[-1] - _, l94 = sch.split(loop=l92, factors=[None, 40], preserve_unit_iters=True) - sch.bind(loop=l94, thread_axis="threadIdx.x") - b95 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - b1 = sch.get_block("matmul_pad") - sch.reverse_compute_inline(b1) - - b96, b97, b98, b99 = sch.get_child_blocks(b95) - l100 = sch.get_loops(block=b96)[0] - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - l107 = sch.get_loops(block=b97)[0] - sch.annotate(block_or_loop=l107, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l107, ann_key="pragma_unroll_explicit", ann_val=1) - l113 = 
sch.get_loops(block=b98)[0] - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - l127 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - b134 = sch.get_block(name="matmul", func_name="main") - l138= sch.get_loops(block=b134)[4] - sch.decompose_reduction(block=b134, loop=l138) - - -def fused_layer_norm1_cast6(sch: tir.Schedule): - b0 = sch.get_block(name="A_red_temp", func_name="main") - b1 = sch.get_block(name="T_layer_norm", func_name="main") - b2 = sch.get_block(name="compute", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - sch.reverse_compute_inline(block=b2) - v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=5) - l5, l6, l7 = sch.get_loops(block=b0) - l8, l9 = sch.split(loop=l7, factors=[None, v4], preserve_unit_iters=True) - sch.bind(loop=l9, thread_axis="threadIdx.x") - v10 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=1) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v10) - l11, l12, l13 = sch.get_loops(block=b1) - l14 = sch.fuse(l11, l12, l13, preserve_unit_iters=True) - l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l16, l17, l15) - sch.bind(loop=l16, thread_axis="blockIdx.x") - sch.bind(loop=l17, thread_axis="threadIdx.x") - l18, l19, l20, l21 = sch.get_loops(block=b0) - l22 = sch.fuse(l18, l19, preserve_unit_iters=True) - sch.bind(loop=l22, thread_axis="blockIdx.x") - sch.enter_postproc() - b23 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b23, ann_key="meta_schedule.unroll_explicit") - b24, b25 = sch.get_child_blocks(b23) - l26, l27, l28 = sch.get_loops(block=b24) - sch.annotate(block_or_loop=l26, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l26, ann_key="pragma_unroll_explicit", ann_val=1) - l29, l30, l31 = sch.get_loops(block=b25) - sch.annotate(block_or_loop=l29, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l29, ann_key="pragma_unroll_explicit", ann_val=1) - - -def layer_norm1(sch: tir.Schedule): - b0 = sch.get_block(name="A_red_temp", func_name="main") - b1 = sch.get_block(name="T_layer_norm", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=4) - l4, l5, l6 = sch.get_loops(block=b0) - l7, l8 = sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) - sch.bind(loop=l8, thread_axis="threadIdx.x") - v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) - l10, l11, l12 = sch.get_loops(block=b1) - l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) - l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l15, l16, l14) - sch.bind(loop=l15, thread_axis="blockIdx.x") - sch.bind(loop=l16, thread_axis="threadIdx.x") - l17, l18, l19, l20 = sch.get_loops(block=b0) - l21 = sch.fuse(l17, l18, 
preserve_unit_iters=True) - sch.bind(loop=l21, thread_axis="blockIdx.x") - sch.enter_postproc() - b22 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") - b23, b24 = sch.get_child_blocks(b22) - l25, l26, l27 = sch.get_loops(block=b23) - sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) - l28, l29, l30 = sch.get_loops(block=b24) - sch.annotate(block_or_loop=l28, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l28, ann_key="pragma_unroll_explicit", ann_val=1) - - -def sch_softmax_cast(cast_to_fp16: bool): - def f(sch: tir.Schedule): - if cast_to_fp16: - b_cast = sch.get_block("compute") - sch.reverse_compute_inline(b_cast) - b0 = sch.get_block("T_softmax_exp") - sch.compute_inline(b0) - b1 = sch.get_block("T_softmax_norm") - l2, l3, l4, l5 = sch.get_loops(b1) - _, l7 = sch.split(l5, [None, 128]) - sch.bind(l7, "threadIdx.x") - b8 = sch.get_block("T_softmax_expsum") - sch.compute_at(b8, l4) - sch.set_scope(b8, 0, "shared") - _, _, _, l12 = sch.get_loops(b8) - _, l14 = sch.split(l12, [None, 128]) - sch.bind(l14, "threadIdx.x") - b15 = sch.get_block("T_softmax_maxelem") - sch.compute_at(b15, l4) - sch.set_scope(b15, 0, "shared") - _, _, _, l19 = sch.get_loops(b15) - _, l21 = sch.split(l19, [None, 128]) - sch.bind(l21, "threadIdx.x") - l22 = sch.fuse(l2, l3, l4) - sch.bind(l22, "blockIdx.x") - return f - - -@T.prim_func -def softmax_cast_mxn_before(p_lv37: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), n, m)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv37[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv37[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv37[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv37[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with 
T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - -@T.prim_func -def softmax_cast_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") - # with T.block("root"): - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] 
+ T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float32(0)) - for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): - for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) - v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) - T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) - if v_i2_o * T.int64(32) + v_i2_i < n: - T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.Cast("float16", T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i]) - - -# fmt: on - - -def _get_dict(): - tvm.ir.assert_structural_equal(MOD["fused_softmax1_cast8"], softmax_cast_mxn_before) - func_dict = { - softmax_cast_mxn_before: softmax_cast_mxn_after, - } - for name, func in [ - ("fused_NT_matmul1_add3", fused_NT_matmul1_add3), - ("fused_NT_matmul1_add3_add5_add5", fused_NT_matmul1_add3_add5_add5), - ( - "fused_NT_matmul1_add3_add5_add5_cast5", - fused_NT_matmul1_add3_add5_add5_cast5, - ), - ("fused_NT_matmul3_add4_gelu1", fused_NT_matmul3_add4_gelu1), - ("fused_NT_matmul4_add3", fused_NT_matmul4_add3), - ( - "fused_NT_matmul_divide_maximum_minimum_cast2", - fused_NT_matmul_divide_maximum_minimum_cast2, - ), - ( - "fused_NT_matmul2_divide1_maximum1_minimum1_cast7", - fused_NT_matmul2_divide1_maximum1_minimum1_cast7, - ), - ("matmul1", matmul1), - ("matmul8", matmul8), - ("fused_softmax_cast3", sch_softmax_cast(True)), - ("fused_layer_norm1_cast6", fused_layer_norm1_cast6), - ("layer_norm1", layer_norm1), - ]: - sch = tir.Schedule(MOD[name]) - func(sch) - func_dict[MOD[name]] = sch.mod["main"] - return { - (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) - for k, v in func_dict.items() - } - - -DICT = _get_dict() - - -def lookup(func): - for (hash_value, func_before), f_after in DICT.items(): - if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( - func, func_before - ): - return f_after - return None diff --git a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py deleted file mode 100644 index e3ff44ba59..0000000000 --- a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b_mod.py +++ /dev/null @@ -1,511 +0,0 @@ -# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements -from tvm.script import ir as I -from tvm.script import tir as T - -""" -Operators: -- fused_NT_matmul1_add3 -- fused_NT_matmul1_add3_add5_add5 -- fused_NT_matmul1_add3_add5_add5_cast5 -- fused_NT_matmul2_divide1_maximum1_minimum1_cast7 -- fused_NT_matmul3_add4_gelu1 -- fused_NT_matmul4_add3 -- fused_NT_matmul_divide_maximum_minimum_cast2 
-- matmul1 -- matmul8 -- fused_softmax1_cast8 -- fused_softmax_cast3 -- fused_layer_norm1_cast6 -- layer_norm1 -""" - -# fmt: off - -@I.ir_module -class Module: - @T.prim_func - def fused_NT_matmul1_add3(p_lv10: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv10 = T.match_buffer(p_lv10, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv10[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv10[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - - @T.prim_func - def fused_NT_matmul1_add3_add5_add5(p_lv48: T.handle, lv1194: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv60: T.handle, p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(2560)), "float16") - lv60 = T.match_buffer(p_lv60, (T.int64(1), n, T.int64(2560)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv48[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv48[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, 
ax2]) - T.reads(lv60[v_ax0, v_ax1, v_ax2], var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] = lv60[v_ax0, v_ax1, v_ax2] + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul1_add3_add5_add5_cast5(p_lv1815: T.handle, lv2496: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias189: T.Buffer((T.int64(2560),), "float16"), p_lv1827: T.handle, p_lv1772: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1815 = T.match_buffer(p_lv1815, (T.int64(1), n, T.int64(2560)), "float16") - lv1827 = T.match_buffer(p_lv1827, (T.int64(1), n, T.int64(2560)), "float16") - lv1772 = T.match_buffer(p_lv1772, (T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1815[v_i0, v_i1, v_k], lv2496[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1815[v_i0, v_i1, v_k] * lv2496[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias189[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias189[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv1827[v_ax0, v_ax1, v_ax2], var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = lv1827[v_ax0, v_ax1, v_ax2] + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv1772[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_2[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv1772[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = 
T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_2[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_2[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_NT_matmul2_divide1_maximum1_minimum1_cast7(p_lv35: T.handle, p_lv36: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - m = T.int64() - lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv35[v_i0, v_i1, v_i2, v_k], lv36[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv35[v_i0, v_i1, v_i2, v_k] * lv36[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - 
@T.prim_func - def fused_NT_matmul3_add4_gelu1(p_lv52: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2560)), "float16") - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute_2 = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv52[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv52[v_i0, v_i1, v_k] * lv1201[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float16(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.Cast("float32", T_multiply[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(compute[v_i0, v_i1, v_i2]) - T.writes(compute_1[v_i0, v_i1, v_i2]) - compute_1[v_i0, v_i1, v_i2] = T.erf(compute[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(compute_1[v_i0, v_i1, v_i2]) - T.writes(compute_2[v_i0, v_i1, v_i2]) - compute_2[v_i0, v_i1, v_i2] = T.Cast("float16", compute_1[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute_2[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute_2[v_ax0, v_ax1, v_ax2] * T.float16(0.5) - for ax0, ax1, ax2 in 
T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float16(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul4_add3(p_lv56: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv56 = T.match_buffer(p_lv56, (T.int64(1), n, T.int64(10240)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv56[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv1208[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - - @T.prim_func - def fused_NT_matmul_divide_maximum_minimum_cast2(lv1869: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv1870: T.handle, p_lv1839: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1870 = T.match_buffer(p_lv1870, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - lv1839 = T.match_buffer(p_lv1839, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1869[v_i0, v_i1, v_i2, v_k], lv1870[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1869[v_i0, v_i1, v_i2, v_k] * lv1870[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1839[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1839[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_softmax1_cast8(p_lv43: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), T.int64(32), n, m)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv43[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv43[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv43[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - 
with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_softmax_cast3(p_lv1877: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1877 = T.match_buffer(p_lv1877, (T.int64(1), T.int64(32), T.int64(1), n)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1877[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1877[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv1877[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1877[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - 
T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def matmul8(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = 
v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def fused_layer_norm1_cast6(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(lv6[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) - T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - -# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py deleted file mode 100644 index 7c9d1c55fa..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1.py +++ /dev/null @@ -1,972 +0,0 @@ -# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals -import tvm -from tvm import tir -from 
tvm.script import tir as T - -from .redpajama_incite_chat_3b_v1_mod import Module as MOD - -# fmt: off - -def fused_NT_matmul1_add4(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 2, 32, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[80, 1, 4, 4, 2]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[128, 5, 4]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - l52, l53, l54 = sch.get_loops(block=b47)[-3:] - sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - l62, l63 = sch.get_loops(block=b57)[-2:] - sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, 
ann_key="meta_schedule.cooperative_fetch") - l71 = sch.get_loops(block=b47)[-1] - _, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l79 = sch.get_loops(block=b57)[-1] - _, l81 = sch.split(loop=l79, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv9_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b85, _ = sch.get_child_blocks(b82) - l100 = sch.get_loops(block=b85)[0] - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - b118 = sch.get_block(name="NT_matmul", func_name="main") - l122 = sch.get_loops(block=b118)[4] - sch.decompose_reduction(block=b118, loop=l122) - - -def fused_NT_matmul1_add4_add5(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_add_1", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l4, l5, l6, l7 = sch.get_loops(block=b0) - v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) - v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 8, 4, 2, 1]) - l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) - v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[40, 2, 16, 1, 2]) - l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) - v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[160, 4, 4]) - l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) - sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) - - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="blockIdx.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="vthread.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) - b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) - l53, l54, l55 = 
sch.get_loops(block=b48)[-3:] - sch.fuse(l53, l54, l55, preserve_unit_iters=True) - v57 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) - b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) - l63, l64 = sch.get_loops(block=b58)[-2:] - sch.fuse(l63, l64, preserve_unit_iters=True) - v66 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=3) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) - sch.enter_postproc() - sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") - l72 = sch.get_loops(block=b48)[-1] - _, l74, l75 = sch.split(loop=l72, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l75) - sch.bind(loop=l74, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") - l80 = sch.get_loops(block=b58)[-1] - _, l82, l83 = sch.split(loop=l80, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l83) - sch.bind(loop=l82, thread_axis="threadIdx.x") - b84 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b84, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv49_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b87, _ = sch.get_child_blocks(b84) - l103 = sch.get_loops(block=b87)[0] - sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) - - b121 = sch.get_block(name="NT_matmul", func_name="main") - l125 = sch.get_loops(block=b121)[4] - sch.decompose_reduction(block=b121, loop=l125) - - -def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l4, [None, 32]) - sch.reorder(l6, l8, l1, l2, l7, l9, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) - v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) - v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) - l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) - v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, 
decision=[1, 1, 16, 2, 4]) - l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) - v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) - l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) - v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[4, 20, 1]) - l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) - sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="blockIdx.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="vthread.x") - l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) - sch.bind(loop=l59, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) - b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) - l66, l67, l68, l69 = sch.get_loops(block=b61)[-4: ] - sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) - v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) - b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) - l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] - sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) - v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) - sch.enter_postproc() - sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") - l88 = sch.get_loops(block=b61)[-1] - _, l90, l91 = sch.split(loop=l88, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l91) - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") - l96 = sch.get_loops(block=b72)[-1] - _, l98, l99 = sch.split(loop=l96, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l99) - sch.bind(loop=l98, thread_axis="threadIdx.x") - b100 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv36_pad") - sch.compute_inline(b1) - b1 = sch.get_block("lv37_pad") - 
sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - b140 = sch.get_block(name="NT_matmul", func_name="main") - l144 = sch.get_loops(block=b140)[5] - sch.decompose_reduction(block=b140, loop=l144) - - -def fused_NT_matmul3_add6_gelu1_cast11(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_multiply", func_name="main") - b3 = sch.get_block(name="compute", func_name="main") - b4 = sch.get_block(name="T_multiply_1", func_name="main") - b5 = sch.get_block(name="T_add_1", func_name="main") - b6 = sch.get_block(name="T_multiply_2", func_name="main") - b7 = sch.get_block(name="compute_1", func_name="main") - b8 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l9, l10, l11, l12 = sch.get_loops(block=b0) - v13, v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l18, l19, l20, l21, l22 = sch.split(loop=l9, factors=[v13, v14, v15, v16, v17], preserve_unit_iters=True) - v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[1, 1, 32, 4, 1]) - l28, l29, l30, l31, l32 = sch.split(loop=l10, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) - v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64, decision=[320, 1, 4, 8, 1]) - l38, l39, l40, l41, l42 = sch.split(loop=l11, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) - v43, v44, v45 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64, decision=[80, 32, 1]) - l46, l47, l48 = sch.split(loop=l12, factors=[v43, v44, v45], preserve_unit_iters=True) - sch.reorder(l18, l28, l38, l19, l29, l39, l20, l30, l40, l46, l47, l21, l31, l41, l48, l22, l32, l42) - l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) - sch.bind(loop=l49, thread_axis="blockIdx.x") - l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) - sch.bind(loop=l50, thread_axis="vthread.x") - l51 = sch.fuse(l20, l30, l40, preserve_unit_iters=True) - sch.bind(loop=l51, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b52 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) - b53 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b53, loop=l46, preserve_unit_loops=True, index=-1) - l58, l59, l60 = sch.get_loops(block=b53)[-3:] - sch.fuse(l58, l59, l60, preserve_unit_iters=True) - v62 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v62) - b63 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b63, loop=l46, preserve_unit_loops=True, index=-1) - l68, l69 = sch.get_loops(block=b63)[-2:] - sch.fuse(l68, l69, preserve_unit_iters=True) - v71 = sch.sample_categorical(candidates=[1, 2, 4, 
8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b63, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) - sch.reverse_compute_inline(block=b7) - sch.compute_inline(block=b5) - sch.compute_inline(block=b4) - sch.compute_inline(block=b3) - sch.compute_inline(block=b2) - sch.compute_inline(block=b1) - sch.reverse_compute_inline(block=b6) - v72 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v72) - sch.enter_postproc() - sch.unannotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch") - l77 = sch.get_loops(block=b53)[-1] - _, l79, l80 = sch.split(loop=l77, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l80) - sch.bind(loop=l79, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b63, ann_key="meta_schedule.cooperative_fetch") - l85 = sch.get_loops(block=b63)[-1] - _, l87, l88 = sch.split(loop=l85, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l88) - sch.bind(loop=l87, thread_axis="threadIdx.x") - b89 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b89, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv57_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b92, _ = sch.get_child_blocks(b89) - l108 = sch.get_loops(block=b92)[0] - sch.annotate(block_or_loop=l108, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l108, ann_key="pragma_unroll_explicit", ann_val=1) - - b126 = sch.get_block(name="NT_matmul", func_name="main") - l130 = sch.get_loops(block=b126)[4] - sch.decompose_reduction(block=b126, loop=l130) - - -def fused_NT_matmul4_add7_cast8_cast12_add5(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="compute", func_name="main") - b3 = sch.get_block(name="compute_1", func_name="main") - b4 = sch.get_block(name="T_add_1", func_name="main") - b5 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l6, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 4, 16, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l7, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[40, 1, 8, 1, 8]) - l35, l36, l37, l38, l39 = sch.split(loop=l8, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[320, 32, 1]) - l43, l44, l45 = sch.split(loop=l9, factors=[v40, v41, v42], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l16, l26, l36, l17, l27, l37, l43, l44, l18, l28, l38, l45, l19, l29, l39) 
- l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="blockIdx.x") - l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) - sch.bind(loop=l47, thread_axis="vthread.x") - l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) - sch.bind(loop=l48, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b49 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b49, loop=l48, preserve_unit_loops=True, index=-1) - b50 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b50, loop=l43, preserve_unit_loops=True, index=-1) - l55, l56, l57 = sch.get_loops(block=b50)[-3:] - sch.fuse(l55, l56, l57, preserve_unit_iters=True) - v59 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch", ann_val=v59) - b60 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l43, preserve_unit_loops=True, index=-1) - l65, l66 = sch.get_loops(block=b60)[-2:] - sch.fuse(l65, l66, preserve_unit_iters=True) - v68 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v68) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v69 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v69) - sch.enter_postproc() - sch.unannotate(block_or_loop=b50, ann_key="meta_schedule.cooperative_fetch") - l74 = sch.get_loops(block=b50)[-1] - _, l76, l77 = sch.split(loop=l74, factors=[None, 128, 4], preserve_unit_iters=True) - sch.vectorize(loop=l77) - sch.bind(loop=l76, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - l82 = sch.get_loops(block=b60)[-1] - _, l84, l85 = sch.split(loop=l82, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l85) - sch.bind(loop=l84, thread_axis="threadIdx.x") - b86 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b86, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv63_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b89, _ = sch.get_child_blocks(b86) - l105 = sch.get_loops(block=b89)[0] - sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) - b123 = sch.get_block(name="NT_matmul", func_name="main") - l127 = sch.get_loops(block=b123)[4] - sch.decompose_reduction(block=b123, loop=l127) - - -def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - 
b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="compute", func_name="main") - b3 = sch.get_block(name="compute_1", func_name="main") - b4 = sch.get_block(name="T_add_1", func_name="main") - b5 = sch.get_block(name="compute_2", func_name="main") - b6 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l7, l8, l9, l10 = sch.get_loops(block=b0) - v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l16, l17, l18, l19, l20 = sch.split(loop=l7, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) - v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 2, 2]) - l26, l27, l28, l29, l30 = sch.split(loop=l8, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) - v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[64, 2, 10, 1, 2]) - l36, l37, l38, l39, l40 = sch.split(loop=l9, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) - v41, v42, v43 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[256, 20, 2]) - l44, l45, l46 = sch.split(loop=l10, factors=[v41, v42, v43], preserve_unit_iters=True) - sch.reorder(l16, l26, l36, l17, l27, l37, l18, l28, l38, l44, l45, l19, l29, l39, l46, l20, l30, l40) - l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) - sch.bind(loop=l47, thread_axis="blockIdx.x") - l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) - sch.bind(loop=l48, thread_axis="vthread.x") - l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) - sch.bind(loop=l49, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b50 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b50, loop=l49, preserve_unit_loops=True, index=-1) - b51 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b51, loop=l44, preserve_unit_loops=True, index=-1) - l56, l57, l58 = sch.get_loops(block=b51)[-3:] - sch.fuse(l56, l57, l58, preserve_unit_iters=True) - v60 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b51, ann_key="meta_schedule.cooperative_fetch", ann_val=v60) - b61 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b61, loop=l44, preserve_unit_loops=True, index=-1) - l66, l67 = sch.get_loops(block=b61)[-2:] - sch.fuse(l66, l67, preserve_unit_iters=True) - v69 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69) - sch.reverse_compute_inline(block=b5) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v70 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v70) - sch.enter_postproc() - 
sch.unannotate(block_or_loop=b51, ann_key="meta_schedule.cooperative_fetch") - l75 = sch.get_loops(block=b51)[-1] - _, l77, l78 = sch.split(loop=l75, factors=[None, 80, 4], preserve_unit_iters=True) - sch.vectorize(loop=l78) - sch.bind(loop=l77, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") - l83 = sch.get_loops(block=b61)[-1] - _, l85, l86 = sch.split(loop=l83, factors=[None, 80, 2], preserve_unit_iters=True) - sch.vectorize(loop=l86) - sch.bind(loop=l85, thread_axis="threadIdx.x") - b87 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b87, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv2047_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b90, _ = sch.get_child_blocks(b87) - l106 = sch.get_loops(block=b90)[0] - sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) - b124 = sch.get_block(name="NT_matmul", func_name="main") - l128 = sch.get_loops(block=b124)[4] - sch.decompose_reduction(block=b124, loop=l128) - - -def fused_NT_matmul_divide_maximum_minimum_cast2(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 1, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l4, [None, 32]) - sch.reorder(l6, l1, l2, l3, l7, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="compute", func_name="main") - b5 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) - v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l16, l17, l18, l19, l20 = sch.split(loop=l6, factors=[v11, v12, v13, v14, v15], preserve_unit_iters=True) - v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[4, 1, 8, 1, 1]) - l26, l27, l28, l29, l30 = sch.split(loop=l7, factors=[v21, v22, v23, v24, v25], preserve_unit_iters=True) - v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l36, l37, l38, l39, l40 = sch.split(loop=l8, factors=[v31, v32, v33, v34, v35], preserve_unit_iters=True) - v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[4, 1, 16, 2, 1]) - l46, l47, l48, l49, l50 = sch.split(loop=l9, factors=[v41, v42, v43, v44, v45], preserve_unit_iters=True) - v51, v52, v53 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64, decision=[5, 8, 2]) - l54, l55, l56 = sch.split(loop=l10, factors=[v51, v52, v53], preserve_unit_iters=True) - sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l55, l19, l29, l39, l49, l56, l20, l30, l40, l50) - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="blockIdx.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="vthread.x") - l59 = sch.fuse(l18, l28, l38, l48, preserve_unit_iters=True) - sch.bind(loop=l59, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b60 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b60, loop=l59, preserve_unit_loops=True, index=-1) - b61 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b61, loop=l54, preserve_unit_loops=True, index=-1) - l66, l67, l68, l69 = sch.get_loops(block=b61)[-4:] - sch.fuse(l66, l67, l68, l69, preserve_unit_iters=True) - v71 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) - b72 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b72, loop=l54, preserve_unit_loops=True, index=-1) - l77, l78, l79, l80 = sch.get_loops(block=b72)[-4:] - sch.fuse(l77, l78, l79, l80, preserve_unit_iters=True) - v82 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v82) - sch.reverse_compute_inline(block=b4) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v83 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v83) - sch.enter_postproc() - sch.unannotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch") - l88 = sch.get_loops(block=b61)[-1] - _, l90, l91 = sch.split(loop=l88, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l91) - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch") - l96 = sch.get_loops(block=b72)[-1] - _, l98, l99 = sch.split(loop=l96, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l99) - sch.bind(loop=l98, thread_axis="threadIdx.x") - b100 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b100, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv2095_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b103, _ = sch.get_child_blocks(b100) - l119 = sch.get_loops(block=b103)[0] - sch.annotate(block_or_loop=l119, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l119, ann_key="pragma_unroll_explicit", ann_val=1) - - b140 = sch.get_block(name="NT_matmul", func_name="main") - l144 = sch.get_loops(block=b140)[4] - sch.decompose_reduction(block=b140, loop=l144) - - -def fused_layer_norm1_cast8(sch: tir.Schedule): - b0 = sch.get_block(name="A_red_temp", func_name="main") - b1 = sch.get_block(name="T_layer_norm", func_name="main") - b2 = sch.get_block(name="compute", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - sch.reverse_compute_inline(block=b2) - v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=5) - l5, l6, l7 = sch.get_loops(block=b0) - l8, l9 = sch.split(loop=l7, factors=[None, v4], preserve_unit_iters=True) - 
sch.bind(loop=l9, thread_axis="threadIdx.x") - v10 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v10) - l11, l12, l13 = sch.get_loops(block=b1) - l14 = sch.fuse(l11, l12, l13, preserve_unit_iters=True) - l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l16, l17, l15) - sch.bind(loop=l16, thread_axis="blockIdx.x") - sch.bind(loop=l17, thread_axis="threadIdx.x") - l18, l19, l20, l21 = sch.get_loops(block=b0) - l22 = sch.fuse(l18, l19, preserve_unit_iters=True) - sch.bind(loop=l22, thread_axis="blockIdx.x") - sch.enter_postproc() - b23 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b23, ann_key="meta_schedule.unroll_explicit") - b24, b25 = sch.get_child_blocks(b23) - l26, l27, l28 = sch.get_loops(block=b24) - sch.annotate(block_or_loop=l26, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l26, ann_key="pragma_unroll_explicit", ann_val=1) - l29, l30, l31 = sch.get_loops(block=b25) - - -def layer_norm1(sch: tir.Schedule): - b0 = sch.get_block(name="A_red_temp", func_name="main") - b1 = sch.get_block(name="T_layer_norm", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=1) - l4, l5, l6 = sch.get_loops(block=b0) - l7, l8 = sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) - sch.bind(loop=l8, thread_axis="threadIdx.x") - v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) - l10, l11, l12 = sch.get_loops(block=b1) - l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) - l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l15, l16, l14) - sch.bind(loop=l15, thread_axis="blockIdx.x") - sch.bind(loop=l16, thread_axis="threadIdx.x") - l17, l18, l19, l20 = sch.get_loops(block=b0) - l21 = sch.fuse(l17, l18, preserve_unit_iters=True) - sch.bind(loop=l21, thread_axis="blockIdx.x") - sch.enter_postproc() - b22 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") - b23, b24 = sch.get_child_blocks(b22) - l25, l26, l27 = sch.get_loops(block=b23) - sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) - l28, l29, l30 = sch.get_loops(block=b24) - - -def matmul3(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 1, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - k0, k1 = sch.split(k, [None, 32]) - sch.reorder(l1, l2, l3, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, 
v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 16, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[8, 1, 10, 1, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[1, 32, 1]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86, l87 = sch.split(loop=l84, factors=[None, 160, 4], preserve_unit_iters=True) - sch.vectorize(loop=l87) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l92 = sch.get_loops(block=b68)[-1] - _, l94, l95 = sch.split(loop=l92, factors=[None, 160, 2], preserve_unit_iters=True) - sch.vectorize(loop=l95) - sch.bind(loop=l94, thread_axis="threadIdx.x") - b96 = 
sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - - _, _, b99, _ = sch.get_child_blocks(b96) - l115 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) - b136 = sch.get_block(name="matmul", func_name="main") - l140 = sch.get_loops(block=b136)[3] - sch.decompose_reduction(block=b136, loop=l140) - - -def matmul9(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - s0, s1 = sch.split(l3, [None, 32]) - k0, k1 = sch.split(k, [None, 32]) - sch.reorder(s0, l1, l2, s1, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 1, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[8, 1, 8, 2, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 1, 5, 2, 4]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[16, 1, 2]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, 
storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - b1 = sch.get_block("matmul_pad") - sch.reverse_compute_inline(b1) - - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86, l87 = sch.split(loop=l84, factors=[None, 40, 4], preserve_unit_iters=True) - sch.vectorize(loop=l87) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l92 = sch.get_loops(block=b68)[-1] - _, l94, l95 = sch.split(loop=l92, factors=[None, 40, 2], preserve_unit_iters=True) - sch.vectorize(loop=l95) - sch.bind(loop=l94, thread_axis="threadIdx.x") - b96 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") - b136 = sch.get_block(name="matmul", func_name="main") - l140 = sch.get_loops(block=b136)[4] - sch.decompose_reduction(block=b136, loop=l140) - - -def softmax_1xn(sch: tir.Schedule): - has_cast = True - if has_cast: - b_cast = sch.get_block("compute") - sch.reverse_compute_inline(b_cast) - - b0 = sch.get_block("T_softmax_exp") - sch.compute_inline(b0) - b1 = sch.get_block("T_softmax_norm") - l2, l3, l4, l5 = sch.get_loops(b1) - _, l7 = sch.split(l5, [None, 128]) - sch.bind(l7, "threadIdx.x") - b8 = sch.get_block("T_softmax_expsum") - sch.compute_at(b8, l4) - sch.set_scope(b8, 0, "shared") - _, _, _, l12 = sch.get_loops(b8) - _, l14 = sch.split(l12, [None, 128]) - sch.bind(l14, "threadIdx.x") - b15 = sch.get_block("T_softmax_maxelem") - sch.compute_at(b15, l4) - sch.set_scope(b15, 0, "shared") - _, _, _, l19 = sch.get_loops(b15) - _, l21 = sch.split(l19, [None, 128]) - sch.bind(l21, "threadIdx.x") - l22 = sch.fuse(l2, l3, l4) - sch.bind(l22, "blockIdx.x") - - -def fused_min_max_triu_te_broadcast_to(sch: tir.Schedule): - b0 = sch.get_block("T_broadcast_to") - sch.reverse_compute_inline(b0) - b1 = sch.get_block("make_diag_mask_te") - i, j = sch.get_loops(b1) - i = sch.fuse(i, j) - i, j = sch.split(i, [None, 128]) - sch.bind(i, "blockIdx.x") - sch.bind(j, "threadIdx.x") - - -@T.prim_func -def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, 
v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i 
< n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_expsum_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with 
T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) - v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) - v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) - T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) - T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] - - - -def _get_dict(): - # tvm.ir.assert_structural_equal(MOD["softmax"], softmax_mxn_before) - func_dict = { - # softmax_mxn_before: softmax_mxn_after, - } - for name, func in [ - # fmt: off - ("fused_layer_norm1_cast8", fused_layer_norm1_cast8), - ("fused_NT_matmul1_add4_add5", fused_NT_matmul1_add4_add5), - ("fused_NT_matmul2_divide1_maximum1_minimum1_cast9", fused_NT_matmul2_divide1_maximum1_minimum1_cast9), - ("fused_NT_matmul4_add7_cast8_cast12_add5", fused_NT_matmul4_add7_cast8_cast12_add5), - ("fused_NT_matmul3_add6_gelu1_cast11", fused_NT_matmul3_add6_gelu1_cast11), - ("fused_NT_matmul_divide_maximum_minimum_cast2", fused_NT_matmul_divide_maximum_minimum_cast2), - ("matmul3", matmul3), - ("fused_NT_matmul1_add4", fused_NT_matmul1_add4), - ("matmul9", matmul9), - ("layer_norm1", layer_norm1), - ("fused_NT_matmul4_add7_cast8_cast12_add5_cast7", fused_NT_matmul4_add7_cast8_cast12_add5_cast7), - ("fused_min_max_triu_te_broadcast_to", fused_min_max_triu_te_broadcast_to), - ("fused_softmax_cast3", softmax_1xn), - # fmt: on - ]: - # print(f"############### {name} ###############") - sch = tir.Schedule(MOD[name]) - func(sch) - # sch.mod["main"].show(black_format=False) - func_dict[MOD[name]] = sch.mod["main"] - return { - (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) - for k, v in func_dict.items() - } - - -DICT = _get_dict() - - -def lookup(func): - for (hash_value, func_before), f_after in DICT.items(): - if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( - func, func_before - ): - return f_after - return None diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py deleted file mode 100644 index b71567bc08..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_mod.py +++ /dev/null @@ -1,722 +0,0 @@ -# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements -from tvm.script import ir as I -from tvm.script import tir as T - -# fmt: off - -@I.ir_module -class Module: - @T.prim_func - def cast7(var_A: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560)), "float16") - compute = T.match_buffer(var_compute, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.Cast("float32", A[v_i0, v_i1, v_i2]) - - @T.prim_func - def 
extend_te(var_A: T.handle, var_concat_te: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16") - m = T.int64() - concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16") - # with T.block("root"): - for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m): - with T.block("concat_te"): - v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j]) - T.reads(A[v_b, v__, v_i, v_j + n - m]) - T.writes(concat_te[v_b, v__, v_i, v_j]) - concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m]) - - @T.prim_func - def full(var_T_full: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n): - with T.block("T_full"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads() - T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) - T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504) - - @T.prim_func - def fused_NT_matmul1_add4(p_lv9: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv9 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv9[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv9[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - - @T.prim_func - def fused_NT_matmul1_add4_add5(p_lv49: T.handle, lv1194: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) 
- T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(p_lv36: T.handle, p_lv37: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - m = T.int64() - lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv36[v_i0, v_i1, v_i2, v_k], lv37[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv36[v_i0, v_i1, v_i2, v_k] * lv37[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - 
with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_NT_matmul3_add6_gelu1_cast11(p_lv57: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv57[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv57[v_i0, v_i1, v_k]) * T.Cast("float32", lv1201[v_i2, v_k]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, 
v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_NT_matmul4_add7_cast8_cast12_add5(p_lv63: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv53: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") - lv53 = T.match_buffer(p_lv53, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv63[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv63[v_i0, v_i1, v_k]) * T.Cast("float32", lv1208[v_i2, v_k]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, 
v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv53[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv53[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(p_lv2047: T.handle, lv2510: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias191: T.Buffer((T.int64(2560),), "float32"), p_lv2037: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") - lv2037 = T.match_buffer(p_lv2037, (T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2047[v_i0, v_i1, v_k], lv2510[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2047[v_i0, v_i1, v_k]) * T.Cast("float32", lv2510[v_i2, v_k]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_2[v_i0, v_i1, v_i2]) - var_compute_intermediate_2[v_i0, v_i1, v_i2] = var_compute_intermediate_1[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_2[v_ax0, v_ax1, v_ax2], lv2037[v_ax0, v_ax1, v_ax2]) - 
T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2037[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_NT_matmul_divide_maximum_minimum_cast2(lv2094: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv2095: T.handle, p_lv2063: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv2095 = T.match_buffer(p_lv2095, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - lv2063 = T.match_buffer(p_lv2063, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv2094[v_i0, v_i1, v_i2, v_k], lv2095[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2094[v_i0, v_i1, v_i2, v_k] * lv2095[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in 
T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_layer_norm1_cast8(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(lv6[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) - T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16") - # with T.block("root"): - var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16") - for i, j in T.grid(n, n): - with T.block("make_diag_mask_te"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads() - T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) - var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): - with T.block("T_broadcast_to"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", 
[ax0, ax1, ax2, ax3]) - T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) - T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] - - @T.prim_func - def fused_softmax1_cast10(p_lv44: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv44[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv44[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_softmax_cast3(p_lv2102: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv2102 = T.match_buffer(p_lv2102, (T.int64(1), T.int64(32), T.int64(1), n)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - T_softmax_exp = 
T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2102[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv2102[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv2102[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv2102[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = 
v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def matmul3(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def matmul9(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def reshape3(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (n, T.int64(32), T.int64(80)), "float16") - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)] - - @T.prim_func - def reshape5(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": 
T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n), "int32") - T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") - # with T.block("root"): - for ax0 in range(n): - with T.block("T_reshape"): - v_ax0 = T.axis.spatial(n, ax0) - T.reads(A[T.int64(0), v_ax0 % n]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n] - - @T.prim_func - def reshape6(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (n, T.int64(2560)), "float16") - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)] - - @T.prim_func - def reshape7(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560)), "float16") - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - - @T.prim_func - def reshape8(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] - - @T.prim_func - def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(80)), "float16"), C: T.Buffer((T.int64(2048), T.int64(80)), "float16"), var_rotary: T.handle, m: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - # with T.block("root"): - for i_batch_size, i_seq_len, i_num_heads, i_head_dim in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("rotary"): - v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim = T.axis.remap("SSSS", [i_batch_size, i_seq_len, i_num_heads, i_head_dim]) - T.reads(B[m + v_i_seq_len - n, v_i_head_dim], 
A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40):v_i_head_dim - T.int64(40) + T.int64(81)], C[m + v_i_seq_len - n, v_i_head_dim]) - T.writes(rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) - rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] = T.Select(v_i_head_dim < T.int64(80), B[m + v_i_seq_len - n, v_i_head_dim] * A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] + C[m + v_i_seq_len - n, v_i_head_dim] * T.Select(v_i_head_dim < T.int64(40), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim + T.int64(40)] * T.float16(-1), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40)]), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) - - @T.prim_func - def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("slice"): - v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) - T.reads(A[v_i, n - T.int64(1), v_k]) - T.writes(slice_1[v_i, v__, v_k]) - slice_1[v_i, v__, v_k] = A[v_i, n - T.int64(1), v_k] - - @T.prim_func - def squeeze1(var_A: T.handle, var_T_squeeze: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(80)), "float16") - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) - T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def take_decode1(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "float16"), var_C: T.handle, var_take_decode: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - C = T.match_buffer(var_C, (n,), "int32") - take_decode = T.match_buffer(var_take_decode, (n, T.int64(2560)), "float16") - # with T.block("root"): - for i, j in T.grid(n, T.int64(2560)): - with T.block("take_decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) - T.writes(take_decode[v_i, v_j]) - take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[C[v_i], v_j // T.int64(32)] - - @T.prim_func - def transpose3(var_A: T.handle, var_T_transpose: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] - - @T.prim_func - def transpose6(var_A: T.handle, var_T_transpose: T.handle): - 
T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(80)), "float16") - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] - - -# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py b/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py deleted file mode 100644 index 460bec0f75..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_incite_chat_3b_v1_tune.py +++ /dev/null @@ -1,1010 +0,0 @@ -from tvm.script import ir as I -from tvm.script import tir as T - -""" - ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done ------------------------------------------------------------------------------------------------------------------------------------------------------------ - 0 | cast | 1 | 1 | 0.0000 | 27.7422 | 27.7422 | 4 | Y - 1 | cast6 | 1 | 1 | 0.0000 | 27.6800 | 27.6800 | 4 | Y - 2 | decode4 | 26214400 | 1 | 167.6862 | 156.3301 | 156.3301 | 172 | Y - 3 | decode5 | 104857600 | 1 | 128.5783 | 815.5153 | 815.5153 | 172 | Y - 4 | decode6 | 104857600 | 1 | 128.6586 | 815.0066 | 815.0066 | 179 | Y - 5 | divide2 | 50432 | 1 | 1.8169 | 27.7575 | 27.7575 | 4 | Y - 6 | fused_NT_matmul1_add4 | 1678049280 | 1 | 2178.8097 | 770.1679 | 770.1679 | 1088 | Y - 7 | fused_NT_matmul1_add4_add5 | 1678376960 | 1 | 2130.5374 | 787.7717 | 787.7717 | 1215 | Y - 8 | fused_NT_matmul2_divide1_maximum1_minimum1_cast9 | 85458944 | 1 | 1211.9454 | 70.5139 | 70.5139 | 192 | Y - 9 | fused_NT_matmul3_add6_gelu1_cast11 | 6717440000 | 1 | 2129.3171 | 3154.7391 | 3154.7391 | 4416 | Y - 10 | fused_NT_matmul4_add7_cast8_cast12_add5 | 6711541760 | 1 | 2072.7296 | 3238.0208 | 3238.0208 | 4544 | Y - 11 | fused_NT_matmul4_add7_cast8_cast12_add5_cast7 | 6711541760 | 1 | 2091.5892 | 3208.8241 | 3208.8241 | 4416 | Y - 12 | fused_NT_matmul_divide_maximum_minimum_cast2 | 667648 | 1 | 23.3021 | 28.6519 | 28.6519 | 64 | Y - 13 | fused_decode1_fused_matmul4_add2_gelu_cast4 | 157337600 | 1 | 812.5380 | 193.6372 | 193.6372 | 319 | Y - 14 | fused_decode2_fused_matmul5_add3_cast1_cast5_add1 | 157291520 | 1 | 730.8166 | 215.2271 | 215.2271 | 320 | Y - 15 | fused_decode2_fused_matmul5_add3_cast1_cast5_add1_cast | 157291520 | 1 | 729.0229 | 215.7566 | 215.7566 | 319 | Y - 16 | fused_decode3_matmul6 | 774635520 | 1 | 868.1608 | 892.2719 | 892.2719 | 1331 | Y - 17 | fused_decode_fused_matmul2_add | 39324160 | 1 | 733.2646 | 53.6289 | 53.6289 | 191 | Y - 18 | fused_decode_fused_matmul2_add_add1 | 39326720 | 1 | 740.8926 | 53.0802 | 53.0802 | 192 | Y - 19 | fused_layer_norm1_cast8 | 4587520 | 1 | 76.3188 | 60.1099 | 60.1099 | 50 | Y - 20 | fused_layer_norm_cast1 | 35840 | 1 | 0.6533 | 54.8634 | 54.8634 | 159 | Y - 21 | fused_reshape2_squeeze | 1 | 1 | 0.0000 | 27.5470 | 27.5470 | 4 | Y - 22 | fused_slice1_cast6 | 1 | 1 | 0.0000 | 27.5899 | 27.5899 | 4 | Y - 23 | fused_transpose4_reshape4 | 1 | 1 | 0.0000 | 27.5157 | 27.5157 | 4 | Y - 24 | layer_norm | 35840 | 1 | 0.6506 | 55.0910 | 55.0910 | 160 | Y - 25 | layer_norm1 | 4587520 | 1 | 74.6941 | 61.4174 | 61.4174 | 50 | Y - 26 | 
matmul3 | 163840 | 1 | 5.8011 | 28.2428 | 28.2428 | 64 | Y - 27 | matmul9 | 20971520 | 1 | 571.2811 | 36.7096 | 36.7096 | 192 | Y - 28 | reshape | 1 | 1 | 0.0000 | 27.9399 | 27.9399 | 1 | Y - 29 | reshape1 | 1 | 1 | 0.0000 | 27.6659 | 27.6659 | 4 | Y - 30 | reshape2 | 1 | 1 | 0.0000 | 27.6446 | 27.6446 | 4 | Y - 31 | softmax2 | 201728 | 1 | 2.8631 | 70.4578 | 70.4578 | 186 | Y - 32 | squeeze | 1 | 1 | 0.0000 | 27.3156 | 27.3156 | 4 | Y - 33 | take_decode | 10240 | 1 | 0.3712 | 27.5835 | 27.5835 | 4 | Y - 34 | transpose2 | 1 | 1 | 0.0000 | 27.6975 | 27.6975 | 4 | Y ------------------------------------------------------------------------------------------------------------------------------------------------------------ -""" - -# fmt: off - -@I.ir_module -class Module: - @T.prim_func - def cast(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.Cast("float32", A[v_i0, v_i1, v_i2]) - - @T.prim_func - def cast6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = A[v_i0, v_i1, v_i2] - - @T.prim_func - def decode4(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(80), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - @T.prim_func - def decode5(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: T.Buffer((T.int64(80), T.int64(10240)), "float16"), T_transpose: T.Buffer((T.int64(10240), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * B[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - @T.prim_func - def decode6(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(320), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(10240)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] - for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - @T.prim_func - def divide2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], B[()]) - T.writes(T_divide[v_ax0, v_ax1, v_ax2]) - T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()] - - @T.prim_func - def fused_decode1_fused_matmul4_add2_gelu_cast4(lv32: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), lv33: T.Buffer((T.int64(80), T.int64(10240)), "float16"), lv2115: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias196: T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv32[v_i // T.int64(8), v_j], lv33[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv32[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv33[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): - 
with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2115[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_decode2_fused_matmul5_add3_cast1_cast5_add1(lv38: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv39: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv2121: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias197: T.Buffer((T.int64(2560),), "float32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with 
T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv38[v_i // T.int64(8), v_j], lv39[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv38[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv39[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2121[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv8[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv8[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_decode2_fused_matmul5_add3_cast1_cast5_add1_cast(lv1154: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv1155: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv4105: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias383: T.Buffer((T.int64(2560),), "float32"), lv380: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": 
T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1154[v_i // T.int64(8), v_j], lv1155[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1154[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1155[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv4105[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias383[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv380[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv380[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - - 
@T.prim_func - def fused_decode3_matmul6(lv1160: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), lv1161: T.Buffer((T.int64(80), T.int64(50432)), "float32"), lv384: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) - for i, j in T.grid(T.int64(2560), T.int64(50432)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1160[v_i // T.int64(8), v_j], lv1161[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.Cast("float16", T.bitwise_and(T.shift_right(lv1160[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1161[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv384[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv384[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - @T.prim_func - def fused_decode_fused_matmul2_add(lv8: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv9: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv2067: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias192: T.Buffer((T.int64(2560),), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv8[v_i // T.int64(8), v_j], lv9[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv8[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv9[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2067[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2067[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] - - @T.prim_func - def 
fused_decode_fused_matmul2_add_add1(lv26: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv27: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv7: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias195: T.Buffer((T.int64(2560),), "float16"), lv2062: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv26[v_i // T.int64(8), v_j], lv27[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv26[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv27[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv7[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2062[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2062[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_layer_norm_cast1(lv2064: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), weight67: T.Buffer((T.int64(2560),), "float32"), bias65: T.Buffer((T.int64(2560),), "float32"), var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) - var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(lv2064[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - 
v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv2064[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv2064[v_ax0, v_ax1, v_k2] * lv2064[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2064[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight67[v_ax2], bias65[v_ax2]) - T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv2064[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight67[v_ax2] + bias65[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_reshape2_squeeze(lv2080: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16") - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(lv2080[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2080[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_slice1_cast6(lv4113: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_slice_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("slice"): - v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) - T.reads(lv4113[v_i, T.int64(0), v_k]) - T.writes(var_slice_intermediate[v_i, v__, v_k]) - var_slice_intermediate[v_i, v__, v_k] = lv4113[v_i, T.int64(0), v_k] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - 
T.reads(var_slice_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = var_slice_intermediate[v_i0, v_i1, v_i2] - - @T.prim_func - def fused_transpose4_reshape4(lv2105: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16") - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(lv2105[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv2105[v_ax0, v_ax2, v_ax1, v_ax3] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) - T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] - - @T.prim_func - def layer_norm(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) - for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def reshape(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0 in range(T.int64(1)): - 
with T.block("T_reshape"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - T.reads(A[T.int64(0), T.int64(0)]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = A[T.int64(0), T.int64(0)] - - @T.prim_func - def reshape1(A: T.Buffer((T.int64(1), T.int64(2560)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax2 % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(2560)] - - @T.prim_func - def reshape2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - - @T.prim_func - def softmax2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(50432))) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1))) - for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(A[v_i0, v_i1, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k]) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) - T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) - for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1]) - with T.init(): - T_softmax_expsum[v_i0, v_i1] = T.float32(0) - T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) - T.block_attr({"axis": 2}) - T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] - - @T.prim_func - def squeeze(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16"), 
T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) - T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def take_decode(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode_1: T.Buffer((T.int64(1), T.int64(2560)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i, j in T.grid(T.int64(1), T.int64(2560)): - with T.block("take_decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) - T.writes(take_decode_1[v_i, v_j]) - take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[C[v_i], v_j // T.int64(32)] - - @T.prim_func - def transpose2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] - - ####################################### Dynamic Shape ####################################### - - @T.prim_func - def fused_NT_matmul1_add4(p_lv9: T.handle, lv1173: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv9 = T.match_buffer(p_lv9, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv9[v_i0, v_i1, v_k], lv1173[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv9[v_i0, v_i1, v_k] * lv1173[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - - @T.prim_func - def fused_NT_matmul1_add4_add5(p_lv49: T.handle, lv1194: T.Buffer((T.int64(2560), T.int64(2560)), "float16"), linear_bias3: 
T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], lv1194[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1194[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul2_divide1_maximum1_minimum1_cast9(p_lv36: T.handle, p_lv37: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - m = T.meta_var(T.int64(128)) - lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv36[v_i0, v_i1, v_i2, v_k], lv37[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv36[v_i0, v_i1, v_i2, v_k] * lv37[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, 
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_NT_matmul3_add6_gelu1_cast11(p_lv57: T.handle, lv1201: T.Buffer((T.int64(10240), T.int64(2560)), "float16"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv57 = T.match_buffer(p_lv57, (T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv57[v_i0, v_i1, v_k], lv1201[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv57[v_i0, v_i1, v_k] * lv1201[v_i2, v_k]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + 
linear_bias4[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_NT_matmul4_add7_cast8_cast12_add5(p_lv63: T.handle, lv1208: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv53: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv63 = T.match_buffer(p_lv63, (T.int64(1), n, T.int64(10240)), "float16") - lv53 = T.match_buffer(p_lv53, (T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv63[v_i0, v_i1, v_k], lv1208[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv63[v_i0, v_i1, v_k] * lv1208[v_i2, v_k]) - for ax0, ax1, ax2 
in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv53[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv53[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul4_add7_cast8_cast12_add5_cast7(p_lv2047: T.handle, lv2510: T.Buffer((T.int64(2560), T.int64(10240)), "float16"), linear_bias191: T.Buffer((T.int64(2560),), "float32"), p_lv2037: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv2047 = T.match_buffer(p_lv2047, (T.int64(1), n, T.int64(10240)), "float16") - lv2037 = T.match_buffer(p_lv2037, (T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_compute_intermediate_2 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2047[v_i0, v_i1, v_k], lv2510[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2047[v_i0, v_i1, v_k] * lv2510[v_i2, v_k]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias191[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias191[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = 
T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_1"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_compute_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate_2[v_i0, v_i1, v_i2]) - var_compute_intermediate_2[v_i0, v_i1, v_i2] = var_compute_intermediate_1[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate_2[v_ax0, v_ax1, v_ax2], lv2037[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_2[v_ax0, v_ax1, v_ax2] + lv2037[v_ax0, v_ax1, v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute_2"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - - @T.prim_func - def fused_NT_matmul_divide_maximum_minimum_cast2(lv2094: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16"), p_lv2095: T.handle, p_lv2063: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv2095 = T.match_buffer(p_lv2095, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - lv2063 = T.match_buffer(p_lv2063, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv2094[v_i0, v_i1, v_i2, v_k], lv2095[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2094[v_i0, v_i1, v_i2, v_k] * lv2095[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.11179039301310044) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = 
T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2063[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - - @T.prim_func - def fused_layer_norm1_cast8(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(lv6[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2]) - T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = (lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * weight1[v_ax2] + bias[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - 
var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2]) - - @T.prim_func - def layer_norm1(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def matmul3(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(32)) - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def matmul9(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - m = T.meta_var(T.int64(32)) - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16") - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80)), "float16") - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, 
v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - -# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py deleted file mode 100644 index b6e91233b3..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32.py +++ /dev/null @@ -1,840 +0,0 @@ -# pylint: disable=missing-docstring,line-too-long,invalid-name,too-many-statements,too-many-locals -import tvm -from tvm import tir -from tvm.script import tir as T - -from .redpajama_q4f32_mod import Module as MOD - -# fmt: off - -def fused_NT_matmul1_divide_maximum_minimum(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l4, [None, 32]) - sch.reorder(l6, l8, l1, l2, l7, l9, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[16, 1, 2, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 4]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[8, 1, 4, 2, 2]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[10, 4, 2]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - l65, l66, l67, l68 = 
sch.get_loops(block=b60)[-4:] - sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - l76, l77, l78, l79 = sch.get_loops(block=b71)[-4:] - sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - l87 = sch.get_loops(block=b60)[-1] - _, l89, l90 = sch.split(loop=l87, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l90) - sch.bind(loop=l89, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") - l95 = sch.get_loops(block=b71)[-1] - _, l97 = sch.split(loop=l95, factors=[None, 32], preserve_unit_iters=True) - sch.bind(loop=l97, thread_axis="threadIdx.x") - b98 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv34_pad") - sch.compute_inline(b1) - b1 = sch.get_block("lv35_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - b140 = sch.get_block(name="NT_matmul", func_name="main") - l144 = sch.get_loops(block=b140)[5] - sch.decompose_reduction(block=b140, loop=l144) - - b101 = sch.get_child_blocks(b98)[2] - l116 = sch.get_loops(block=b101)[0] - sch.annotate(block_or_loop=l116, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l116, ann_key="pragma_unroll_explicit", ann_val=1) - - -def fused_NT_matmul2_add2_gelu(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_multiply", func_name="main") - b3 = sch.get_block(name="compute", func_name="main") - b4 = sch.get_block(name="T_multiply_1", func_name="main") - b5 = sch.get_block(name="T_add_1", func_name="main") - b6 = sch.get_block(name="T_multiply_2", func_name="main") - b7 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l8, l9, l10, l11 = sch.get_loops(block=b0) - v12, v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l17, l18, l19, l20, l21 = sch.split(loop=l8, factors=[v12, v13, v14, v15, v16], preserve_unit_iters=True) - v22, v23, v24, v25, v26 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64, decision=[1, 2, 16, 2, 2]) - 
l27, l28, l29, l30, l31 = sch.split(loop=l9, factors=[v22, v23, v24, v25, v26], preserve_unit_iters=True) - v32, v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64, decision=[320, 1, 8, 4, 1]) - l37, l38, l39, l40, l41 = sch.split(loop=l10, factors=[v32, v33, v34, v35, v36], preserve_unit_iters=True) - v42, v43, v44 = sch.sample_perfect_tile(loop=l11, n=3, max_innermost_factor=64, decision=[160, 4, 4]) - l45, l46, l47 = sch.split(loop=l11, factors=[v42, v43, v44], preserve_unit_iters=True) - sch.reorder(l17, l27, l37, l18, l28, l38, l19, l29, l39, l45, l46, l20, l30, l40, l47, l21, l31, l41) - l48 = sch.fuse(l17, l27, l37, preserve_unit_iters=True) - sch.bind(loop=l48, thread_axis="blockIdx.x") - l49 = sch.fuse(l18, l28, l38, preserve_unit_iters=True) - sch.bind(loop=l49, thread_axis="vthread.x") - l50 = sch.fuse(l19, l29, l39, preserve_unit_iters=True) - sch.bind(loop=l50, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024) - b51 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b51, loop=l50, preserve_unit_loops=True, index=-1) - b52 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b52, loop=l45, preserve_unit_loops=True, index=-1) - l57, l58, l59 = sch.get_loops(block=b52)[-3:] - sch.fuse(l57, l58, l59, preserve_unit_iters=True) - v61 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61) - b62 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b62, loop=l45, preserve_unit_loops=True, index=-1) - l67, l68 = sch.get_loops(block=b62)[-2:] - sch.fuse(l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - sch.compute_inline(block=b5) - sch.compute_inline(block=b4) - sch.compute_inline(block=b3) - sch.compute_inline(block=b2) - sch.compute_inline(block=b1) - sch.reverse_compute_inline(block=b6) - v71 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b7, ann_key="meta_schedule.unroll_explicit", ann_val=v71) - sch.enter_postproc() - sch.unannotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch") - l76 = sch.get_loops(block=b52)[-1] - _, l78, l79 = sch.split(loop=l76, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l79) - sch.bind(loop=l78, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b62)[-1] - _, l86 = sch.split(loop=l84, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l86, thread_axis="threadIdx.x") - b87 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b87, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv51_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b90, _ = 
sch.get_child_blocks(b87) - l105 = sch.get_loops(block=b90)[0] - sch.annotate(block_or_loop=l105, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l105, ann_key="pragma_unroll_explicit", ann_val=1) - b123 = sch.get_block(name="NT_matmul", func_name="main") - l127 = sch.get_loops(block=b123)[4] - sch.decompose_reduction(block=b123, loop=l127) - - -def fused_NT_matmul3_add_cast_add1(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="compute", func_name="main") - b3 = sch.get_block(name="T_add_1", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l5, l6, l7, l8 = sch.get_loops(block=b0) - v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l14, l15, l16, l17, l18 = sch.split(loop=l5, factors=[v9, v10, v11, v12, v13], preserve_unit_iters=True) - v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[1, 4, 32, 1, 1]) - l24, l25, l26, l27, l28 = sch.split(loop=l6, factors=[v19, v20, v21, v22, v23], preserve_unit_iters=True) - v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[40, 1, 4, 16, 1]) - l34, l35, l36, l37, l38 = sch.split(loop=l7, factors=[v29, v30, v31, v32, v33], preserve_unit_iters=True) - v39, v40, v41 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[640, 4, 4]) - l42, l43, l44 = sch.split(loop=l8, factors=[v39, v40, v41], preserve_unit_iters=True) - sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l43, l17, l27, l37, l44, l18, l28, l38) - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="blockIdx.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="vthread.x") - l47 = sch.fuse(l16, l26, l36, preserve_unit_iters=True) - sch.bind(loop=l47, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b48 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b48, loop=l47, preserve_unit_loops=True, index=-1) - b49 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b49, loop=l42, preserve_unit_loops=True, index=-1) - l54, l55, l56 = sch.get_loops(block=b49)[-3:] - sch.fuse(l54, l55, l56, preserve_unit_iters=True) - v58 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v58) - b59 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b59, loop=l42, preserve_unit_loops=True, index=-1) - l64, l65 = sch.get_loops(block=b59)[-2:] - sch.fuse(l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) 
- sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v68 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v68) - sch.enter_postproc() - sch.unannotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch") - l73 = sch.get_loops(block=b49)[-1] - _, l75, l76 = sch.split(loop=l73, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l76) - sch.bind(loop=l75, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b59, ann_key="meta_schedule.cooperative_fetch") - l81 = sch.get_loops(block=b59)[-1] - _, l83, l84 = sch.split(loop=l81, factors=[None, 128, 2], preserve_unit_iters=True) - sch.vectorize(loop=l84) - sch.bind(loop=l83, thread_axis="threadIdx.x") - b85 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b85, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv56_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b88, _ = sch.get_child_blocks(b85) - l104 = sch.get_loops(block=b88)[0] - sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) - b121 = sch.get_block(name="NT_matmul", func_name="main") - l125 = sch.get_loops(block=b121)[4] - sch.decompose_reduction(block=b121, loop=l125) - - -def fused_NT_matmul4_divide2_maximum1_minimum1(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 1, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l4, [None, 32]) - sch.reorder(l6, l1, l2, l3, l7, l5) - - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[16, 2, 1, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 32, 1, 2]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[20, 2, 2]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, 
preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - l65, l66, l67, l68 = sch.get_loops(block=b60)[-4:] - sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - l76, l77, l78, l79 = sch.get_loops(block=b71)[-4:] - sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - l87 = sch.get_loops(block=b60)[-1] - _, l89, l90 = sch.split(loop=l87, factors=[None, 16, 2], preserve_unit_iters=True) - sch.vectorize(loop=l90) - sch.bind(loop=l89, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") - l95 = sch.get_loops(block=b71)[-1] - _, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) - sch.bind(loop=l97, thread_axis="threadIdx.x") - b98 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b98, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv1836_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - b140 = sch.get_block(name="NT_matmul", func_name="main") - l144 = sch.get_loops(block=b140)[4] - sch.decompose_reduction(block=b140, loop=l144) - - -def fused_NT_matmul_add(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, 
decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[2, 4, 8, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[64, 5, 8, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[320, 2, 4]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - l52, l53, l54 = sch.get_loops(block=b47)[-3:] - sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - l62, l63 = sch.get_loops(block=b57)[-2:] - sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - l71 = sch.get_loops(block=b47)[-1] - _, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l79 = sch.get_loops(block=b57)[-1] - _, l81 = sch.split(loop=l79, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv7_pad") - sch.compute_inline(b1) - b2 = 
sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b85, _ = sch.get_child_blocks(b82) - l100 = sch.get_loops(block=b85)[0] - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - b118 = sch.get_block(name="NT_matmul", func_name="main") - l122 = sch.get_loops(block=b118)[4] - sch.decompose_reduction(block=b118, loop=l122) - - -def fused_NT_matmul_add_add1(sch: tir.Schedule): - b0 = sch.get_block(name="NT_matmul", func_name="main") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="T_add_1", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l4, l5, l6, l7 = sch.get_loops(block=b0) - v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) - v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 2, 32, 1, 1]) - l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) - v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[80, 2, 1, 16, 1]) - l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) - v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[320, 1, 8]) - l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) - sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="blockIdx.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="vthread.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) - b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) - l53, l54, l55 = sch.get_loops(block=b48)[-3:] - sch.fuse(l53, l54, l55, preserve_unit_iters=True) - v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) - b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) - l63, l64 = sch.get_loops(block=b58)[-2:] - sch.fuse(l63, l64, preserve_unit_iters=True) - v66 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b58, 
ann_key="meta_schedule.cooperative_fetch", ann_val=v66) - sch.reverse_compute_inline(block=b2) - sch.reverse_compute_inline(block=b1) - v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) - sch.enter_postproc() - sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") - l72 = sch.get_loops(block=b48)[-1] - _, l74, l75 = sch.split(loop=l72, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l75) - sch.bind(loop=l74, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") - l80 = sch.get_loops(block=b58)[-1] - _, l82, l83 = sch.split(loop=l80, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l83) - sch.bind(loop=l82, thread_axis="threadIdx.x") - b84 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b84, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("lv45_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - _, _, b87, _ = sch.get_child_blocks(b84) - l103 = sch.get_loops(block=b87)[0] - sch.annotate(block_or_loop=l103, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l103, ann_key="pragma_unroll_explicit", ann_val=1) - b121 = sch.get_block(name="NT_matmul", func_name="main") - l125 = sch.get_loops(block=b121)[4] - sch.decompose_reduction(block=b121, loop=l125) - - - -def layer_norm(sch: tir.Schedule): - b0 = sch.get_block(name="A_red_temp", func_name="main") - b1 = sch.get_block(name="T_layer_norm", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], decision=4) - _, _, l6 = sch.get_loops(block=b0) - _, l8 = sch.split(loop=l6, factors=[None, v3], preserve_unit_iters=True) - sch.bind(loop=l8, thread_axis="threadIdx.x") - v9 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9) - l10, l11, l12 = sch.get_loops(block=b1) - l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True) - l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l15, l16, l14) - sch.bind(loop=l15, thread_axis="blockIdx.x") - sch.bind(loop=l16, thread_axis="threadIdx.x") - l17, l18, _, _ = sch.get_loops(block=b0) - l21 = sch.fuse(l17, l18, preserve_unit_iters=True) - sch.bind(loop=l21, thread_axis="blockIdx.x") - sch.enter_postproc() - b22 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.unroll_explicit") - b23, _ = sch.get_child_blocks(b22) - l25, _, _ = sch.get_loops(block=b23) - sch.annotate(block_or_loop=l25, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l25, ann_key="pragma_unroll_explicit", ann_val=1) - - -def matmul(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 32, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - s0, s1 = sch.split(l3, [None, 32]) - k0, k1 = sch.split(k, [None, 32]) - 
sch.reorder(s0, l1, l2, s1, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[8, 4, 1, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[16, 4, 2, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 80, 1, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[8, 4, 1]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86, l87 = sch.split(loop=l84, 
factors=[None, 160, 4], preserve_unit_iters=True) - sch.vectorize(loop=l87) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l92 = sch.get_loops(block=b68)[-1] - _, l94, l95 = sch.split(loop=l92, factors=[None, 160, 2], preserve_unit_iters=True) - sch.vectorize(loop=l95) - sch.bind(loop=l94, thread_axis="threadIdx.x") - b96 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b96, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - b1 = sch.get_block("matmul_1_pad") - sch.reverse_compute_inline(b1) - - _, _, b99, _ = sch.get_child_blocks(b96) - l115 = sch.get_loops(block=b99)[0] - sch.annotate(block_or_loop=l115, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l115, ann_key="pragma_unroll_explicit", ann_val=1) - b136 = sch.get_block(name="matmul", func_name="main") - l140 = sch.get_loops(block=b136)[4] - sch.decompose_reduction(block=b136, loop=l140) - - - -def matmul8(sch: tir.Schedule): - b0 = sch.get_block(name="matmul", func_name="main") - sch.pad_einsum(b0, [1, 1, 1, 1, 32]) - l1, l2, l3, l4, k = sch.get_loops(b0) - k0, k1 = sch.split(k, [None, 32]) - sch.reorder(l1, l2, l3, k0, l4, k1) - - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4, _, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 2, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[2, 1, 40, 1, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[8, 2, 2]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, k0, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", 
consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l62, l63, l64, l65 = sch.get_loops(block=b57)[-4:] - sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l73, l74, l75, l76 = sch.get_loops(block=b68)[-4:] - sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l84 = sch.get_loops(block=b57)[-1] - _, l86 = sch.split(loop=l84, factors=[None, 80], preserve_unit_iters=True) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l91 = sch.get_loops(block=b68)[-1] - _, l93 = sch.split(loop=l91, factors=[None, 80], preserve_unit_iters=True) - sch.bind(loop=l93, thread_axis="threadIdx.x") - b94 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b94, ann_key="meta_schedule.unroll_explicit") - - b1 = sch.get_block("A_pad") - sch.compute_inline(b1) - b1 = sch.get_block("B_pad") - sch.compute_inline(b1) - - b132 = sch.get_block(name="matmul", func_name="main") - l136 = sch.get_loops(block=b132)[3] - sch.decompose_reduction(block=b132, loop=l136) - - -@T.prim_func -def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), 
T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = 
T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_expsum_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) - v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) - v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) - T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) - T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] - - -def 
fused_min_max_triu_te_broadcast_to(sch: tir.Schedule): - b0 = sch.get_block("T_broadcast_to") - sch.reverse_compute_inline(b0) - b1 = sch.get_block("make_diag_mask_te") - i, j = sch.get_loops(b1) - i = sch.fuse(i, j) - i, j = sch.split(i, [None, 128]) - sch.bind(i, "blockIdx.x") - sch.bind(j, "threadIdx.x") - -def softmax_1xn(sch: tir.Schedule): - has_cast = False - if has_cast: - b_cast = sch.get_block("compute") - sch.reverse_compute_inline(b_cast) - - b0 = sch.get_block("T_softmax_exp") - sch.compute_inline(b0) - b1 = sch.get_block("T_softmax_norm") - l2, l3, l4, l5 = sch.get_loops(b1) - _, l7 = sch.split(l5, [None, 128]) - sch.bind(l7, "threadIdx.x") - b8 = sch.get_block("T_softmax_expsum") - sch.compute_at(b8, l4) - sch.set_scope(b8, 0, "shared") - _, _, _, l12 = sch.get_loops(b8) - _, l14 = sch.split(l12, [None, 128]) - sch.bind(l14, "threadIdx.x") - b15 = sch.get_block("T_softmax_maxelem") - sch.compute_at(b15, l4) - sch.set_scope(b15, 0, "shared") - _, _, _, l19 = sch.get_loops(b15) - _, l21 = sch.split(l19, [None, 128]) - sch.bind(l21, "threadIdx.x") - l22 = sch.fuse(l2, l3, l4) - sch.bind(l22, "blockIdx.x") - -def _get_dict(): - tvm.ir.assert_structural_equal(MOD["softmax"], softmax_mxn_before) - func_dict = { - softmax_mxn_before: softmax_mxn_after, - } - for name, func in [ - # fmt: off - ("fused_NT_matmul1_divide_maximum_minimum", fused_NT_matmul1_divide_maximum_minimum), - ("fused_NT_matmul2_add2_gelu", fused_NT_matmul2_add2_gelu), - ("fused_NT_matmul3_add_cast_add1", fused_NT_matmul3_add_cast_add1), - ("fused_NT_matmul4_divide2_maximum1_minimum1", fused_NT_matmul4_divide2_maximum1_minimum1), - ("fused_NT_matmul_add", fused_NT_matmul_add), - ("fused_NT_matmul_add_add1", fused_NT_matmul_add_add1), - ("layer_norm", layer_norm), - ("matmul", matmul), - ("matmul8", matmul8), - ("softmax2", softmax_1xn), - ("fused_min_max_triu_te_broadcast_to", fused_min_max_triu_te_broadcast_to), - # fmt: on - ]: - # print(f"############### {name} ###############") - sch = tir.Schedule(MOD[name]) - func(sch) - # sch.mod["main"].show(black_format=False) - func_dict[MOD[name]] = sch.mod["main"] - return { - (tvm.ir.structural_hash(k), k): v.with_attr("tir.is_scheduled", True) - for k, v in func_dict.items() - } - - -DICT = _get_dict() - - -def lookup(func): - for (hash_value, func_before), f_after in DICT.items(): - if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( - func, func_before - ): - return f_after - return None diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py deleted file mode 100644 index b6c4cbc33d..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_mod.py +++ /dev/null @@ -1,577 +0,0 @@ -# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements -from tvm.script import ir as I -from tvm.script import tir as T - -# fmt: off - -@I.ir_module -class Module: - @T.prim_func - def extend_te(var_A: T.handle, var_concat_te: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n)) - m = T.int64() - concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m)) - # with T.block("root"): - for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m): - with T.block("concat_te"): - v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j]) - T.reads(A[v_b, v__, v_i, v_j + n - m]) - T.writes(concat_te[v_b, v__, v_i, v_j]) - concat_te[v_b, v__, 
v_i, v_j] = T.if_then_else(v_j < m - n, T.float32(3.4028234663852886e+38), A[v_b, v__, v_i, v_j + n - m]) - - @T.prim_func - def full(var_T_full: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n)) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n): - with T.block("T_full"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads() - T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3]) - T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(3.4028234663852886e+38) - - @T.prim_func - def fused_NT_matmul1_divide_maximum_minimum(p_lv34: T.handle, p_lv35: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv34 = T.match_buffer(p_lv34, (T.int64(1), T.int64(32), n, T.int64(80))) - m = T.int64() - lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), m, T.int64(80))) - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv34[v_i0, v_i1, v_i2, v_k], lv35[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv34[v_i0, v_i1, v_i2, v_k] * lv35[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - - @T.prim_func - def fused_NT_matmul2_add2_gelu(p_lv51: T.handle, lv38: T.Buffer((T.int64(10240), T.int64(2560)), "float32"), linear_bias4: T.Buffer((T.int64(10240),), 
"float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(2560))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv51[v_i0, v_i1, v_k], lv38[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv38[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul3_add_cast_add1(p_lv56: T.handle, lv45: T.Buffer((T.int64(2560), T.int64(10240)), "float32"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv49: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": 
T.bool(True)}) - n = T.int64() - lv56 = T.match_buffer(p_lv56, (T.int64(1), n, T.int64(10240))) - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv56[v_i0, v_i1, v_k], lv45[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv45[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate_1[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv49[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv49[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul4_divide2_maximum1_minimum1(lv1835: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), p_lv1836: T.handle, p_lv1806: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1836 = T.match_buffer(p_lv1836, (T.int64(1), T.int64(32), n, T.int64(80))) - lv1806 = T.match_buffer(p_lv1806, (T.int64(1), T.int64(1), T.int64(1), n)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1835[v_i0, v_i1, v_i2, v_k], lv1836[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1835[v_i0, v_i1, v_i2, v_k] * lv1836[v_i0, v_i1, v_i3, v_k] - for ax0, 
ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) - - @T.prim_func - def fused_NT_matmul_add(p_lv7: T.handle, lv10: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv7[v_i0, v_i1, v_k], lv10[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * lv10[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - - @T.prim_func - def fused_NT_matmul_add_add1(p_lv45: T.handle, lv31: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias3: T.Buffer((T.int64(2560),), "float32"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(2560))) - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), 
T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv45[v_i0, v_i1, v_k], lv31[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv31[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n)) - # with T.block("root"): - var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n)) - for i, j in T.grid(n, n): - with T.block("make_diag_mask_te"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads() - T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) - var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float32(-3.4028234663852886e+38), T.float32(3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): - with T.block("T_broadcast_to"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) - T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] - - @T.prim_func - def layer_norm(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = 
T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def matmul(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80))) - matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n), "int32") - T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") - # with T.block("root"): - for ax0 in range(n): - with T.block("T_reshape"): - v_ax0 = T.axis.spatial(n, ax0) - T.reads(A[T.int64(0), v_ax0 % n]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n] - - @T.prim_func - def reshape1(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (n, T.int64(2560))) - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560)] - - @T.prim_func - def reshape2(var_A: T.handle, var_T_reshape: T.handle): - 
T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(80))) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(80) + v_ax3) // T.int64(2560) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - - @T.prim_func - def reshape3(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - m = T.int64() - A = T.match_buffer(var_A, (m, T.int64(32), T.int64(80))) - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), m, T.int64(32), T.int64(80))) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), m, T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(80) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(80) + v_ax2) % T.int64(32), v_ax3 % T.int64(80)] - - @T.prim_func - def reshape4(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) - T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(2560) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] - - @T.prim_func - def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(80)), "float32"), C: T.Buffer((T.int64(2048), T.int64(80)), "float32"), var_rotary: T.handle, m: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) - rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(80))) - # with T.block("root"): - for i_batch_size, i_seq_len, i_num_heads, i_head_dim in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("rotary"): - v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim = T.axis.remap("SSSS", [i_batch_size, i_seq_len, i_num_heads, i_head_dim]) - T.reads(B[m + v_i_seq_len - n, v_i_head_dim], A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40):v_i_head_dim - T.int64(40) + T.int64(81)], C[m + v_i_seq_len - n, v_i_head_dim]) - T.writes(rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) - rotary[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] = T.Select(v_i_head_dim < T.int64(80), B[m + 
v_i_seq_len - n, v_i_head_dim] * A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim] + C[m + v_i_seq_len - n, v_i_head_dim] * T.Select(v_i_head_dim < T.int64(40), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim + T.int64(40)] * T.float32(-1), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim - T.int64(40)]), A[v_i_batch_size, v_i_seq_len, v_i_num_heads, v_i_head_dim]) - - @T.prim_func - def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("slice"): - v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) - T.reads(A[v_i, n - T.int64(1), v_k]) - T.writes(slice_1[v_i, v__, v_k]) - slice_1[v_i, v__, v_k] = A[v_i, n - T.int64(1), v_k] - - @T.prim_func - def softmax(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - @T.prim_func - def softmax2(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - T_softmax_maxelem = 
T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - @T.prim_func - def squeeze(var_A: T.handle, var_T_squeeze: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) - T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(80))) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) - T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def take_decode(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "uint32"), var_C: T.handle, var_take_decode: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - C = T.match_buffer(var_C, (n,), "int32") - take_decode_1 = T.match_buffer(var_take_decode, (n, T.int64(2560))) - # with T.block("root"): - for i, j in T.grid(n, T.int64(2560)): - with T.block("take_decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) - T.writes(take_decode_1[v_i, v_j]) - take_decode_1[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[C[v_i], v_j // T.int64(32)], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(B[C[v_i], v_j // T.int64(32)], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - - @T.prim_func - def transpose(var_A: T.handle, var_T_transpose: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(80))) - T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(80))) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] - - @T.prim_func - def transpose1(var_A: T.handle, var_T_transpose: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(80))) - T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(80))) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] -# fmt: on diff --git a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py b/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py deleted file mode 100644 index 1b1169ea00..0000000000 --- a/mlc_llm/dispatch/gpt_neox/redpajama_q4f32_tune.py +++ /dev/null @@ -1,743 +0,0 @@ -# pylint: disable=pointless-string-statement,invalid-name,missing-docstring,line-too-long,too-many-locals,too-many-arguments,too-many-statements -from tvm.script import ir as I -from tvm.script import tir as T - -# fmt: off - -@I.ir_module -class Module: - @T.prim_func - def cast1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), compute: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = A[v_i0, v_i1, v_i2] - - @T.prim_func - def decode(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), T_transpose: T.Buffer((T.int64(2560), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode_1 = T.alloc_buffer((T.int64(2560), T.int64(2560))) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode_1[v_i, v_j]) - decode_1[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): - with 
T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode_1[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0] - - @T.prim_func - def decode1(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: T.Buffer((T.int64(80), T.int64(10240)), "uint32"), T_transpose: T.Buffer((T.int64(10240), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(2560), T.int64(10240))) - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - @T.prim_func - def decode2(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), T_transpose: T.Buffer((T.int64(2560), T.int64(10240)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(10240), T.int64(2560))) - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - @T.prim_func - def divide1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], B[()]) - T.writes(T_divide[v_ax0, v_ax1, v_ax2]) - T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()] - - @T.prim_func - def fused_decode3_matmul1(lv1352: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), lv1353: T.Buffer((T.int64(80), T.int64(50432)), "uint32"), lv1800: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - 
T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) - for i, j in T.grid(T.int64(2560), T.int64(50432)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1352[v_i // T.int64(8), v_j], lv1353[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1352[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1353[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1353[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1800[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1800[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - @T.prim_func - def fused_decode4_fused_matmul7_add3(lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1364: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), lv1808: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias192: T.Buffer((T.int64(2560),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1363[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1364[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1364[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1808[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1808[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] - - @T.prim_func - def fused_decode4_fused_matmul7_add3_add4(lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1382: T.Buffer((T.int64(80), T.int64(2560)), "uint32"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias195: T.Buffer((T.int64(2560),), "float32"), lv1805: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for i, j in T.grid(T.int64(2560), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1381[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1382[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1382[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv5[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias195[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv1805[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv1805[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_decode5_fused_matmul9_add5_gelu1(lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), lv1388: T.Buffer((T.int64(80), T.int64(10240)), "uint32"), lv1852: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), linear_bias196: T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - var_T_add_intermediate = 
T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) - for i, j in T.grid(T.int64(2560), T.int64(10240)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1387[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1388[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1388[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1852[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1852[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - 
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_decode6_fused_matmul10_add3_cast1_add4(lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv1394: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1857: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float32"), linear_bias197: T.Buffer((T.int64(2560),), "float32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for i, j in T.grid(T.int64(10240), T.int64(2560)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1393[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1394[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1394[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1857[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1857[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv6[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv6[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_reshape7_squeeze1(lv1821: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), 
var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80))) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(lv1821[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1821[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_slice1_cast1(lv3599: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_slice_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) - for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("slice"): - v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) - T.reads(lv3599[v_i, T.int64(0), v_k]) - T.writes(var_slice_intermediate[v_i, v__, v_k]) - var_slice_intermediate[v_i, v__, v_k] = lv3599[v_i, T.int64(0), v_k] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_slice_intermediate[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = var_slice_intermediate[v_i0, v_i1, v_i2] - - @T.prim_func - def fused_transpose7_reshape8(lv1844: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80))) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(lv1844[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1844[v_ax0, v_ax2, v_ax1, v_ax3] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)]) - T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2560) // T.int64(80), v_ax2 % T.int64(80)] - - @T.prim_func - def layer_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), B: 
T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) - for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def reshape5(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0 in range(T.int64(1)): - with T.block("T_reshape"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - T.reads(A[T.int64(0), T.int64(0)]) - T.writes(T_reshape[v_ax0]) - T_reshape[v_ax0] = A[T.int64(0), T.int64(0)] - - @T.prim_func - def reshape6(A: T.Buffer((T.int64(1), T.int64(2560)), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax2 % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) - T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(2560)] - - @T.prim_func - def reshape7(A: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(80) + v_ax3) % T.int64(2560)] - - @T.prim_func - def softmax1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): - 
T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(50432))) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1))) - for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(A[v_i0, v_i1, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k]) - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2]) - T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1]) - for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1]) - with T.init(): - T_softmax_expsum[v_i0, v_i1] = T.float32(0) - T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(50432)): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2]) - T.block_attr({"axis": 2}) - T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1] - - @T.prim_func - def squeeze1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(80)): - with T.block("T_squeeze"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2]) - T.writes(T_squeeze[v_ax0, v_ax1, v_ax2]) - T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2] - - @T.prim_func - def take_decode1(A: T.Buffer((T.int64(50432), T.int64(320)), "uint32"), B: T.Buffer((T.int64(50432), T.int64(80)), "uint32"), C: T.Buffer((T.int64(1),), "int32"), take_decode: T.Buffer((T.int64(1), T.int64(2560)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i, j in T.grid(T.int64(1), T.int64(2560)): - with T.block("take_decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[C[v_i], v_j // T.int64(8)], C[v_i], B[C[v_i], v_j // T.int64(32)]) - T.writes(take_decode[v_i, v_j]) - take_decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(A[C[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(B[C[v_i], v_j // T.int64(32)], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(B[C[v_i], v_j // T.int64(32)], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - - @T.prim_func - def transpose6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(80)), "float32"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), 
T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80)): - with T.block("T_transpose"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3]) - T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) - T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3] - - ########## Dynamic shape ########## - - @T.prim_func - def fused_NT_matmul1_divide_maximum_minimum(p_lv34: T.handle, p_lv35: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - m = T.meta_var(T.int64(128)) - lv34 = T.match_buffer(p_lv34, (T.int64(1), T.int64(32), n, T.int64(80))) - lv35 = T.match_buffer(p_lv35, (T.int64(1), T.int64(32), m, T.int64(80))) - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv34[v_i0, v_i1, v_i2, v_k], lv35[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv34[v_i0, v_i1, v_i2, v_k] * lv35[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - - @T.prim_func - def fused_NT_matmul2_add2_gelu(p_lv51: T.handle, lv38: T.Buffer((T.int64(10240), T.int64(2560)), "float32"), linear_bias4: T.Buffer((T.int64(10240),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) 
- lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(2560))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(10240))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_multiply_1 = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - T_add = T.alloc_buffer((T.int64(1), n, T.int64(10240))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(10240), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv51[v_i0, v_i1, v_k], lv38[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv38[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias4[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias4[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) - T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(T_multiply[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(compute[v_ax0, v_ax1, v_ax2]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) - T.writes(T_add[v_ax0, v_ax1, v_ax2]) - T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(10240)): - with T.block("T_multiply_2"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul3_add_cast_add1(p_lv56: T.handle, lv45: T.Buffer((T.int64(2560), T.int64(10240)), "float32"), linear_bias5: T.Buffer((T.int64(2560),), "float32"), p_lv49: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv56 = T.match_buffer(p_lv56, (T.int64(1), n, 
T.int64(10240))) - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_compute_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(10240)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv56[v_i0, v_i1, v_k], lv45[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv56[v_i0, v_i1, v_k] * lv45[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias5[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias5[v_ax2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_T_add_intermediate_1[v_i0, v_i1, v_i2]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) - var_compute_intermediate[v_i0, v_i1, v_i2] = var_T_add_intermediate_1[v_i0, v_i1, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_compute_intermediate[v_ax0, v_ax1, v_ax2], lv49[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate[v_ax0, v_ax1, v_ax2] + lv49[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def fused_NT_matmul4_divide2_maximum1_minimum1(lv1835: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32"), p_lv1836: T.handle, p_lv1806: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv1836 = T.match_buffer(p_lv1836, (T.int64(1), T.int64(32), n, T.int64(80))) - lv1806 = T.match_buffer(p_lv1806, (T.int64(1), T.int64(1), T.int64(1), n)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(80)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1835[v_i0, v_i1, v_i2, v_k], lv1836[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1835[v_i0, v_i1, v_i2, v_k] * lv1836[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), 
n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.11180339723346898) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1806[v_ax0, T.int64(0), v_ax2, v_ax3]) - - @T.prim_func - def fused_NT_matmul_add(p_lv7: T.handle, lv10: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv7 = T.match_buffer(p_lv7, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv7[v_i0, v_i1, v_k], lv10[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv7[v_i0, v_i1, v_k] * lv10[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias[v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias[v_ax2] - - @T.prim_func - def fused_NT_matmul_add_add1(p_lv45: T.handle, lv31: T.Buffer((T.int64(2560), T.int64(2560)), "float32"), linear_bias3: T.Buffer((T.int64(2560),), "float32"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(2560))) - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(2560))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(2560), T.int64(2560)): - with 
T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv45[v_i0, v_i1, v_k], lv31[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv31[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias3[v_ax2]) - T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias3[v_ax2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2], lv2[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + lv2[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def layer_norm(var_A: T.handle, B: T.Buffer((T.int64(2560),), "float32"), C: T.Buffer((T.int64(2560),), "float32"), var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(2560))) - T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), n, T.int64(2560))) - # with T.block("root"): - A_red_temp_v0 = T.alloc_buffer((T.int64(1), n)) - A_red_temp_v1 = T.alloc_buffer((T.int64(1), n)) - for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("A_red_temp"): - v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) - T.reads(A[v_ax0, v_ax1, v_k2]) - T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1]) - with T.init(): - A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) - A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + A[v_ax0, v_ax1, v_k2] * A[v_ax0, v_ax1, v_k2] - A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 - A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): - with T.block("T_layer_norm"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(A[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, v_ax1], B[v_ax2], C[v_ax2]) - T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) - T_layer_norm[v_ax0, v_ax1, v_ax2] = (A[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * B[v_ax2] + C[v_ax2] - - @T.prim_func - def matmul(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(128)) - m = T.meta_var(T.int64(32)) - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(80))) - matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(80))) - # with T.block("root"): - for i0, i1, i2, i3, k in 
T.grid(T.int64(1), T.int64(32), n, T.int64(80), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - @T.prim_func - def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(80)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.meta_var(T.int64(32)) - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n)) - B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(80))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(80), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - -# fmt: on diff --git a/mlc_llm/dispatch/llama/__init__.py b/mlc_llm/dispatch/llama/__init__.py deleted file mode 100644 index 2374080799..0000000000 --- a/mlc_llm/dispatch/llama/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .main import lookup_func as lookup diff --git a/mlc_llm/dispatch/llama/main.py b/mlc_llm/dispatch/llama/main.py deleted file mode 100644 index 166739b85a..0000000000 --- a/mlc_llm/dispatch/llama/main.py +++ /dev/null @@ -1,6712 +0,0 @@ -import tvm -from tvm import IRModule -from tvm.script import tir as T - - -# fmt: off -@T.prim_func -def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16") - # with T.block("root"): - var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16") - for i, j in T.grid(n, n): - with T.block("make_diag_mask_te"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads() - T.writes(var_make_diag_mask_te_intermediate[v_i, v_j]) - var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n): - with T.block("T_broadcast_to"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3]) - T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3] - - -def fused_min_max_triu_te_broadcast_to_sch_func(): - sch = tvm.tir.Schedule(fused_min_max_triu_te_broadcast_to) - b0 = sch.get_block("T_broadcast_to") - sch.reverse_compute_inline(b0) - return sch.mod["main"] - - -@T.prim_func -def rms_norm_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096),), "float32"), var_rms_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) - rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - rxplaceholderred_temp = 
T.alloc_buffer((T.int64(1), n)) - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rxplaceholderred_temp"): - v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) - T.reads(rxplaceholder_1[v_bsz, v_i, v_k]) - T.writes(rxplaceholderred_temp[v_bsz, v_i]) - with T.init(): - rxplaceholderred_temp[v_bsz, v_i] = T.float32(0) - rxplaceholderred_temp[v_bsz, v_i] = rxplaceholderred_temp[v_bsz, v_i] + rxplaceholder_1[v_bsz, v_i, v_k] * rxplaceholder_1[v_bsz, v_i, v_k] - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): - v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) - T.reads(rxplaceholder[v_k], rxplaceholder_1[v_bsz, v_i, v_k], rxplaceholderred_temp[v_bsz, v_i]) - T.writes(rms_norm_1[v_bsz, v_i, v_k]) - rms_norm_1[v_bsz, v_i, v_k] = rxplaceholder[v_k] * (rxplaceholder_1[v_bsz, v_i, v_k] / T.sqrt(rxplaceholderred_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))) - - -@T.prim_func -def rms_norm_after(var_A: T.handle, var_weight: T.Buffer((T.int64(4096),), "float32"), var_rms_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096))) - rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("compute_o"): - v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) - v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) - T.reads(A[v_bsz, v_i_o * T.int64(32):v_i_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - T.writes(rms_norm[v_bsz, T.int64(0) : T.int64(n), T.int64(0):T.int64(4096)]) - sq_sum_pad_local = T.alloc_buffer((T.int64(32),), scope="shared") - for bsz, i_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(16)): - for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("compute"): - v_i_i = T.axis.spatial(T.int64(32), i_1) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1) - T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) - T.writes(sq_sum_pad_local[v_i_i]) - with T.init(): - sq_sum_pad_local[v_i_i] = T.float32(0) - sq_sum_pad_local[v_i_i] = sq_sum_pad_local[v_i_i] + T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], T.float32(0)) * T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i], T.float32(0)) - for bsz_i_fused_1, k_0 in T.grid(T.int64(32), T.int64(16)): - for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("compute_cache_write"): - v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) - v_i_i = T.axis.spatial(n, bsz_i_fused_1) - v_k = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1) - T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k], var_weight[v_k], sq_sum_pad_local[v_i_i]) - T.writes(rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) - if v_i_i < n: - rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] = var_weight[v_k] * (A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] / T.sqrt(sq_sum_pad_local[v_i_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))) - - -@T.prim_func -def rms_norm_fp16_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096)), "float16") - rms_norm_1 = 
T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - rxplaceholderred_temp = T.alloc_buffer((T.int64(1), n)) - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rxplaceholderred_temp"): - v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k]) - T.reads(rxplaceholder_1[v_bsz, v_i, v_k]) - T.writes(rxplaceholderred_temp[v_bsz, v_i]) - with T.init(): - rxplaceholderred_temp[v_bsz, v_i] = T.float32(0) - rxplaceholderred_temp[v_bsz, v_i] = rxplaceholderred_temp[v_bsz, v_i] + T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) * T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) - for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("rms_norm"): - v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k]) - T.reads(rxplaceholder[v_k], rxplaceholder_1[v_bsz, v_i, v_k], rxplaceholderred_temp[v_bsz, v_i]) - T.writes(rms_norm_1[v_bsz, v_i, v_k]) - rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", rxplaceholder[v_k]) * (T.Cast("float32", rxplaceholder_1[v_bsz, v_i, v_k]) / T.sqrt(rxplaceholderred_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) - - -@T.prim_func -def rms_norm_fp16_after(var_A: T.handle, var_weight: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), dtype="float16") - rms_norm = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), dtype="float16") - # with T.block("root"): - for i_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("compute_o"): - v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) - v_i_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i_0) - T.reads(A[v_bsz, v_i_o * T.int64(32):v_i_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - T.writes(rms_norm[v_bsz, T.int64(0) : T.int64(n), T.int64(0):T.int64(4096)]) - sq_sum_pad_local = T.alloc_buffer((T.int64(32),), scope="shared") - for bsz, i_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(16)): - for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("compute"): - v_i_i = T.axis.spatial(T.int64(32), i_1) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1) - T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) - T.writes(sq_sum_pad_local[v_i_i]) - with T.init(): - sq_sum_pad_local[v_i_i] = T.float32(0) - sq_sum_pad_local[v_i_i] = sq_sum_pad_local[v_i_i] + T.if_then_else(v_i_o * T.int64(32) + v_i_i < n, T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]) * T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k_i]), T.float32(0)) - for bsz_i_fused_1, k_0 in T.grid(T.int64(32), T.int64(16)): - for k_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("compute_cache_write"): - v_bsz = T.axis.spatial(T.int64(1), T.int64(0)) - v_i_i = T.axis.spatial(n, bsz_i_fused_1) - v_k = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1) - T.reads(A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k], var_weight[v_k], sq_sum_pad_local[v_i_i]) - T.writes(rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) - if v_i_i < n: - rms_norm[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k] = T.Cast("float16", T.Cast("float32", var_weight[v_k]) * (T.Cast("float32", A[v_bsz, v_i_o * T.int64(32) + v_i_i, v_k]) / T.sqrt(sq_sum_pad_local[v_i_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07)))) - - -@T.prim_func -def 
softmax_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, n)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): - for k_1 in 
T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_expsum_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * 
T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * n + T.int64(255)) // T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n // n) - v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n % n) - v_i3 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % n) - T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * n) - T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_mxn_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - 
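# [Editor's sketch, not part of the original patch] Every "before"/"after" pair in this
# block computes the same numerically stable row softmax; only the schedule differs.
# A NumPy reference for spot-checking either variant:
import numpy as np

def softmax_ref(x: np.ndarray) -> np.ndarray:
    # x has shape (1, 32, n, m); reduce over the last axis, matching the TIR above
    x_max = x.max(axis=-1, keepdims=True)              # T_softmax_maxelem
    x_exp = np.exp(x - x_max)                          # T_softmax_exp
    return x_exp / x_exp.sum(axis=-1, keepdims=True)   # T_softmax_norm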
T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_expsum_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - 
T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float32(0)) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * m + T.int64(255)) // T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m // n) - v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // m % n) - v_i3 = T.axis.spatial(m, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % m) - T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * m) - T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] - -@T.prim_func -def softmax_cast_mxn_before(p_lv37: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - lv37 = T.match_buffer(p_lv37, (T.int64(1), T.int64(32), n, m)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n)) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n)) - 
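# [Editor's note, a sketch under assumed shapes] The scheduled kernels above pad the
# reduction axis up to a multiple of 128 and neutralize the padding: -FLT_MAX feeds the
# running max and 0 feeds the exp-sum, which is what the T.if_then_else guards encode.
# The same trick in plain NumPy:
import numpy as np

def padded_row_max(row: np.ndarray, valid_len: int, pad_to: int = 128) -> float:
    padded_len = ((valid_len + pad_to - 1) // pad_to) * pad_to
    padded = np.full(padded_len, -3.4028234663852886e+38, dtype=row.dtype)
    padded[:valid_len] = row[:valid_len]
    return float(padded.max())   # identical to row[:valid_len].max()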
var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv37[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv37[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv37[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv37[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - -@T.prim_func -def softmax_cast_mxn_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") - # with T.block("root"): - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), 
thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float32(-3.4028234663852886e+38))) - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float32(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float32(0)) - for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): - for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) - v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) - T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) - if v_i2_o * T.int64(32) + v_i2_i < n: - T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.Cast("float16", T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i]) - - -@T.prim_func -def softmax_mxn_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m), "float16") - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), 
T.int64(32), n, m): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - -@T.prim_func -def softmax_mxn_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), dtype="float16") - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, m), dtype="float16") - # with T.block("root"): - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(m + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_norm[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):m]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = 
T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504))) - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (m + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((m + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < m, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]), T.float16(0)) - for i0_i1_i2_1_i3_fused_0 in range((T.int64(32) * T.int64(32) * m) // T.int64(128)): - for i0_i1_i2_1_i3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // T.int64(32) // m) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) // m % T.int64(32)) - v_i3 = T.axis.spatial(m, (i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1) % m) - T.where(i0_i1_i2_1_i3_fused_0 * T.int64(128) + i0_i1_i2_1_i3_fused_1 < T.int64(32) * T.int64(32) * m) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i], A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3], T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3]) - if v_i2_o * T.int64(32) + v_i2_i < n: - T_softmax_norm[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] = T.exp(A[v_i0, v_i1, v_i2_o * T.int64(32) + v_i2_i, v_i3] - T_softmax_maxelem_pad_0_local[v_i0, v_i1, v_i2_i]) / T_softmax_expsum_pad_0_local[v_i0, v_i1, v_i2_i] - - -@T.prim_func -def softmax_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n), "float16") - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, 
i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_fp16_after(var_A: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, n), dtype="float16") - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), n, n), dtype="float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype="float16") - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n), dtype="float16") - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_maxelem_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)]) - T.writes(T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_maxelem_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i]) - T.writes(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(-65504) - T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.max(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i], T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T.float16(-65504))) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_maxelem_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + 
i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_maxelem_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i2_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"): - with T.block("T_softmax_expsum_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0) - T.reads(A[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):(n + T.int64(127)) // T.int64(128) * T.int64(128)], T_softmax_maxelem[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T.writes(T_softmax_expsum[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32)]) - T_softmax_expsum_pad_0_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32)), scope="shared", dtype="float16") - for i0, i1, i2_1, k_0 in T.grid(T.int64(1), T.int64(32), T.int64(32), (n + T.int64(127)) // T.int64(128)): - for k_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum"): - v_i1_i, v_i2_i = T.axis.remap("SS", [i1, i2_1]) - v_k_i = T.axis.reduce(T.int64(32) * ((n + T.int64(127)) // T.int64(128)), k_0 * T.int64(128) + k_1) - T.reads(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i], T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T.writes(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - with T.init(): - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T.float16(0) - T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] + T.if_then_else(v_i2_o * T.int64(32) + v_i2_i < n and v_k_i < n, T.exp(A[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i, v_k_i] - T_softmax_maxelem[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]), T.float16(0)) - for i0_i1_i2_1_fused_0 in range(T.int64(8)): - for i0_i1_i2_1_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - with T.block("T_softmax_expsum_cache_write"): - v_i1_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) // T.int64(32)) - v_i2_i = T.axis.spatial(T.int64(32), (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32)) - T.where(v_i2_o * T.int64(32) + (i0_i1_i2_1_fused_0 * T.int64(128) + i0_i1_i2_1_fused_1) % T.int64(32) < n) - T.reads(T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i]) - T.writes(T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i]) - T_softmax_expsum[v_i0, v_i1_i, v_i2_o * T.int64(32) + v_i2_i] = T_softmax_expsum_pad_0_local[v_i0, v_i1_i, v_i2_i] - for i0_i1_i2_fused_i3_fused_0 in T.thread_binding((n * T.int64(32) * n + T.int64(255)) // T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_fused_i3_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("T_softmax_norm"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(32), (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n // n) - v_i2 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) // n % n) - v_i3 = T.axis.spatial(n, (i0_i1_i2_fused_i3_fused_0 * T.int64(256) + i0_i1_i2_fused_i3_fused_1) % n) - T.where(i0_i1_i2_fused_i3_fused_0 * T.int64(256) + 
i0_i1_i2_fused_i3_fused_1 < n * T.int64(32) * n) - T.reads(T_softmax_expsum[v_i0, v_i1, v_i2], A[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T.exp(A[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_1xn_before(var_inp0: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - inp0 = T.match_buffer(var_inp0, (T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(inp0[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], inp0[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inp0[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(inp0[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -@T.prim_func -def softmax_cast_1xn_before(p_lv1614: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1614 = T.match_buffer(p_lv1614, (T.int64(1), T.int64(32), T.int64(1), n)) - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1))) - var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", 
[i0, i1, i2, k]) - T.reads(lv1614[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1614[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(lv1614[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1614[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3]) - - -@T.prim_func -def softmax_1xn_fp16_before(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - T_softmax_norm = T.match_buffer(var_T_softmax_norm, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - # with T.block("root"): - T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), "float16") - T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_maxelem"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float16(-65504) - T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], rxplaceholder[v_i0, v_i1, v_i2, v_k]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_exp"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2]) - T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3]) - T_softmax_exp[v_i0, v_i1, 
v_i2, v_i3] = T.exp(rxplaceholder[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2]) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_expsum"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k]) - T.writes(T_softmax_expsum[v_i0, v_i1, v_i2]) - with T.init(): - T_softmax_expsum[v_i0, v_i1, v_i2] = T.float16(0) - T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k] - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_softmax_norm"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2]) - T.writes(T_softmax_norm[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"axis": 3}) - T_softmax_norm[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2] - - -def softmax_1xn_sch_func(f_softmax, cast_to_fp16: bool = False): - sch = tvm.tir.Schedule(f_softmax) - if cast_to_fp16: - b_cast = sch.get_block("compute") - sch.reverse_compute_inline(b_cast) - - b0 = sch.get_block("T_softmax_exp") - sch.compute_inline(b0) - b1 = sch.get_block("T_softmax_norm") - l2, l3, l4, l5 = sch.get_loops(b1) - l6, l7 = sch.split(l5, [None, 128]) - sch.bind(l7, "threadIdx.x") - b8 = sch.get_block("T_softmax_expsum") - sch.compute_at(b8, l4) - sch.set_scope(b8, 0, "shared") - l9, l10, l11, l12 = sch.get_loops(b8) - l13, l14 = sch.split(l12, [None, 128]) - sch.bind(l14, "threadIdx.x") - b15 = sch.get_block("T_softmax_maxelem") - sch.compute_at(b15, l4) - sch.set_scope(b15, 0, "shared") - l16, l17, l18, l19 = sch.get_loops(b15) - l20, l21 = sch.split(l19, [None, 128]) - sch.bind(l21, "threadIdx.x") - l22 = sch.fuse(l2, l3, l4) - sch.bind(l22, "blockIdx.x") - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] - - -@T.prim_func -def matmul1_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - matmul_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), scope="local") - 
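# [Editor's sketch] softmax_1xn_sch_func above takes an unscheduled kernel and returns
# the scheduled PrimFunc; a hypothetical usage, assuming the fp32 1xn variant:
import tvm

scheduled = softmax_1xn_sch_func(softmax_1xn_before)
print(scheduled.script())   # threadIdx.x / blockIdx.x bindings now visible
# tvm.ir.assert_structural_equal(scheduled, softmax_1xn_after)  # if an "after" reference exists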
rxplaceholder_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") - rxplaceholder_1_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") - for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(T.int64(128), thread="threadIdx.x"): - for i0_3_init, i1_3_init, i2_3_init, i3_3_init, i0_4_init, i1_4_init, i2_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) - v_i1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + i1_3_init * T.int64(2) + i1_4_init) - v_i2 = T.axis.spatial(T.int64(1), i2_3_init + i2_4_init) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + i3_3_init + i3_4_init) - T.reads() - T.writes(matmul_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - matmul_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - for k_0, k_1_0 in T.grid((n + T.int64(127)) // T.int64(128), T.int64(8)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): - with T.block("rxplaceholder_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(16)) - v2 = T.axis.spatial(T.int64(1), T.int64(0)) - v3 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(16)) - T.reads(rxplaceholder[v0, v1, v2, v3]) - T.writes(rxplaceholder_pad_shared[v0, v1, v2, v3]) - rxplaceholder_pad_shared[v0, v1, v2, v3] = T.if_then_else(v3 < n, rxplaceholder[v0, v1, v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(8)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): - with T.block("rxplaceholder_1_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(128)) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(128) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(512) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(rxplaceholder_1[v0, v1, v2, v3]) - T.writes(rxplaceholder_1_pad_shared[v0, v1, v2, v3]) - 
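# [Editor's sketch, not from this patch] The shared-memory staging here zero-pads the
# dynamic reduction axis n up to a multiple of 128; the extra zero products leave the
# matmul result unchanged, as the plain-NumPy equivalent shows:
import numpy as np

def padded_matmul(a: np.ndarray, b: np.ndarray, pad_to: int = 128) -> np.ndarray:
    # a: (1, 32, 1, n), b: (1, 32, n, 128); pad the shared n axis with zeros
    n = a.shape[-1]
    pad = (-n) % pad_to
    a_p = np.pad(a, [(0, 0)] * (a.ndim - 1) + [(0, pad)])
    b_p = np.pad(b, [(0, 0)] * (b.ndim - 2) + [(0, pad), (0, 0)])
    return a_p @ b_p   # equal to a @ b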
rxplaceholder_1_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) - for k_1_1, i0_3, i1_3, i2_3, i3_3, k_1_2, i0_4, i1_4, i2_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(8), T.int64(1), T.int64(2), T.int64(1), T.int64(1)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) - v_i1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + i1_3 * T.int64(2) + i1_4) - v_i2 = T.axis.spatial(T.int64(1), i2_3 + i2_4) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + i3_3 + i3_4) - v_k = T.axis.reduce((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(16) + k_1_1 * T.int64(8) + k_1_2) - T.reads(matmul_local[v_i0, v_i1, v_i2, v_i3], rxplaceholder_pad_shared[v_i0, v_i1, v_i2, v_k], rxplaceholder_1_pad_shared[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - matmul_local[v_i0, v_i1, v_i2, v_i3] = matmul_local[v_i0, v_i1, v_i2, v_i3] + rxplaceholder_pad_shared[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1_pad_shared[v_i0, v_i1, v_k, v_i3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(1), T.int64(1)): - with T.block("matmul_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_2_i1_2_i2_2_i3_2_fused // T.int64(8) * T.int64(2) + ax1) - v2 = T.axis.spatial(T.int64(1), ax2) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_0_i3_0_fused * T.int64(8) + i0_2_i1_2_i2_2_i3_2_fused % T.int64(8) + ax3) - T.reads(matmul_local[v0, v1, v2, v3]) - T.writes(matmul[v0, v1, v2, v3]) - matmul[v0, v1, v2, v3] = matmul_local[v0, v1, v2, v3] - - -@T.prim_func -def matmul2_before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - inp0 = T.match_buffer(var_inp0, (T.int64(1), n, T.int64(4096))) - matmul = T.match_buffer(var_matmul, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(inp0[v_i0, v_i1, v_k], inp1[v_k, v_i2]) - T.writes(matmul[v_i0, v_i1, v_i2]) - with T.init(): - matmul[v_i0, v_i1, v_i2] = T.float32(0) - matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + inp0[v_i0, v_i1, v_k] * inp1[v_k, v_i2] - -def matmul2_sch_func(): - sch = tvm.tir.Schedule(matmul2_before) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, l5 = sch.get_loops(block=b0) - v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) - v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[2, 2, 2, 4, 1]) - l21, l22, l23, l24, l25 = 
sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) - v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) - l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) - v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[512, 4, 2]) - l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) - sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) - l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) - sch.bind(loop=l42, thread_axis="blockIdx.x") - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="vthread.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) - b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) - _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) - l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) - v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) - b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) - _, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) - l63 = sch.fuse(l61, l62, preserve_unit_iters=True) - v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) - v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) - sch.enter_postproc() - sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") - _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) - l71, l72, l73 = sch.split(loop=l70, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l73) - sch.bind(loop=l72, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") - _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) - l79, l80, l81 = sch.split(loop=l78, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l81) - sch.bind(loop=l80, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) - _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) - sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) - 
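# [Editor's note] The trace above begins with pad_einsum, which pads the dynamic n axis
# to a multiple of 32 and introduces "inp0_pad"/"matmul_pad" copy blocks that the
# function inlines again at the very end. A minimal standalone sketch of that first step:
import tvm

sch = tvm.tir.Schedule(matmul2_before)
blk = sch.get_block("matmul")
sch.pad_einsum(blk, [1, 32, 1, 1])   # only the n iterator is actually padded
print(sch.mod.script())              # shows the new inp0_pad / matmul_pad blocks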
_, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) - _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) - _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - b119 = sch.get_block(name="matmul", func_name="main") - _, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) - b132 = sch.decompose_reduction(block=b119, loop=l123) - b1 = sch.get_block("inp0_pad") - sch.compute_inline(b1) - b2 = sch.get_block("matmul_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def matmul5_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) - matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(rxplaceholder[T.int64(0), v_i1, v_i2, v_k], rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3]) - T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[T.int64(0), v_i1, v_i2, v_k] * rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3] - - -@T.prim_func -def matmul5_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128))) - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - C_pad = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), T.int64(128))) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), (n + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") - for i2_0 in range((n + T.int64(127)) // T.int64(128)): - for i0_0_i1_0_i2_1_0_i3_0_fused in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_fused in T.thread_binding(T.int64(4), 
thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i0_3_init, i1_3_init, i2_1_3_init, i3_3_init, i0_4_init, i1_4_init, i2_1_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3_init + i1_4_init) - v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3_init * T.int64(4) + i2_1_4_init) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3_init + i3_4_init) - T.reads() - T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - for k_0, k_1_0 in T.grid((n + T.int64(127)) // T.int64(128), T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(rxplaceholder[v0, v1, v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n and v3 < n, rxplaceholder[v0, v1, v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(64)) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) - T.reads(rxplaceholder_1[v0, v1, v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) - for k_1_1, i0_3, i1_3, i2_1_3, i3_3, k_1_2, i0_4, i1_4, i2_1_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(4), 
T.int64(1), T.int64(1), T.int64(4), T.int64(1)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3 + i1_4) - v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3 * T.int64(4) + i2_1_4) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3 + i3_4) - v_k = T.axis.reduce((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + k_1_1 * T.int64(4) + k_1_2) - T.reads(C_pad_local[v_i0, v_i1, v_i2, v_i3], A_pad_shared[T.int64(0), v_i1, v_i2, v_k], B_pad_shared[T.int64(0), v_i1, v_k, v_i3]) - T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[v_i0, v_i1, v_i2, v_i3] = C_pad_local[v_i0, v_i1, v_i2, v_i3] + A_pad_shared[T.int64(0), v_i1, v_i2, v_k] * B_pad_shared[T.int64(0), v_i1, v_k, v_i3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(2)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + ax1) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + ax2) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3]) - T.writes(C_pad[v0, v1, v2, v3]) - C_pad[v0, v1, v2, v3] = C_pad_local[v0, v1, v2, v3] - for i0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): - for i1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i2, i3 in T.grid(n, T.int64(128)): - with T.block("C_pad"): - vi0, vi1, vi2, vi3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(C_pad[vi0, vi1, vi2, vi3]) - T.writes(matmul[vi0, vi1, vi2, vi3]) - matmul[vi0, vi1, vi2, vi3] = C_pad[vi0, vi1, vi2, vi3] - -@T.prim_func -def matmul5_with_m_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, m = T.int64(), T.int64() - A = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) - B = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128))) - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3] - - -@T.prim_func -def 
matmul5_with_m_after(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128))) - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128))) - # with T.block("root"): - C_pad = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), T.int64(128))) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(127)) // T.int64(128) * T.int64(128), (m + T.int64(127)) // T.int64(128) * T.int64(128)), scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (m + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(128)), scope="shared") - for i2_0 in range((n + T.int64(127)) // T.int64(128)): - for i0_0_i1_0_i2_1_0_i3_0_fused in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i0_3_init, i1_3_init, i2_1_3_init, i3_3_init, i0_4_init, i1_4_init, i2_1_4_init, i3_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3_init + i1_4_init) - v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3_init * T.int64(4) + i2_1_4_init) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3_init + i3_4_init) - T.reads() - T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - for k_0, k_1_0 in T.grid((m + T.int64(127)) // T.int64(128), T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial((m + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(256) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + 
ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(rxplaceholder[v0, v1, v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n and v3 < m, rxplaceholder[v0, v1, v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8)) - v2 = T.axis.spatial((m + T.int64(127)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(64)) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(64)) - T.reads(rxplaceholder_1[v0, v1, v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < m, rxplaceholder_1[v0, v1, v2, v3], T.float32(0)) - for k_1_1, i0_3, i1_3, i2_1_3, i3_3, k_1_2, i0_4, i1_4, i2_1_4, i3_4 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(4), T.int64(1)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + i1_3 + i1_4) - v_i2 = T.axis.spatial((n + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + i2_1_3 * T.int64(4) + i2_1_4) - v_i3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + i3_3 + i3_4) - v_k = T.axis.reduce((m + T.int64(128) - T.int64(1)) // T.int64(128) * T.int64(128), k_0 * T.int64(128) + k_1_0 * T.int64(8) + k_1_1 * T.int64(4) + k_1_2) - T.reads(C_pad_local[v_i0, v_i1, v_i2, v_i3], A_pad_shared[T.int64(0), v_i1, v_i2, v_k], B_pad_shared[T.int64(0), v_i1, v_k, v_i3]) - T.writes(C_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[v_i0, v_i1, v_i2, v_i3] = C_pad_local[v_i0, v_i1, v_i2, v_i3] + A_pad_shared[T.int64(0), v_i1, v_i2, v_k] * B_pad_shared[T.int64(0), v_i1, v_k, v_i3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(2)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_0_fused // T.int64(8) + ax1) - v2 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), i2_0 * T.int64(128) + i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(8) // T.int64(2) * T.int64(32) + i0_1_i1_1_i2_1_1_i3_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_2_i2_1_2_i3_2_fused // T.int64(16) * T.int64(4) + ax2) - v3 = T.axis.spatial(T.int64(128), i0_0_i1_0_i2_1_0_i3_0_fused % T.int64(2) * T.int64(64) + i0_1_i1_1_i2_1_1_i3_1_fused % T.int64(2) * T.int64(32) + i0_2_i1_2_i2_1_2_i3_2_fused % T.int64(16) * T.int64(2) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3]) - T.writes(C_pad[v0, v1, v2, v3]) - C_pad[v0, v1, v2, v3] = 
C_pad_local[v0, v1, v2, v3] - for i0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): - for i1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i2, i3 in T.grid(n, T.int64(128)): - with T.block("C_pad"): - vi0, vi1, vi2, vi3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(C_pad[vi0, vi1, vi2, vi3]) - T.writes(matmul[vi0, vi1, vi2, vi3]) - matmul[vi0, vi1, vi2, vi3] = C_pad[vi0, vi1, vi2, vi3] - - -@T.prim_func -def NT_matmul_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_NT_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) - NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) - NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] - - -@T.prim_func -def NT_matmul_after(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_NT_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) - NT_matmul_1 = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) - T.reads(rxplaceholder_1[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], rxplaceholder[T.int64(0):T.int64(4096), T.int64(0):T.int64(4096)]) - T.writes(NT_matmul_1[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") - rxplaceholder_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared") - for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): - for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(1), T.int64(2), T.int64(4), T.int64(2)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(4) + i1_1_4_init) - v_i2_i = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init * T.int64(2) + i2_4_init) - T.reads() - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - 
C_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) - for k_0 in range(T.int64(128)): - for ax0_ax1_ax2_fused_0 in range(T.int64(8)): - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(2)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) // T.int64(32)) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) % T.int64(32)) - T.reads(rxplaceholder_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - T.writes(A_pad_shared[v0, v1, v2]) - A_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, rxplaceholder_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) - for ax0_ax1_fused_0 in range(T.int64(8)): - for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): - with T.block("rxplaceholder_shared"): - v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) - v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) - T.reads(rxplaceholder[v0, v1]) - T.writes(rxplaceholder_shared[v0, v1]) - rxplaceholder_shared[v0, v1] = rxplaceholder[v0, v1] - for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(2), T.int64(4), T.int64(1), T.int64(4), T.int64(2)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(4) + i1_1_4) - v_i2_i = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3 * T.int64(2) + i2_4) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i], A_pad_shared[T.int64(0), v_i1_i, v_k_i], rxplaceholder_shared[v_i2_i, v_k_i]) - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i] + A_pad_shared[T.int64(0), v_i1_i, v_k_i] * rxplaceholder_shared[v_i2_i, v_k_i] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) - T.reads(C_pad_local[v0, v1, v2]) - T.writes(NT_matmul_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: - if v_i1_o * T.int64(32) + v1 < n: - NT_matmul_1[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = C_pad_local[v0, v1, v2] - - -@T.prim_func -def NT_matmul4_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(32000), T.int64(4096)), "float32"), var_NT_matmul: T.handle): - 
T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096))) - NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(32000))) - # with T.block("root"): - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(32000), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) - NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] - - -def NT_matmul4_sch_func(): - sch = tvm.tir.Schedule(NT_matmul4_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 256, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, l5 = sch.get_loops(block=b0) - v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) - v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 8, 4, 1]) - l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) - v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[668, 1, 8, 1, 6]) - l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) - v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[128, 4, 8]) - l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) - sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) - l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) - sch.bind(loop=l42, thread_axis="blockIdx.x") - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="vthread.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) - b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) - _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) - l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) - v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) - b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) - _, l57, l58, l59, l60, l61, l62 = 
sch.get_loops(block=b56) - l63 = sch.fuse(l61, l62, preserve_unit_iters=True) - v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) - v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) - sch.enter_postproc() - sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") - _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) - l71, l72, l73 = sch.split(loop=l70, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l73) - sch.bind(loop=l72, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") - _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) - l79, l80, l81 = sch.split(loop=l78, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l81) - sch.bind(loop=l80, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) - _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) - sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) - _, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) - _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) - _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - b119 = sch.get_block(name="NT_matmul", func_name="main") - _, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) - b132 = sch.decompose_reduction(block=b119, loop=l123) - b1 = sch.get_block("rxplaceholder_1_pad") - sch.compute_inline(b1) - b3 = sch.get_block("NT_matmul_pad") - sch.reverse_compute_inline(b3) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def NT_matmul9_before(rxplaceholder: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), rxplaceholder_1: T.Buffer((T.int64(32000), T.int64(4096)), "float32"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_k], rxplaceholder_1[v_i2, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2] = T.float32(0) - NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder[v_i0, v_i1, v_k] * rxplaceholder_1[v_i2, v_k] - - -def 
NT_matmul9_sch_func(): - sch = tvm.tir.Schedule(NT_matmul9_before) - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4, l5 = sch.get_loops(block=b0) - v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) - v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) - v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[668, 1, 48, 1, 1]) - l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) - v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[64, 64, 1]) - l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) - sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) - l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) - sch.bind(loop=l42, thread_axis="blockIdx.x") - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="vthread.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) - b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) - l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) - l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) - v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) - b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) - l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) - l63 = sch.fuse(l61, l62, preserve_unit_iters=True) - v64 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) - v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) - sch.enter_postproc() - sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") - l66, l67, l68, l69, l70 = sch.get_loops(block=b46) - l71, l72, l73 = sch.split(loop=l70, factors=[None, 48, 2], preserve_unit_iters=True) - sch.vectorize(loop=l73) - sch.bind(loop=l72, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b56, 
ann_key="meta_schedule.cooperative_fetch") - l74, l75, l76, l77, l78 = sch.get_loops(block=b56) - l79, l80, l81 = sch.split(loop=l78, factors=[None, 48, 2], preserve_unit_iters=True) - sch.vectorize(loop=l81) - sch.bind(loop=l80, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - b83, b84, b85, b86 = sch.get_child_blocks(b82) - l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) - sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) - l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) - l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) - l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - b119 = sch.get_block(name="NT_matmul", func_name="main") - l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) - b132 = sch.decompose_reduction(block=b119, loop=l123) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - - -@T.prim_func -def fused_matmul1_add1(p_lv39: T.handle, lv40: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096))) - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv39[v_i0, v_i1, v_k], lv40[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * lv40[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_matmul1_add1_sch_func(): - sch = tvm.tir.Schedule(fused_matmul1_add1) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 8, 4, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[512, 4, 2]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], 
preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="matmul", func_name="main") - _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - b133 = sch.decompose_reduction(block=b120, loop=l124) - b1 = sch.get_block("lv39_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_matmul3_multiply(p_lv43: T.handle, lv46: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv48: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) - lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], lv46[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv46[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv48[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_matmul3_multiply_sch_func(): - sch = tvm.tir.Schedule(fused_matmul3_multiply) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="T_multiply", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - 
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 4, 2, 4, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[344, 2, 16, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[512, 1, 8]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, 
factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="matmul", func_name="main") - _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - b133 = sch.decompose_reduction(block=b120, loop=l124) - b1 = sch.get_block("lv43_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_matmul3_silu(p_lv43: T.handle, lv44: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], lv44[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -def fused_matmul3_silu_sch_func(): - sch = tvm.tir.Schedule(fused_matmul3_silu) - b0 = sch.get_block("matmul") - 
sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="compute", func_name="main") - b2 = sch.get_block(name="T_multiply", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l4, l5, l6, l7 = sch.get_loops(block=b0) - v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) - v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 2, 2, 8, 1]) - l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) - v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[344, 2, 16, 1, 1]) - l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) - v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[512, 1, 8]) - l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) - sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="blockIdx.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="vthread.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) - b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) - _, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b48) - l56 = sch.fuse(l53, l54, l55, preserve_unit_iters=True) - v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) - b58 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) - _, l59, l60, l61, l62, l63, l64 = sch.get_loops(block=b58) - l65 = sch.fuse(l63, l64, preserve_unit_iters=True) - v66 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=3) - sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) - sch.compute_inline(block=b1) - v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) - l68, l69, l70 = sch.get_loops(block=b2) - l71 = sch.fuse(l68, l69, l70, preserve_unit_iters=True) - l72, l73, l74 = 
sch.split(loop=l71, factors=[None, 256, 256], preserve_unit_iters=True) - sch.reorder(l73, l74, l72) - sch.bind(loop=l73, thread_axis="blockIdx.x") - sch.bind(loop=l74, thread_axis="threadIdx.x") - sch.enter_postproc() - sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b48) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") - _, l83, l84, l85, l86, l87 = sch.get_loops(block=b58) - l88, l89, l90 = sch.split(loop=l87, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l90) - sch.bind(loop=l89, thread_axis="threadIdx.x") - b91 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b91, ann_key="meta_schedule.unroll_explicit") - _, b92, b93, b94, b95, _, b96 = sch.get_child_blocks(b91) - _, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b92) - sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) - _, l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b93) - sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) - _, l111, l112, l113, l114, l115, l116, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b94) - sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) - _, l123, l124, l125, l126, l127, l128 = sch.get_loops(block=b95) - sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) - l129, l130, l131 = sch.get_loops(block=b96) - sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) - b132 = sch.get_block(name="matmul", func_name="main") - _, l133, l134, l135, l136, l137, l138, l139, l140, l141, l142, l143, l144 = sch.get_loops(block=b132) - b145 = sch.decompose_reduction(block=b132, loop=l136) - b1 = sch.get_block("lv43_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_matmul4_add1(p_lv49: T.handle, lv50: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv42: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008))) - lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], lv50[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv50[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv42[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_matmul4_add1_sch_func(): - sch = tvm.tir.Schedule(fused_matmul4_add1) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 4, 8, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 16, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[1376, 2, 4]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 
0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="matmul", func_name="main") - _, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - b133 = sch.decompose_reduction(block=b120, loop=l124) - b1 = sch.get_block("lv49_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul_add1_before(p_lv39: T.handle, linear_weight3: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096))) - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv39[v_i0, v_i1, v_k], linear_weight3[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - 
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * linear_weight3[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_NT_matmul_add1_after(p_lv33: T.handle, linear_weight3: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv33 = T.match_buffer(p_lv33, (T.int64(1), n, T.int64(4096))) - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) - T.reads(lv33[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight3[T.int64(0):T.int64(4096), T.int64(0):T.int64(4096)], lv2[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - T.writes(var_T_add_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") - lv33_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") - linear_weight3_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared") - for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(1), T.int64(4), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + i1_1_3_init + i1_1_4_init) - v_i2_i = T.axis.spatial(T.int64(4096), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init) - T.reads() - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) - for k_0 in range(T.int64(128)): - for ax0_ax1_ax2_fused_0 in range(T.int64(8)): - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(2)): - with T.block("lv33_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) // T.int64(32)) - v2 = 
T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(128) + ax0_ax1_ax2_fused_1 * T.int64(2) + ax0_ax1_ax2_fused_2) % T.int64(32)) - T.reads(lv33[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - T.writes(lv33_pad_shared[v0, v1, v2]) - lv33_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv33[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) - for ax0_ax1_fused_0 in range(T.int64(8)): - for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): - with T.block("linear_weight3_shared"): - v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) - v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) - T.reads(linear_weight3[v0, v1]) - T.writes(linear_weight3_shared[v0, v1]) - linear_weight3_shared[v0, v1] = linear_weight3[v0, v1] - for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(4), T.int64(4), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + i1_1_3 + i1_1_4) - v_i2_i = T.axis.spatial(T.int64(4096), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv33_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight3_shared[v_i2_i, v_k_i]) - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv33_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight3_shared[v_i2_i, v_k_i] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused * T.int64(8) + i0_2_i1_1_2_i2_2_fused // T.int64(8) + ax1) - v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) - T.reads(lv2[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: - if v_i1_o * T.int64(32) + v1 < n: - var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv2[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] + var_NT_matmul_intermediate_pad_local[v0, v1, v2] - - -@T.prim_func -def fused_NT_matmul1_divide_add_maximum_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128))) - lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), n, T.int64(128))) - lv5 = T.match_buffer(p_lv5, 
(T.int64(1), T.int64(1), n, n)) - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv28[T.int64(0), v_i1, v_i2, v_k], lv29[T.int64(0), v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[T.int64(0), v_i1, v_i2, v_k] * lv29[T.int64(0), v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv5[v_ax0, T.int64(0), v_ax2, v_ax3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - - -@T.prim_func -def fused_NT_matmul1_divide_add_maximum_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128))) - lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), n, T.int64(128))) - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, n)) - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n)) - # with T.block("root"): - for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((n + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((n + T.int64(31)) // T.int64(32))) - v_i3_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((n + T.int64(31)) // T.int64(32))) - T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * 
T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") - for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) - T.reads() - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) - for k_0 in range(T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % 
T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < n, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) - for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) - v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) - T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) - v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: - if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < n: - var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088388349161020605) + lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3], T.float32(-3.4028234663852886e+38)) - -@T.prim_func -def fused_NT_matmul1_divide_add_maximum_with_m_before(p_lv30: T.handle, p_lv31: 
T.handle, p_lv7: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv30 = T.match_buffer(p_lv30, (T.int64(1), T.int64(32), n, T.int64(128))) - m = T.int64() - lv31 = T.match_buffer(p_lv31, (T.int64(1), T.int64(32), m, T.int64(128))) - lv7 = T.match_buffer(p_lv7, (T.int64(1), T.int64(1), n, m)) - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv30[v_i0, v_i1, v_i2, v_k], lv31[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv30[v_i0, v_i1, v_i2, v_k] * lv31[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv7[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv7[v_ax0, T.int64(0), v_ax2, v_ax3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - -@T.prim_func -def fused_NT_matmul1_divide_add_maximum_with_m_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128))) - lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128))) - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) - v_i3_o = 
T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) - T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), scope="shared") - for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) - T.reads() - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) - for k_0 in range(T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), 
thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) - for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) - v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) - T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) - v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: - if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: - var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o 
* T.int64(32) + v3] = T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088388349161020605) + lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3], T.float32(-3.4028234663852886e+38)) - - -@T.prim_func -def fused_NT_matmul6_divide1_add2_maximum1_before(lv2732: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv2733: T.handle, p_lv2709: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv2733 = T.match_buffer(p_lv2733, (T.int64(1), T.int64(32), n, T.int64(128))) - lv2709 = T.match_buffer(p_lv2709, (T.int64(1), T.int64(1), T.int64(1), n)) - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv2732[T.int64(0), v_i1, v_i2, v_k], lv2733[T.int64(0), v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv2732[T.int64(0), v_i1, v_i2, v_k] * lv2733[T.int64(0), v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2709[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv2709[v_ax0, T.int64(0), v_ax2, v_ax3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - - -@T.prim_func -def fused_NT_matmul6_divide1_add2_maximum1_after(lv2732: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv2733: T.handle, p_lv2709: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv2733 = T.match_buffer(p_lv2733, (T.int64(1), T.int64(32), n, T.int64(128))) - lv2709 = T.match_buffer(p_lv2709, (T.int64(1), T.int64(1), T.int64(1), n)) - var_T_maximum_intermediate = T.match_buffer(p_output0, 
(T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32)), scope="local") - lv2732_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), scope="shared") - lv2733_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(128)), scope="shared") - for i3_0 in range((n + T.int64(31)) // T.int64(32)): - for i0_0_i1_0_i2_0_i3_1_0_fused in T.thread_binding(T.int64(32), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_i3_1_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): - for i0_2_i1_2_i2_2_i3_1_2_fused in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i0_3_init, i1_3_init, i2_3_init, i3_1_3_init, i0_4_init, i1_4_init, i2_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + i1_3_init + i1_4_init) - v_i2 = T.axis.spatial(T.int64(1), i2_3_init + i2_4_init) - v_i3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + i3_1_3_init + i3_1_4_init) - T.reads() - T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - for k_0 in range(T.int64(8)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("lv2732_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(64) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(16)) - v2 = T.axis.spatial(T.int64(1), T.int64(0)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(64) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(16)) - T.reads(lv2732[v0, v1, v2, v3]) - T.writes(lv2732_shared[v0, v1, v2, v3]) - lv2732_shared[v0, v1, v2, v3] = lv2732[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2733_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) // T.int64(128)) - v2 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + 
(ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(128) // T.int64(16)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(4) + ax0_ax1_ax2_ax3_fused_2) % T.int64(16)) - T.reads(lv2733[v0, v1, v2, v3]) - T.writes(lv2733_pad_shared[v0, v1, v2, v3]) - lv2733_pad_shared[v0, v1, v2, v3] = T.if_then_else(v2 < n, lv2733[v0, v1, v2, v3], T.float32(0)) - for k_1, i0_3, i1_3, i2_3, i3_1_3, k_2, i0_4, i1_4, i2_4, i3_1_4 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4) - v_i1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + i1_3 + i1_4) - v_i2 = T.axis.spatial(T.int64(1), i2_3 + i2_4) - v_i3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + i3_1_3 + i3_1_4) - v_k = T.axis.reduce(T.int64(128), k_0 * T.int64(16) + k_1 * T.int64(16) + k_2) - T.reads(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3], lv2732_shared[v_i0, v_i1, v_i2, v_k], lv2733_pad_shared[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2, v_i3] + lv2732_shared[v_i0, v_i1, v_i2, v_k] * lv2733_pad_shared[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_0_i3_1_0_fused // T.int64(4) * T.int64(4) + i0_2_i1_2_i2_2_i3_1_2_fused // T.int64(8) + ax1) - v2 = T.axis.spatial(T.int64(1), ax2) - v3 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i3_0 * T.int64(32) + i0_0_i1_0_i2_0_i3_1_0_fused % T.int64(4) * T.int64(8) + i0_2_i1_2_i2_2_i3_1_2_fused % T.int64(8) + ax3) - T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2, v3]) - T.writes(var_NT_matmul_intermediate[v0, v1, v2, v3]) - if v3 < n: - var_NT_matmul_intermediate[v0, v1, v2, v3] = var_NT_matmul_intermediate_pad_local[v0, v1, v2, v3] - for ax0_ax1_ax2_ax3_fused_0 in T.thread_binding(n, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - with T.block("T_add"): - v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_ax1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_ax3_fused_0 * T.int64(32) + ax0_ax1_ax2_ax3_fused_1) // n) - v_ax2 = T.axis.spatial(T.int64(1), T.int64(0)) - v_ax3 = T.axis.spatial(n, (ax0_ax1_ax2_ax3_fused_0 * T.int64(32) + ax0_ax1_ax2_ax3_fused_1) % n) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv2709[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) + 
lv2709[v_ax0, T.int64(0), v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - - -@T.prim_func -def fused_NT_matmul2_multiply_before(p_lv43: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv48: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) - lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], linear_weight6[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight6[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv48[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_NT_matmul2_multiply_after(p_lv37: T.handle, linear_weight6: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_lv42: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv37 = T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096))) - lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(11008))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) - T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight6[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)], lv42[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)]) - T.writes(var_T_multiply_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)]) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="local") - lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") - linear_weight6_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="shared") - for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): - for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(2)): - with T.block("NT_matmul_init"): - v_i1_i = 
T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init) - v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init * T.int64(2) + i2_4_init) - T.reads() - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) - for k_0 in range(T.int64(128)): - for ax0_ax1_ax2_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv37_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) - T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - T.writes(lv37_pad_shared[v0, v1, v2]) - lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) - for ax0_ax1_fused_0 in range(T.int64(8)): - for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): - with T.block("linear_weight6_shared"): - v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) - v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) - T.reads(linear_weight6[v0, v1]) - T.writes(linear_weight6_shared[v0, v1]) - linear_weight6_shared[v0, v1] = linear_weight6[v0, v1] - for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(2), T.int64(2)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4) - v_i2_i = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3 * T.int64(2) + i2_4) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight6_shared[v_i2_i, v_k_i]) - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight6_shared[v_i2_i, v_k_i] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // 
T.int64(8) * T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) - T.reads(lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: - if v_i1_o * T.int64(32) + v1 < n: - var_T_multiply_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv42[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] * var_NT_matmul_intermediate_pad_local[v0, v1, v2] - - -@T.prim_func -def fused_NT_matmul2_silu_before(p_lv43: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - compute = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], linear_weight4[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * linear_weight4[v_i2, v_k] - for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_NT_matmul2_silu_after(p_lv37: T.handle, linear_weight4: T.Buffer((T.int64(11008), T.int64(4096)), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv37 = T.match_buffer(p_lv37, (T.int64(1), n, T.int64(4096))) - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008))) - for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) - T.reads(lv37[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)], linear_weight4[T.int64(0):T.int64(11008), T.int64(0):T.int64(4096)]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), 
T.int64(0):T.int64(11008)]) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="local") - lv37_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") - linear_weight4_shared = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="shared") - for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(344), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"): - for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(4), T.int64(2), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3_init * T.int64(2) + i1_1_4_init) - v_i2_i = T.axis.spatial(T.int64(11008), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3_init) - T.reads() - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) - for k_0 in range(T.int64(128)): - for ax0_ax1_ax2_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv37_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) - T.reads(lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - T.writes(lv37_pad_shared[v0, v1, v2]) - lv37_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv37[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) - for ax0_ax1_fused_0 in range(T.int64(8)): - for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): - with T.block("linear_weight4_shared"): - v0 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) - v1 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) - T.reads(linear_weight4[v0, v1]) - T.writes(linear_weight4_shared[v0, v1]) - linear_weight4_shared[v0, v1] = linear_weight4[v0, v1] - for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(4), T.int64(4), T.int64(1), T.int64(2), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + i1_1_3 * T.int64(2) + i1_1_4) - v_i2_i = T.axis.spatial(T.int64(11008), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + i2_3) - v_k_i = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), 
v_i1_i, v_i2_i], lv37_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight4_shared[v_i2_i, v_k_i]) - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv37_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight4_shared[v_i2_i, v_k_i] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(4)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(11008), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(4) + ax2) - T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: - if v_i1_o * T.int64(32) + v1 < n: - var_NT_matmul_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_0 in range((n * T.int64(11008) + T.int64(65535)) // T.int64(65536)): - with T.block("T_multiply"): - v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_ax1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) // T.int64(11008)) - v_ax2 = T.axis.spatial(T.int64(11008), (ax0_ax1_ax2_fused_0 * T.int64(65536) + ax0_ax1_ax2_fused_1 * T.int64(256) + ax0_ax1_ax2_fused_2) % T.int64(11008)) - T.where((ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1) * T.int64(256) + ax0_ax1_ax2_fused_2 < n * T.int64(11008)) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * T.sigmoid(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - - -@T.prim_func -def fused_NT_matmul3_add1_before(p_lv49: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv42: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008))) - lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096))) - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], linear_weight5[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * 
linear_weight5[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_NT_matmul3_add1_after(p_lv43: T.handle, linear_weight5: T.Buffer((T.int64(4096), T.int64(11008)), "float32"), p_lv36: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(11008))) - lv36 = T.match_buffer(p_lv36, (T.int64(1), n, T.int64(4096))) - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096))) - # with T.block("root"): - for i1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i1_0) - T.reads(lv43[T.Add(v_i0, T.int64(0)), v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(11008)], linear_weight5[T.int64(0):T.int64(4096), T.int64(0):T.int64(11008)], lv36[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - T.writes(var_T_add_intermediate[v_i0, v_i1_o * T.int64(32):v_i1_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(4096)]) - var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") - lv43_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(11008)), scope="shared") - linear_weight5_shared = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="shared") - for i0_0_i1_1_0_i2_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_1_i2_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_1_2_i2_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_1_3_init, i2_3_init, i1_1_4_init, i2_4_init in T.grid(T.int64(2), T.int64(2), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3_init + i1_1_4_init) - v_i2_i = T.axis.spatial(T.int64(4096), i2_4_init + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3_init) - T.reads() - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = T.float32(0) - for k_0 in range(T.int64(344)): - for ax0_ax1_ax2_fused_0 in range(T.int64(4)): - for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv43_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) // T.int64(32)) - v2 = T.axis.spatial(T.int64(11008), 
k_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2) % T.int64(32)) - T.reads(lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - T.writes(lv43_pad_shared[v0, v1, v2]) - lv43_pad_shared[v0, v1, v2] = T.if_then_else(v_i1_o * T.int64(32) + v1 < n, lv43[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], T.float32(0)) - for ax0_ax1_fused_0 in range(T.int64(8)): - for ax0_ax1_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(T.int64(2)): - with T.block("linear_weight5_shared"): - v0 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) // T.int64(32)) - v1 = T.axis.spatial(T.int64(11008), k_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * T.int64(2) + ax0_ax1_fused_2) % T.int64(32)) - T.reads(linear_weight5[v0, v1]) - T.writes(linear_weight5_shared[v0, v1]) - linear_weight5_shared[v0, v1] = linear_weight5[v0, v1] - for k_1, i0_3, i1_1_3, i2_3, k_2, i0_4, i1_1_4, i2_4 in T.grid(T.int64(8), T.int64(1), T.int64(2), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + i1_1_3 + i1_1_4) - v_i2_i = T.axis.spatial(T.int64(4096), i2_4 + i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + i2_3) - v_k_i = T.axis.reduce(T.int64(11008), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - T.reads(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i], lv43_pad_shared[T.int64(0), v_i1_i, v_k_i], linear_weight5_shared[v_i2_i, v_k_i]) - T.writes(var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] = var_NT_matmul_intermediate_pad_local[T.int64(0), v_i1_i, v_i2_i] + lv43_pad_shared[T.int64(0), v_i1_i, v_k_i] * linear_weight5_shared[v_i2_i, v_k_i] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(2)): - with T.block("var_NT_matmul_intermediate_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_1_i1_1_1_i2_1_fused // T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused // T.int64(8) * T.int64(2) + ax1) - v2 = T.axis.spatial(T.int64(4096), i0_0_i1_1_0_i2_0_fused * T.int64(32) + i0_1_i1_1_1_i2_1_fused % T.int64(2) * T.int64(16) + i0_2_i1_1_2_i2_2_fused % T.int64(8) * T.int64(2) + ax2) - T.reads(lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2], var_NT_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i1_o * T.int64(32) + v1 and v_i1_o * T.int64(32) + v1 < n: - if v_i1_o * T.int64(32) + v1 < n: - var_T_add_intermediate[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] = lv36[v_i0 + v0, v_i1_o * T.int64(32) + v1, v2] + var_NT_matmul_intermediate_pad_local[v0, v1, v2] - - - -@T.prim_func -def fused_NT_matmul_divide_maximum_minimum_cast_before(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": 
T.bool(True)}) - n = T.int64() - lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - -def fused_NT_matmul_divide_maximum_minimum_cast_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul_divide_maximum_minimum_cast_before) - b_cast = sch.get_block("compute") - sch.reverse_compute_inline(b_cast) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 1, 1, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l4, [None, 32]) - sch.reorder(l6, l1, l2, l3, l7, l5) - - b0 = 
sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[4, 4, 8]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) - l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) - l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 
0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - - # inline ewise - sch.reverse_compute_inline(b2) - # l83, l84, l85, l86 = sch.get_loops(block=b2) - # l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) - # v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - # l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) - # sch.bind(loop=l89, thread_axis="blockIdx.x") - # sch.bind(loop=l90, thread_axis="threadIdx.x") - - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) - l96, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l97, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") - _, l98, l99, l100, l101, l102 = sch.get_loops(block=b71) - l103, l104 = sch.split(loop=l102, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l104, thread_axis="threadIdx.x") - b105 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.unroll_explicit") - _, b106, b107, b108, b109, _ = sch.get_child_blocks(b105) - _, l111, l112, l113, l114, l115, l116 = sch.get_loops(block=b106) - sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) - _, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b107) - sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) - _, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136 = sch.get_loops(block=b108) - sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) - _, l137, l138, l139, l140, l141, l142, l143 = sch.get_loops(block=b109) - sch.annotate(block_or_loop=l137, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l137, ann_key="pragma_unroll_explicit", ann_val=1) - - b146 = sch.get_block(name="NT_matmul", func_name="main") - l0, l147, l148, l149, l150, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160 = sch.get_loops(block=b146) - sch.bind(l0, "blockIdx.y") - b161 = sch.decompose_reduction(block=b146, loop=l150) - - b1 = sch.get_block("lv1606_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - -@T.prim_func -def fused_NT_matmul_divide_maximum_minimum_before(lv1540: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv1541: T.handle, p_lv1517: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv1541 = T.match_buffer(p_lv1541, (T.int64(1), T.int64(32), n, T.int64(128))) - lv1517 = T.match_buffer(p_lv1517, (T.int64(1), T.int64(1), T.int64(1), n)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - var_T_maximum_intermediate = 
T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv1540[v_i0, v_i1, v_i2, v_k], lv1541[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1540[v_i0, v_i1, v_i2, v_k] * lv1541[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3]) - -def fused_NT_matmul_divide_maximum_minimum_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul_divide_maximum_minimum_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 1, 1, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l4, [None, 32]) - sch.reorder(l6, l1, l2, l3, l7, l5) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_maximum", func_name="main") - b3 = sch.get_block(name="T_minimum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = 
sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[4, 4, 8]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) - l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) - l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - - # inline ewise - sch.reverse_compute_inline(b2) - # l83, l84, l85, l86 = sch.get_loops(block=b2) - # l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) - # v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - # l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) - # sch.bind(loop=l89, thread_axis="blockIdx.x") - # sch.bind(loop=l90, thread_axis="threadIdx.x") - - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) - l96, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l97, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") - _, l98, l99, l100, l101, l102 = sch.get_loops(block=b71) - l103, l104 = sch.split(loop=l102, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l104, 
thread_axis="threadIdx.x") - b105 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.unroll_explicit") - _, b106, b107, b108, b109, _ = sch.get_child_blocks(b105) - _, l111, l112, l113, l114, l115, l116 = sch.get_loops(block=b106) - sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) - _, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b107) - sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1) - _, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136 = sch.get_loops(block=b108) - sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) - _, l137, l138, l139, l140, l141, l142, l143 = sch.get_loops(block=b109) - sch.annotate(block_or_loop=l137, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l137, ann_key="pragma_unroll_explicit", ann_val=1) - - b146 = sch.get_block(name="NT_matmul", func_name="main") - l0, l147, l148, l149, l150, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160 = sch.get_loops(block=b146) - sch.bind(l0, "blockIdx.y") - b161 = sch.decompose_reduction(block=b146, loop=l150) - - b1 = sch.get_block("lv1541_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - -@T.prim_func -def fused_NT_matmul1_add3_before(p_lv39: T.handle, lv1848: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv39 = T.match_buffer(p_lv39, (T.int64(1), n, T.int64(4096)), "float16") - lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv39[v_i0, v_i1, v_k], lv1848[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv39[v_i0, v_i1, v_k] * lv1848[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_NT_matmul1_add3_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul1_add3_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_add", 
func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 2, 8, 1, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[256, 1, 4, 4, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[256, 1, 16]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 32, 2], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, 
l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 32, 4], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="NT_matmul", func_name="main") - l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - sch.bind(l0, "blockIdx.y") - b133 = sch.decompose_reduction(block=b120, loop=l124) - - b1 = sch.get_block("lv39_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul2_divide1_add2_maximum1_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, n), "float16") - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, n), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") - var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, n), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, n, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, 
v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv5[v_ax0, T.int64(0), v_ax2, v_ax3] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, n): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - - -def fused_NT_matmul2_divide1_add2_maximum1_sch_func(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 1, 32, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l4, [None, 32]) - sch.reorder(l6, l8, l1, l2, l7, l9, l5) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_add", func_name="main") - b3 = sch.get_block(name="T_maximum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[8, 16, 1]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, 
write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - _, _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) - l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - _, _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) - l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - l83, l84, l85, l86 = sch.get_loops(block=b2) - l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) - v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) - sch.bind(loop=l89, thread_axis="blockIdx.x") - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - _, _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) - l96, l97, l98 = sch.split(loop=l95, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l98) - sch.bind(loop=l97, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch") - _, _, l99, l100, l101, l102, l103 = sch.get_loops(block=b71) - l104, l105, l106 = sch.split(loop=l103, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l106) - sch.bind(loop=l105, thread_axis="threadIdx.x") - b107 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b107, ann_key="meta_schedule.unroll_explicit") - _, _, b108, b109, b110, b111, _, b112 = sch.get_child_blocks(b107) - _, _, l113, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b108) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b109) - sch.annotate(block_or_loop=l120, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l120, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b110) - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l141, l142, l143, l144, l145, l146, l147 = sch.get_loops(block=b111) - 
sch.annotate(block_or_loop=l141, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l141, ann_key="pragma_unroll_explicit", ann_val=1) - l148, l149 = sch.get_loops(block=b112) - sch.annotate(block_or_loop=l148, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l148, ann_key="pragma_unroll_explicit", ann_val=1) - b150 = sch.get_block(name="NT_matmul", func_name="main") - l0, l1, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160, l161, l162, l163, l164 = sch.get_loops(block=b150) - l2 = sch.fuse(l0, l1) - sch.bind(l2, "blockIdx.y") - b165 = sch.decompose_reduction(block=b150, loop=l154) - - b1 = sch.get_block("lv28_pad") - sch.compute_inline(b1) - b2 = sch.get_block("lv29_pad") - sch.compute_inline(b2) - b3 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b3) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul2_divide1_maximum1_minimum1_cast3_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - m = T.int64() - lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") - var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16") - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, 
v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) - var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) - -@T.prim_func -def fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16") - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) - v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) - T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float16", scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float16", scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float16", scope="shared") - for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) - 
v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) - T.reads() - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) - for k_0 in range(T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) - for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) - v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) - T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, 
v_k_i]) - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) - v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: - if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: - var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.Cast("float32", T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3])) - -@T.prim_func -def fused_NT_matmul2_divide1_maximum1_minimum1_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128))) - m = T.int64() - lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128))) - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m)) - var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m)) - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - 
var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_maximum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38)) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m): - with T.block("T_minimum"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) - var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3]) - -@T.prim_func -def fused_NT_matmul2_divide1_maximum1_minimum1_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - n = T.int64() - m = T.int64() - lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128)), "float32") - lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128)), "float32") - lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float32") - var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m)) - # with T.block("root"): - for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"): - with T.block("NT_matmul_o"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32))) - v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32))) - T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)]) - C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float32", scope="local") - A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared") - B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared") - for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}): - for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"): - for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), 
T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_init"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init) - v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init) - T.reads() - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0) - for k_0 in range(T.int64(16)): - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("A_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3]) - T.writes(A_pad_shared[v0, v1, v2, v3]) - A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0)) - for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)): - with T.block("B_pad_shared"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4)) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8)) - v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8)) - T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3]) - T.writes(B_pad_shared[v0, v1, v2, v3]) - B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0)) - for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("NT_matmul_update"): - v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3) - v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4) - v_i3_i = 
T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3) - v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2) - T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]) - T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i]) - T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"}) - C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): - with T.block("C_pad_local"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1) - v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2) - v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3) - T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n: - if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m: - var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]) - -def fused_NT_matmul2_divide1_add2_maximum1_sch_func(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 1, 32, 32, 1]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l4, [None, 32]) - sch.reorder(l6, l8, l1, l2, l7, l9, l5) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_divide", func_name="main") - b2 = sch.get_block(name="T_add", func_name="main") - b3 = sch.get_block(name="T_maximum", func_name="main") - b4 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0) - v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) - v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) - l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) - v30, v31, v32, v33, v34 = 
sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) - l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True) - v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 8, 1, 2]) - l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True) - v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[8, 16, 1]) - l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True) - sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49) - l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True) - sch.bind(loop=l56, thread_axis="blockIdx.x") - l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True) - sch.bind(loop=l57, thread_axis="vthread.x") - l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True) - sch.bind(loop=l58, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1) - b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1) - _, _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60) - l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True) - v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70) - b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1) - _, _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71) - l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True) - v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) - sch.reverse_compute_inline(block=b3) - sch.compute_inline(block=b1) - v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82) - l83, l84, l85, l86 = sch.get_loops(block=b2) - l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True) - v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True) - sch.bind(loop=l89, thread_axis="blockIdx.x") - sch.bind(loop=l90, thread_axis="threadIdx.x") - sch.enter_postproc() - sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch") - _, _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60) - l96, l97, l98 = sch.split(loop=l95, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l98) - sch.bind(loop=l97, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b71, 
ann_key="meta_schedule.cooperative_fetch") - _, _, l99, l100, l101, l102, l103 = sch.get_loops(block=b71) - l104, l105, l106 = sch.split(loop=l103, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l106) - sch.bind(loop=l105, thread_axis="threadIdx.x") - b107 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b107, ann_key="meta_schedule.unroll_explicit") - _, _, b108, b109, b110, b111, _, b112 = sch.get_child_blocks(b107) - _, _, l113, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b108) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b109) - sch.annotate(block_or_loop=l120, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l120, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136, l137, l138, l139, l140 = sch.get_loops(block=b110) - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - _, _, l141, l142, l143, l144, l145, l146, l147 = sch.get_loops(block=b111) - sch.annotate(block_or_loop=l141, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l141, ann_key="pragma_unroll_explicit", ann_val=1) - l148, l149 = sch.get_loops(block=b112) - sch.annotate(block_or_loop=l148, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l148, ann_key="pragma_unroll_explicit", ann_val=1) - b150 = sch.get_block(name="NT_matmul", func_name="main") - l0, l1, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160, l161, l162, l163, l164 = sch.get_loops(block=b150) - l2 = sch.fuse(l0, l1) - sch.bind(l2, "blockIdx.y") - b165 = sch.decompose_reduction(block=b150, loop=l154) - - b1 = sch.get_block("lv28_pad") - sch.compute_inline(b1) - b2 = sch.get_block("lv29_pad") - sch.compute_inline(b2) - b3 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b3) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul3_multiply1_before(p_lv43: T.handle, lv1866: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv48: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)), "float16") - lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(11008)), "float16") - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], lv1866[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv1866[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv48[v_ax0, v_ax1, v_ax2], 
var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv48[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_NT_matmul3_multiply1_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul3_multiply1_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_multiply", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 8, 2, 2]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[344, 4, 8, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[128, 16, 2]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], 
probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="NT_matmul", func_name="main") - l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - sch.bind(l0, "blockIdx.y") - b133 = sch.decompose_reduction(block=b120, loop=l124) - - b1 = sch.get_block("lv43_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul3_silu1_before(p_lv43: T.handle, lv1857: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv43 = T.match_buffer(p_lv43, (T.int64(1), n, T.int64(4096)), "float16") - var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") - compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv43[v_i0, v_i1, v_k], lv1857[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv1857[v_i2, v_k] - for i0, i1, i2 in T.grid(T.int64(1), n, 
T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -def fused_NT_matmul3_silu1_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul3_silu1_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="compute", func_name="main") - b2 = sch.get_block(name="T_multiply", func_name="main") - b3 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l4, l5, l6, l7 = sch.get_loops(block=b0) - v8, v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l13, l14, l15, l16, l17 = sch.split(loop=l4, factors=[v8, v9, v10, v11, v12], preserve_unit_iters=True) - v18, v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 8, 4, 1]) - l23, l24, l25, l26, l27 = sch.split(loop=l5, factors=[v18, v19, v20, v21, v22], preserve_unit_iters=True) - v28, v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[344, 4, 8, 1, 1]) - l33, l34, l35, l36, l37 = sch.split(loop=l6, factors=[v28, v29, v30, v31, v32], preserve_unit_iters=True) - v38, v39, v40 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[128, 16, 2]) - l41, l42, l43 = sch.split(loop=l7, factors=[v38, v39, v40], preserve_unit_iters=True) - sch.reorder(l13, l23, l33, l14, l24, l34, l15, l25, l35, l41, l42, l16, l26, l36, l43, l17, l27, l37) - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="blockIdx.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="vthread.x") - l46 = sch.fuse(l15, l25, l35, preserve_unit_iters=True) - sch.bind(loop=l46, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b47 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b47, loop=l46, preserve_unit_loops=True, index=-1) - b48 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b48, loop=l41, preserve_unit_loops=True, index=-1) - _, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b48) - l56 = sch.fuse(l53, l54, l55, preserve_unit_iters=True) - v57 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch", ann_val=v57) - b58 = sch.cache_read(block=b0, read_buffer_index=1, 
storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b58, loop=l41, preserve_unit_loops=True, index=-1) - _, l59, l60, l61, l62, l63, l64 = sch.get_loops(block=b58) - l65 = sch.fuse(l63, l64, preserve_unit_iters=True) - v66 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch", ann_val=v66) - sch.compute_inline(block=b1) - v67 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1) - sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v67) - - # reverse compute inline the silu part - sch.reverse_compute_inline(b2) - # l68, l69, l70 = sch.get_loops(block=b2) - # l71 = sch.fuse(l68, l69, l70, preserve_unit_iters=True) - # l72, l73, l74 = sch.split(loop=l71, factors=[None, 256, 256], preserve_unit_iters=True) - #sch.reorder(l73, l74, l72) - # sch.bind(loop=l73, thread_axis="blockIdx.x") - # sch.bind(loop=l74, thread_axis="threadIdx.x") - sch.enter_postproc() - sch.unannotate(block_or_loop=b48, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b48) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b58, ann_key="meta_schedule.cooperative_fetch") - _, l83, l84, l85, l86, l87 = sch.get_loops(block=b58) - l88, l89, l90 = sch.split(loop=l87, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l90) - sch.bind(loop=l89, thread_axis="threadIdx.x") - b91 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b91, ann_key="meta_schedule.unroll_explicit") - _, b92, b93, b94, b95, _ = sch.get_child_blocks(b91) - _, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b92) - sch.annotate(block_or_loop=l97, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l97, ann_key="pragma_unroll_explicit", ann_val=1) - _, l104, l105, l106, l107, l108, l109, l110 = sch.get_loops(block=b93) - sch.annotate(block_or_loop=l104, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l104, ann_key="pragma_unroll_explicit", ann_val=1) - _, l111, l112, l113, l114, l115, l116, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b94) - sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1) - _, l123, l124, l125, l126, l127, l128 = sch.get_loops(block=b95) - sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16) - sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1) - # l129, l130, l131 = sch.get_loops(block=b96) - # sch.annotate(block_or_loop=l129, ann_key="pragma_auto_unroll_max_step", ann_val=16) - # sch.annotate(block_or_loop=l129, ann_key="pragma_unroll_explicit", ann_val=1) - b132 = sch.get_block(name="NT_matmul", func_name="main") - l0, l133, l134, l135, l136, l137, l138, l139, l140, l141, l142, l143, l144 = sch.get_loops(block=b132) - sch.bind(l0, "blockIdx.y") - b145 = sch.decompose_reduction(block=b132, loop=l136) - - b1 = sch.get_block("lv43_pad") - sch.compute_inline(b1) - - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - - return 
sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_NT_matmul4_add3_before(p_lv49: T.handle, lv1875: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv42: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(11008)), "float16") - lv42 = T.match_buffer(p_lv42, (T.int64(1), n, T.int64(4096)), "float16") - var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16") - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv49[v_i0, v_i1, v_k], lv1875[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv49[v_i0, v_i1, v_k] * lv1875[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv42[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv42[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -def fused_NT_matmul4_add3_sch_func(): - sch = tvm.tir.Schedule(fused_NT_matmul4_add3_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="T_add", func_name="main") - b2 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l3, l4, l5, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l3, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 8, 1, 4]) - l22, l23, l24, l25, l26 = sch.split(loop=l4, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[128, 2, 8, 2, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l5, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[688, 16, 1]) - l40, l41, l42 = sch.split(loop=l6, factors=[v37, v38, v39], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l13, l23, l33, l14, l24, l34, l40, l41, l15, l25, l35, l42, l16, l26, l36) - l43 = sch.fuse(l12, l22, l32, preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="blockIdx.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="vthread.x") - l45 = sch.fuse(l14, l24, l34, preserve_unit_iters=True) - sch.bind(loop=l45, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b46 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b46, loop=l45, preserve_unit_loops=True, index=-1) - b47 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b47, loop=l40, preserve_unit_loops=True, index=-1) - _, l48, l49, l50, l51, l52, l53, l54 = sch.get_loops(block=b47) - l55 = sch.fuse(l52, l53, l54, preserve_unit_iters=True) - v56 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch", ann_val=v56) - b57 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l40, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b57) - l64 = sch.fuse(l62, l63, preserve_unit_iters=True) - v65 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) - sch.reverse_compute_inline(block=b1) - v66 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v66) - sch.enter_postproc() - sch.unannotate(block_or_loop=b47, ann_key="meta_schedule.cooperative_fetch") - _, l67, l68, l69, l70, l71 = sch.get_loops(block=b47) - l72, l73, l74 = sch.split(loop=l71, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l74) - sch.bind(loop=l73, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l75, l76, l77, l78, l79 = sch.get_loops(block=b57) - l80, l81, l82 = sch.split(loop=l79, factors=[None, 64, 2], preserve_unit_iters=True) - sch.vectorize(loop=l82) - sch.bind(loop=l81, thread_axis="threadIdx.x") - b83 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b83, ann_key="meta_schedule.unroll_explicit") - _, b84, b85, b86, b87, _ = sch.get_child_blocks(b83) - _, l88, l89, l90, l91, l92, l93, l94 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l88, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l88, ann_key="pragma_unroll_explicit", ann_val=1) - _, l95, l96, l97, l98, l99, l100, l101 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l95, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l95, ann_key="pragma_unroll_explicit", ann_val=1) - _, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112, l113 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l102, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l102, ann_key="pragma_unroll_explicit", ann_val=1) - _, l114, l115, l116, l117, l118, l119 = sch.get_loops(block=b87) - sch.annotate(block_or_loop=l114, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l114, ann_key="pragma_unroll_explicit", ann_val=1) - b120 = sch.get_block(name="NT_matmul", func_name="main") - l0, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b120) - sch.bind(l0, "blockIdx.y") - b133 = sch.decompose_reduction(block=b120, loop=l124) - - b1 = 
sch.get_block("lv49_pad") - sch.compute_inline(b1) - b2 = sch.get_block("var_NT_matmul_intermediate_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def matmul1_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), T.int64(1), n), "float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] - - -def matmul1_fp16_sch_func(): - sch = tvm.tir.Schedule(matmul1_fp16_before) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 1, 1, 1, 128]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - sch.split(l5, [None, 128]) - - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - l2, l3, l4, l5, ko, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[8, 1, 16, 1, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[4, 16, 2]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, ko, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, 
loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - l58, l59, l60, _, l61, l62, l63, l64, l65 = sch.get_loops(block=b57) - l66 = sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - l69, l70, l71, _, l72, l73, l74, l75, l76 = sch.get_loops(block=b68) - l77 = sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=1) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=2) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - l80, l81, l82, _, l83, l84 = sch.get_loops(block=b57) - l85, l86 = sch.split(loop=l84, factors=[None, 256], preserve_unit_iters=True) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - l87, l88, l89, _, l90, l91 = sch.get_loops(block=b68) - l92, l93, l94 = sch.split(loop=l91, factors=[None, 256, 2], preserve_unit_iters=True) - sch.vectorize(loop=l94) - sch.bind(loop=l93, thread_axis="threadIdx.x") - b95 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") - _, _, b96, b97, b98, b99 = sch.get_child_blocks(b95) - l100, l101, l102, _, l103, l104, l105 = sch.get_loops(block=b96) - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - l106, l107, l108, _, l109, l110, l111, l112 = sch.get_loops(block=b97) - sch.annotate(block_or_loop=l106, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l106, ann_key="pragma_unroll_explicit", ann_val=1) - l113, l114, l115, _, l116, l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b98) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b99) - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=64) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - b134 = sch.get_block(name="matmul", func_name="main") - l135, l136, l137, ko, l138, l139, l140, l141, l142, l143, l144, l145, l146, l147, l148 = sch.get_loops(block=b134) - b149 = sch.decompose_reduction(block=b134, loop=ko) - - b1 = sch.get_block("rxplaceholder_pad") - sch.compute_inline(b1) - b2 = sch.get_block("rxplaceholder_1_pad") - sch.compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def 
matmul8_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, n), "float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), n): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] - -@T.prim_func -def matmul8_with_m_fp16_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - m = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), T.int64(32), n, m), "float16") - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(1), T.int64(32), m, T.int64(128)), "float16") - matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") - # with T.block("root"): - for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) - T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3]) - T.writes(matmul[v_i0, v_i1, v_i2, v_i3]) - with T.init(): - matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0) - matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3] - -def matmul8_fp16_sch_func(func): - sch = tvm.tir.Schedule(func) - b0 = sch.get_block("matmul") - sch.pad_einsum(b0, [1, 1, 32, 1, 128]) - l1, l2, l3, l4, l5 = sch.get_loops(b0) - l6, l7 = sch.split(l3, [None, 32]) - l8, l9 = sch.split(l5, [None, 128]) - sch.reorder(l6, l1, l2, l7, l4, l8, l9) - - b0 = sch.get_block(name="matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, l5, ko, l6 = sch.get_loops(block=b0) - v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l12, l13, l14, l15, l16 = sch.split(loop=l2, factors=[v7, v8, v9, v10, v11], preserve_unit_iters=True) - v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[32, 1, 1, 1, 1]) - l22, l23, l24, l25, l26 = sch.split(loop=l3, factors=[v17, v18, v19, v20, v21], preserve_unit_iters=True) - v27, v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 4, 2, 4]) - l32, l33, l34, l35, l36 = sch.split(loop=l4, factors=[v27, v28, v29, v30, v31], preserve_unit_iters=True) - v37, v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[4, 1, 16, 2, 1]) - l42, l43, l44, l45, l46 = sch.split(loop=l5, factors=[v37, v38, v39, v40, v41], preserve_unit_iters=True) - v47, v48, v49 = 
sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64, decision=[16, 1, 8]) - l50, l51, l52 = sch.split(loop=l6, factors=[v47, v48, v49], preserve_unit_iters=True) - sch.reorder(l12, l22, l32, l42, l13, l23, l33, l43, l14, l24, l34, l44, ko, l50, l51, l15, l25, l35, l45, l52, l16, l26, l36, l46) - l53 = sch.fuse(l12, l22, l32, l42, preserve_unit_iters=True) - sch.bind(loop=l53, thread_axis="blockIdx.x") - l54 = sch.fuse(l13, l23, l33, l43, preserve_unit_iters=True) - sch.bind(loop=l54, thread_axis="vthread.x") - l55 = sch.fuse(l14, l24, l34, l44, preserve_unit_iters=True) - sch.bind(loop=l55, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b56 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b56, loop=l55, preserve_unit_loops=True, index=-1) - b57 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b57, loop=l50, preserve_unit_loops=True, index=-1) - _, l58, l59, l60, _, l61, l62, l63, l64, l65 = sch.get_loops(block=b57) - l66 = sch.fuse(l62, l63, l64, l65, preserve_unit_iters=True) - v67 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) - b68 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b68, loop=l50, preserve_unit_loops=True, index=-1) - _, l69, l70, l71, _, l72, l73, l74, l75, l76 = sch.get_loops(block=b68) - l77 = sch.fuse(l73, l74, l75, l76, preserve_unit_iters=True) - v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0) - sch.annotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) - v79 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=3) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v79) - sch.enter_postproc() - sch.unannotate(block_or_loop=b57, ann_key="meta_schedule.cooperative_fetch") - _, l80, l81, l82, _, l83, l84 = sch.get_loops(block=b57) - l85, l86, l87 = sch.split(loop=l84, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l87) - sch.bind(loop=l86, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b68, ann_key="meta_schedule.cooperative_fetch") - _, l88, l89, l90, _, l91, l92 = sch.get_loops(block=b68) - l93, l94 = sch.split(loop=l92, factors=[None, 64], preserve_unit_iters=True) - sch.bind(loop=l94, thread_axis="threadIdx.x") - b95 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b95, ann_key="meta_schedule.unroll_explicit") - _, _, b96, b97, b98, b99, _ = sch.get_child_blocks(b95) - _, l100, l101, l102, _, l103, l104, l105, l106 = sch.get_loops(block=b96) - sch.annotate(block_or_loop=l100, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l100, ann_key="pragma_unroll_explicit", ann_val=1) - _, l107, l108, l109, _, l110, l111, l112 = sch.get_loops(block=b97) - sch.annotate(block_or_loop=l107, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l107, ann_key="pragma_unroll_explicit", ann_val=1) - _, l113, l114, l115, _, 
l116, l117, l118, l119, l120, l121, l122, l123, l124, l125, l126 = sch.get_loops(block=b98) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - _, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b99) - sch.annotate(block_or_loop=l127, ann_key="pragma_auto_unroll_max_step", ann_val=512) - sch.annotate(block_or_loop=l127, ann_key="pragma_unroll_explicit", ann_val=1) - b134 = sch.get_block(name="matmul", func_name="main") - l0, l135, l136, l137, ko, l138, l139, l140, l141, l142, l143, l144, l145, l146, l147, l148 = sch.get_loops(block=b134) - sch.bind(l0, "blockIdx.y") - b149 = sch.decompose_reduction(block=b134, loop=ko) - - b1 = sch.get_block("rxplaceholder_pad") - sch.compute_inline(b1) - b2 = sch.get_block("rxplaceholder_1_pad") - sch.compute_inline(b2) - b3 = sch.get_block("matmul_pad") - sch.reverse_compute_inline(b3) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def NT_matmul1_fp16_before(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int64() - rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (T.int64(1), n, T.int64(4096)), "float16") - NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16") - # with T.block("root"): - for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(rxplaceholder_1[v_i0, v_i1, v_k], rxplaceholder[v_i2, v_k]) - T.writes(NT_matmul[v_i0, v_i1, v_i2]) - with T.init(): - NT_matmul[v_i0, v_i1, v_i2] = T.float16(0) - NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + rxplaceholder_1[v_i0, v_i1, v_k] * rxplaceholder[v_i2, v_k] - - -def NT_matmul1_fp16_sch_func(): - sch = tvm.tir.Schedule(NT_matmul1_fp16_before) - b0 = sch.get_block("NT_matmul") - sch.pad_einsum(b0, [1, 32, 1, 1]) - l1, l2, l3, l4 = sch.get_loops(b0) - l5, l6 = sch.split(l2, [None, 32]) - sch.reorder(l5, l1, l6, l3, l4) - - b0 = sch.get_block(name="NT_matmul", func_name="main") - b1 = sch.get_block(name="root", func_name="main") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") - _, l2, l3, l4, l5 = sch.get_loops(block=b0) - v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) - l11, l12, l13, l14, l15 = sch.split(loop=l2, factors=[v6, v7, v8, v9, v10], preserve_unit_iters=True) - v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[1, 1, 4, 2, 4]) - l21, l22, l23, l24, l25 = sch.split(loop=l3, factors=[v16, v17, v18, v19, v20], preserve_unit_iters=True) - v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[128, 1, 16, 1, 2]) - l31, l32, l33, l34, l35 = sch.split(loop=l4, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) - v36, v37, v38 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64, decision=[512, 2, 4]) - l39, l40, l41 = sch.split(loop=l5, factors=[v36, v37, v38], preserve_unit_iters=True) - sch.reorder(l11, l21, l31, l12, l22, l32, l13, l23, l33, l39, l40, l14, l24, l34, l41, l15, l25, l35) - l42 = sch.fuse(l11, l21, l31, preserve_unit_iters=True) - sch.bind(loop=l42, thread_axis="blockIdx.x") - l43 = sch.fuse(l12, l22, l32, 
preserve_unit_iters=True) - sch.bind(loop=l43, thread_axis="vthread.x") - l44 = sch.fuse(l13, l23, l33, preserve_unit_iters=True) - sch.bind(loop=l44, thread_axis="threadIdx.x") - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32) - sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256) - b45 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") - sch.reverse_compute_at(block=b45, loop=l44, preserve_unit_loops=True, index=-1) - b46 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b46, loop=l39, preserve_unit_loops=True, index=-1) - _, l47, l48, l49, l50, l51, l52, l53 = sch.get_loops(block=b46) - l54 = sch.fuse(l51, l52, l53, preserve_unit_iters=True) - v55 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) - b56 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0]) - sch.compute_at(block=b56, loop=l39, preserve_unit_loops=True, index=-1) - _, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b56) - l63 = sch.fuse(l61, l62, preserve_unit_iters=True) - v64 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=2) - sch.annotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch", ann_val=v64) - v65 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4) - sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v65) - sch.enter_postproc() - sch.unannotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch") - _, l66, l67, l68, l69, l70 = sch.get_loops(block=b46) - l71, l72, l73 = sch.split(loop=l70, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l73) - sch.bind(loop=l72, thread_axis="threadIdx.x") - sch.unannotate(block_or_loop=b56, ann_key="meta_schedule.cooperative_fetch") - _, l74, l75, l76, l77, l78 = sch.get_loops(block=b56) - l79, l80, l81 = sch.split(loop=l78, factors=[None, 64, 4], preserve_unit_iters=True) - sch.vectorize(loop=l81) - sch.bind(loop=l80, thread_axis="threadIdx.x") - b82 = sch.get_block(name="root", func_name="main") - sch.unannotate(block_or_loop=b82, ann_key="meta_schedule.unroll_explicit") - _, b83, b84, b85, b86, _ = sch.get_child_blocks(b82) - _, l87, l88, l89, l90, l91, l92, l93 = sch.get_loops(block=b83) - sch.annotate(block_or_loop=l87, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l87, ann_key="pragma_unroll_explicit", ann_val=1) - _, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b84) - sch.annotate(block_or_loop=l94, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l94, ann_key="pragma_unroll_explicit", ann_val=1) - _, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111, l112 = sch.get_loops(block=b85) - sch.annotate(block_or_loop=l101, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l101, ann_key="pragma_unroll_explicit", ann_val=1) - _, l113, l114, l115, l116, l117, l118 = sch.get_loops(block=b86) - sch.annotate(block_or_loop=l113, ann_key="pragma_auto_unroll_max_step", ann_val=1024) - sch.annotate(block_or_loop=l113, ann_key="pragma_unroll_explicit", ann_val=1) - 
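All of the *_sch_func helpers in this hunk deal with the dynamic sequence length n the same way: pad_einsum rounds the dynamic block iterator up to a multiple of 32, the padded loop is split into an outer chunk bound to blockIdx.y plus a fixed 32-wide inner tile that feeds the SSSRRSRS tiling, and the <buffer>_pad blocks introduced by the padding are inlined away at the end. A minimal standalone sketch of that recipe, assuming a TVM build with tvm.script and tir.Schedule.pad_einsum; the toy 64x64 shapes and the names nt_matmul_dyn, a, b, c are illustrative only and not taken from this patch:

import tvm
from tvm.script import tir as T

@T.prim_func
def nt_matmul_dyn(var_a: T.handle, b: T.Buffer((T.int64(64), T.int64(64)), "float16"), var_c: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    n = T.int64()  # dynamic sequence length, analogous to n in the functions above
    a = T.match_buffer(var_a, (n, T.int64(64)), "float16")
    c = T.match_buffer(var_c, (n, T.int64(64)), "float16")
    for i, j, k in T.grid(n, T.int64(64), T.int64(64)):
        with T.block("NT_matmul"):
            v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
            with T.init():
                c[v_i, v_j] = T.float16(0)
            c[v_i, v_j] = c[v_i, v_j] + a[v_i, v_k] * b[v_j, v_k]

sch = tvm.tir.Schedule(nt_matmul_dyn)
blk = sch.get_block("NT_matmul")
# Pad the dynamic i axis (extent n) up to a multiple of 32; j and k are left untouched.
sch.pad_einsum(blk, [32, 1, 1])
i, j, k = sch.get_loops(blk)
# Outer chunk over the padded length becomes the grid's y dimension;
# the 32-wide inner tile is what the SSSRRSRS tiling then decomposes further.
i_outer, i_inner = sch.split(i, [None, 32])
sch.bind(i_outer, "blockIdx.y")
# Inline the padding blocks, mirroring the compute_inline / reverse_compute_inline
# calls at the end of each *_sch_func; block names follow the "<buffer>_pad"
# convention seen in this hunk (e.g. lv43_pad, NT_matmul_pad).
sch.compute_inline(sch.get_block("a_pad"))
sch.reverse_compute_inline(sch.get_block("c_pad"))
print(sch.mod.script())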
b119 = sch.get_block(name="NT_matmul", func_name="main") - l0, l120, l121, l122, l123, l124, l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b119) - sch.bind(l0, "blockIdx.y") - b132 = sch.decompose_reduction(block=b119, loop=l123) - - b1 = sch.get_block("rxplaceholder_1_pad") - sch.compute_inline(b1) - b2 = sch.get_block("NT_matmul_pad") - sch.reverse_compute_inline(b2) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def decode6(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(4096))) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - -@T.prim_func -def decode7(rxplaceholder: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(11008))) - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - -@T.prim_func -def decode8(rxplaceholder: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(11008), T.int64(4096))) - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", 
[i, j])
-            T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j])
-            T.writes(decode[v_i, v_j])
-            decode[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
-    for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
-
-
-@T.prim_func
-def decode4_fp16(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "float16"), rxplaceholder_2: T.Buffer((T.int64(128), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
-    for i, j in T.grid(T.int64(4096), T.int64(4096)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j])
-            T.writes(decode[v_i, v_j])
-            decode[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j]
-    for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
-
-@T.prim_func
-def decode5_fp16(rxplaceholder: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(11008)), "float16"), rxplaceholder_2: T.Buffer((T.int64(128), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
-    for i, j in T.grid(T.int64(4096), T.int64(11008)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j])
-            T.writes(decode[v_i, v_j])
-            decode[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j]
-    for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
-
-@T.prim_func
-def decode6_fp16(rxplaceholder: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(344), T.int64(4096)), "float16"), rxplaceholder_2: T.Buffer((T.int64(344), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
-    for i, j in T.grid(T.int64(11008), T.int64(4096)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(rxplaceholder[v_i // T.int64(8), v_j], rxplaceholder_1[v_i // T.int64(32), v_j], rxplaceholder_2[v_i // T.int64(32), v_j])
-            T.writes(decode[v_i, v_j])
-            decode[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * rxplaceholder_1[v_i // T.int64(32), v_j] + rxplaceholder_2[v_i // T.int64(32), v_j]
-    for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
-
-
-@T.prim_func
-def decode_int3_fp16(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
-    for i, j in T.grid(T.int64(4096), T.int64(4096)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j])
-            T.writes(decode_1[v_i, v_j])
-            decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
-    for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode_1[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0]
-
-@T.prim_func
-def decode1_int3_fp16(A: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
-    for i, j in T.grid(T.int64(4096), T.int64(11008)):
-        with T.block("decode"):
-            v_i, v_j = T.axis.remap("SS", [i, j])
-            T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j])
-            T.writes(decode[v_i, v_j])
-            decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
-    for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
-        with T.block("T_transpose"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(decode[v_ax1, v_ax0])
-            T.writes(T_transpose[v_ax0, v_ax1])
-            T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
-
-@T.prim_func
-def decode2_int3_fp16(A: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")):
-    T.func_attr({"tir.noalias": T.bool(True)})
-    # with T.block("root"):
-    decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
-    for i, j in T.grid(T.int64(11008), T.int64(4096)):
-        with T.block("decode"):
-            v_i, v_j =
T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] - for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - -@T.prim_func -def decode_int3_int16_fp16(A: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) - T.writes(decode_1[v_i, v_j]) - decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] - for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode_1[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0] - -@T.prim_func -def decode1_int3_int16_fp16(A: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] - for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - -@T.prim_func -def decode2_int3_int16_fp16(A: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j]) - T.writes(decode[v_i, v_j]) - decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j] - for ax0, ax1 in T.grid(T.int64(4096), 
T.int64(11008)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(decode[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] - - -def decode_sch_func(orig_func): - sch = tvm.tir.Schedule(orig_func) - b0 = sch.get_block(name="decode", func_name="main") - l1, l2 = sch.get_loops(block=b0) - l3, l4 = sch.split(loop=l1, factors=[None, 8], preserve_unit_iters=True) - v5, v6, v7 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=4, decision=[32, 8, 2]) - l8, l9, l10 = sch.split(loop=l3, factors=[v5, v6, v7], preserve_unit_iters=True) - v11, v12 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[256, 16]) - l13, l14 = sch.split(loop=l2, factors=[v11, v12], preserve_unit_iters=True) - sch.reorder(l8, l13, l9, l14, l10, l4) - sch.bind(loop=l8, thread_axis="blockIdx.y") - sch.bind(loop=l13, thread_axis="blockIdx.x") - sch.bind(loop=l9, thread_axis="threadIdx.y") - sch.bind(loop=l14, thread_axis="threadIdx.x") - sch.unroll(loop=l4) - b15 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="shared") - sch.compute_inline(block=b15) - b16 = sch.get_block(name="T_transpose", func_name="main") - sch.reverse_compute_at(block=b16, loop=l13, preserve_unit_loops=True, index=-1) - b17 = sch.get_block(name="T_transpose", func_name="main") - l18, l19, l20, l21 = sch.get_loops(block=b17) - l22 = sch.fuse(l20, l21, preserve_unit_iters=True) - l23, l24, l25 = sch.split(loop=l22, factors=[None, v12, 4], preserve_unit_iters=True) - sch.bind(loop=l24, thread_axis="threadIdx.x") - sch.vectorize(loop=l25) - sch.storage_align(block=b0, buffer_index=0, axis=0, factor=32, offset=1) - return sch.mod["main"].with_attr("tir.is_scheduled", 1) - - -@T.prim_func -def fused_decode3_matmul1_before(lv2931: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv2932: T.Buffer((T.int64(128), T.int64(32000)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000))) - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv2931[v_i // T.int64(8), v_j], lv2932[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv2932[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv2932[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1511[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1511[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode3_matmul1_after(lv1123: T.Buffer((T.int64(512), 
T.int64(32000)), "uint32"), lv1124: T.Buffer((T.int64(128), T.int64(32000)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local") - var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local") - lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv1511_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv1511[v0, v1, v2]) - T.writes(lv1511_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("var_decode_intermediate_pad"): - v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1123[v0 // T.int64(8), v1], lv1124[v0 // T.int64(32), v1]) - T.writes(var_decode_intermediate_pad_local[v0, v1]) - var_decode_intermediate_pad_local[v0, v1] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1124[v0 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1124[v0 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, 
v_k] * var_decode_intermediate_pad_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_pad_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_pad_local[v0, v1, v2] - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_before(lv3184: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv3185: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv452: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv3184[v_i // T.int64(8), v_j], lv3185[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv3184[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv3185[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv3185[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv452[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv452[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2710[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2710[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_after(lv1143: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv1144: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") - lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") - for 
i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv3_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv3[v0, v1, v2]) - T.writes(lv3_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv3_shared[v0, v1, v2] = lv3[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1143[v_j // T.int64(8), v_i], lv1144[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1144[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1144[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode4_matmul5_before(lv3166: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv3167: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv2712: T.Buffer((T.int64(1), T.int64(1), 
T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096))) - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv3166[v_i // T.int64(8), v_j], lv3167[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv3166[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv3167[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv3167[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2712[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2712[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode4_matmul5_after(lv1128: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv1129: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") - lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2712_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2712[v0, v1, v2]) - T.writes(lv2712_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2712_shared[v0, v1, v2] = lv2712[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for 
ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1128[v_j // T.int64(8), v_i], lv1129[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1129[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1129[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_before(lv1617: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1618: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1617[v_i // T.int64(8), v_j], lv1618[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1617[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1618[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1618[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - 
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_after(lv1153: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1154: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1153[v_j // T.int64(8), v_i], lv1154[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1154[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(lv1154[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_before(lv1611: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1612: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008))) - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1611[v_i // T.int64(8), v_j], lv1612[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1612[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1612[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), 
T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_after(lv1148: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv1149: T.Buffer((T.int64(128), T.int64(11008)), "uint32"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1148[v_j // T.int64(8), v_i], lv1149[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1149[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1149[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - 
v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_before(lv1623: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1624: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv230: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv228: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096))) - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1623[v_i // T.int64(8), v_j], lv1624[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1623[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1624[v_i // T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1624[v_i // T.int64(32), v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv230[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv230[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv228[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv228[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1159: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: 
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local") - lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(2)): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv6_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504)) - T.reads(lv6[v0, v1, v2]) - T.writes(lv6_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv6_shared[v0, v1, v2] = lv6[v0, v1, v2] - for k_0_1 in range(T.int64(86)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1158[v_j // T.int64(8), v_i], lv1159[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float32", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(lv1159[v_j // T.int64(32), v_i], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(lv1159[v_j // T.int64(32), v_i], T.uint32(16)), T.uint32(65535)), T.uint32(16))) - for k_0_2_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with 
T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode3_matmul1_fp16_before(lv5865: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv2705: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv5865[v_i // T.int64(8), v_j], lv5866[v_i // T.int64(32), v_j], lv5867[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv5865[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v_i // T.int64(32), v_j] + lv5867[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode3_matmul1_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local", dtype="float16") - var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") - lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv1511_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - 
T.reads(lv1511[v0, v1, v2]) - T.writes(lv1511_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("var_decode_intermediate_pad"): - v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // T.int64(32), v1]) - T.writes(var_decode_intermediate_pad_local[v0, v1]) - var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_pad_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_pad_local[v0, v1, v2] - - -@T.prim_func -def fused_decode3_matmul1_cast_fp16_before(lv1803: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv1804: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1805: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1803[v_i // T.int64(8), v_j], lv1804[v_i // T.int64(32), v_j], lv1805[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1803[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8) * T.int64(4))), T.uint32(15))) * lv1804[v_i // T.int64(32), v_j] + lv1805[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) - - -@T.prim_func -def fused_decode3_matmul1_cast_fp16_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv5867: T.Buffer((T.int64(128), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local", dtype="float16") - var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") - lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv1511_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv1511[v0, v1, v2]) - T.writes(lv1511_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv1511_shared[v0, v1, v2] = lv1511[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("var_decode_intermediate_pad"): - v0 = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1123[v0 // T.int64(8), v1], lv5866[v0 // T.int64(32), v1], lv5867[v0 // 
T.int64(32), v1]) - T.writes(var_decode_intermediate_pad_local[v0, v1]) - var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(8), v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * lv5866[v0 // T.int64(32), v1] + lv5867[v0 // T.int64(32), v1] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_pad_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_fp16_before(lv35: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv37: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv35[v_i // T.int64(8), v_j], lv36[v_i // T.int64(32), v_j], lv37[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv35[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_i // T.int64(32), v_j] + lv37[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv2710[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, 
v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2710[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_fp16_after(lv1143: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv37: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv3_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv3[v0, v1, v2]) - T.writes(lv3_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv3_shared[v0, v1, v2] = lv3[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1143[v_j // T.int64(8), v_i], lv36[v_j // T.int64(32), v_i], lv37[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv36[v_j // T.int64(32), v_i] + lv37[v_j // T.int64(32), v_i] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, 
v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode4_matmul5_fp16_before(lv11: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv13: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv11[v_i // T.int64(8), v_j], lv12[v_i // T.int64(32), v_j], lv13[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv11[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_i // T.int64(32), v_j] + lv13[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2712[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2712[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode4_matmul5_fp16_after(lv1128: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv13: T.Buffer((T.int64(128), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with 
T.block("lv2712_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2712[v0, v1, v2]) - T.writes(lv2712_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2712_shared[v0, v1, v2] = lv2712[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1128[v_j // T.int64(8), v_i], lv12[v_j // T.int64(32), v_i], lv13[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv12[v_j // T.int64(32), v_i] + lv13[v_j // T.int64(32), v_i] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_fp16_before(lv51: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv53: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv51[v_i // T.int64(8), v_j], lv52[v_i // 
T.int64(32), v_j], lv53[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv51[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv52[v_i // T.int64(32), v_j] + lv53[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv5[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv5[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_fp16_after(lv1153: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv53: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in 
T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1153[v_j // T.int64(8), v_i], lv52[v_j // T.int64(32), v_i], lv53[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv52[v_j // T.int64(32), v_i] + lv53[v_j // T.int64(32), v_i] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_fp16_before(lv43: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv45: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv43[v_i // T.int64(8), v_j], lv44[v_i // T.int64(32), v_j], lv45[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv43[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv44[v_i // T.int64(32), v_j] + lv45[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv2749[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + 
lv2749[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_fp16_after(lv1148: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv45: T.Buffer((T.int64(128), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(11008)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(4)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax1_ax2_fused_2 in T.vectorized(T.int64(4)): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4096), ax1_ax2_fused_0 * T.int64(1024) + ax1_ax2_fused_1 * T.int64(4) + ax1_ax2_fused_2) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = lv2749[v0, v1, v2] - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(64)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4096), k_0_0 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1148[v_j // T.int64(8), v_i], lv44[v_j // T.int64(32), v_i], lv45[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % 
T.int64(8) * T.int64(4))), T.uint32(15))) * lv44[v_j // T.int64(32), v_i] + lv45[v_j // T.int64(32), v_i] - for k_0_1_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0_0 * T.int64(64) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_fp16_before(lv59: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv61: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv59[v_i // T.int64(8), v_j], lv60[v_i // T.int64(32), v_j], lv61[v_i // T.int64(32), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = T.Cast("float16", T.bitwise_and(T.shift_right(lv59[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * lv60[v_i // T.int64(32), v_j] + lv61[v_i // T.int64(32), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv5[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_fp16_after(lv1158: 
T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv61: T.Buffer((T.int64(344), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(2)): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv6_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5504)) - T.reads(lv6[v0, v1, v2]) - T.writes(lv6_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv6_shared[v0, v1, v2] = lv6[v0, v1, v2] - for k_0_1 in range(T.int64(86)): - for ax0_0 in range(T.int64(8)): - for ax0_1 in T.unroll(T.int64(8)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + ax0_0 * T.int64(8) + ax0_1) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1158[v_j // T.int64(8), v_i], lv60[v_j // T.int64(32), v_i], lv61[v_j // T.int64(32), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(8), v_i], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) * lv60[v_j // T.int64(32), v_i] + lv61[v_j // T.int64(32), v_i] - for k_0_2_k_1_fused in range(T.int64(64)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(11008), k_0_0 * T.int64(5504) + k_0_1 * T.int64(64) + k_0_2_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + 
lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode3_matmul1_cast_int3_fp16_before(lv2931: T.Buffer((T.int64(412), T.int64(32000)), "uint32"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv2931[v_i // T.int64(10), v_j], lv2932[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv2931[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) - - -@T.prim_func -def fused_decode3_matmul1_cast_int3_fp16_after(lv1123: T.Buffer((T.int64(412), T.int64(32000)), "uint32"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), T.int64(32000)), scope="local", dtype="float16") - var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") - lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv1511_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv1511[v0, v1, v2]) - T.writes(lv1511_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("var_decode_intermediate_pad"): - v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1123[v0 // T.int64(10), v1], lv5866[v0 // T.int64(40), v1]) - T.writes(var_decode_intermediate_pad_local[v0, v1]) - var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1123[v0 // T.int64(10), v1], T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * lv5866[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_pad_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_int3_fp16_before(lv1605: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1605[v_i // T.int64(10), v_j], lv1606[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1605[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_int3_fp16_after(lv1143: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv36: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv3_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv3[v0, v1, v2]) - T.writes(lv3_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, 
v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1143[v_j // T.int64(10), v_i], lv36[v_j // T.int64(40), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1143[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv36[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode4_matmul5_int3_fp16_before(lv1587: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1587[v_i // T.int64(10), v_j], lv1588[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1587[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode4_matmul5_int3_fp16_after(lv1128: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), lv12: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), 
T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2712_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2712[v0, v1, v2]) - T.writes(lv2712_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1128[v_j // T.int64(10), v_i], lv12[v_j // T.int64(40), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1128[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv12[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - 
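# --- editor's note (illustration, not part of the patch) --------------------
# Every kernel in this hunk is the same dequantize-then-GEMV pattern; the
# sketch below is a rough NumPy reference of what they compute, with the
# layouts read off the buffer shapes above (groups of 32 weights / 8 values
# per uint32 for the 4-bit path, groups of 40 / 10 values per uint32 for the
# 3-bit path).  Function and argument names are illustrative only, and this
# reference accumulates in fp32 whereas the kernels accumulate in fp16.
import numpy as np

def dequant_gemv_int4(x, packed, scale, offset):
    # packed: (K//8, N) uint32, eight 4-bit fields per word along K
    # scale, offset: (K//32, N) float16, one pair per group of 32 weights
    K = packed.shape[0] * 8
    k = np.arange(K)
    nib = (packed[k // 8, :] >> ((k % 8)[:, None] * 4)) & 0xF          # (K, N)
    w = nib.astype(np.float16) * scale[k // 32, :] + offset[k // 32, :]
    return x.astype(np.float32) @ w.astype(np.float32)                 # (N,)

def dequant_gemv_int3(x, packed, scale):
    # packed: (ceil(K/40)*4, N) uint32, ten 3-bit fields per word, zero
    # point fixed at 3; scale: (ceil(K/40), N) float16 per group of 40
    k = np.arange(x.shape[0])
    tri = (packed[k // 10, :] >> ((k % 10)[:, None] * 3)) & 0x7
    w = (tri.astype(np.float16) - np.float16(3)) * scale[k // 40, :]
    return x.astype(np.float32) @ w.astype(np.float32)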
T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_int3_fp16_before(lv1617: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1617[v_i // T.int64(10), v_j], lv1618[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1617[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_int3_fp16_after(lv1153: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv52: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2749_shared"): - 
v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1153[v_j // T.int64(10), v_i], lv52[v_j // T.int64(40), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1153[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv52[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_int3_fp16_before(lv1611: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1611[v_i // T.int64(10), v_j], lv1612[v_i // T.int64(40), v_j]) - 
T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1611[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_int3_fp16_after(lv1148: T.Buffer((T.int64(412), T.int64(11008)), "uint32"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - 
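# --- editor's note (illustration, not part of the patch) --------------------
# The unscheduled silu1 kernels above materialize sigmoid(x) in a separate
# "compute" buffer and then multiply in "T_multiply"; the scheduled versions
# fold both into the final write-back as x * sigmoid(x), i.e. SiLU.
# Equivalent reference in NumPy:
import numpy as np

def silu(x):
    # matches p_output0 = matmul * T.sigmoid(matmul) in the epilogue below
    return x * (1.0 / (1.0 + np.exp(-x)))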
T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1148[v_j // T.int64(10), v_i], lv44[v_j // T.int64(40), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1148[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv44[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_int3_fp16_before(lv1623: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv167: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv165: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1623[v_i // T.int64(10), v_j], lv1624[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1623[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv167[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv167[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv165[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv165[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_int3_fp16_after(lv1158: T.Buffer((T.int64(1104), T.int64(4096)), "uint32"), lv60: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(11040), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11040)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(2)): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(22)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv6_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + (ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1)) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(5520)) - T.reads(lv6[v0, v1, v2]) - T.writes(lv6_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv6_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(11008), lv6[v0, v1, v2], T.float16(0)) - for k_0_1 in range(T.int64(69)): - for ax0_0 in T.unroll(T.int64(80)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1158[v_j // T.int64(10), v_i], lv60[v_j // T.int64(40), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.bitwise_and(T.shift_right(lv1158[v_j // T.int64(10), v_i], T.Cast("uint32", v_j % T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3) - for k_0_2_k_1_fused in range(T.int64(80)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = 
T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(11040), k_0_0 * T.int64(5520) + k_0_1 * T.int64(80) + k_0_2_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * lv60[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode3_matmul1_cast_int3_int16_fp16_before(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3025: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(32000)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv3025[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3025[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) - p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) - - -@T.prim_func -def fused_decode3_matmul1_cast_int3_int16_fp16_after(lv1123: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv5866: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4120), 
T.int64(32000)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local", dtype="float16") - lv1511_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(125), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv1511_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv1511[v0, v1, v2]) - T.writes(lv1511_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv1511_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv1511[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("var_decode_intermediate_pad"): - v0 = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v1 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1123[v0 // T.int64(5), v1]) - T.writes(var_decode_intermediate_pad_local[v0, v1]) - var_decode_intermediate_pad_local[v0, v1] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1123[v0 // T.int64(5), v1]), T.Cast("uint16", v0 % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in range(T.int64(1)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv5866[v_j, v_i]) - T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv5866[v_j, v_i] - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv1511_shared[v_i0, v_i1, v_k], var_decode_intermediate_pad_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv1511_shared[v_i0, v_i1, v_k] * var_decode_intermediate_pad_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), 
v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_pad_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(32000), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_pad_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = T.Cast("float32", var_matmul_intermediate_pad_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_int3_int16_fp16_before(lv1605: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv164: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1518: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1605[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv164[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv164[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv1518[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1518[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode4_fused_matmul5_add3_int3_int16_fp16_after(lv1143: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv36: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv2710: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv3_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for 
i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv3_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv3[v0, v1, v2]) - T.writes(lv3_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv3_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv3[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1143[v_j // T.int64(5), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1143[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in range(T.int64(1)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv36[v_j, v_i]) - T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv36[v_j, v_i] - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv3_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv3_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv2710[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv2710[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] - - 
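Editor's note, not part of the patch: every "decode" block in the deleted int3 kernels above and below recovers a weight element as (field - 3) * scale, where the 3-bit field is unpacked from a packed storage word (10 fields per uint32, or 5 per uint16 in the int16-storage variants) and scale is the per-group float16 scale with a group size of 40 along the reduction axis. The standalone NumPy sketch below is for orientation only; the names packed, scales and elems_per_word are illustrative and do not appear in the kernels.

import numpy as np

def dequantize_int3(packed: np.ndarray, scales: np.ndarray,
                    elems_per_word: int, group_size: int = 40) -> np.ndarray:
    # packed: (k_pad // elems_per_word, n) unsigned words holding 3-bit fields
    # scales: (k_pad // group_size, n) float16 per-group scales
    # Returns the (k_pad, n) float16 weight matrix; callers slice off rows that
    # only exist because k was padded up to a multiple of group_size.
    k_pad = packed.shape[0] * elems_per_word
    out = np.empty((k_pad, packed.shape[1]), dtype=np.float16)
    for i in range(k_pad):
        word = packed[i // elems_per_word]                 # packed row for element i
        field = (word >> ((i % elems_per_word) * 3)) & 7   # extract the 3-bit field
        out[i] = (field.astype(np.float16) - np.float16(3)) * scales[i // group_size]
    return out

# uint32 storage packs 10 elements per word, uint16 storage packs 5:
#   w = dequantize_int3(packed_u32, scales, elems_per_word=10)
#   w = dequantize_int3(packed_u16, scales, elems_per_word=5)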
-@T.prim_func -def fused_decode4_matmul5_int3_int16_fp16_before(lv1587: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1520: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1587[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1520[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1520[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - - -@T.prim_func -def fused_decode4_matmul5_int3_int16_fp16_after(lv1128: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv12: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv2712: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(4096)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv2712_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2712_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2712[v0, v1, v2]) - T.writes(lv2712_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2712_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2712[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) 
+ i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1128[v_j // T.int64(5), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1128[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in range(T.int64(1)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv12[v_j, v_i]) - T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv12[v_j, v_i] - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2712_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2712_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(var_matmul_intermediate[v0, v1, v2]) - var_matmul_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_before(lv1617: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1617[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j] - 
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_after(lv1153: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv52: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv5: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1153[v_j // T.int64(5), v_i]) - 
T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1153[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in range(T.int64(1)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv52[v_j, v_i]) - T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv52[v_j, v_i] - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv5[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv5[v0, v1, v2] * var_matmul_intermediate_local[v0, v1, v2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_int3_int16_fp16_before(lv1611: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1557: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16") - for i, j in T.grid(T.int64(4096), T.int64(11008)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1611[v_i // T.int64(5), v_j], lv1612[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1611[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv1557[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1557[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("compute"): - v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) - T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) - T.writes(compute[v_i0, v_i1, v_i2]) - compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2]) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode5_fused_matmul8_silu1_int3_int16_fp16_after(lv1148: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv2749: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")): - T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(4120), T.int64(11008)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(103), T.int64(11008)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), scope="local", dtype="float16") - lv2749_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4120)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(43), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(17)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(4120), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv2749[v0, v1, v2]) - T.writes(lv2749_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(4120)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv2749_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(4096), lv2749[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(103)): - for ax0_0 in T.unroll(T.int64(40)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(4120), k_0_0 * T.int64(40) + ax0_0) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1148[v_j // T.int64(5), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", 
T.bitwise_and(T.shift_right(T.Cast("uint16", lv1148[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in range(T.int64(1)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(103), k_0_0 + ax0_0) - v_i = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv44[v_j, v_i]) - T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv44[v_j, v_i] - for k_0_1_k_1_fused in range(T.int64(40)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(4120), k_0_0 * T.int64(40) + k_0_1_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv2749_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv2749_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(11008), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = var_matmul_intermediate_local[v0, v1, v2] * T.sigmoid(var_matmul_intermediate_local[v0, v1, v2]) - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_int3_int16_fp16_before(lv1623: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv167: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv165: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - # with T.block("root"): - var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") - var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") - for i, j in T.grid(T.int64(11008), T.int64(4096)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv1623[v_i // T.int64(5), v_j], lv1624[v_i // T.int64(40), v_j]) - T.writes(var_decode_intermediate[v_i, v_j]) - var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1623[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j] - for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv167[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) - T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv167[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] - for ax0, ax1, ax2 in 
T.grid(T.int64(1), T.int64(1), T.int64(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv165[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv165[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - -@T.prim_func -def fused_decode6_fused_matmul9_add3_int3_int16_fp16_after(lv1158: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv60: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - # with T.block("root"): - var_decode_intermediate_local = T.alloc_buffer((T.int64(11040), T.int64(4096)), scope="local", dtype="float16") - var_scale_intermediate_local = T.alloc_buffer((T.int64(276), T.int64(4096)), scope="local", dtype="float16") - var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local", dtype="float16") - lv6_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11040)), scope="shared", dtype="float16") - for i0_i1_i2_0_fused in T.thread_binding(T.int64(16), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 16, "pragma_unroll_explicit": 1}): - for i2_1 in T.thread_binding(T.int64(1), thread="vthread.x"): - for i2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - for ax0, ax1_ax2_fused_0 in T.grid(T.int64(1), T.int64(44)): - for ax1_ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): - with T.block("lv2749_shared"): - v0 = T.axis.spatial(T.int64(1), ax0) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial(T.int64(11040), ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1) - T.reads(lv6[v0, v1, v2]) - T.writes(lv6_shared[v0, v1, v2]) - T.where(ax1_ax2_fused_0 * T.int64(256) + ax1_ax2_fused_1 < T.int64(11040)) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 8]]}) - lv6_shared[v0, v1, v2] = T.if_then_else(v2 < T.int64(11008), lv6[v0, v1, v2], T.float16(0)) - with T.block("matmul_init"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - T.reads() - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = T.float32(0) - for k_0_0 in range(T.int64(138)): - for ax0_0 in T.unroll(T.int64(80)): - for ax1 in range(T.int64(1)): - with T.block("decode"): - v_j = T.axis.spatial(T.int64(11040), k_0_0 * T.int64(80) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv1158[v_j // T.int64(5), v_i]) - T.writes(var_decode_intermediate_local[v_j, v_i]) - var_decode_intermediate_local[v_j, v_i] = T.Cast("float16", T.Cast("int16", T.bitwise_and(T.shift_right(T.Cast("uint16", lv1158[v_j // T.int64(5), v_i]), T.Cast("uint16", v_j % T.int64(5)) * T.uint16(3)), T.uint16(7))) - T.int16(3)) - for ax0_0 in T.unroll(T.int64(2)): - for ax1 in range(T.int64(1)): - with T.block("scale"): - v_j = T.axis.spatial(T.int64(276), k_0_0 * T.int64(2) + ax0_0) - v_i = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax1) - T.reads(lv60[v_j, v_i]) - 
T.writes(var_scale_intermediate_local[v_j, v_i]) - var_scale_intermediate_local[v_j, v_i] = lv60[v_j, v_i] - for k_0_2_k_1_fused in range(T.int64(80)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_1 * T.int64(256) + i2_2) - v_k = T.axis.reduce(T.int64(11040), k_0_0 * T.int64(80) + k_0_2_k_1_fused) - T.reads(var_matmul_intermediate_local[v_i0, v_i1, v_i2], lv6_shared[v_i0, v_i1, v_k], var_decode_intermediate_local[v_k, v_i2], var_scale_intermediate_local[v_k // T.int64(40), v_i2]) - T.writes(var_matmul_intermediate_local[v_i0, v_i1, v_i2]) - var_matmul_intermediate_local[v_i0, v_i1, v_i2] = var_matmul_intermediate_local[v_i0, v_i1, v_i2] + lv6_shared[v_i0, v_i1, v_k] * var_decode_intermediate_local[v_k, v_i2] * var_scale_intermediate_local[v_k // T.int64(40), v_i2] - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): - with T.block("var_matmul_intermediate_local"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(T.int64(4096), i0_i1_i2_0_fused * T.int64(256) + i2_2 + ax2) - T.reads(lv4[v0, v1, v2], var_matmul_intermediate_local[v0, v1, v2]) - T.writes(p_output0_intermediate[v0, v1, v2]) - p_output0_intermediate[v0, v1, v2] = lv4[v0, v1, v2] + var_matmul_intermediate_local[v0, v1, v2] -################################################ - -def get_dict_key(func): - return tvm.ir.structural_hash(func), func - - -tir_dispatch_dict = { - get_dict_key(fused_min_max_triu_te_broadcast_to): fused_min_max_triu_te_broadcast_to_sch_func(), - get_dict_key(rms_norm_before): rms_norm_after, - get_dict_key(rms_norm_fp16_before): rms_norm_fp16_after, - get_dict_key(softmax_before): softmax_after, - get_dict_key(softmax_mxn_before): softmax_mxn_after, - get_dict_key(softmax_cast_mxn_before): softmax_cast_mxn_after, - get_dict_key(softmax_fp16_before): softmax_fp16_after, - get_dict_key(softmax_mxn_fp16_before): softmax_mxn_fp16_after, - get_dict_key(softmax_1xn_before): softmax_1xn_sch_func(softmax_1xn_before), - get_dict_key(softmax_cast_1xn_before): softmax_1xn_sch_func(softmax_cast_1xn_before, cast_to_fp16=True), - get_dict_key(softmax_1xn_fp16_before): softmax_1xn_sch_func(softmax_1xn_fp16_before), - get_dict_key(matmul1_before): matmul1_after, - get_dict_key(matmul2_before): matmul2_sch_func(), - get_dict_key(matmul5_before): matmul5_after, - get_dict_key(matmul5_with_m_before): matmul5_with_m_after, - get_dict_key(NT_matmul_before): NT_matmul_after, - get_dict_key(NT_matmul4_before): NT_matmul4_sch_func(), - get_dict_key(NT_matmul9_before): NT_matmul9_sch_func(), - get_dict_key(fused_matmul1_add1): fused_matmul1_add1_sch_func(), - get_dict_key(fused_matmul3_multiply): fused_matmul3_multiply_sch_func(), - get_dict_key(fused_matmul3_silu): fused_matmul3_silu_sch_func(), - get_dict_key(fused_matmul4_add1): fused_matmul4_add1_sch_func(), - get_dict_key(fused_NT_matmul_add1_before): fused_NT_matmul_add1_after, - get_dict_key(fused_NT_matmul1_divide_add_maximum_before): fused_NT_matmul1_divide_add_maximum_after, - get_dict_key(fused_NT_matmul1_divide_add_maximum_with_m_before): fused_NT_matmul1_divide_add_maximum_with_m_after, - get_dict_key(fused_NT_matmul6_divide1_add2_maximum1_before): fused_NT_matmul6_divide1_add2_maximum1_after, - get_dict_key(fused_NT_matmul2_multiply_before): fused_NT_matmul2_multiply_after, - get_dict_key(fused_NT_matmul2_silu_before): fused_NT_matmul2_silu_after, - 
get_dict_key(fused_NT_matmul3_add1_before): fused_NT_matmul3_add1_after, - get_dict_key(fused_NT_matmul_divide_maximum_minimum_cast_before): fused_NT_matmul_divide_maximum_minimum_cast_sch_func(), - get_dict_key(fused_NT_matmul_divide_maximum_minimum_before): fused_NT_matmul_divide_maximum_minimum_sch_func(), - get_dict_key(fused_NT_matmul1_add3_before): fused_NT_matmul1_add3_sch_func(), - get_dict_key(fused_NT_matmul2_divide1_add2_maximum1_before): fused_NT_matmul2_divide1_add2_maximum1_sch_func(fused_NT_matmul2_divide1_add2_maximum1_before), - get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_cast3_before): fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after, - get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_before): fused_NT_matmul2_divide1_maximum1_minimum1_after, - get_dict_key(fused_NT_matmul3_multiply1_before): fused_NT_matmul3_multiply1_sch_func(), - get_dict_key(fused_NT_matmul3_silu1_before): fused_NT_matmul3_silu1_sch_func(), - get_dict_key(fused_NT_matmul4_add3_before): fused_NT_matmul4_add3_sch_func(), - get_dict_key(matmul1_fp16_before): matmul1_fp16_sch_func(), - get_dict_key(matmul8_fp16_before): matmul8_fp16_sch_func(matmul8_fp16_before), - get_dict_key(matmul8_with_m_fp16_before): matmul8_fp16_sch_func(matmul8_with_m_fp16_before), - get_dict_key(NT_matmul1_fp16_before): NT_matmul1_fp16_sch_func(), - get_dict_key(decode6): decode_sch_func(decode6), - get_dict_key(decode7): decode_sch_func(decode7), - get_dict_key(decode8): decode_sch_func(decode8), - get_dict_key(decode4_fp16): decode_sch_func(decode4_fp16), - get_dict_key(decode5_fp16): decode_sch_func(decode5_fp16), - get_dict_key(decode6_fp16): decode_sch_func(decode6_fp16), - get_dict_key(decode_int3_fp16): decode_sch_func(decode_int3_fp16), - get_dict_key(decode1_int3_fp16): decode_sch_func(decode1_int3_fp16), - get_dict_key(decode2_int3_fp16): decode_sch_func(decode2_int3_fp16), - get_dict_key(decode_int3_int16_fp16): decode_sch_func(decode_int3_int16_fp16), - get_dict_key(decode1_int3_int16_fp16): decode_sch_func(decode1_int3_int16_fp16), - get_dict_key(decode2_int3_int16_fp16): decode_sch_func(decode2_int3_int16_fp16), - get_dict_key(fused_decode3_matmul1_before): fused_decode3_matmul1_after, - get_dict_key(fused_decode4_fused_matmul5_add3_before): fused_decode4_fused_matmul5_add3_after, - get_dict_key(fused_decode4_matmul5_before): fused_decode4_matmul5_after, - get_dict_key(fused_decode5_fused_matmul8_multiply1_before): fused_decode5_fused_matmul8_multiply1_after, - get_dict_key(fused_decode5_fused_matmul8_silu1_before): fused_decode5_fused_matmul8_silu1_after, - get_dict_key(fused_decode6_fused_matmul9_add3_before): fused_decode6_fused_matmul9_add3_after, - get_dict_key(fused_decode3_matmul1_fp16_before): fused_decode3_matmul1_fp16_after, - get_dict_key(fused_decode3_matmul1_cast_fp16_before): fused_decode3_matmul1_cast_fp16_after, - get_dict_key(fused_decode4_fused_matmul5_add3_fp16_before): fused_decode4_fused_matmul5_add3_fp16_after, - get_dict_key(fused_decode4_matmul5_fp16_before): fused_decode4_matmul5_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_multiply1_fp16_before): fused_decode5_fused_matmul8_multiply1_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_silu1_fp16_before): fused_decode5_fused_matmul8_silu1_fp16_after, - get_dict_key(fused_decode6_fused_matmul9_add3_fp16_before): fused_decode6_fused_matmul9_add3_fp16_after, - get_dict_key(fused_decode3_matmul1_cast_int3_fp16_before): fused_decode3_matmul1_cast_int3_fp16_after, - 
get_dict_key(fused_decode4_fused_matmul5_add3_int3_fp16_before): fused_decode4_fused_matmul5_add3_int3_fp16_after, - get_dict_key(fused_decode4_matmul5_int3_fp16_before): fused_decode4_matmul5_int3_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_multiply1_int3_fp16_before): fused_decode5_fused_matmul8_multiply1_int3_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_silu1_int3_fp16_before): fused_decode5_fused_matmul8_silu1_int3_fp16_after, - get_dict_key(fused_decode6_fused_matmul9_add3_int3_fp16_before): fused_decode6_fused_matmul9_add3_int3_fp16_after, - get_dict_key(fused_decode3_matmul1_cast_int3_int16_fp16_before): fused_decode3_matmul1_cast_int3_int16_fp16_after, - get_dict_key(fused_decode4_fused_matmul5_add3_int3_int16_fp16_before): fused_decode4_fused_matmul5_add3_int3_int16_fp16_after, - get_dict_key(fused_decode4_matmul5_int3_int16_fp16_before): fused_decode4_matmul5_int3_int16_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_before): fused_decode5_fused_matmul8_multiply1_int3_int16_fp16_after, - get_dict_key(fused_decode5_fused_matmul8_silu1_int3_int16_fp16_before): fused_decode5_fused_matmul8_silu1_int3_int16_fp16_after, - get_dict_key(fused_decode6_fused_matmul9_add3_int3_int16_fp16_before): fused_decode6_fused_matmul9_add3_int3_int16_fp16_after, -} -# fmt: on - - -def lookup_func(func): - for (hash_value, func_before), f_after in tir_dispatch_dict.items(): - if tvm.ir.structural_hash(func) == hash_value and tvm.ir.structural_equal( - func, func_before - ): - return f_after - return None diff --git a/mlc_llm/quantization/__init__.py b/mlc_llm/quantization/__init__.py deleted file mode 100644 index 6284df6fa8..0000000000 --- a/mlc_llm/quantization/__init__.py +++ /dev/null @@ -1,232 +0,0 @@ -from .quantization import FQuantize -from .quantization import QuantizationScheme -from .quantization import QuantizationSpec, NoQuantizationSpec, ParamQuantKind -from .quantization import QuantSpecUpdater -from .group_quantization import GroupQuantizationSpec -from .autogptq_quantization import AutogptqQuantizationSpec -from .ft_quantization import FTQuantizationSpec, FTQuantizeUpdater - - -# The predefined quantization schemes. 
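# Editor's note (not part of the original file): each entry below is a
# QuantizationScheme keyed by its name, so a scheme is picked with a plain
# dictionary lookup, e.g. `quantization_schemes["q3f16_1"]`; its per-parameter
# specs (linear_weight, embedding_table, final_fc_weight) describe how each
# class of weights is quantized. Scheme names follow the pattern
# q<weight-bits>f<float-bits>_<variant>.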
-quantization_schemes = { - "autogptq_llama_q4f16_0": QuantizationScheme( - name="autogptq_llama_q4f16_0", - linear_weight=AutogptqQuantizationSpec( - dtype="float16", - mode="int4", - sym=False, - group_size=128, - ), - embedding_table=NoQuantizationSpec("float16"), - final_fc_weight=NoQuantizationSpec("float16"), - ), - "autogptq_llama_q4f16_1": QuantizationScheme( - name="autogptq_llama_q4f16_1", - linear_weight=AutogptqQuantizationSpec( - dtype="float16", - mode="int4", - sym=False, - group_size=-1, - ), - embedding_table=NoQuantizationSpec("float16"), - final_fc_weight=NoQuantizationSpec("float16"), - ), - "q0f16": QuantizationScheme("q0f16", NoQuantizationSpec("float16")), - "q0f32": QuantizationScheme("q0f32", NoQuantizationSpec("float32")), - "q3f16_0": QuantizationScheme( - name="q3f16_0", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int3", - sym=True, - storage_nbit=16, - group_size=40, - transpose=True, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int3", - sym=True, - storage_nbit=16, - group_size=40, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - ), - "q3f16_1": QuantizationScheme( - name="q3f16_1", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int3", - sym=True, - storage_nbit=16, - group_size=40, - transpose=False, - ), - embedding_table="same_as_linear_weight", - final_fc_weight="same_as_linear_weight", - ), - "q4f16_0": QuantizationScheme( - name="q4f16_0", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=True, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - ), - "q4f16_1": QuantizationScheme( - name="q4f16_1", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - embedding_table="same_as_linear_weight", - final_fc_weight="same_as_linear_weight", - ), - "q4f16_2": QuantizationScheme( - name="q4f16_2", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - embedding_table=NoQuantizationSpec("float16"), - final_fc_weight=NoQuantizationSpec("float16"), - ), - "q4f16_ft": QuantizationScheme( - name="q4f16_ft", - linear_weight=FTQuantizationSpec( - dtype="float16", - nbit=4, - group_size=-1, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - qspec_updater_class=FTQuantizeUpdater, - ), - "q4f16_ft_group": QuantizationScheme( - name="q4f16_ft_group", - linear_weight=FTQuantizationSpec( - dtype="float16", - nbit=4, - group_size=64, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - qspec_updater_class=FTQuantizeUpdater, - ), - "q4f32_0": QuantizationScheme( - name="q4f32_0", - linear_weight=GroupQuantizationSpec( - dtype="float32", - mode="int4", - sym=False, - storage_nbit=32, - group_size=32, - transpose=True, - ), - embedding_table=GroupQuantizationSpec( - dtype="float32", - mode="int4", - sym=False, - storage_nbit=32, - group_size=32, - transpose=False, - ), - 
final_fc_weight="same_as_linear_weight", - ), - "q4f32_1": QuantizationScheme( - name="q4f32_1", - linear_weight=GroupQuantizationSpec( - dtype="float32", - mode="int4", - sym=False, - storage_nbit=32, - group_size=32, - transpose=False, - ), - embedding_table="same_as_linear_weight", - final_fc_weight="same_as_linear_weight", - ), - "q8f16_ft": QuantizationScheme( - name="q8f16_ft", - linear_weight=FTQuantizationSpec( - dtype="float16", - nbit=8, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int8", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - qspec_updater_class=FTQuantizeUpdater, - ), - "q8f16_ft_group": QuantizationScheme( - name="q8f16_ft_group", - linear_weight=FTQuantizationSpec( - dtype="float16", - nbit=8, - group_size=64, - ), - embedding_table=GroupQuantizationSpec( - dtype="float16", - mode="int8", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - final_fc_weight="same_as_linear_weight", - qspec_updater_class=FTQuantizeUpdater, - ), - "q8f16_1": QuantizationScheme( - name="q8f16_1", - linear_weight=GroupQuantizationSpec( - dtype="float16", - mode="int8", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ), - embedding_table="same_as_linear_weight", - final_fc_weight="same_as_linear_weight", - ), -} diff --git a/mlc_llm/quantization/autogptq_quantization.py b/mlc_llm/quantization/autogptq_quantization.py deleted file mode 100644 index 2cdc186dbc..0000000000 --- a/mlc_llm/quantization/autogptq_quantization.py +++ /dev/null @@ -1,193 +0,0 @@ -from dataclasses import dataclass -from typing import Any, List, Literal, Optional, Tuple -from tvm import relax, te, tir, topi -from . import tir_utils -from .quantization import QuantizationSpec -from .quantization import FQuantize, FTEDequantize, convert_TE_func - - -@dataclass -class AutogptqQuantizationSpec(QuantizationSpec): - """The quantization specification for group quantization algorithm.""" - - mode: Literal["int2", "int3", "int4", "int8"] - sym: bool - group_size: int - storage_nbit: int = 32 - - quantized_suffix = ["qweight", "qzeros", "scales", "g_idx"] - - def get_loaded_tensor_info( - self, pname: str, param_info: relax.TensorStructInfo - ) -> Tuple[List[str], List[relax.TensorStructInfo]]: - assert self.storage_nbit == 32, "Only support 32bit storage currently" - - quantized_pnames = self.quant_convert_pname_fwd(pname) - if len(quantized_pnames) == 1: - return quantized_pnames, [param_info] - else: - assert len(quantized_pnames) == 4 - assert param_info.ndim == 2 - nbit = int(self.mode[-1]) - tensor_info = [] - outfeatures, infeatures = param_info.shape.values - group_size = self.group_size if self.group_size != -1 else infeatures - - def get_quantized_shape_dtype(quantized_pname: str): - if quantized_pname.endswith("qweight"): - return (infeatures // self.storage_nbit * nbit, outfeatures), "uint32" - elif quantized_pname.endswith("qzeros"): - return ( - infeatures // group_size, - outfeatures // self.storage_nbit * nbit, - ), "uint32" - elif quantized_pname.endswith("scales"): - return (infeatures // group_size, outfeatures), "float16" - elif quantized_pname.endswith("g_idx"): - return (infeatures,), "uint32" - else: - raise ValueError(f"Unrecognized quantized parameter name {quantized_pname}") - - for quantized_pname in quantized_pnames: - shape, dtype = get_quantized_shape_dtype(quantized_pname) - tensor_info.append(relax.TensorStructInfo(shape, dtype)) - - return quantized_pnames, 
tensor_info - - def quant_convert_pname_fwd(self, torch_pname: str) -> List[str]: - # For Llama: - if "_proj.weight" in torch_pname: - return [torch_pname.replace("weight", suffix) for suffix in self.quantized_suffix] - return [torch_pname] - - def run_prequantize(self, model_path: str) -> str: - # with auto-gptq >= 0.2.0 - try: - import auto_gptq # pylint: disable=import-outside-toplevel - import transformers # pylint: disable=import-outside-toplevel - except ImportError: - raise ImportError( - "Please install auto_gptq package (version >= 0.2.0) and " - "transformers package to use AutoGPTQ quantization." - ) - import os - from transformers import AutoTokenizer - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig - - quantized_model_path = ( - model_path - + f"-gptq-i{self.mode[-1]}" - + ("-sym" if self.sym else "") - + f"-g{self.group_size}" - ) - if os.path.isdir(quantized_model_path): - return quantized_model_path - - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) - examples = [ - tokenizer( - "MLC LLM is a universal solution that allows any language models " - "to be deployed natively on a diverse set of hardware backends and " - "native applications, plus a productive framework for everyone to " - "further optimize model performance for their own use cases." - ) - ] - quantize_config = BaseQuantizeConfig( - bits=int(self.mode[-1]), # quantize bits - desc_act=False, # disable activation description - group_size=self.group_size, # disable group quantization - ) - - model = AutoGPTQForCausalLM.from_pretrained(model_path, quantize_config) - model.quantize(examples) - - # save quantized model - model.save_quantized(quantized_model_path) - tokenizer.save_pretrained(quantized_model_path) - return quantized_model_path - - def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: - return None - - def get_dequantize_func( - self, - param_info: relax.TensorStructInfo, - qparam_info: List[relax.TensorStructInfo], - ) -> Optional[FQuantize]: - return convert_TE_func( - decoding_func( - sym=self.sym, - nbit=int(self.mode[-1]), - storage_nbit=self.storage_nbit, - dim_length=param_info.shape.values[-1], - dtype=self.dtype, - ), - func_name="decode", - ) - - def convert_param_bkwd(self, torch_pname: str, torch_param): - target_dtype = ( - self.dtype if "_proj." 
not in torch_pname or "scales" in torch_pname else "uint32" - ) - - # For Llama - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(target_dtype))] - - def compute_relax_param(self, relax_pname: str, torch_params: List[Any]): - import numpy as np - - # For Llama - if "query_key_value_proj" in relax_pname: - assert len(torch_params) == 3 - elif "gate_up_proj" in relax_pname: - assert len(torch_params) == 2 - else: - raise ValueError("Unexpected param loading") - - if "g_idx" in relax_pname: - return torch_params[0].astype("uint32") - else: - target_dtype = self.dtype if "scales" in relax_pname else "uint32" - return np.concatenate(torch_params, axis=-1).astype(target_dtype) - - -def decoding_func( - sym: bool, - nbit: int, - storage_nbit: int, - dim_length: tir.PrimExpr, - dtype: str = "float16", -) -> FTEDequantize: - assert dtype in ["float16"], "Only support float16 currently" - assert sym == False, "Only support sym=False currently" - assert storage_nbit == 32, "Only support storage_nbit=32 currently" - - def te_decode_asym(qweight, qzeros, scales, g_idx): - n_float_per_u32 = 32 // nbit - - def f_decode_asym(i, j): - zeros = tir_utils._tir_u32_to_int_to_float( - nbit, - qzeros[g_idx[i], j // n_float_per_u32], - j % n_float_per_u32, - dtype=dtype, - ) - data_float = tir_utils._tir_u32_to_int_to_float( - nbit, - qweight[i // n_float_per_u32, j], - i % n_float_per_u32, - dtype=dtype, - ) - scale_float, bias_float = scales[g_idx[i], j], zeros + 1 - w = (data_float - bias_float) * scale_float - return w - - shape = (dim_length, qweight.shape[1]) - w = te.compute(shape=shape, fcompute=f_decode_asym, name="decode") - w = topi.transpose(w) - return w - - return te_decode_asym diff --git a/mlc_llm/quantization/ft_quantization.py b/mlc_llm/quantization/ft_quantization.py deleted file mode 100644 index 286ca9a28c..0000000000 --- a/mlc_llm/quantization/ft_quantization.py +++ /dev/null @@ -1,219 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional - -import tvm -from tvm.contrib.nvcc import parse_compute_version -from tvm import relax, te, tir, topi -from tvm.script import tir as T -from tvm.relax.expr_functor import visitor - -from . import tir_utils -from .quantization import QuantizationSpec, QuantSpecUpdater -from .quantization import FQuantize, convert_TE_func -from .group_quantization import GroupQuantizationSpec - - -@dataclass -class FTQuantizationSpec(QuantizationSpec): - """The quantization specification for the FasterTransformer kernel.""" - - def __init__(self, dtype, nbit, group_size=-1): - super().__init__(dtype) - self.nbit = nbit - assert group_size in [-1, 64, 128], f"Group size {group_size} is not supported." 
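For illustration only (not part of the patch): a minimal NumPy sketch of the asymmetric int4 dequantization performed by the deleted autogptq decoding_func above, where each uint32 word packs eight 4-bit fields and the zero point comes from qzeros. The toy word values, the scale, and the helper name `unpack` are assumptions made up for this example; `unpack` mirrors `_tir_u32_to_int_to_float` from tir_utils.py.

import numpy as np

def unpack(nbit, word, pos):
    # pull the pos-th nbit-wide field out of a packed uint32 word,
    # matching _tir_u32_to_int_to_float: (val >> (pos * nbit)) & mask
    mask = (1 << nbit) - 1
    return float((int(word) >> (pos * nbit)) & mask)

qweight_word = np.uint32(0x76543210)   # assumed toy data: eight packed int4 values 0..7
zeros_word = np.uint32(0x77777777)     # assumed toy qzeros entry
scale = 0.05                           # assumed toy scale
zero_point = unpack(4, zeros_word, 0)  # -> 7.0
# w = (data_float - bias_float) * scale_float, with bias_float = zeros + 1
w = [(unpack(4, qweight_word, p) - (zero_point + 1.0)) * scale for p in range(8)]
# e.g. w[3] == (3 - 8) * 0.05 == -0.25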
- self.group_size = group_size - - if tvm.cuda(0).exist: - major, minor = parse_compute_version(tvm.cuda(0).compute_version) - if major == 8: - self.sm = 80 - else: - self.sm = 10 * major + minor - else: - self.sm = None - - self.do_preprocess = True - - def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: - assert self.sm is not None - - def f_quantize(bb: relax.BlockBuilder, inputs: List[relax.Expr]): - encoded_data = bb.emit_te( - encoding_func( - self.nbit, - 8, - group_size=self.group_size, - dtype=self.dtype, - ), - inputs[0], - primfunc_name_hint="encode", - ) - - packed_weight = bb.normalize(encoded_data[0]) - - if self.do_preprocess: - encoded_weight = bb.emit( - relax.call_pure_packed( - "cutlass.ft_preprocess_weight", - packed_weight, - self.sm, - self.nbit == 4, - sinfo_args=packed_weight.struct_info, - ) - ) - else: - encoded_weight = packed_weight - - return bb.emit(relax.Tuple([encoded_weight, encoded_data[1]])) - - return f_quantize - - def get_dequantize_func( - self, - param_info: relax.TensorStructInfo, - qparam_info: List[relax.TensorStructInfo], - ) -> Optional[FQuantize]: - return convert_TE_func( - decoding_func( - self.nbit, - storage_nbit=8, - group_size=self.group_size, - ), - func_name="decode", - ) - - -def encoding_func(nbit: int, storage_nbit: int, group_size: int, dtype: str = "float32"): - def te_encode_sym(weight: te.Tensor): - """Encode the weight tensor of shape [N, K] into a quantized weight tensor of shape - [K, N // float_per_int] and a scale tensor of shape [K // group_size, N] - """ - n_float_per_int = storage_nbit // nbit - max_int_value = (1 << (nbit - 1)) - 1 - - cur_group_size = weight.shape[1] if group_size == -1 else group_size - scale_min_shape = (tir.ceildiv(weight.shape[1], cur_group_size), weight.shape[0]) - k = te.reduce_axis((0, cur_group_size), name="k") - max_abs_value = te.compute( - shape=scale_min_shape, - fcompute=lambda group, i: te.max( - te.abs( - tir.if_then_else( - group * cur_group_size + k < weight.shape[1], - weight[i, group * cur_group_size + k], - tir.const(0, dtype=weight.dtype), - ) - ), - axis=k, - ), - name="max_abs_value", - ) - - def f_compute_scale(*idx): - max_value = tir.max(tir.Cast(dtype, max_abs_value(*idx)), tir.const(1e-4, dtype)) - return max_value / tir.const(max_int_value, dtype) - - scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale") - storage_dtype = "int" + str(storage_nbit) - - def f_scale_weight(i, j): - w_scaled = tir.round(tir.Cast(dtype, weight[i, j]) / scale[j // cur_group_size, i]) - w_scaled = T.min( - T.max(w_scaled, tir.const(-max_int_value - 1, dtype)), - tir.const(max_int_value, dtype), - ).astype(storage_dtype) - if n_float_per_int == 1: - return w_scaled - return w_scaled & tir.const((1 << nbit) - 1, storage_dtype) - - n_i32 = tir.ceildiv(weight.shape[0], n_float_per_int) - - if n_float_per_int == 1: - w_gathered = te.compute( - shape=(weight.shape[1], n_i32), - fcompute=lambda j, i: f_scale_weight(i, j), - name="w_gathered", - ) - else: - k = te.reduce_axis((0, n_float_per_int), name="k") - reducer = te.comm_reducer( - fcombine=lambda x, y: tir.bitwise_or(x, y), - fidentity=lambda dtype: tir.const(0, storage_dtype), - name="bitwise_or", - ) - w_gathered = te.compute( - shape=(weight.shape[1], n_i32), - fcompute=lambda j, i: reducer( - tir.if_then_else( - i * n_float_per_int + k < weight.shape[0], - f_scale_weight(i * n_float_per_int + k, j) - << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), - tir.const(0, 
storage_dtype), - ), - axis=k, - ), - name="w_gathered", - ) - - return w_gathered, topi.cast(scale, "float16") - - return te_encode_sym - - -def decoding_func(nbit: int, storage_nbit: int, group_size: int): - def te_decode_sym(data, scale): - n_float_per_int = storage_nbit // nbit - cur_group_size = data.shape[0] if group_size == -1 else group_size - - def f_decode_sym(i, j): - if n_float_per_int == 1: - data_float = tir.Cast("float16", data[i, j]) - else: - f_convert = tir_utils._tir_packed_int_to_int_to_float(storage_nbit) - data_float = f_convert( - nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype="float16" - ) - - scale_float = scale[i // cur_group_size, j] - return data_float * scale_float - - shape = (data.shape[0], data.shape[1] * n_float_per_int) - w = te.compute(shape=shape, fcompute=f_decode_sym, name="decode") - # Dummy transpose for FuseDecodeTranspose - return topi.transpose(w) - - return te_decode_sym - - -@visitor -class FTQuantizeUpdater(QuantSpecUpdater._cls): - def visit_call_(self, call: relax.Call): - if call.op != tvm.ir.Op.get("relax.matmul"): - return - rhs = self.lookup_binding(call.args[1]) - assert rhs is not None - if ( - rhs.op != tvm.ir.Op.get("relax.permute_dims") - or rhs.attrs.axes is not None - or rhs.args[0].struct_info.ndim != 2 - ): - return - - if rhs.args[0] not in self.param_map: - return - - param = self.param_map[rhs.args[0]] - - if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0: - # FT requires N to be a multiple of 8 - # FT does not support fp32 output dtype - # TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed - # by `cast(..., "float16")`, `matmul -> cast` can be offloaded. - param.quant_spec = GroupQuantizationSpec( - param.param_info.dtype, - mode="int4", - sym=True, - storage_nbit=32, - group_size=32, - transpose=False, - ) diff --git a/mlc_llm/quantization/group_quantization.py b/mlc_llm/quantization/group_quantization.py deleted file mode 100644 index 7603ad29f3..0000000000 --- a/mlc_llm/quantization/group_quantization.py +++ /dev/null @@ -1,214 +0,0 @@ -from dataclasses import dataclass -from typing import List, Literal, Optional - -import tvm -from tvm import relax, te, tir, topi -from tvm.script import tir as T -from tvm.relax.expr_functor import visitor - -from . 
import tir_utils -from .quantization import QuantizationSpec, QuantSpecUpdater -from .quantization import NoQuantizationSpec -from .quantization import FQuantize, FTEQuantize, FTEDequantize, convert_TE_func - - -@dataclass -class GroupQuantizationSpec(QuantizationSpec): - """The quantization specification for group quantization algorithm.""" - - mode: Literal["int3", "int4"] - sym: bool - storage_nbit: int - group_size: int - transpose: bool - - def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: - return convert_TE_func( - encoding_func( - sym=self.sym, - group_size=self.group_size, - nbit=int(self.mode[-1]), - mode=self.mode, - storage_nbit=self.storage_nbit, - transpose=self.transpose, - dtype=self.dtype, - ), - func_name="encode", - ) - - def get_dequantize_func( - self, - param_info: relax.TensorStructInfo, - qparam_info: List[relax.TensorStructInfo], - ) -> Optional[FQuantize]: - return convert_TE_func( - decoding_func( - sym=self.sym, - group_size=self.group_size, - nbit=int(self.mode[-1]), - mode=self.mode, - storage_nbit=self.storage_nbit, - dim_length=param_info.shape.values[-1], - data_transposed=self.transpose, - transpose_output=self.transpose, - dtype=self.dtype, - ), - func_name="decode", - ) - - -# fmt: off -def encoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, transpose: bool=True, dtype: str = "float32") -> FTEQuantize: - def te_encode_asym(weight: te.Tensor): - assert weight.shape[1] % group_size == 0 - n_group = weight.shape[1] // group_size - n_float_per_u32 = 32 // nbit - - scale_min_shape = (weight.shape[0], n_group) - k = te.reduce_axis((0, group_size), name="k") - min_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.min(weight[i, j * group_size + k], axis=k), name="min_value") - max_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(weight[i, j * group_size + k], axis=k), name="max_value") - scale = te.compute(shape=scale_min_shape, fcompute=lambda i, j: (max_value[i, j] - min_value[i, j]) / tir.const((1 << nbit) - 1, dtype), name="scale") - - def f_scale_weight(i, j): - group_idx = j // group_size - w_scaled = tir.round((weight[i, j] - min_value[i, group_idx]) / scale[i, group_idx]).astype("int32") - w_scaled = T.min(T.max(w_scaled, tir.const(0, "int32")), tir.const((1 << nbit) - 1, "int32")) - w_scaled = w_scaled.astype("uint32") - return w_scaled - - k = te.reduce_axis((0, n_float_per_u32), name="k") - reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name="bitwise_or") - if dtype == "float32": - if transpose: - w_gathered = te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") - scale_bias = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name="scale_min") - else: - w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") - scale_bias = te.compute(shape=(weight.shape[0], n_group), fcompute=lambda i, j: tir_utils._tir_f32x2_to_bf16x2_to_u32(scale[i, j], min_value[i, j], round_to_even=True), name="scale_min") - return w_gathered, scale_bias - else: - if transpose: - w_gathered = 
te.compute(shape=(weight.shape[1] // n_float_per_u32, weight.shape[0]), fcompute=lambda j, i: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") - scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j], name="scale_transpose") - min_value = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: min_value[i, j], name="min_transpose") - else: - w_gathered = te.compute(shape=(weight.shape[0], weight.shape[1] // n_float_per_u32), fcompute=lambda i, j: reducer(f_scale_weight(i, j * n_float_per_u32 + k) << (k * nbit).astype("uint32"), axis=k), name="w_gathered") - return w_gathered, scale, min_value - - def te_encode_sym(weight: te.Tensor): - n_group = tir.ceildiv(weight.shape[1], group_size) - n_float_per_int = storage_nbit // nbit - max_int_value = (1 << (nbit - 1)) - 1 - assert group_size % n_float_per_int == 0 - - scale_min_shape = (weight.shape[0], n_group) - k = te.reduce_axis((0, group_size), name="k") - max_abs_value = te.compute(shape=scale_min_shape, fcompute=lambda i, j: te.max(tir.if_then_else(j * group_size + k < weight.shape[1], te.abs(weight[i, j * group_size + k]), tir.min_value(dtype)), axis=k), name="max_abs_value") - - def f_compute_scale(i, j): - max_value = tir.max(max_abs_value[i, j], tir.const(1e-4, dtype)) - return (max_value / tir.const(max_int_value, dtype)) if mode.startswith("int") else max_value - - scale = te.compute(shape=scale_min_shape, fcompute=f_compute_scale, name="scale") - storage_dtype = ("uint" + str(storage_nbit)) if mode.startswith("int") else "uint32" - - def f_scale_weight(i, j): - group_idx = j // group_size - if mode.startswith("int"): - w_scaled = tir.round(weight[i, j] / scale[i, group_idx] + tir.const(max_int_value, dtype)) - w_scaled = T.min(T.max(w_scaled, tir.const(0, dtype)), tir.const(max_int_value * 2, dtype)).astype(storage_dtype) - return w_scaled - else: - f_convert = tir_utils._tir_f32_to_uint_to_f4 if dtype == "float32" else tir_utils._tir_f16_to_uint_to_f4 - return f_convert(weight[i, j] / scale[i, group_idx]) - - k = te.reduce_axis((0, n_float_per_int), name="k") - reducer = te.comm_reducer(fcombine=lambda x, y: tir.bitwise_or(x, y), fidentity=lambda dtype: tir.const(0, dtype), name="bitwise_or") - n_i32 = tir.ceildiv(group_size, n_float_per_int) * n_group - if transpose: - w_gathered = te.compute(shape=(n_i32, weight.shape[0]), fcompute=lambda j, i: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name="w_gathered") - scale = te.compute(shape=(n_group, weight.shape[0]), fcompute=lambda j, i: scale[i, j]) - else: - w_gathered = te.compute(shape=(weight.shape[0], n_i32), fcompute=lambda i, j: reducer(tir.if_then_else(j * n_float_per_int + k < weight.shape[1], f_scale_weight(i, j * n_float_per_int + k) << (k.astype(storage_dtype) * tir.const(nbit, storage_dtype)), tir.const(0, storage_dtype)), axis=k), name="w_gathered") - return w_gathered, scale - - return te_encode_sym if sym else te_encode_asym - - -def decoding_func(sym: bool, group_size: int, nbit: int, mode: str, storage_nbit: int, dim_length: tir.PrimExpr, data_transposed: bool=True, transpose_output: bool=False, dtype: str = "float32") -> FTEDequantize: - def te_decode_asym(*args): - n_float_per_u32 = 32 // nbit - data = args[0] - if dtype == "float32": - scale_bias_bf16x2 = args[1] - else: - scale, min_value = 
args[1], args[2] - - def f_decode_asym(i, j): - if data_transposed: - data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i // n_float_per_u32, j], i % n_float_per_u32, dtype=dtype) - if dtype == "float32": - scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i // group_size, j]) - else: - scale_float, bias_float = scale[i // group_size, j], min_value[i // group_size, j] - else: - data_float = tir_utils._tir_u32_to_int_to_float(nbit, data[i, j // n_float_per_u32], j % n_float_per_u32, dtype=dtype) - if dtype == "float32": - scale_float, bias_float = tir_utils._tir_u32_to_bf16x2_to_f32x2(scale_bias_bf16x2[i, j // group_size]) - else: - scale_float, bias_float = scale[i, j // group_size], min_value[i, j // group_size] - w = data_float * scale_float + bias_float - return w - - shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length) - w = te.compute(shape=shape, fcompute=f_decode_asym, name="decode") - if transpose_output: - w = topi.transpose(w) - return w - - def te_decode_sym(data, scale): - n_float_per_int = storage_nbit // nbit - - def f_decode_sym(i, j): - f_convert = tir_utils._tir_packed_uint_to_uint_to_float(storage_nbit) if mode.startswith("int") else (tir_utils._tir_u32_to_f4_to_f32 if dtype == "float32" else tir_utils._tir_u32_to_f4_to_f16) - if data_transposed: - data_float = f_convert(nbit, data[i // n_float_per_int, j], i % n_float_per_int, dtype=dtype) - scale_float = scale[i // group_size, j] - else: - data_float = f_convert(nbit, data[i, j // n_float_per_int], j % n_float_per_int, dtype=dtype) - scale_float = scale[i, j // group_size] - return data_float * scale_float - - shape = (dim_length, data.shape[1]) if data_transposed else (data.shape[0], dim_length) - w = te.compute(shape=shape, fcompute=f_decode_sym, name="decode") - if transpose_output: - w = topi.transpose(w) - return w - - return te_decode_sym if sym else te_decode_asym -# fmt: on - - -# A simple example demo showing how QuantSpecUpdater is used. -# NOTE: This visitor is only for demo purpose and should not be put into real use. -@visitor -class GroupQuantDemoUpdater(QuantSpecUpdater._cls): - def visit_call_(self, call: relax.Call): - if call.op != tvm.ir.Op.get("relax.matmul"): - return - rhs = self.lookup_binding(call.args[1]) - assert rhs is not None - if ( - rhs.op != tvm.ir.Op.get("relax.permute_dims") - or rhs.attrs.axes is not None - or rhs.args[0].struct_info.ndim != 2 - ): - return - - if rhs.args[0] not in self.param_map: - return - param = self.param_map[rhs.args[0]] - # Update to no quantization for matmul with float32 output dtype. - if call.struct_info.dtype == "float32": - param.quant_spec = NoQuantizationSpec(param.param_info.dtype) diff --git a/mlc_llm/quantization/quantization.py b/mlc_llm/quantization/quantization.py deleted file mode 100644 index 2922c936b8..0000000000 --- a/mlc_llm/quantization/quantization.py +++ /dev/null @@ -1,217 +0,0 @@ -import enum -from dataclasses import dataclass -from typing import Any, Callable, List, Literal, Optional, Tuple, Type, Union - -import tvm -from tvm import relax, te -from tvm.relax.expr_functor import PyExprVisitor, visitor - -FQuantize = Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var] -FTEQuantize = Callable[[te.Tensor], List[te.Tensor]] -FTEDequantize = Callable[[List[te.Tensor]], te.Tensor] - - -@dataclass -class QuantizationSpec: - """The base dataclass of quantization specification. - A specification describes how a parameter is quantized and dequantized. 
- - A subclass of QuantizationSpec - - contains more data fields (e.g., the "group size" in group quantization) - which instruct the quantization/dequantization, - - defines the `get_quantize_func` method, which returns a function - (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes a - Relax BlockBuilder and the weight relax Var to be quantized, computes - the quantization and returns the relax Var of quantized results. - algorithm of the quantization. - - defines the `get_dequantize_func` method, which returns function - (`Callable[[relax.BlockBuilder, List[relax.Expr]], relax.Var]`) that takes - the quantized results, computes and returns the dequantization result. - - optionally overloads the `get_loaded_tensor_info` when the parameter is - pre-quantized, in which case `get_loaded_tensor_info` needs to be overloaded - so that we know how many quantized data tensors there are, and the dtype - and shape of each quantized data tensor. - """ - - dtype: str - - def get_loaded_tensor_info( - self, pname: str, param_info: relax.TensorStructInfo - ) -> Tuple[List[str], List[relax.TensorStructInfo]]: - """Returns the names and shapes and dtypes of the tensors that need to - be loaded from the disk. - - It is useful when the parameter is pre-quantized. In such cases, we need - to know how many tensors the parameter is quantized into, and together - with the dtype and shape of each tensor, so that we can load the - pre-quantized tensors in. - """ - return [pname], [param_info] - - def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: - """Returns the function which computes quantization. - Returning `None` means the parameter does not need quantization or is - pre-quantized. - - The returned function takes a Relax BlockBuilder and a (list of) weight - relax Var to be quantized, computes the quantization and returns the - quantization result Relax Var(s). - - You can use `convert_TE_func` to convert a TE function to the function - of the desired return format. See `group_quantization.py` for examples. - """ - return NotImplementedError() - - def get_dequantize_func( - self, - param_info: relax.TensorStructInfo, - qparam_info: List[relax.TensorStructInfo], - ) -> Optional[FQuantize]: - """Returns the function which computes dequantization. - Returning `None` means the parameter does not need dequantization. - - The returned function takes a Relax BlockBuilder and a (list of) - quantized weight relax Var, computes the dequantization and returns the - result Relax Var(s). - - You can use `convert_TE_func` to convert a TE function to the function - of the desired return format. See `group_quantization.py` for examples. - """ - return NotImplementedError() - - -@dataclass -class NoQuantizationSpec(QuantizationSpec): - """The quantization specification that describes doing no quantization.""" - - def get_quantize_func(self, param_info: relax.TensorStructInfo) -> Optional[FQuantize]: - return None - - def get_dequantize_func( - self, - param_info: relax.TensorStructInfo, - qparam_info: List[relax.TensorStructInfo], - ) -> Optional[FQuantize]: - return None - - -class ParamQuantKind(enum.IntEnum): - """The parameter quantization kind class. 
- - We categorized all the parameters in a model into four kinds: - - the weights of the internal linear layers, which are the main targets of quantization, - - the embedding table of every token, - - the weight of the fully-connected layer at the end of the model, which is - used for computes the logits of each input token, - - other parameters (e.g., the weight of layer normalization, etc.). - """ - - linear_weight = 0 - embedding_table = 1 - final_fc_weight = 2 - others = 3 - - -class QuantizationScheme: - """The quantization scheme class describes how an entire model is quantized. - It contains the quantization specification for each parameter quantization kind. - - Besides, it has an optional field for a visitor class which has the ability to - take the constructed model (in format of IRModule) as input, go through the - model and update the QuantizationSpec for certain parameters. - """ - - name: str - linear_weight: QuantizationSpec - embedding_table: QuantizationSpec - final_fc_weight: QuantizationSpec - others: QuantizationSpec - - qspec_updater_class: Optional[Type["QuantSpecUpdater"]] - f_convert_param_bkwd: Optional[Callable[[str, Any], Optional[List[Tuple[str, Any]]]]] - f_compute_relax_param: Optional[Callable[[str, List[Any]], Any]] - f_run_prequantize: Optional[Callable[[str], str]] - - def __init__( - self, - name: str, - linear_weight: QuantizationSpec, - *, - embedding_table: Optional[Union[QuantizationSpec, Literal["same_as_linear_weight"]]] = None, - final_fc_weight: Optional[Union[QuantizationSpec, Literal["same_as_linear_weight"]]] = None, - others: Optional[QuantizationSpec] = None, - qspec_updater_class: Optional[Type["QuantSpecUpdater"]] = None, - ) -> None: - self.name = name - self.linear_weight = linear_weight - self.others = others if others is not None else NoQuantizationSpec(self.model_dtype) - - if embedding_table is None: - self.embedding_table = self.others - elif embedding_table == "same_as_linear_weight": - self.embedding_table = self.linear_weight - else: - self.embedding_table = embedding_table - - if final_fc_weight is None: - self.final_fc_weight = self.others - elif final_fc_weight == "same_as_linear_weight": - self.final_fc_weight = self.linear_weight - else: - self.final_fc_weight = final_fc_weight - - self.qspec_updater_class = qspec_updater_class - self.f_convert_param_bkwd = None - self.f_compute_relax_param = None - self.f_run_prequantize = None - - for spec in [self.linear_weight, self.embedding_table, self.final_fc_weight, self.others]: - if hasattr(spec, "convert_param_bkwd"): - self.f_convert_param_bkwd = spec.convert_param_bkwd - if hasattr(spec, "compute_relax_param"): - self.f_compute_relax_param = spec.compute_relax_param - if hasattr(spec, "run_prequantize"): - self.f_run_prequantize = spec.run_prequantize - - @property - def model_dtype(self) -> str: - """Returns the overall model dtype, which is defined as the dtype of - the linear layers. 
- """ - return self.linear_weight.dtype - - -def convert_TE_func(te_func: Union[FTEQuantize, FTEDequantize], func_name: str) -> FQuantize: - def func(bb: relax.BlockBuilder, inputs: List[relax.Expr]) -> relax.Var: - return bb.call_te(te_func, *inputs, primfunc_name_hint=func_name) - - return func - - -@visitor -class QuantSpecUpdater(PyExprVisitor): - def __init__(self, param_manager) -> None: - super().__init__() - self.param_manager = param_manager - self.param_map = None - self.builder = relax.BlockBuilder() - - def lookup_binding(self, var: relax.Var): - return self.builder.lookup_binding(var) - - def visit_module(self, mod: tvm.IRModule): - for gv, func in mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - - self.param_map = dict() - num_input = int(func.attrs["num_input"]) - params_in_func = self.param_manager.params_in_func[gv.name_hint] - assert len(func.params) - num_input == len(params_in_func) - for i, relax_param in enumerate(func.params[num_input:]): - self.param_map[relax_param] = params_in_func[i] - - self.builder.normalize(func) - self.visit_expr(func) diff --git a/mlc_llm/quantization/tir_utils.py b/mlc_llm/quantization/tir_utils.py deleted file mode 100644 index 02d4c72c71..0000000000 --- a/mlc_llm/quantization/tir_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -"""TIR computation utilities for quantization.""" - -import tvm -from tvm import tir - -# fmt: off -def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool=True): - mask = tir.const((1 << 16) - 1, "uint32") - res = [] - for data in [v0, v1]: - u32_val = tir.reinterpret("uint32", data) - if round_to_even: - rounding_bias = ((u32_val >> tir.const(16, "uint32")) & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") - u32_val += rounding_bias - res.append((u32_val >> tir.const(16, "uint32")) & mask) - return res[0] | (res[1] << tir.const(16, "uint32")) - - -def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): - mask = tir.const((1 << 16) - 1, "uint32") - x0 = x & mask - x1 = (x >> 16) & mask - return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) - - -def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == "uint32" - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) - - -def _tir_packed_uint_to_uint_to_float(storage_nbit: int): - storage_dtype = "uint" + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype - max_int_value = (1 << (nbit - 1)) - 1 - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const((1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) - - return f_convert - - -def _tir_packed_int_to_int_to_float(storage_nbit: int): - storage_dtype = "int" + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask - return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) - - return f_convert - - -def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float32" - val_u32 = tir.reinterpret("uint32", val) - # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) - # 
e_f32 == 120 -> e_f4 = 1 - # e_f32 < 120 -> e_f4 = 0 - m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") - e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") - s = (val_u32 >> tir.const(31, "uint32")) - e_f4 = tir.Select(e_f32 > tir.const(120, "uint32"), tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 - - -def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float16" - val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) - m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") - e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") - s = (val_u32 >> tir.const(15, "uint32")) - e_f4 = tir.Select(e_f16 > tir.const(8, "uint32"), tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 - - -def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert nbit == 4 - assert dtype == "float32" - assert val.dtype == "uint32" - # e_f4 == 0 -> e_f32 = 0 - # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f32 = e_f4 | tir.const(120, "uint32") - val_f32 = tir.reinterpret("float32", (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) - - -def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint32" - # e_f4 == 0 -> e_f16 = 0 - # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f16 = e_f4 | tir.const(8, "uint32") - val_f16 = tir.reinterpret("float16", (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) -# fmt: on diff --git a/mlc_llm/relax_model/__init__.py b/mlc_llm/relax_model/__init__.py deleted file mode 100644 index 9ee3d0db52..0000000000 --- a/mlc_llm/relax_model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . 
import llama diff --git a/mlc_llm/relax_model/chatglm.py b/mlc_llm/relax_model/chatglm.py deleted file mode 100644 index f1a5b574dc..0000000000 --- a/mlc_llm/relax_model/chatglm.py +++ /dev/null @@ -1,807 +0,0 @@ -import argparse -import math -from dataclasses import dataclass -from typing import List, Tuple - -import tvm -from tvm import relax, te, tir -from tvm.relax.op import ( - astype, - broadcast_to, - expand_dims, - matmul, - maximum, - minimum, - permute_dims, - repeat, - reshape, - split, - squeeze, -) -from tvm.relax.op.nn import silu, softmax -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import Embedding, Linear, ModuleList, RotaryEmbedding -from .param_manager import ParamManager - - -@dataclass -class ChatGLMConfig: - def __init__( - self, - add_bias_linear: bool = False, - add_qkv_bias: bool = True, - ffn_hidden_size: int = 13696, - hidden_size: int = 4096, - kv_channels: int = 128, - layernorm_epsilon: float = 1e-05, - multi_query_group_num: int = 2, - num_attention_heads: int = 32, - num_layers: int = 28, - max_sequence_length: int = 2048, - padded_vocab_size: int = 65024, - eos_token_id: int = 2, - bos_token_id: int = 0, - dtype: str = "float32", - **kwargs, - ): - self.add_bias_linear = add_bias_linear - self.add_qkv_bias = add_qkv_bias - self.ffn_hidden_size = ffn_hidden_size - self.hidden_size = hidden_size - self.kv_channels = kv_channels - self.layernorm_epsilon = layernorm_epsilon - self.multi_query_group_num = multi_query_group_num - self.num_attention_heads = num_attention_heads - self.num_layers = num_layers - self.max_sequence_length = min(2048, max_sequence_length) - self.padded_vocab_size = padded_vocab_size - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.dtype = dtype - self.kwargs = kwargs - - -def _repeat_kv(k: relax.Expr, v: relax.Expr, n_rep: int, shape: relax.Expr): - k = nn.emit(reshape(repeat(k, n_rep, 1), shape)) - v = nn.emit(reshape(repeat(v, n_rep, 1), shape)) - return k, v - - -def _reshape(x: relax.Expr, shape: Tuple[int]): - x = nn.emit(reshape(x, R.shape(shape))) - return x - - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, dtype, eps=1e-5): - self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") - self.eps = tvm.tir.const(eps, dtype) - - def forward(self, hidden_states): - def f_rms_norm(x, weight): - is_float32 = x.dtype == "float32" - - def f_square(x): - return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x - - k = te.reduce_axis((0, x.shape[2]), name="k") - square_sum = te.compute( - (x.shape[0], x.shape[1]), - lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), - name=x.op.name + "red_temp", - ) - - def f_div_cast(bsz, i, k): - x_val = x[bsz, i, k] - if not is_float32: - x_val = tir.Cast("float32", x_val) - return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.eps) - - def f_mul_cast(x, y): - value = x * y - if not is_float32: - value = tir.Cast(x.dtype, value) - return value - - return te.compute( - x.shape, - lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), - name="rms_norm", - ) - - return nn.emit_te( - f_rms_norm, - hidden_states, - self.weight, - primfunc_name_hint="rms_norm", - ) - - -class CoreAttention(nn.Module): - def __init__(self, config: ChatGLMConfig): - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per 
partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - - self.dtype = config.dtype - - def forward( - self, - q: relax.Expr, - k: relax.Expr, - v: relax.Expr, - attention_mask: relax.Expr, - ) -> relax.Expr: - bsz, sl, nh, hd = q.struct_info.shape - kv_sl = k.struct_info.shape[1] - - # [bsz, nh, sl, hd] - q = nn.emit(permute_dims(q, [0, 2, 1, 3])) - - # [bsz, nh, kv_sl, hd] - k = nn.emit(permute_dims(k, [0, 2, 1, 3])) - v = nn.emit(permute_dims(v, [0, 2, 1, 3])) - - # Calculate Q.K: [bsz, nh, sl, kv_sl] - matmul_result = nn.emit( - matmul(q, permute_dims(k, [0, 1, 3, 2])) - / relax.const(self.norm_factor, q.struct_info.dtype) - ) - attention_scores = _reshape(matmul_result, (bsz, nh, sl, kv_sl)) - - # Apply attention mask: [bsz, nh, sl, kv_sl] - attention_scores = nn.emit( - maximum( - attention_scores, - relax.const( - tvm.tir.min_value(attention_scores.struct_info.dtype).value, - attention_scores.struct_info.dtype, - ), - ) - ) - attention_scores = nn.emit(minimum(attention_scores, attention_mask)) - - # Calculate Softmax(Q.K) - if attention_scores.struct_info.dtype != "float32": - attention_scores = astype(attention_scores, "float32") - attention_probs = nn.emit(softmax(attention_scores, axis=-1)) - if attention_probs.struct_info.dtype != q.struct_info.dtype: - attention_probs = astype(attention_probs, q.struct_info.dtype) - - # Calculate Softmax(Q.K).V - context = nn.emit(matmul(attention_probs, v)) - context = nn.emit(permute_dims(context, [0, 2, 1, 3])) - context = _reshape(context, (bsz, sl, nh * hd)) - - return context - - -class SelfAttention(nn.Module): - def __init__( - self, - config: ChatGLMConfig, - rotary_pos_emb: RotaryEmbedding, - ): - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
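As a plain-Python illustration (not part of the patch), the sketch below restates the math of the deleted CoreAttention.forward above: heads are moved to the second axis, scores are Q.K^T scaled by 1/sqrt(head_dim), and instead of adding a bias the scores are clamped, first from below to the dtype minimum and then from above against a mask that holds the dtype maximum at visible positions and the dtype minimum at masked ones, before softmax and the multiplication by V. Function and variable names here are made up for the example, and q/k/v are assumed to already have matching head counts.

import numpy as np

def core_attention(q, k, v, attention_mask):
    # q: [bsz, seq, heads, head_dim]; k, v: [bsz, kv_seq, heads, head_dim]
    # attention_mask: [bsz, 1, seq, kv_seq] with dtype-max where visible, dtype-min where masked
    bsz, seq, heads, head_dim = q.shape
    q, k, v = [np.transpose(x, (0, 2, 1, 3)) for x in (q, k, v)]   # -> [bsz, heads, seq, head_dim]
    scores = (q @ np.transpose(k, (0, 1, 3, 2))) / np.sqrt(head_dim)
    scores = np.maximum(scores, np.finfo(scores.dtype).min)        # clamp from below
    scores = np.minimum(scores, attention_mask)                    # pull masked positions down
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)              # softmax over kv_seq
    context = probs @ v                                            # [bsz, heads, seq, head_dim]
    return np.transpose(context, (0, 2, 1, 3)).reshape(bsz, seq, heads * head_dim)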
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - # Multi-query attention config - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size - + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - - self.query_key_value = Linear( - config.hidden_size, - self.qkv_hidden_size, - config.dtype, - bias=config.add_bias_linear or config.add_qkv_bias, - ) - - self.rotary_pos_emb = rotary_pos_emb - - self.core_attention = CoreAttention(config) - - self.dense = Linear( - self.projection_size, - config.hidden_size, - config.dtype, - bias=config.add_bias_linear, - ) - - self.dtype = config.dtype - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr, relax.Expr], - attention_mask: relax.Expr, - ) -> Tuple[relax.Expr, Tuple[relax.Expr, relax.Expr]]: - # hidden_states: [bsz, sl, hs] - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - bsz, sl, _ = hidden_states.struct_info.shape - kv_sl = all_seq_len_shape.struct_info.values[0] - - mixed_x_layer = nn.emit( - split( - self.query_key_value(hidden_states), - indices_or_sections=[ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - ( - self.num_attention_heads_per_partition - + self.num_multi_query_groups_per_partition - ) - * self.hidden_size_per_attention_head, - ], - axis=-1, - ) - ) - - q_shape = ( - bsz, - sl, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - kv_shape = ( - bsz, - sl, - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - - # queries: [bsz, sl, nh, hd] - q = _reshape(relax.TupleGetItem(mixed_x_layer, 0), q_shape) - - # keys: [bsz, sl, ng, hd] - k = _reshape(relax.TupleGetItem(mixed_x_layer, 1), kv_shape) - - # values: [bsz, sl, ng, hd] - v = _reshape(relax.TupleGetItem(mixed_x_layer, 2), kv_shape) - - # apply rotary embeddings - q, k = self.rotary_pos_emb(q, k, kv_sl - sl) - - assert k.struct_info.shape[0] == 1 and v.struct_info.shape[0] == 1 - squeezed_k, squeezed_v = nn.emit(squeeze(k, axis=0)), nn.emit(squeeze(v, axis=0)) - - k_cache, v_cache = past_key_value - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeezed_k, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeezed_v, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - past_key_value = (k_cache, v_cache) - - kv_sl = all_seq_len_shape.struct_info.values[0] - bsz, _, n_groups, head_dim = k.struct_info.shape - kv_cache_shape = R.shape([kv_sl, n_groups, head_dim]) - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], - ) - ) - v = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], - ) - ) - - n_rep = self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition - kv_attn_shape = R.shape( - [ - bsz, - kv_sl, - 
self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ] - ) - k, v = _repeat_kv(k, v, n_rep, kv_attn_shape) - - # core attention computation - context_layer = self.core_attention(q, k, v, attention_mask) - - # apply output projection - output = self.dense(context_layer) - - return output, past_key_value - - -class MLP(nn.Module): - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.dtype = config.dtype - - self.dense_h_to_4h = Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - config.dtype, - bias=config.add_bias_linear, - ) - - def swiglu(x: relax.Expr): - x = nn.emit(split(x, 2, axis=-1)) - return nn.emit(silu(x[0]) * x[1]) - - self.activation_func = swiglu - - self.dense_4h_to_h = Linear( - config.ffn_hidden_size, - config.hidden_size, - config.dtype, - bias=config.add_bias_linear, - ) - - def forward(self, hidden_states): - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.activation_func(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - - return hidden_states - - -class GLMBlock(nn.Module): - def __init__(self, config: ChatGLMConfig, rotary_pos_emb: RotaryEmbedding): - self.input_layernorm = RMSNorm( - hidden_size=config.hidden_size, - dtype=config.dtype, - eps=config.layernorm_epsilon, - ) - self.post_attention_layernorm = RMSNorm( - hidden_size=config.hidden_size, - dtype=config.dtype, - eps=config.layernorm_epsilon, - ) - - self.self_attention = SelfAttention(config, rotary_pos_emb) - self.mlp = MLP(config) - - self.dtype = config.dtype - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: relax.Expr, - ): - layernorm_output = self.input_layernorm(hidden_states) - attention_output, present_key_value = self.self_attention( - layernorm_output, all_seq_len_shape, past_key_value, attention_mask - ) - - # residual connection - layernorm_input = nn.emit(attention_output + hidden_states) - - layernorm_output = self.post_attention_layernorm(layernorm_input) - mlp_output = self.mlp(layernorm_output) - - # residual connection - output = nn.emit(mlp_output + layernorm_input) - - return output, present_key_value - - -class GLMTransformer(nn.Module): - def __init__(self, config: ChatGLMConfig, rotary_pos_emb: RotaryEmbedding): - self.num_layers = config.num_layers - - self.layers = ModuleList([GLMBlock(config, rotary_pos_emb) for _ in range(self.num_layers)]) - self.final_layernorm = RMSNorm( - hidden_size=config.hidden_size, - dtype=config.dtype, - eps=config.layernorm_epsilon, - ) - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - attention_mask: relax.Expr, - ): - present_kv_cache = [] - for i, block in enumerate(self.layers): - past_key_value = past_key_values[i * 2], past_key_values[i * 2 + 1] - hidden_states, (present_k_cache, present_v_cache) = block( - hidden_states, - all_seq_len_shape=all_seq_len_shape, - past_key_value=past_key_value, - attention_mask=attention_mask, - ) - present_kv_cache.append(present_k_cache) - present_kv_cache.append(present_v_cache) - hidden_states = self.final_layernorm(hidden_states) - return hidden_states, present_kv_cache - - -class ChatGLMModel(nn.Module): - def __init__(self, config: ChatGLMConfig): - self.num_layers = config.num_layers - - self.embedding = Embedding( - 
num_embeddings=config.padded_vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - - self.seq_length = config.max_sequence_length - rotary_dim = config.kv_channels // 2 - - self.rotary_pos_emb = RotaryEmbedding( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - position_embedding_base=10000, - max_sequence_length=config.max_sequence_length, - rotary_dim=rotary_dim, - swizzle_style="glm", - dtype=config.dtype, - ) - self.encoder = GLMTransformer(config, self.rotary_pos_emb) - self.output_layer = Linear( - in_features=config.hidden_size, - out_features=config.padded_vocab_size, - bias=False, - dtype=config.dtype, - ) - - self.dtype = config.dtype - - def _prepare_decoder_attention_mask(self, input_shape, kv_sl, dtype): - # create causal mask - # [bsz, sl] -> [bsz, 1, sl, kv_sl] - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - bsz, sl = input_shape - - def min_max_triu_te(): - return te.compute( - (sl, sl), - lambda i, j: tvm.tir.Select( - j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) - ), - name="make_diag_mask_te", - ) - - mask = nn.emit_te(min_max_triu_te) - mask = nn.emit(expand_dims(mask, 0)) - diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, sl, sl))) - if kv_sl == sl: - return diag_mask - - def extend_te(x, sl, kv_sl): - return te.compute( - (bsz, 1, sl, kv_sl), - lambda b, _, i, j: te.if_then_else( - j < kv_sl - sl, - tvm.tir.max_value(dtype), - x[b, _, i, j - (kv_sl - sl)], - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, sl, kv_sl) - else: - # Get kv_sl from input parameters - # [bsz, sl=1] -> [bsz, 1, sl=1, kv_sl] - bsz, sl = input_shape - mask = relax.op.full( - (bsz, 1, sl, kv_sl), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - return nn.emit(mask) - - def forward( - self, - input_ids: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - batch_size, seq_length = input_ids.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - - # Token Embeddings - inputs_embeds = self.embedding(input_ids) - - attention_mask = self._prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - dtype=self.dtype, - ) - - hidden_states, present_kv_cache = self.encoder( - inputs_embeds, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - attention_mask=attention_mask, - ) - - return hidden_states, present_kv_cache - - -class ChatGLMForCausalLM(nn.Module): - def __init__(self, config: ChatGLMConfig): - self.transformer = ChatGLMModel(config) - - self.dtype = config.dtype - - def forward( - self, - input_ids: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - hidden_states, key_value_cache = self.transformer( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - - def te_slice_last(x: te.Tensor): - _, sl, hs = x.shape - return te.compute( - shape=(1, 1, hs), - fcompute=lambda i, _, k: x[i, sl - 1, k], - name="slice_last", - ) - - hidden_states = nn.emit_te( - te_slice_last, - hidden_states, - primfunc_name_hint="slice_last", - ) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - lm_logits = self.transformer.output_layer(hidden_states) - - if lm_logits.struct_info.dtype != "float32": - lm_logits = nn.emit(astype(lm_logits, "float32")) - - return lm_logits, key_value_cache - - -def get_param_quant_kind(name: str, 
param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "embedding.weight" in name: - return ParamQuantKind.embedding_table - elif "transformer.output_layer.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: ChatGLMConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "prefill" - - bsz = tvm.tir.IntImm("int64", 1) - sl = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = ChatGLMForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, sl), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), - ) - - with bb.dataflow(): - logits, key_value_cache = model( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: ChatGLMConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - bsz = 1 - all_seq_len = tvm.tir.SizeVar("m", "int64") - - with bb.function(func_name): - model = ChatGLMForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.num_layers * 2)]), - ) - with bb.dataflow(): - logits, key_value_cache = model( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: - init_shape = relax.ShapeExpr( - ( - config.max_sequence_length, - config.multi_query_group_num, - config.hidden_size // config.num_attention_heads, - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.num_layers * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None: - with 
bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, config.padded_vocab_size), dtype="float32", name="logits") - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model(args: argparse.Namespace, hf_config): - model = args.model - dtype = args.quantization.model_dtype - - if ( - model.startswith("chatglm2") - or model.startswith("codegeex2") - or model.startswith("chatglm3") - ): - config = ChatGLMConfig( - **hf_config, - dtype=dtype, - ) - - param_manager = ParamManager() - bb = relax.BlockBuilder() - create_encoding_func(bb, param_manager, config, args.quantization) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model, - max_window_size=config.max_sequence_length, - stop_tokens=[0], - add_prefix_space=False, - prefill_chunk_size=args.prefill_chunk_size, - ) - - mod = bb.get() - - tir_bound_map = dict() - tir_bound_map["n"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["m"] = config.max_sequence_length - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - if "transformer.embedding" in pname: - return [ - pname.replace("transformer.embedding", "transformer.embedding.word_embeddings") - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if "transformer.embedding.word_embeddings" in torch_pname: - return [ - ( - torch_pname.replace( - "transformer.embedding.word_embeddings", - "transformer.embedding", - ), - torch_param.astype(dtype), - ) - ] - else: - return [(torch_pname, torch_param.astype(dtype))] - - param_manager.set_param_loading_func( - args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd - ) - return mod, param_manager, [None] * len(param_manager.param_names), config - - raise ValueError(f"Unsupported model {model}") diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py deleted file mode 100644 index d55c2ca5e6..0000000000 --- a/mlc_llm/relax_model/commons.py +++ /dev/null @@ -1,363 +0,0 @@ -import json -from typing import Dict, List, Optional - -import mlc_llm -import tvm -from tvm import relax, te, tir, topi - - -def create_metadata_func( - bb: relax.BlockBuilder, - model_name: str, - max_window_size: int, - stop_tokens: List[int], - add_prefix_space: bool, - prefill_chunk_size: int = -1, - sliding_window: int = -1, -): - metadata = json.dumps( - { - "model_name": model_name, - "max_window_size": max_window_size, - "stop_tokens": stop_tokens, - "add_prefix_space": add_prefix_space, - "prefill_chunk_size": prefill_chunk_size, - "sliding_window": sliding_window, - } - ) - with bb.function("get_metadata", params=[]): - bb.emit_func_output(relax.StringImm(metadata)) - - -def _get_shard_strategies( - model_config, num_shards: int, param_shape_is_already_sharded: bool -) -> Dict[str, tvm.tir.PrimFunc]: - head_dim = model_config.hidden_size // model_config.num_attention_heads - q_heads = model_config.num_attention_heads - 
kv_heads = model_config.get_num_key_value_heads() - - # pylint: disable=invalid-name - def shard_qkv_weight_scale(weight: relax.TensorStructInfo): - (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) - if param_shape_is_already_sharded: - spatial *= num_shards - a = te.placeholder((spatial, red), dtype=dtype) - w = topi.reshape(a, (spatial // head_dim, head_dim, red)) - q = te.compute((q_heads, head_dim, red), lambda i, j, k: w[i, j, k]) - k = te.compute((kv_heads, head_dim, red), lambda i, j, k: w[q_heads + i, j, k]) - v = te.compute((kv_heads, head_dim, red), lambda i, j, k: w[q_heads + kv_heads + i, j, k]) - q = topi.reshape(q, (num_shards, q_heads // num_shards, head_dim, red)) - k = topi.reshape(k, (num_shards, kv_heads // num_shards, head_dim, red)) - v = topi.reshape(v, (num_shards, kv_heads // num_shards, head_dim, red)) - w = topi.concatenate((q, k, v), axis=1) - w = topi.reshape(w, (num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim, red)) - func = te.create_prim_func([a, w]) - return func - - def shard_k_weight_scale(weight: relax.TensorStructInfo): - (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((spatial, red), dtype=dtype) - w = topi.reshape(a, (spatial, num_shards, red // num_shards)) - w = topi.transpose(w, (1, 0, 2)) - func = te.create_prim_func([a, w]) - return func - - def shard_axis_0(weight: relax.TensorStructInfo): - (red, spatial), dtype = weight.shape, weight.dtype - red, spatial = int(red), int(spatial) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((red, spatial), dtype=dtype) - w = topi.reshape(a, (num_shards, red // num_shards, spatial)) - func = te.create_prim_func([a, w]) - return func - - def shard_axis_1(weight: relax.TensorStructInfo): - (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((spatial, red), dtype=dtype) - w = topi.reshape(a, (spatial, num_shards, red // num_shards)) - w = topi.transpose(w, (1, 0, 2)) - func = te.create_prim_func([a, w]) - return func - - def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): - (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) - if param_shape_is_already_sharded: - spatial *= num_shards - a = te.placeholder((spatial, red), dtype=dtype) - g = te.compute((spatial // 2, red), lambda i, j: a[i, j]) - u = te.compute((spatial // 2, red), lambda i, j: a[spatial // 2 + i, j]) - g = topi.reshape(g, (num_shards, spatial // 2 // num_shards, red)) - u = topi.reshape(u, (num_shards, spatial // 2 // num_shards, red)) - w = topi.concatenate((g, u), axis=1) - w = topi.reshape(w, (num_shards, spatial // num_shards, red)) - func = te.create_prim_func([a, w]) - return func - - # pylint: enable=invalid-name - - return { - "shard_qkv": shard_qkv_weight_scale, - "shard_mlp_k": shard_k_weight_scale, - "shard_o_proj_k": shard_k_weight_scale, - "shard_gate_up": shard_gate_up_weight_scale, - "shard_axis_0": shard_axis_0, - "shard_axis_1": shard_axis_1, - } - - -def _get_shard_strategies_ft( - model_config, num_shards: int, param_shape_is_already_sharded: bool -) -> Dict[str, tvm.tir.PrimFunc]: - q_heads = model_config.num_attention_heads - kv_heads = model_config.get_num_key_value_heads() - - def shard_qkv_weight_scale(x: relax.TensorStructInfo): - (red, spatial), dtype = 
x.shape, x.dtype - red, spatial = int(red), int(spatial) - if param_shape_is_already_sharded: - spatial *= num_shards - head_dim = spatial // (q_heads + 2 * kv_heads) - a = te.placeholder((red, spatial), dtype=dtype) - w = topi.reshape(a, (red, spatial // head_dim, head_dim)) - q = te.compute((red, q_heads, head_dim), lambda i, j, k: w[i, j, k]) - k = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + j, k]) - v = te.compute((red, kv_heads, head_dim), lambda i, j, k: w[i, q_heads + kv_heads + j, k]) - q = topi.reshape(q, (red, num_shards, q_heads // num_shards, head_dim)) - k = topi.reshape(k, (red, num_shards, kv_heads // num_shards, head_dim)) - v = topi.reshape(v, (red, num_shards, kv_heads // num_shards, head_dim)) - w = topi.concatenate((q, k, v), axis=2) - w = topi.reshape(w, (red, num_shards, (q_heads + kv_heads * 2) // num_shards * head_dim)) - w = topi.transpose(w, (1, 0, 2)) - func = te.create_prim_func([a, w]) - return func - - def shard_k_weight(weight: relax.TensorStructInfo): - (red, spatial), dtype = weight.shape, weight.dtype - red, spatial = int(red), int(spatial) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((red, spatial), dtype=dtype) - w = topi.reshape(a, (num_shards, red // num_shards, spatial)) - func = te.create_prim_func([a, w]) - return func - - def shard_axis_0(weight: relax.TensorStructInfo): - (red, spatial), dtype = weight.shape, weight.dtype - red, spatial = int(red), int(spatial) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((red, spatial), dtype=dtype) - w = topi.reshape(a, (num_shards, red // num_shards, spatial)) - func = te.create_prim_func([a, w]) - return func - - def shard_axis_1(weight: relax.TensorStructInfo): - (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) - if param_shape_is_already_sharded: - red *= num_shards - a = te.placeholder((spatial, red), dtype=dtype) - w = topi.reshape(a, (spatial, num_shards, red // num_shards)) - w = topi.transpose(w, (1, 0, 2)) - func = te.create_prim_func([a, w]) - return func - - def shard_gate_up_weight_scale(x: relax.TensorStructInfo): - (red, spatial), dtype = x.shape, x.dtype - red, spatial = int(red), int(spatial) - if param_shape_is_already_sharded: - spatial *= num_shards - a = te.placeholder((red, spatial), dtype=dtype) - g = te.compute((red, spatial // 2), lambda i, j: a[i, j]) - u = te.compute((red, spatial // 2), lambda i, j: a[i, spatial // 2 + j]) - g = topi.reshape(g, (red, num_shards, spatial // 2 // num_shards)) - u = topi.reshape(u, (red, num_shards, spatial // 2 // num_shards)) - w = topi.concatenate((g, u), axis=2) - w = topi.reshape(w, (red, num_shards, spatial // num_shards)) - w = topi.transpose(w, (1, 0, 2)) - func = te.create_prim_func([a, w]) - return func - - return { - "shard_qkv": shard_qkv_weight_scale, - "shard_mlp_k": shard_k_weight, - "shard_o_proj_k": shard_k_weight, - "shard_gate_up": shard_gate_up_weight_scale, - "shard_axis_0": shard_axis_0, - "shard_axis_1": shard_axis_1, - } - - -def create_shard_info_func(param_manager, args, model_config) -> tvm.IRModule: - shard_strategy_to_func = _get_shard_strategies( - model_config, - num_shards=args.num_shards, - param_shape_is_already_sharded=args.build_model_only, - ) - - shard_info_dict = {} - shard_funcs = {} - - def add_to_shard_info(param_name: str, func_name: Optional[str]): - shard_info = [] - if func_name is not None: - func = shard_funcs[func_name] - buffer = func.buffer_map[func.params[-1]] - shape = 
[int(i) for i in buffer.shape] - dtype = str(buffer.dtype) - shard_info.append((func_name, [shape, dtype])) - - shard_info_dict[param_name] = shard_info - - q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")] - for _, param in param_manager.params.items(): - if param.shard_strategy is None: - pass - elif param.shard_strategy in shard_strategy_to_func: - for i, weight in enumerate(param_manager.param2qrange[param]): - if args.use_presharded_weights: - sharding_func_name = None - else: - sharding_func_name = f"{param.shard_strategy}_{i}" - if sharding_func_name not in shard_funcs: - shard_funcs[sharding_func_name] = shard_strategy_to_func[ - param.shard_strategy - ](q_params[weight]) - add_to_shard_info(f"param_{weight}", sharding_func_name) - else: - raise NotImplementedError(f"Shard strategy not implemented: {param.shard_strategy}") - - bb = relax.BlockBuilder() # pylint: disable=invalid-name - - for name, func in shard_funcs.items(): - func = func.with_attr({"global_symbol": name}) - bb.add_func(func, name) - - with bb.function("get_shard_info", params=[]): - bb.emit_func_output(relax.StringImm(json.dumps(shard_info_dict))) - - return bb.get() - - -def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule: - use_ft_quant = args.quantization.name in [ - "q4f16_ft", - "q8f16_ft", - "q4f16_ft_group", - "q8f16_ft_group", - ] - - if use_ft_quant: - shard_strategy_to_func = _get_shard_strategies_ft( - model_config, - num_shards=args.num_shards, - param_shape_is_already_sharded=args.build_model_only, - ) - else: - shard_strategy_to_func = _get_shard_strategies( - model_config, - num_shards=args.num_shards, - param_shape_is_already_sharded=args.build_model_only, - ) - - q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")] - - # The order of the quantized parameters must be preserved. - # Therefore, we need to loop over q_params and look up information - # as needed, rather than looping over original parameters and - # looking up the quantized parameters as needed. - orig_param_lookup = {} - for param in param_manager.params_in_func["prefill"]: - qrange = param_manager.param2qrange[param] - for i_orig_part, i_qparam in enumerate(qrange): - orig_param_lookup[i_qparam] = ( - param, - i_orig_part, - len(qrange), - ) - - bb = relax.BlockBuilder() # pylint: disable=invalid-name - with bb.function("transform_params", attrs={"num_input": 1}): - rank = tir.SizeVar("rank", "int64") - # TODO(Lunderberg): Support primitive inputs to relax - # functions. Currently, using a PrimStructInfo as the - # argument results in an error thrown during - # `vm_shape_lower.cc`, due to BindParams failing to replace - # the symbolic variable "rank" when defined in a R.PrimValue. 
- # - # rank_arg = relax.Var("rank", relax.PrimStructInfo(value=rank)) - rank_arg = relax.Var("rank_arg", relax.ShapeStructInfo([rank])) - - args = [rank_arg] - output = [] - - for i_qparam, qparam_sinfo in enumerate(q_params): - param, i_orig_part, num_orig_parts = orig_param_lookup[i_qparam] - - if isinstance(param.quant_spec, mlc_llm.quantization.NoQuantizationSpec): - arg_name = param.name - elif num_orig_parts == 1: - arg_name = f"{param.name}.quantized" - else: - arg_name = f"{param.name}.quantized_{i_orig_part}" - - arg = relax.Var(arg_name, qparam_sinfo) - - if param.shard_strategy is None or ( - use_ft_quant - and param.shard_strategy in ["shard_mlp_k", "shard_o_proj_k"] - and qparam_sinfo.shape[0] == 1 - ): - sharded = arg - else: - strategy_func = shard_strategy_to_func[param.shard_strategy]( - qparam_sinfo - ).without_attr("global_symbol") - strategy_gvar = bb.add_func( - strategy_func, - func_name=f"{arg_name}.sharding_func", - ) - - # TODO(Lunderberg): Write the strategies as relax - # functions, so the sharded shapes can be inferred. - reordered_buffer = strategy_func.buffer_map[strategy_func.params[-1]] - reordered_sinfo = relax.TensorStructInfo( - reordered_buffer.shape, reordered_buffer.dtype - ) - reordered = relax.op.call_tir( - strategy_gvar, relax.Tuple([arg]), out_sinfo=reordered_sinfo - ) - - # TODO(Lunderberg): Allow relax.PrimValue as the index - # in a TupleGetItem. This would allow all of the - # splits to be generated at once in the merged - # function, and could be optimized to an in-place view. - # - # split = relax.op.split(reordered, indices_or_sections=num_shards, axis=0)[rank] - split = relax.op.strided_slice( - reordered, - axes=[0], - begin=[rank], - end=[rank + 1], - assume_inbound=True, - ) - - sharded = relax.op.squeeze(split, axis=0) - - args.append(arg) - output.append(sharded) - - with bb.dataflow(): - gv = bb.emit_output(output) - bb.emit_func_output(output=gv, params=args) - - return bb.get() diff --git a/mlc_llm/relax_model/gpt_bigcode.py b/mlc_llm/relax_model/gpt_bigcode.py deleted file mode 100644 index 4f72400e3c..0000000000 --- a/mlc_llm/relax_model/gpt_bigcode.py +++ /dev/null @@ -1,667 +0,0 @@ -import argparse -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import tvm -from tvm import relax, te -from tvm.relax.op import ( - astype, - broadcast_to, - expand_dims, - matmul, - maximum, - minimum, - permute_dims, - reshape, - squeeze, -) -from tvm.relax.op.nn import gelu, layer_norm, softmax -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import Embedding, Linear, ModuleList -from .param_manager import ParamManager - - -@dataclass -class GPTBigCodeConfig: - def __init__( - self, - bos_token_id: int = 0, - eos_token_id: int = 0, - initializer_range: float = 0.02, - layer_norm_epsilon: float = 1e-05, - max_sequence_length: int = 2048, - n_embd: int = 6144, - n_head: int = 48, - n_inner: int = 24576, - n_layer: int = 40, - n_positions: int = 8192, - scale_attn_weights: bool = True, - vocab_size: int = 49152, - dtype: str = "float32", - **kwargs, - ): - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.initializer_range = initializer_range - self.layer_norm_epsilon = layer_norm_epsilon - self.max_sequence_length = max_sequence_length - self.n_embd = n_embd - self.n_head = n_head - self.n_inner = n_inner - self.n_layer = n_layer - 
self.n_positions = n_positions - self.scale_attn_weights = scale_attn_weights - self.vocab_size = vocab_size - self.dtype = dtype - self.kwargs = kwargs - - -def _prepare_decoder_attention_mask(input_shape, src_len, dtype): - # create causal mask - # [bsz, seq_len] -> [bsz, tgt_seq_len, 1, src_seq_len] - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - bsz, tgt_len = input_shape - - def min_max_triu_te(): - return te.compute( - (tgt_len, tgt_len), - lambda i, j: tvm.tir.Select( - j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) - ), - name="make_diag_mask_te", - ) - - mask = nn.emit_te(min_max_triu_te) - mask = nn.emit(expand_dims(mask, 1)) - diag_mask = nn.emit(broadcast_to(mask, (bsz, tgt_len, 1, tgt_len))) - if src_len == tgt_len: - return diag_mask - - def extend_te(x, tgt_len, src_len): - return te.compute( - (bsz, tgt_len, 1, src_len), - lambda b, i, _, j: te.if_then_else( - j < src_len - tgt_len, - tvm.tir.max_value(dtype), - x[b, i, _, j - (src_len - tgt_len)], - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) - else: - # Get src_len from input parameters - # [bsz, seq_len] -> [bsz, tgt_seq_len, 1, src_seq_len] - bsz, tgt_len = input_shape - mask = relax.op.full( - (bsz, tgt_len, 1, src_len), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - return nn.emit(mask) - - -def apply_position_embedding(t_embd, weight, offset: int = 0): - def f_position_embedding(tensor, weight, offset): - def position_compute(*idx): - b, s, e = idx - return weight[s + offset, e] + tensor[b, s, e] - - return tvm.te.compute(tensor.shape, position_compute, name="position") - - hidden_states = nn.emit_te( - f_position_embedding, - t_embd, - weight, - offset, - primfunc_name_hint="position_embedding", - ) - return hidden_states - - -class LayerNorm(nn.Module): - def __init__( - self, - hidden_size, - dtype, - eps=1e-5, - ): - super().__init__() - self.dtype = dtype - - self.eps = eps - self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="weight") - self.bias = nn.Parameter((hidden_size,), dtype=dtype, name="bias") - - def forward(self, x: relax.Expr) -> relax.Var: - if x.struct_info.dtype != self.dtype: - x = nn.emit(relax.op.astype(x, self.dtype)) - x = nn.emit( - layer_norm( - x, - gamma=self.weight, - beta=self.bias, - axes=-1, - epsilon=self.eps, - ) - ) - return x - - -class GPTBigCodeAttention(nn.Module): - """Multi-query attention from 'Fast Transformer Decoding: One Write-Head is All You Need'""" - - def __init__(self, config: GPTBigCodeConfig): - if config.n_embd % config.n_head != 0: - raise ValueError( - f"hidden_size must be divisible by n_head (got `hidden_size`: {config.n_embd}" - f" and `n_head`: {config.n_head})." 
- ) - self.n_embd = config.n_embd - self.n_head = config.n_head - self.head_dim = config.n_embd // config.n_head - - self.c_attn = Linear(self.n_embd, self.n_embd + 2 * self.head_dim, config.dtype, bias=True) - self.c_proj = Linear(self.n_embd, self.n_embd, config.dtype, bias=True) - - self.dtype = config.dtype - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: - # hidden_states: [batch_size, seq_len, n_embd] - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - batch_size, seq_len, _ = hidden_states.struct_info.shape - kv_seq_len = all_seq_len_shape.struct_info.values[0] - - def te_slice(x: te.Tensor, start: int, end: int): - batch_size, seq_len, _ = x.shape - return te.compute( - shape=(batch_size, seq_len, end - start), - fcompute=lambda i, j, k: x[i, j, start + k], - name="slice", - ) - - query_key_value = self.c_attn(hidden_states) - # queries: [batch_size, seq_len, n_embd] - q = nn.emit_te(te_slice, query_key_value, 0, self.n_embd, primfunc_name_hint="slice") - # keys: [batch_size, seq_len, head_dim] - k = nn.emit_te( - te_slice, - query_key_value, - self.n_embd, - self.n_embd + self.head_dim, - primfunc_name_hint="slice", - ) - # values: [batch_size, seq_len, head_dim] - v = nn.emit_te( - te_slice, - query_key_value, - self.n_embd + self.head_dim, - self.n_embd + 2 * self.head_dim, - primfunc_name_hint="slice", - ) - - squeezed_k = nn.emit(squeeze(k, axis=0)) - squeezed_v = nn.emit(squeeze(v, axis=0)) - - assert k.struct_info.shape[0] == 1 and v.struct_info.shape[0] == 1 - - k_cache, v_cache = past_key_value - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeezed_k, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeezed_v, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - past_key_value = (k_cache, v_cache) - - batch_size, _, head_size = k.struct_info.shape - kv_cache_shape = R.shape([kv_seq_len, head_size]) - kv_states_shape = R.shape([batch_size, kv_seq_len, head_size]) - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], - ) - ) - v = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], - ) - ) - - k = nn.emit(reshape(k, kv_states_shape)) - v = nn.emit(reshape(v, kv_states_shape)) - - q_state_shape = R.shape([batch_size, seq_len * self.n_head, self.head_dim]) - q = nn.emit(reshape(q, q_state_shape)) - - # Calculate Q.K - attn_weights = nn.emit( - matmul(q, permute_dims(k, [0, 2, 1])) - / relax.const(math.sqrt(self.head_dim), q.struct_info.dtype) - ) - - # Apply attention mask - attn_weights = nn.emit( - maximum( - attn_weights, - relax.const( - tvm.tir.min_value(attn_weights.struct_info.dtype).value, - attn_weights.struct_info.dtype, - ), - ) - ) - attn_shape = R.shape([batch_size, seq_len, self.n_head, kv_seq_len]) - attn_view = R.shape([batch_size, seq_len * self.n_head, kv_seq_len]) - 
attn_weights = nn.emit(reshape(attn_weights, attn_shape)) - attn_weights = nn.emit(minimum(attn_weights, attention_mask)) - attn_weights = nn.emit(reshape(attn_weights, attn_view)) - - # Calculate Softmax(Q.K) - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != q.struct_info.dtype: - attn_weights = astype(attn_weights, q.struct_info.dtype) - - # Calculate Softmax(Q.K).V - attn_output = nn.emit(matmul(attn_weights, v)) - - # Apply output projection - attn_output = self.c_proj( - reshape( - attn_output, - (batch_size, seq_len, self.n_embd), - ) - ) - - return attn_output, past_key_value - - -class GPTBigCodeMLP(nn.Module): - def __init__(self, config: GPTBigCodeConfig): - super().__init__() - self.dtype = config.dtype - - self.c_fc = Linear(config.n_embd, config.n_inner, config.dtype, bias=True) - self.c_proj = Linear(config.n_inner, config.n_embd, config.dtype, bias=True) - - def forward(self, hidden_states): - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - hidden_states = self.c_fc(hidden_states) - hidden_states = nn.emit(gelu(hidden_states)) - hidden_states = self.c_proj(hidden_states) - - return hidden_states - - -class GPTBigCodeBlock(nn.Module): - def __init__(self, config: GPTBigCodeConfig): - self.dtype = config.dtype - - self.ln_1 = LayerNorm( - hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon - ) - self.ln_2 = LayerNorm( - hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon - ) - - self.attn = GPTBigCodeAttention(config) - self.mlp = GPTBigCodeMLP(config) - - def forward( - self, - hidden_states, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ): - attn_input = self.ln_1(hidden_states) - attn_output, present_key_value = self.attn( - attn_input, all_seq_len_shape, past_key_value, attention_mask - ) - - # residual connection - attn_output = nn.emit(attn_output + hidden_states) - - mlp_input = self.ln_2(attn_output) - mlp_output = self.mlp(mlp_input) - - # residual connection - hidden_states = nn.emit(astype(mlp_output, self.dtype) + attn_output) - - return hidden_states, present_key_value - - -class GPTBigCodeModel(nn.Module): - def __init__(self, config: GPTBigCodeConfig): - self.wte = Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.n_embd, - dtype=config.dtype, - ) - self.wpe = Embedding( - num_embeddings=config.n_positions, - embedding_dim=config.n_embd, - dtype=config.dtype, - ) - - self.h = ModuleList([GPTBigCodeBlock(config) for _ in range(config.n_layer)]) - self.ln_f = LayerNorm( - hidden_size=config.n_embd, dtype=config.dtype, eps=config.layer_norm_epsilon - ) - - def forward( - self, - input_ids: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - batch_size, seq_length = input_ids.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - - # Token Embeddings - t_embd = self.wte(input_ids) - - # Position Embeddings - offset = seq_length_with_past - seq_length - hidden_states = apply_position_embedding(t_embd, self.wpe.weight, offset=offset) - - attention_mask = _prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - dtype=hidden_states.struct_info.dtype, - ) - - present_kv_cache = [] - for i, block in enumerate(self.h): - past_key_value 
= ( - (past_key_values[i * 2], past_key_values[i * 2 + 1]) - if past_key_values is not None - else None - ) - hidden_states, (present_k_cache, present_v_cache) = block( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - all_seq_len_shape=all_seq_len_shape, - ) - present_kv_cache.append(present_k_cache) - present_kv_cache.append(present_v_cache) - hidden_states = self.ln_f(hidden_states) - return hidden_states, present_kv_cache - - -class GPTBigCodeForCausalLM(nn.Module): - def __init__(self, config: GPTBigCodeConfig): - self.dtype = config.dtype - - self.transformer = GPTBigCodeModel(config) - self.lm_head = Linear( - in_features=config.n_embd, - out_features=config.vocab_size, - bias=False, - dtype=config.dtype, - ) - - def forward( - self, - input_ids: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - hidden_states, key_value_cache = self.transformer( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - - def te_slice_last(x: te.Tensor): - _, seq_len, n_embd = x.shape - return te.compute( - shape=(1, 1, n_embd), - fcompute=lambda i, _, k: x[i, seq_len - 1, k], - name="slice_last", - ) - - hidden_states = nn.emit_te( - te_slice_last, - hidden_states, - primfunc_name_hint="slice_last", - ) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - logits = self.lm_head(hidden_states) - - if logits.struct_info.dtype != "float32": - logits = nn.emit(astype(logits, "float32")) - - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "wte.weight" in name: - return ParamQuantKind.embedding_table - elif "lm_head.weight" in name: - return ParamQuantKind.final_fc_weight - elif "wpe" not in name and param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTBigCodeConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "prefill" - - batch_size = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = GPTBigCodeForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), - ) - - with bb.dataflow(): - logits, key_value_cache = model( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTBigCodeConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - bsz = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len = tvm.tir.SizeVar("m", "int64") 
- - with bb.function(func_name): - model = GPTBigCodeForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo([relax.ObjectStructInfo() for _ in range(config.n_layer * 2)]), - ) - with bb.dataflow(): - logits, key_value_cache = model( - input_ids=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> None: - init_shape = relax.ShapeExpr( - ( - config.max_sequence_length, - config.n_embd // config.n_head, - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.n_layer * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model(args: argparse.Namespace, hf_config): - model = args.model - dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len - - if ( - model.startswith("starcoder") - or model.startswith("WizardCoder-") - or model.startswith("gpt_bigcode") - ): - config = GPTBigCodeConfig( - **hf_config, - dtype=dtype, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len - elif config.max_sequence_length is None: - config.max_sequence_length = 2048 - - param_manager = ParamManager() - bb = relax.BlockBuilder() - create_encoding_func(bb, param_manager, config, args.quantization) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model, - max_window_size=config.max_sequence_length, - stop_tokens=[0], - add_prefix_space=False, - prefill_chunk_size=args.prefill_chunk_size, - ) - - mod = bb.get() - - tir_bound_map = dict() - tir_bound_map["n"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["m"] = config.max_sequence_length - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - param_manager.set_param_loading_func( - args.model_path, - 
args.use_safetensors, - f_convert_param_bkwd=lambda torch_pname, torch_param: [ - (torch_pname, torch_param.astype(dtype)) - ], - ) - return mod, param_manager, [None] * len(param_manager.param_names), config - - raise ValueError(f"Unsupported model {model}") diff --git a/mlc_llm/relax_model/gpt_neox.py b/mlc_llm/relax_model/gpt_neox.py deleted file mode 100644 index 30f2d25ac5..0000000000 --- a/mlc_llm/relax_model/gpt_neox.py +++ /dev/null @@ -1,739 +0,0 @@ -# pylint: disable=missing-docstring,too-few-public-methods,too-many-instance-attributes,invalid-name,too-many-locals,too-many-arguments -import argparse -import math -from typing import List, Optional, Tuple, Union - -import tvm -from tvm import relax, te -from tvm.relax.op import ( - astype, - broadcast_to, - matmul, - maximum, - minimum, - permute_dims, - reshape, - squeeze, -) -from tvm.relax.op.nn import gelu, softmax -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding -from .param_manager import ParamManager - - -class GPTNeoXConfig: # pylint: disable=too-many-instance-attributes - def __init__( - self, - use_parallel_residual, - hidden_size, - intermediate_size, - num_attention_heads, - num_hidden_layers, - vocab_size, - rotary_pct, - rotary_emb_base, - layer_norm_eps, - max_sequence_length, - dtype, - ffn_out_dtype, - **kwargs, - ): - self.use_parallel_residual = use_parallel_residual - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - self.vocab_size = vocab_size - self.rotary_pct = rotary_pct - self.rotary_emb_base = rotary_emb_base - self.layer_norm_eps = layer_norm_eps - self.max_sequence_length = max_sequence_length - self.dtype = dtype - self.ffn_out_dtype = ffn_out_dtype - self.kwargs = kwargs - - -class GPTNeoXAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - rotary_embedding: RotaryEmbedding, - dtype: str, - ): - if hidden_size % num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size}" - f" and `num_heads`: {num_heads})." 
- ) - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.rotary_embedding = rotary_embedding - self.query_key_value = Linear(hidden_size, hidden_size * 3, dtype, bias=True) - self.dense = Linear(hidden_size, hidden_size, dtype, bias=True) - self.dtype = dtype - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: - # hidden_states: [batch_size, seq_len, hidden_size] - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - batch_size, seq_len, _ = hidden_states.struct_info.shape - kv_seq_len = all_seq_len_shape.struct_info.values[0] - - # qkv_states: [batch_size, seq_len, hidden_size * 3] - qkv_states = nn.emit( - relax.op.split( - reshape( - self.query_key_value(hidden_states), - (batch_size, seq_len, self.num_heads, 3 * self.head_dim), - ), - indices_or_sections=3, - axis=-1, - ) - ) - - # q/k/v states: [batch_size, seq_len, num_attention_heads, head_size] - q, k, v = [relax.TupleGetItem(qkv_states, idx) for idx in range(3)] - q, k = self.rotary_embedding(q, k, kv_seq_len - seq_len) - - if past_key_value is not None: - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k_cache, v_cache = past_key_value - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeeze(k, axis=0), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeeze(v, axis=0), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - batch_size, _, num_heads, head_size = k.struct_info.shape - kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size]) - kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size]) - k = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], - ) - ) - v = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], - ) - ) - k = nn.emit(reshape(k, kv_states_shape)) - v = nn.emit(reshape(v, kv_states_shape)) - past_key_value = (k_cache, v_cache) - else: - past_key_value = (None, None) - - q = nn.emit(permute_dims(q, [0, 2, 1, 3])) - k = nn.emit(permute_dims(k, [0, 2, 1, 3])) - v = nn.emit(permute_dims(v, [0, 2, 1, 3])) - - # Calculate QK - attn_weights = nn.emit( - matmul(q, permute_dims(k, [0, 1, 3, 2])) - / relax.const( - math.sqrt(self.head_dim), - q.struct_info.dtype, - ) - ) - # Apply attention mask - attn_weights = nn.emit( - maximum( - attn_weights, - relax.const( - tvm.tir.min_value(attn_weights.struct_info.dtype).value, - attn_weights.struct_info.dtype, - ), - ) - ) - attn_weights = nn.emit(minimum(attn_weights, attention_mask)) - # Calculate Softmax(QK) - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != q.struct_info.dtype: - attn_weights = astype(attn_weights, q.struct_info.dtype) - # Calculate Softmax(QK)V - attn_output = nn.emit(matmul(attn_weights, v)) - # Apply 
output projection - attn_output = self.dense( - reshape( - permute_dims(attn_output, [0, 2, 1, 3]), - (batch_size, seq_len, self.hidden_size), - ) - ) - return attn_output, past_key_value - - -class GPTNeoXMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - dtype: str, - out_dtype: Optional[str], - ): - super().__init__() - if out_dtype is None: - out_dtype = dtype - self.dense_h_to_4h = Linear( - hidden_size, - intermediate_size, - dtype=dtype, - out_dtype=out_dtype, - ) - self.dense_4h_to_h = Linear( - intermediate_size, - hidden_size, - dtype=dtype, - out_dtype=out_dtype, - ) - self.dtype = dtype - - def forward(self, hidden_states): - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = nn.emit(gelu(hidden_states)) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - hidden_states = self.dense_4h_to_h(hidden_states) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - return hidden_states - - -class GPTNeoXLayer(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - layer_norm_eps: float, - num_heads: int, - use_parallel_residual: bool, - rotary_embedding: RotaryEmbedding, - dtype: str, - ffn_out_dtype: Optional[str], - ): - self.input_layernorm = LayerNorm( - hidden_size, - eps=layer_norm_eps, - dtype=dtype, - ) - self.post_attention_layernorm = LayerNorm( - hidden_size, - eps=layer_norm_eps, - dtype=dtype, - ) - self.attention = GPTNeoXAttention( - hidden_size, - num_heads=num_heads, - rotary_embedding=rotary_embedding, - dtype=dtype, - ) - self.mlp = GPTNeoXMLP( - hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, - out_dtype=ffn_out_dtype, - ) - self.use_parallel_residual = use_parallel_residual - self.dtype = dtype - - def forward( - self, - hidden_states, - all_seq_len_shape: relax.Expr, - past_key_value: Optional[Tuple[relax.Expr]] = None, - attention_mask: Optional[relax.Expr] = None, - ): - attn_input = self.input_layernorm(hidden_states) - attn_output, present_key_value = self.attention( - attn_input, - all_seq_len_shape, - past_key_value, - attention_mask, - ) - if self.use_parallel_residual: - mlp_input = self.post_attention_layernorm(hidden_states) - mlp_output = self.mlp(mlp_input) - hidden_states = nn.emit(mlp_output + attn_output + hidden_states) - else: - attn_output = nn.emit(attn_output + hidden_states) - mlp_input = self.post_attention_layernorm(attn_output) - mlp_output = self.mlp(mlp_input) - hidden_states = nn.emit(astype(mlp_output, self.dtype) + attn_output) - return hidden_states, present_key_value - - -def _prepare_decoder_attention_mask(input_shape, src_len, dtype): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - bsz, tgt_len = input_shape - - def min_max_triu_te(): - return te.compute( - (tgt_len, tgt_len), - lambda i, j: tvm.tir.Select( - j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype) - ), - name="make_diag_mask_te", - ) - - mask = nn.emit_te(min_max_triu_te) - diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) - if src_len == tgt_len: - return diag_mask - - def extend_te(x, tgt_len, src_len): - return te.compute( - (bsz, 1, tgt_len, src_len), - lambda b, _, i, j: te.if_then_else( - j < src_len - 
tgt_len, - tvm.tir.max_value(dtype), - x[b, _, i, j - (src_len - tgt_len)], - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) - else: - # Get src_len from input parameters - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - bsz, tgt_len = input_shape - mask = relax.op.full( - (bsz, 1, tgt_len, src_len), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - return nn.emit(mask) - - -class GPTNeoXEmbedTokens(nn.Module): - def __init__(self, config: GPTNeoXConfig): - self.embed_in = Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - - def forward(self, input_ids: relax.Expr): - return self.embed_in(input_ids) - - -class GPTNeoXEmbedTokensWrapper(nn.Module): - def __init__(self, config: GPTNeoXConfig): - # build a wrapper to ensure that the naming of the embed_in parameter is consistent - self.gpt_neox = GPTNeoXEmbedTokens(config) - - def forward(self, input_ids: relax.Expr): - return self.gpt_neox(input_ids) - - -class GPTNeoXModel(nn.Module): - def __init__( - self, - config: GPTNeoXConfig, - sep_embed: bool = False, - ): - rotary_embedding = RotaryEmbedding( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - position_embedding_base=config.rotary_emb_base, - max_sequence_length=config.max_sequence_length, - rotary_pct=config.rotary_pct, - dtype=config.dtype, - ) - - self.embed_in = None - if not sep_embed: - self.embed_in = Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - - self.layers = ModuleList( - [ - GPTNeoXLayer( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - layer_norm_eps=config.layer_norm_eps, - num_heads=config.num_attention_heads, - rotary_embedding=rotary_embedding, - use_parallel_residual=config.use_parallel_residual, - dtype=config.dtype, - ffn_out_dtype=config.ffn_out_dtype, - ) - for _ in range(config.num_hidden_layers) - ] - ) - self.final_layer_norm = LayerNorm( - hidden_size=config.hidden_size, - eps=config.layer_norm_eps, - dtype=config.dtype, - ) - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: Optional[Tuple[relax.Expr, relax.Expr]], - ): - # embed positions - hidden_states = self.embed_in(inputs) if self.embed_in else inputs - - batch_size, seq_length, _ = hidden_states.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - attention_mask = _prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - dtype=hidden_states.struct_info.dtype, - ) - present_kv_cache = [] - for i, layer in enumerate(self.layers): - past_key_value = ( - (past_key_values[i * 2], past_key_values[i * 2 + 1]) - if past_key_values is not None - else None - ) - hidden_states, (present_k_cache, present_v_cache) = layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - all_seq_len_shape=all_seq_len_shape, - ) - present_kv_cache.append(present_k_cache) - present_kv_cache.append(present_v_cache) - hidden_states = self.final_layer_norm(hidden_states) - return hidden_states, present_kv_cache - - -class GPTNeoXForCausalLM(nn.Module): - def __init__( - self, - config: GPTNeoXConfig, - sep_embed: bool = False, - ): - self.gpt_neox = GPTNeoXModel(config, sep_embed) - self.embed_out = Linear( - in_features=config.hidden_size, - out_features=config.vocab_size, - bias=False, - dtype="float32", - ) - - def forward( - self, 
- inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: Optional[List[relax.Expr]], - ): - hidden_states, key_value_cache = self.gpt_neox( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - - def _slice(x: te.Tensor): - _, seq_len, hidden_dim = x.shape - return te.compute( - shape=(1, 1, hidden_dim), - fcompute=lambda i, _, k: x[i, seq_len - 1, k], - name="slice", - ) - - hidden_states = nn.emit_te( - _slice, - hidden_states, - primfunc_name_hint="slice", - ) - hidden_states = astype(hidden_states, "float32") - logits = self.embed_out(hidden_states) - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "embed_in.weight" in name: - return ParamQuantKind.embedding_table - elif "embed_out.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTNeoXConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - bsz = 1 - seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = GPTNeoXEmbedTokensWrapper(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - with bb.dataflow(): - inputs_embeds = model(input_ids) - params = [input_ids] + model.parameters() - gv = bb.emit_output(inputs_embeds) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var("embed") - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTNeoXConfig, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - func_name = "prefill_with_embed" if sep_embed else "prefill" - - batch_size = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = GPTNeoXForCausalLM(config, sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = ( - nn.Placeholder( - (batch_size, seq_len, hidden_size), - dtype=config.dtype, - name="input_embeds", - ) - if sep_embed - else nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") - ) - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - inputs, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTNeoXConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - batch_size = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len 
= tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = GPTNeoXForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var( - "all_seq_len", - relax.ShapeStructInfo((all_seq_len,)), - ) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs=input_ids, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_kv_cache_func( - bb: relax.BlockBuilder, - config: GPTNeoXConfig, -) -> None: - init_shape = relax.ShapeExpr( - ( - config.max_sequence_length, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.num_hidden_layers * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: GPTNeoXConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model( - args: argparse.Namespace, - hf_config, -): - model = args.model - dtype = args.quantization.model_dtype - ffn_out_dtype = "float32" - sep_embed = args.sep_embed - - if model.startswith("dolly-"): - stop_tokens = [2] - ffn_out_dtype = "float16" - elif model.startswith("stablelm-"): - stop_tokens = [50278, 50279, 50277, 1, 0] - ffn_out_dtype = "float16" - elif model.lower().startswith("stablecode-"): - stop_tokens = [0] - elif model.lower().startswith("redpajama-"): - stop_tokens = [0] - else: - raise ValueError(f"Unsupported model {model}") - - config = GPTNeoXConfig( - **hf_config, - max_sequence_length=args.max_seq_len if args.max_seq_len != -1 else 2048, - dtype=dtype, - ffn_out_dtype=ffn_out_dtype, - ) - - param_manager = ParamManager() - bb = relax.BlockBuilder() - if sep_embed: - create_embed_func(bb, param_manager, config, args.quantization) - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model, - max_window_size=config.max_sequence_length, - stop_tokens=stop_tokens, - add_prefix_space=False, - prefill_chunk_size=args.prefill_chunk_size, - ) - mod = bb.get() - - tir_bound_map = 
dict() - tir_bound_map["n"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["m"] = config.max_sequence_length - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - # torch_param: numpy.ndarray - if "layernorm" in torch_pname or "layer_norm" in torch_pname or "embed_out" in torch_pname: - return [(torch_pname, torch_param.astype("float32"))] - elif ".dense_h_to_4h.bias" in torch_pname or ".dense_4h_to_h.bias" in torch_pname: - return [(torch_pname, torch_param.astype(ffn_out_dtype))] - else: - return [(torch_pname, torch_param.astype(dtype))] - - param_manager.set_param_loading_func( - args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd - ) - return mod, param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/gptj.py b/mlc_llm/relax_model/gptj.py deleted file mode 100644 index ea755a447a..0000000000 --- a/mlc_llm/relax_model/gptj.py +++ /dev/null @@ -1,692 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union - -import tvm -from tvm import relax, te -from tvm.relax.op import ( - astype, - broadcast_to, - full, - matmul, - maximum, - minimum, - permute_dims, - reshape, - squeeze, - triu, -) -from tvm.relax.op.nn import gelu, softmax -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .gpt_neox import create_kv_cache_func -from .modules import Embedding, LayerNorm, Linear, ModuleList, RotaryEmbedding -from .param_manager import ParamManager - - -def _min_value(dtype) -> relax.Expr: - v = tvm.tir.min_value(dtype).value - if dtype == "float16": - v = -55504.0 - return relax.const(v, dtype) - - -def _max_value(dtype) -> relax.Expr: - v = tvm.tir.max_value(dtype).value - if dtype == "float16": - v = 55504.0 - return relax.const(v, dtype) - - -@dataclass -class GPTJConfig: # pylint: disable=too-many-instance-attributes - def __init__( - self, - vocab_size, - n_embd, - n_inner, - n_head, - n_layer, - bos_token_id, - eos_token_id, - rotary_dim, - tie_word_embeddings, - dtype="float32", - layer_norm_eps=1e-5, - max_sequence_length=2048, - rotary_emb_base=10000, - **kwargs, - ): - self.vocab_size = vocab_size - self.hidden_size = n_embd - self.intermediate_size = n_inner if n_inner is not None else 4 * n_embd - self.num_attention_heads = n_head - self.num_hidden_layers = n_layer - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.rotary_dim = rotary_dim - self.tie_word_embeddings = tie_word_embeddings - self.dtype = dtype - self.layer_norm_eps = layer_norm_eps - self.max_sequence_length = max_sequence_length - self.rotary_emb_base = rotary_emb_base - self.kwargs = kwargs - - -class GPTJMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): - super().__init__() - self.fc_in = Linear(hidden_size, intermediate_size, dtype, bias=True) - self.fc_out = Linear(intermediate_size, hidden_size, dtype, bias=True) - self.dtype = dtype - - def forward(self, hidden_states): - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = 
nn.emit(astype(hidden_states, self.dtype)) - hidden_states = self.fc_in(hidden_states) - hidden_states = nn.emit(gelu(hidden_states)) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - hidden_states = self.fc_out(hidden_states) - return nn.emit(hidden_states) - - -class GPTJAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size: int, - num_heads: int, - rotary_embedding: RotaryEmbedding, - dtype: str, - ): - if hidden_size % num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size}" - f" and `num_heads`: {num_heads})." - ) - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.rotary_embedding = rotary_embedding - self.q_proj = Linear(hidden_size, hidden_size, dtype, bias=False) - self.k_proj = Linear(hidden_size, hidden_size, dtype, bias=False) - self.v_proj = Linear(hidden_size, hidden_size, dtype, bias=False) - self.out_proj = Linear(hidden_size, hidden_size, dtype, bias=False) - self.dtype = dtype - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Optional[Tuple[relax.Expr, relax.Expr]] = None, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[Tuple[None, None], Tuple[relax.Expr, relax.Expr]]]: - # hidden_states: [batch_size, seq_len, hidden_size] - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - batch_size, seq_len, _ = hidden_states.struct_info.shape - kv_seq_len = all_seq_len_shape.struct_info.values[0] - - def _project(proj): - return nn.emit( - reshape( - proj(hidden_states), - (batch_size, seq_len, self.num_heads, self.head_dim), - ) - ) - - # q/k/v states: [batch_size, seq_len, num_attention_heads, head_size] - q, k, v = ( - _project(self.q_proj), - _project(self.k_proj), - _project(self.v_proj), - ) - q, k = self.rotary_embedding(q, k, kv_seq_len - seq_len) - - if past_key_value is not None: - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k_cache, v_cache = past_key_value - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeeze(k, axis=0), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeeze(v, axis=0), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - batch_size, _, num_heads, head_size = k.struct_info.shape - kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size]) - kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size]) - k = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)], - ) - ) - v = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)], - ) - ) - k = nn.emit(reshape(k, kv_states_shape)) - v = nn.emit(reshape(v, kv_states_shape)) - past_key_value = (k_cache, v_cache) - else: - past_key_value = (None, None) - - q = nn.emit(permute_dims(q, [0, 2, 1, 3])) - k = nn.emit(permute_dims(k, [0, 2, 1, 3])) - v = nn.emit(permute_dims(v, [0, 2, 1, 3])) - - # Calculate QK - attn_weights = nn.emit( - matmul(q, 
permute_dims(k, [0, 1, 3, 2])) - / relax.const( - math.sqrt(self.head_dim), - q.struct_info.dtype, - ) - ) - # Apply attention mask - attn_weights = nn.emit(attn_weights + attention_mask) - attn_weights = nn.emit( - minimum( - maximum( - attn_weights, - _min_value(attn_weights.struct_info.dtype), - ), - _max_value(attn_weights.struct_info.dtype), - ) - ) - # Calculate Softmax(QK) - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != q.struct_info.dtype: - attn_weights = astype(attn_weights, q.struct_info.dtype) - # Calculate Softmax(QK)V - attn_output = nn.emit(matmul(attn_weights, v)) - # Apply output projection - attn_output = self.out_proj( - reshape( - permute_dims(attn_output, [0, 2, 1, 3]), - (batch_size, seq_len, self.hidden_size), - ) - ) - return attn_output, past_key_value - - -class GPTJLayer(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - layer_norm_eps: float, - num_heads: int, - rotary_embedding: RotaryEmbedding, - dtype: str, - ): - self.ln_1 = LayerNorm( - hidden_size, - eps=layer_norm_eps, - dtype=dtype, - ) - self.attn = GPTJAttention( - hidden_size, - num_heads=num_heads, - rotary_embedding=rotary_embedding, - dtype=dtype, - ) - self.mlp = GPTJMLP( - hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, - ) - self.dtype = dtype - - def forward( - self, - hidden_states, - all_seq_len_shape: relax.Expr, - past_key_value: Optional[Tuple[relax.Expr]] = None, - attention_mask: Optional[relax.Expr] = None, - ): - normalized_input = self.ln_1(hidden_states) - attn_output, present_key_value = self.attn( - normalized_input, - all_seq_len_shape, - past_key_value, - attention_mask, - ) - mlp_output = self.mlp(normalized_input) - hidden_states = nn.emit(mlp_output + attn_output + hidden_states) - return hidden_states, present_key_value - - -def _prepare_decoder_attention_mask(input_shape, src_len, dtype): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - bsz, tgt_len = input_shape - mask = full((tgt_len, tgt_len), _min_value(dtype)) - mask = triu(mask, k=1) - diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) - if src_len == tgt_len: - return diag_mask - - def extend_te(x, tgt_len, src_len): - return te.compute( - (bsz, 1, tgt_len, src_len), - lambda b, _, i, j: te.if_then_else( - j < src_len - tgt_len, 0, x[b, _, i, j - (src_len - tgt_len)] - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) - else: - # Get src_len from input parameters - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - bsz, tgt_len = input_shape - mask = relax.op.zeros((bsz, 1, tgt_len, src_len), dtype) - return nn.emit(mask) - - -class GPTJEmbedTokens(nn.Module): - def __init__(self, config: GPTJConfig): - self.wte = Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - - def forward(self, input_ids: relax.Expr): - return self.wte(input_ids) - - -class GPTJEmbedTokensWrapper(nn.Module): - def __init__(self, config: GPTJConfig): - # build a wrapper to ensure that the naming of the embed_in parameter is consistent - self.gptj = GPTJEmbedTokens(config) - - def forward(self, input_ids: relax.Expr): - return self.gptj(input_ids) - - -class GPTJModel(nn.Module): - def __init__( - self, - config: GPTJConfig, - 
sep_embed: bool = False, - ): - rotary_embedding = RotaryEmbedding( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - position_embedding_base=config.rotary_emb_base, - max_sequence_length=config.max_sequence_length, - rotary_dim=config.rotary_dim, - swizzle_style="gptj", - dtype=config.dtype, - ) - self.wte = None - if not sep_embed: - self.wte = Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - self.h = ModuleList( - [ - GPTJLayer( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - layer_norm_eps=config.layer_norm_eps, - num_heads=config.num_attention_heads, - rotary_embedding=rotary_embedding, - dtype=config.dtype, - ) - for _ in range(config.num_hidden_layers) - ] - ) - self.ln_f = LayerNorm( - hidden_size=config.hidden_size, - eps=config.layer_norm_eps, - dtype=config.dtype, - ) - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: Optional[Tuple[relax.Expr, relax.Expr]], - ): - batch_size, seq_length = inputs.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - # embed positions - hidden_states = self.wte(inputs) if self.wte is not None else inputs - attention_mask = _prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - dtype=hidden_states.struct_info.dtype, - ) - present_kv_cache = [] - for i, layer in enumerate(self.h): - past_key_value = ( - (past_key_values[i * 2], past_key_values[i * 2 + 1]) - if past_key_values is not None - else None - ) - hidden_states, (present_k_cache, present_v_cache) = layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - all_seq_len_shape=all_seq_len_shape, - ) - present_kv_cache.append(present_k_cache) - present_kv_cache.append(present_v_cache) - hidden_states = self.ln_f(hidden_states) - return hidden_states, present_kv_cache - - -class GPTJForCausalLM(nn.Module): - def __init__( - self, - config: GPTJConfig, - sep_embed: bool = False, - ): - self.transformer = GPTJModel(config, sep_embed) - self.lm_head = Linear( - in_features=config.hidden_size, - out_features=config.vocab_size, - bias=True, - dtype=config.dtype, - ) - self.dtype = config.dtype - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: Optional[List[relax.Expr]], - ): - hidden_states, key_value_cache = self.transformer( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - if hidden_states.struct_info.dtype != self.dtype: - hidden_states = nn.emit(astype(hidden_states, self.dtype)) - - def _slice(x: te.Tensor): - _, seq_len, hidden_dim = x.shape - return te.compute( - shape=(1, 1, hidden_dim), - fcompute=lambda i, _, k: x[i, seq_len - 1, k], - name="slice", - ) - - hidden_states = nn.emit_te( - _slice, - hidden_states, - primfunc_name_hint="slice", - ) - logits = self.lm_head(hidden_states) - if logits.struct_info.dtype != "float32": - logits = nn.emit(astype(logits, "float32")) - - return logits, key_value_cache - - -def check_parameters(param_dict, param_list): - relax_shape_to_list = lambda _: [s.value for s in _.values] - shape_dict_0 = {k: relax_shape_to_list(v.struct_info.shape) for k, v in param_dict.items()} - shape_dict_1 = {k: list(v.shape) for (k, v) in param_list} - assert len(shape_dict_0) == len(shape_dict_1) - for k, v in shape_dict_0.items(): - assert k in shape_dict_1, "{}".format(k) - assert v == shape_dict_1[k], 
"key={}, shape_0={}, shape_1={}".format(k, v, shape_dict_1[k]) - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "wte.weight" in name: - return ParamQuantKind.embedding_table - elif "lm_head.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTJConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - bsz = 1 - seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = GPTJEmbedTokensWrapper(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - with bb.dataflow(): - inputs_embeds = model(input_ids) - params = [input_ids] + model.parameters() - gv = bb.emit_output(inputs_embeds) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var("embed") - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTJConfig, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - func_name = "prefill_with_embed" if sep_embed else "prefill" - - batch_size = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = GPTJForCausalLM(config, sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = ( - nn.Placeholder( - (batch_size, seq_len, hidden_size), - dtype=config.dtype, - name="input_embeds", - ) - if sep_embed - else nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") - ) - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - inputs, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: GPTJConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - batch_size = tvm.tir.IntImm("int64", 1) - seq_len = tvm.tir.IntImm("int64", 1) - all_seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = GPTJForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((batch_size, seq_len), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var( - "all_seq_len", - relax.ShapeStructInfo((all_seq_len,)), - ) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs=input_ids, - 
all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_softmax_func(bb: relax.BlockBuilder, config: GPTJConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model(args, hf_config): - model_name = args.model - dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len - sep_embed = args.sep_embed - - if model_name.startswith("gpt-j-"): - stop_tokens = [50256] - elif model_name.startswith("moss-"): - stop_tokens = [106068] - - config = GPTJConfig(**hf_config, dtype=dtype) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len - - param_manager = ParamManager() - bb = relax.BlockBuilder() - if sep_embed: - create_embed_func(bb, param_manager, config, args.quantization) - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model_name, - max_window_size=config.max_sequence_length, - stop_tokens=stop_tokens, - add_prefix_space=True, - prefill_chunk_size=args.prefill_chunk_size, - ) - mod = bb.get() - - tir_bound_map = dict() - tir_bound_map["n"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["m"] = config.max_sequence_length - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - import re - - str_pattern = re.compile(r"(q|k|v)_proj") - if re.search(str_pattern, pname) is not None: - return [str_pattern.sub("qkv_proj", pname)] - else: - return [pname] - - hidden_size = config.hidden_size - - def f_convert_param_bkwd(torch_pname: str, torch_param) -> Optional[List[Tuple[str, Any]]]: - # torch_param: numpy.ndarray - if torch_pname.endswith("qkv_proj.weight"): - assert torch_param.ndim == 2 - mp_num = 4 - torch_param = torch_param.astype(dtype).reshape(mp_num, 3, -1, hidden_size) - q_weight = torch_param[:, 0, :, :].reshape(hidden_size, hidden_size) - k_weight = torch_param[:, 2, :, :].reshape(hidden_size, hidden_size) - v_weight = torch_param[:, 1, :, :].reshape(hidden_size, hidden_size) - return [ - (torch_pname.replace("qkv_proj", "q_proj"), q_weight), - (torch_pname.replace("qkv_proj", "k_proj"), k_weight), - (torch_pname.replace("qkv_proj", "v_proj"), v_weight), - ] - if "ln_1" in torch_pname or "ln_f" in torch_pname: - return [(torch_pname, torch_param.astype("float32"))] - else: - return [(torch_pname, torch_param.astype(dtype))] - - param_manager.set_param_loading_func( - args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd - ) - return mod, 
param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py deleted file mode 100644 index 7cad3d6fc4..0000000000 --- a/mlc_llm/relax_model/llama.py +++ /dev/null @@ -1,1505 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union - -import numpy as np -import tvm -from tvm import relax, te, tir -from tvm.relax.op import ccl -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import ModuleList -from .param_manager import ParamManager - - -@dataclass -class LlamaConfig: - def __init__( - self, - dtype="float32", - max_sequence_length=2048, - vocab_size=32000, # some models like WizardMath can have 32001 - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - initializer_range=0.02, - rms_norm_eps=1e-6, - pad_token_id=-1, - bos_token_id=0, - eos_token_id=1, - tie_word_embeddings=False, - position_embedding_base=10000, - combine_matmul=True, - build_model_only=False, - num_shards=1, - sliding_window=None, - target_kind=None, - **kwargs, - ): - self.dtype = dtype - self.max_sequence_length = max_sequence_length - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - self.position_embedding_base = position_embedding_base - self.combine_matmul = combine_matmul - self.sliding_window = sliding_window - self.target_kind = target_kind - - if build_model_only and num_shards > 1: - self.num_shards = num_shards - else: - self.num_shards = 1 - self.kwargs = kwargs - - def get_num_key_value_heads(self): - if self.num_key_value_heads is None: - return self.num_attention_heads - - return self.num_key_value_heads - - -class Linear(nn.Module): - def __init__(self, in_features, out_features, dtype: str, bias=True): - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name="linear_weight") - if bias: - self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") - else: - self.bias = None - - def forward(self, input: relax.Expr) -> relax.Var: - return nn.emit(relax.op.linear(input, self.weight, self.bias)) - - -class Embedding(nn.Module): - def __init__(self, num_embeddings, embedding_dim, dtype: str): - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.weight = nn.Parameter( - (num_embeddings, embedding_dim), dtype=dtype, name="embedding_weight" - ) - - def forward(self, x: relax.Expr) -> relax.Var: - from tvm.relax.op import reshape, take - - ndim = x.struct_info.ndim - if ndim == 1: - return nn.emit(take(self.weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = self.weight.struct_info.shape.values[-1] - x = nn.emit(reshape(x, shape=[-1])) - embedding = nn.emit(take(self.weight, x, axis=0)) - return nn.emit(reshape(embedding, [*x_shape, 
emb_size])) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, dtype, eps=1e-6): - self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") - self.variance_epsilon = tvm.tir.const(eps, dtype) - - def forward(self, hidden_states): - from tvm import te, tir - - def f_rms_norm(x, weight): - is_float32 = x.dtype == "float32" - - def f_square(x): - return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x - - def f_mul_cast(x, y): - value = x * y - if not is_float32: - value = tir.Cast(x.dtype, value) - return value - - def f_div_cast_2d(i, k): - x_val = x[i, k] - if not is_float32: - x_val = tir.Cast("float32", x_val) - return x_val / tir.sqrt(square_sum[i] / x.shape[1] + self.variance_epsilon) - - def f_div_cast_3d(bsz, i, k): - x_val = x[bsz, i, k] - if not is_float32: - x_val = tir.Cast("float32", x_val) - return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) - - k = te.reduce_axis((0, x.shape[-1]), name="k") - - if len(x.shape) == 2: - square_sum = te.compute( - (x.shape[0],), - lambda i: te.sum(f_square(x[i, k]), axis=k), - name=x.op.name + "red_temp", - ) - - return te.compute( - x.shape, - lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)), - name="rms_norm", - ) - else: - square_sum = te.compute( - (x.shape[0], x.shape[1]), - lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), - name=x.op.name + "red_temp", - ) - - return te.compute( - x.shape, - lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)), - name="rms_norm", - ) - - return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") - - -class LlamaMLP(nn.Module): - def __init__(self, config: LlamaConfig): - self.combine_matmul = config.combine_matmul - self.num_shards = config.num_shards - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size // self.num_shards - dtype = config.dtype - if self.combine_matmul: - self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.gate_up_proj.weight.shard_dim = 0 - self.gate_up_proj.weight.shard_strategy = "shard_gate_up" - self.down_proj.weight.shard_dim = 1 - self.down_proj.weight.shard_strategy = "shard_mlp_k" - else: - self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - self.gate_proj.weight.shard_dim = 0 - self.gate_proj.weight.shard_strategy = "shard_axis_0" - self.down_proj.weight.shard_dim = 1 - self.down_proj.weight.shard_strategy = "shard_axis_1" - self.up_proj.weight.shard_dim = 0 - self.up_proj.weight.shard_strategy = "shard_axis_0" - - def forward(self, x): - if self.combine_matmul: - gate_up_results = nn.emit( - relax.op.split( - self.gate_up_proj(x), - indices_or_sections=2, - axis=-1, - ) - ) - gate_result = relax.TupleGetItem(gate_up_results, 0) - up_result = relax.TupleGetItem(gate_up_results, 1) - else: - gate_result = self.gate_proj(x) - up_result = self.up_proj(x) - - result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) - return result - - -def rotary_modulate_by_freq(tensor, idx, pos, position_embedding_base): - head_dim = tensor.shape[-1] - dtype = tensor.dtype - n_feat_half = head_dim // 2 - feat_idx = idx[-1] - inv_freq = te.const(1, "float32") / ( - te.power( - 
te.const(position_embedding_base, "float32"), - ((2 * feat_idx) % head_dim).astype("float32") / head_dim.astype("float32"), - ) - ) - freq = pos * inv_freq - left_indices = idx[:-1] + (feat_idx - n_feat_half,) - right_indices = idx[:-1] + (feat_idx + n_feat_half,) - return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype(dtype) * tvm.tir.Select( - feat_idx >= n_feat_half, - tensor[(*left_indices,)], - -tensor[(*right_indices,)], - ) - - -def apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0): - def f_rotary_embedding(tensor, offset): - def rotary_compute(*idx): - pos = (offset + idx[-3]).astype("float32") - return rotary_modulate_by_freq( - tensor, - idx, - pos, - position_embedding_base, - ) - - return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") - - q_embed = nn.emit_te(f_rotary_embedding, q, offset, primfunc_name_hint="rotary_embedding") - k_embed = nn.emit_te(f_rotary_embedding, k, offset, primfunc_name_hint="rotary_embedding") - return q_embed, k_embed - - -class LlamaAttentionBase(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - dtype = config.dtype - self.num_shards = config.num_shards - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - self.num_query_heads = config.num_attention_heads // self.num_shards - self.head_dim = self.hidden_size // config.num_attention_heads - self.position_embedding_base = config.position_embedding_base - - self.combine_matmul = config.combine_matmul - if self.combine_matmul: - self.query_key_value_proj = Linear( - self.hidden_size, - (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, - dtype=dtype, - bias=False, - ) - self.query_key_value_proj.weight.shard_dim = 0 - self.query_key_value_proj.weight.shard_strategy = "shard_qkv" - else: - self.q_proj = Linear( - self.hidden_size, - self.num_query_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.k_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.v_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.q_proj.weight.shard_dim = 0 - self.k_proj.weight.shard_dim = 0 - self.v_proj.weight.shard_dim = 0 - self.q_proj.weight.shard_strategy = "shard_axis_0" - self.k_proj.weight.shard_strategy = "shard_axis_0" - self.v_proj.weight.shard_strategy = "shard_axis_0" - - self.o_proj = Linear( - self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False - ) - self.o_proj.weight.shard_dim = 1 - self.o_proj.weight.shard_strategy = "shard_o_proj_k" - - def project_qkv(self, hidden_states, query_output_shape, kv_output_shape): - from tvm.relax.op import reshape, split - - if self.combine_matmul: - qkv_states = nn.emit( - split( - self.query_key_value_proj(hidden_states), - indices_or_sections=[ - self.num_query_heads * self.head_dim, - (self.num_query_heads + self.num_key_value_heads) * self.head_dim, - ], - axis=-1, - ) - ) - query_states = relax.TupleGetItem(qkv_states, 0) - key_states = relax.TupleGetItem(qkv_states, 1) - value_states = relax.TupleGetItem(qkv_states, 2) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = nn.emit( - reshape(query_states, query_output_shape), - ) - key_states = nn.emit( - reshape(key_states, 
kv_output_shape), - ) - value_states = nn.emit( - reshape(value_states, kv_output_shape), - ) - - return query_states, key_states, value_states - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: Union[relax.Expr, Tuple[relax.Expr]], - layer_id: int, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: - bsz, q_len, _ = hidden_states.struct_info.shape - - query_states, key_states, value_states = self.project_qkv( - hidden_states, - (bsz, q_len, self.num_query_heads, self.head_dim), - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ) - - from tvm.relax.op import reshape - - attn_output, past_key_values = self.attention_fwd( - query_states, - key_states, - value_states, - past_key_values, - bsz, - q_len, - layer_id=layer_id, - all_seq_len_shape=all_seq_len_shape, - attention_mask=attention_mask, - ) - - attn_output = nn.emit( - reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) - ) - attn_output = self.o_proj(attn_output) - return attn_output, past_key_values - - def attention_fwd( - self, - query_states: relax.Expr, - key_states: relax.Expr, - value_states: relax.Expr, - past_key_values: relax.Expr, - batch_size: tir.PrimExpr, - q_len: tir.PrimExpr, - **kwargs, - ): - raise NotImplementedError() - - -class LlamaPagedAttention(LlamaAttentionBase): - def __init__(self, config: LlamaConfig): - super().__init__(config) - - def attention_fwd( - self, - query_states: relax.Expr, - key_states: relax.Expr, - value_states: relax.Expr, - past_key_values: relax.Expr, - batch_size: tir.PrimExpr, - q_len: tir.PrimExpr, - **kwargs, - ) -> Tuple[relax.Expr, relax.Expr]: - assert "layer_id" in kwargs and isinstance(kwargs["layer_id"], int) - layer_id = kwargs["layer_id"] - - f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention") - attn_output = nn.emit( - relax.call_dps_packed( - f_kv_cache_attention, - [ - past_key_values, - relax.PrimValue(layer_id), - query_states, - key_states, - value_states, - ], - out_sinfo=relax.TensorStructInfo( - ((batch_size, q_len, self.num_query_heads, self.head_dim)), - query_states.struct_info.dtype, - ), - ) - ) - return attn_output, past_key_values - - -class LlamaAttention(LlamaAttentionBase): - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.config = config - - def attention_fwd( - self, - query_states: relax.Expr, - key_states: relax.Expr, - value_states: relax.Expr, - past_key_values: relax.Expr, - batch_size: tir.PrimExpr, - q_len: tir.PrimExpr, - **kwargs, - ) -> Tuple[relax.Expr, Tuple[relax.Expr]]: - assert "attention_mask" in kwargs - assert "all_seq_len_shape" in kwargs - attention_mask = kwargs["attention_mask"] - kv_seq_len = kwargs["all_seq_len_shape"].struct_info.values[0] - - from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, squeeze - from tvm.relax.op.nn import softmax - - offset = kv_seq_len - q_len - query_states, key_states = apply_rotary_pos_emb( - query_states, - key_states, - self.position_embedding_base, - offset=offset, - ) - # [bsz, t, nh, hd] - - kv_states_shape = key_states.struct_info.shape - kv_states_dtype = key_states.struct_info.dtype - assert kv_states_shape[0] == 1 # bsz - kv_states_shape = R.shape( - [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] - ) - kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]]) - - squeezed_key = nn.emit(squeeze(key_states, axis=0)) 
- squeezed_value = nn.emit(squeeze(value_states, axis=0)) - k_cache, v_cache = past_key_values - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeezed_key, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeezed_value, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - past_key_values = (k_cache, v_cache) - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k_cache = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], - ) - ) - v_cache = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], - ) - ) - key_states = nn.emit(reshape(k_cache, kv_states_shape)) - value_states = nn.emit(reshape(v_cache, kv_states_shape)) - if self.num_key_value_heads != self.num_query_heads: - n_rep = self.num_query_heads // self.num_key_value_heads - key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) - value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) - - if self.config.target_kind == "android": - attn_weights = nn.emit( - matmul( - permute_dims(query_states, [0, 2, 1, 3]), permute_dims(key_states, [0, 2, 3, 1]) - ) - / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) - ) - else: - query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) - key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) - value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) - - attn_weights = nn.emit( - matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) - / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) - ) - - tvm.ir.assert_structural_equal( - attention_mask.struct_info.shape.values, - (batch_size, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), - ) - - attn_weights = nn.emit( - maximum( - attn_weights, - relax.const( - tvm.tir.min_value(attn_weights.struct_info.dtype).value, - attn_weights.struct_info.dtype, - ), - ) - ) - attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) - - # upcast attention to fp32 - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != query_states.struct_info.dtype: - attn_weights = astype(attn_weights, query_states.struct_info.dtype) - if self.config.target_kind == "android": - attn_output = nn.emit(matmul(attn_weights, permute_dims(value_states, [0, 2, 1, 3]))) - else: - attn_output = nn.emit(matmul(attn_weights, value_states)) - attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) - return attn_output, past_key_values - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, enable_batching: bool): - attn_class = LlamaPagedAttention if enable_batching else LlamaAttention - self.hidden_size = config.hidden_size - self.self_attn = attn_class(config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - - def post_self_attn(self, hidden_states, residual): - if 
self.self_attn.num_shards > 1: - residual = nn.emit( - residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.self_attn.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - return hidden_states - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: Union[relax.Expr, Tuple[relax.Expr]], - layer_id: int, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - all_seq_len_shape=all_seq_len_shape, - layer_id=layer_id, - ) - hidden_states = self.post_self_attn(hidden_states, residual) - return hidden_states, present_key_value - - -def _make_causal_mask(input_ids_shape, dtype, src_len): - from tvm.relax.op import broadcast_to - - bsz, tgt_len = input_ids_shape - - def min_max_triu_te(): - return te.compute( - (tgt_len, tgt_len), - lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), - name="make_diag_mask_te", - ) - - mask = nn.emit_te(min_max_triu_te) - diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) - if src_len == tgt_len: - return diag_mask - - def extend_te(x, tgt_len, src_len): - return te.compute( - (bsz, 1, tgt_len, src_len), - lambda b, _, i, j: te.if_then_else( - j < src_len - tgt_len, - tvm.tir.max_value(dtype), - x[b, _, i, j - (src_len - tgt_len)], - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) - - -class LlamaEmbedTokens(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar): - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = self.embed_tokens(input_ids) - return inputs_embeds - - -class LlamaEmbedTokensWrapper(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar): - # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent - self.model = LlamaEmbedTokens(config, vocab_size_var) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = self.model(input_ids) - return inputs_embeds - - -class LlamaModelBase(nn.Module): - def __init__( - self, - config: LlamaConfig, - vocab_size_var: tir.SizeVar, - sep_embed: bool = False, - enable_batching: bool = False, - ): - self.num_shards = config.num_shards - self.padding_idx = config.pad_token_id - self.embed_tokens = None - - if not sep_embed: - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - self.layers = ModuleList( - [LlamaDecoderLayer(config, enable_batching) for _ in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, 
eps=config.rms_norm_eps) - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: relax.Expr, - ): - raise NotImplementedError() - - -class LlamaModelForSingleSequence(LlamaModelBase): - def __init__( - self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False - ): - super().__init__(config, vocab_size_var, sep_embed, enable_batching=False) - - def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) - else: - # Get src_len from input parameters - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - bsz, tgt_len = input_shape - combined_attention_mask = nn.emit( - relax.op.full( - (bsz, 1, tgt_len, src_len), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - ) - return combined_attention_mask - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: relax.Expr, - ): - if self.num_shards > 1: - inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) - if self.embed_tokens: - inputs_embeds = self.embed_tokens(inputs) - else: - inputs_embeds = inputs - # retrieve input_ids - batch_size, seq_length, _ = inputs_embeds.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - # embed positions - attention_mask = self._prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - inputs_embeds.struct_info.dtype, - ) - - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = () - - for idx, decoder_layer in enumerate(self.layers): - assert past_key_values is not None - past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) - - hidden_states, key_value_cache = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_values=past_key_value, - all_seq_len_shape=all_seq_len_shape, - layer_id=idx, - ) - next_decoder_cache += key_value_cache - - hidden_states = self.norm(hidden_states) - - assert len(next_decoder_cache) == len(self.layers) * 2 - return hidden_states, next_decoder_cache - - -class LlamaModelForBatching(LlamaModelBase): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool): - assert sep_embed - super().__init__(config, vocab_size_var, sep_embed=True, enable_batching=True) - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: relax.Expr, - ): - assert all_seq_len_shape is None - if self.num_shards > 1: - inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) - if self.embed_tokens: - inputs_embeds = self.embed_tokens(inputs) - else: - inputs_embeds = inputs - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - assert past_key_values is not None - hidden_states, past_key_values = decoder_layer( - hidden_states, - attention_mask=None, - past_key_values=past_key_values, - all_seq_len_shape=all_seq_len_shape, - layer_id=idx, - ) - - hidden_states = self.norm(hidden_states) - return hidden_states, past_key_values - - -class LlamaForCausalLM(nn.Module): - def __init__( - self, - config: LlamaConfig, - vocab_size_var: tvm.tir.SizeVar, - sep_embed: bool = False, - enable_batching: bool = False, - output_all_logits: bool = False, - ): - model_class = 
LlamaModelForBatching if enable_batching else LlamaModelForSingleSequence - self.model = model_class(config, vocab_size_var, sep_embed) - self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) - - ############ Rotary embedding constants ############ - assert config.hidden_size % config.num_attention_heads == 0 - head_dim = config.hidden_size // config.num_attention_heads - - # Set the cached sin/cos to the maximum of 2048 and max seq len. - # This will be eliminated further with online rotary embedding calculation. - cache_len = te.var("cached_rotary_embedding_len", "int64") - self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") - self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") - - # Mark if output_all_logits is True - self.output_all_logits = output_all_logits - ############ End ############ - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: relax.Expr, - logit_positions: Optional[relax.Expr] = None, - ): - hidden_states, key_value_cache = self.model( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - - def te_slicing(x: te.Tensor): - assert x.ndim == 3 - return te.compute( - shape=(x.shape[0], 1, x.shape[2]), - fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], - name="slice", - ) - - if not self.output_all_logits and hidden_states.struct_info.shape[1] != 1: - if logit_positions is None: - hidden_states = nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice") - else: - hidden_states = relax.op.take(hidden_states, logit_positions, axis=1) - logits = self.lm_head(hidden_states) - - if logits.struct_info.dtype != "float32": - logits = nn.emit(relax.op.astype(logits, "float32")) - - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "embed_tokens" in name: - return ParamQuantKind.embedding_table - elif "lm_head.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") - with bb.function(func_name): - model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") - with bb.dataflow(): - inputs_embeds = model(input_ids) - params = [input_ids] + model.parameters() - gv = bb.emit_output(inputs_embeds) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def create_prefill_func_for_single_seq( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - func_name = "prefill_with_embed" if sep_embed else "prefill" - - bsz = 1 - seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") - all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = 
LlamaForCausalLM( - config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed, enable_batching=False - ) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = ( - nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") - if sep_embed - else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - ) - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, all_seq_len_shape, past_key_values=past_key_values - ) - params = [ - inputs, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_prefill_func_for_batching( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "prefill_with_embed" - - bsz = tir.SizeVar("batch_size", "int64") - total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = LlamaForCausalLM( - config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed=True, enable_batching=True - ) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = nn.Placeholder( - (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" - ) - logit_pos = nn.Placeholder((bsz,), dtype="int32", name="logit_positions") - past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, - all_seq_len_shape=None, - past_key_values=past_key_values, - logit_positions=logit_pos, - ) - params = [inputs, logit_pos, past_key_values] + model.parameters() - gv = bb.emit_output((logits, key_value_cache)) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func_for_single_seq( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - bsz = 1 - all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") - - with bb.function(func_name): - model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - input_ids, all_seq_len_shape, past_key_values=past_key_values - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def 
create_decoding_func_for_batching( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode_with_embed" - - bsz = tir.SizeVar("batch_size", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = LlamaForCausalLM( - config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed=True, enable_batching=True - ) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = nn.Placeholder((bsz, 1, hidden_size), dtype=config.dtype, name="inputs_embeds") - past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, all_seq_len_shape=None, past_key_values=past_key_values - ) - params = [inputs, past_key_values] + model.parameters() - gv = bb.emit_output((logits, key_value_cache)) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 2)) - - -def create_verification_func_for_batching( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "verify_with_embed" - - total_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = LlamaForCausalLM( - config, - tvm.tir.SizeVar("vocab_size", "int64"), - sep_embed=True, - enable_batching=True, - output_all_logits=True, - ) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = nn.Placeholder( - (1, total_seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" - ) - past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, - all_seq_len_shape=None, - past_key_values=past_key_values, - ) - params = [inputs, past_key_values] + model.parameters() - gv = bb.emit_output((logits, key_value_cache)) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 2)) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: - num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - init_shape = relax.ShapeExpr( - ( - config.max_sequence_length, - num_key_value_heads, - config.hidden_size // config.num_attention_heads, # head_dim - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.num_hidden_layers * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: - head_dim = config.hidden_size // config.num_attention_heads - num_qo_heads = config.num_attention_heads // config.num_shards - num_kv_heads = config.get_num_key_value_heads() // config.num_shards - - page_size = tir.SizeVar("page_size", "int64") - total_seq_len = tir.SizeVar("total_seq_len", "int64") - reserved_nseq = tir.SizeVar("reserved_nseq", "int64") - cache_config = relax.Var( - "cache_config", - relax.ShapeStructInfo([reserved_nseq, 
total_seq_len, page_size]), - ) - - with bb.function("create_kv_cache", [cache_config]): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros((), config.dtype)) - f_kv_cache_create = relax.extern("vm.builtin.paged_attention_kv_cache_create") - cache = bb.emit_output( - relax.call_pure_packed( - f_kv_cache_create, - cache_config, - relax.PrimValue(config.num_hidden_layers), - relax.PrimValue(num_qo_heads), - relax.PrimValue(num_kv_heads), - relax.PrimValue(head_dim), - relax.PrimValue(1), - relax.PrimValue(config.position_embedding_base), - zeros, - bb.get().get_global_var("kv_cache_transpose_append"), - bb.get().get_global_var("attention_prefill"), - bb.get().get_global_var("attention_decode"), - bb.get().get_global_var("attention_prefill_ragged"), - bb.get().get_global_var("attention_prefill_ragged_begin_forward"), - bb.get().get_global_var("attention_prefill_ragged_end_forward"), - bb.get().get_global_var("attention_prefill_begin_forward"), - bb.get().get_global_var("attention_prefill_end_forward"), - bb.get().get_global_var("attention_decode_begin_forward"), - bb.get().get_global_var("attention_decode_end_forward"), - bb.get().get_global_var("attention_rope_in_place"), - bb.get().get_global_var("attention_merge_state"), - bb.get().get_global_var("kv_cache_debug_get_kv"), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - bb.emit_func_output(cache) - - -def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), dtype="float32", name="logits" - ) - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None: - with bb.function("softmax_with_temperature"): - bsz = tvm.tir.SizeVar("batch_size", "int64") - logits = nn.Placeholder( - (bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")), - dtype="float32", - name="logits", - ) - temperature = nn.Placeholder((bsz,), dtype="float32", name="temperature") - with bb.dataflow(): - t_reshaped = bb.emit(relax.op.reshape(temperature, (bsz, 1, 1))) - div = bb.emit(relax.op.divide(logits, t_reshaped)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None: - from tvm.script import tir as T - - num_kv_heads = config.get_num_key_value_heads() // config.num_shards - head_dim = config.hidden_size // config.num_attention_heads - - @T.prim_func - def kv_cache_transpose_append( - var_pages: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - var_position_map: T.handle, - ): - ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() - - pages = T.match_buffer( - var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), config.dtype - ) - k_data = T.match_buffer(var_k_data, (ntoken, num_kv_heads, head_dim), config.dtype) - v_data = T.match_buffer(var_v_data, (ntoken, num_kv_heads, head_dim), config.dtype) - position_map = T.match_buffer(var_position_map, (ntoken,), "int32") - - for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim): - with 
T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf - ] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - position: T.int64 = T.Cast("int64", position_map[vgpos]) - pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf - ] = v_data[vgpos, vh, vf] - - @T.prim_func - def kv_cache_debug_get_kv( - var_pages: T.handle, - var_position_map: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - layer_id: T.int64, - ): - seqlen = T.SizeVar("seqlen", "int64") - page_size = T.SizeVar("page_size", "int64") - num_pages = T.int64() - - pages = T.match_buffer( - var_pages, (num_pages, 2, num_kv_heads, page_size, head_dim), config.dtype - ) - position_map = T.match_buffer(var_position_map, (seqlen,), "int32") - k_data = T.match_buffer( - var_k_data, (config.num_hidden_layers, seqlen, num_kv_heads, head_dim), config.dtype - ) - v_data = T.match_buffer( - var_v_data, (config.num_hidden_layers, seqlen, num_kv_heads, head_dim), config.dtype - ) - - for p, h, d in T.grid(seqlen, num_kv_heads, head_dim): - with T.block("copy0"): - vp, vh, vd = T.axis.remap("SSS", [p, h, d]) - position: T.int64 = T.Cast("int64", position_map[vp]) - k_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd - ] - v_data[layer_id, vp, vh, vd] = pages[ - T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd - ] - - bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append") - bb.add_func(kv_cache_debug_get_kv, "kv_cache_debug_get_kv") - bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill") - bb.add_func(relax.extern("paged_kv_cache.attention_kernel_decode"), "attention_decode") - bb.add_func( - relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), - "attention_prefill_ragged", - ) - bb.add_func( - relax.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), - "attention_prefill_begin_forward", - ) - bb.add_func( - relax.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), - "attention_prefill_end_forward", - ) - bb.add_func( - relax.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), - "attention_decode_begin_forward", - ) - bb.add_func( - relax.extern("paged_kv_cache.attention_kernel_decode_end_forward"), - "attention_decode_end_forward", - ) - bb.add_func( - relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), - "attention_prefill_ragged_begin_forward", - ) - bb.add_func( - relax.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), - "attention_prefill_ragged_end_forward", - ) - bb.add_func( - relax.extern("flashinfer.merge_state_in_place"), - "attention_merge_state", - ) - bb.add_func( - relax.extern("flashinfer.batch_qk_apply_rotary_in_place"), - "attention_rope_in_place", - ) - - -def setup_params(mod, param_manager, dtype, config, args): - def f_convert_pname_fwd(pname: str) -> List[str]: - if not config.combine_matmul: - return [pname] - - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" - if qkv_str in pname: - return [ - pname.replace(qkv_str, "q_proj"), - pname.replace(qkv_str, "k_proj"), - pname.replace(qkv_str, "v_proj"), - ] - elif gate_up_str in pname: - return [ - 
pname.replace(gate_up_str, "gate_proj"), - pname.replace(gate_up_str, "up_proj"), - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if not config.combine_matmul: - return [(torch_pname, torch_param.astype(dtype))] - - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(dtype))] - - def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): - # Expected to enter this function only for the combined linear matmul weights. - # Other weights are supposed to be loaded in `f_convert_param_bkwd` since - # each other relax param has a unique corresponding torch param. - if not config.combine_matmul: - # When matmul combination is not turned on, each relax param has a unique - # corresponding torch param, and this function is not expected to be entered. - raise NotImplementedError( - "Matmul combination is not turned on, and the function " - "is not expected to be entered" - ) - hidden_size = config.hidden_size - head_dim = config.hidden_size // config.num_attention_heads - - if "query_key_value_proj" in relax_pname: - q_heads = config.num_attention_heads - kv_heads = config.get_num_key_value_heads() - q, k, v = torch_params - assert q.shape == (q_heads * head_dim, hidden_size) - assert k.shape == (kv_heads * head_dim, hidden_size) - assert v.shape == (kv_heads * head_dim, hidden_size) - qkv = np.concatenate([q, k, v], axis=0).astype(dtype) - return qkv - if "gate_up_proj" in relax_pname: - gate, up = torch_params - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - return gate_up - raise ValueError("Unexpected param loading") - - param_manager.set_param_loading_func( - args.model_path, - args.use_safetensors, - f_convert_pname_fwd, - f_convert_param_bkwd, - f_compute_relax_param, - ) - - device = tvm.cpu() - param_list = [None] * param_manager.nparam_to_load - - head_dim = config.hidden_size / config.num_attention_heads - inv_freq = 1.0 / ( - config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) - ) - - # The following cos/sin values can be removed but **are kept for compatibility issues**. - t = np.arange(2048, dtype=inv_freq.dtype) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) - param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) - - return mod, param_manager, param_list, config - - -def get_model(args, hf_config): - model_name = args.model - dtype = args.quantization.model_dtype - enable_batching = args.enable_batching - sep_embed = args.sep_embed - - if enable_batching and not sep_embed: - raise ValueError("`sep_embed` is required when batching is enabled.") - - position_embedding_base = 10000 - - if "rope_theta" in hf_config: - position_embedding_base = hf_config["rope_theta"] - - # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, - # while Llama-1 variants use `max_sequence_length`. - # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. - # If none of them is defined, throw an error. 
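# Illustrative sketch (not part of the deleted file; the helper name below is
# hypothetical): the comment above chooses between the two Hugging Face config
# conventions for the maximum sequence length. Stated on its own, the same
# preference order is simply:
#
#     def resolve_max_sequence_length(hf_config: dict) -> int:
#         # Llama-1 style configs carry "max_sequence_length" (typically 2048);
#         # Llama-2 style configs carry "max_position_embeddings" (typically 4096).
#         if "max_sequence_length" in hf_config:
#             return hf_config["max_sequence_length"]
#         if "max_position_embeddings" in hf_config:
#             return hf_config["max_position_embeddings"]
#         raise ValueError("The model config should contain information about maximum sequence length.")
#
# The branches below apply this preference while also forwarding the remaining
# hf_config fields into LlamaConfig.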
- if "max_sequence_length" in hf_config: - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - target_kind=args.target_kind, - ) - elif "max_position_embeddings" in hf_config: - config = LlamaConfig( - **hf_config, - dtype=dtype, - max_sequence_length=hf_config["max_position_embeddings"], - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - target_kind=args.target_kind, - ) - else: - raise Exception( - "The model config should contain information about maximum sequence length." - ) - - # If there is a user-provided maximum sequence length, override hf config. - if args.max_seq_len != -1: - config.max_sequence_length = args.max_seq_len - - param_manager = ParamManager() - bb = relax.BlockBuilder() - - if sep_embed: - create_embed_func(bb, param_manager, config, args.quantization) - - if enable_batching: - emit_paged_kv_cache_op(bb, config) - create_prefill_func_for_batching(bb, param_manager, config, args.quantization) - create_decoding_func_for_batching(bb, param_manager, config, args.quantization) - create_verification_func_for_batching(bb, param_manager, config, args.quantization) - create_paged_kv_cache_func(bb, config) - create_softmax_func_for_batching(bb, config) - else: - create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func_for_single_seq(bb, config) - - create_metadata_func( - bb, - model_name=model_name, - max_window_size=config.max_sequence_length, - stop_tokens=[2], - add_prefix_space=False, - prefill_chunk_size=args.prefill_chunk_size, - ) - - mod = bb.get() - - tir_bound_map = dict() - tir_bound_map["num_tokens_without_cache"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["num_tokens_with_cache"] = config.max_sequence_length - tir_bound_map["vocab_size"] = args.max_vocab_size - if enable_batching: - tir_bound_map["nseq"] = args.max_batch_size - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py deleted file mode 100644 index 4ff6fb0621..0000000000 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ /dev/null @@ -1,662 +0,0 @@ -from typing import Optional, Tuple - -import numpy as np -import tvm -from tvm import relax, te -from tvm.ir import VDevice -from tvm.relax.op import ccl, concat, expand_dims, repeat, reshape, take, zeros -from tvm.relax.op.nn import attention_var_len -from tvm.relax.testing import nn -from tvm.script import relax as R -from tvm.script.ir_builder import tir as T - -from ..quantization import QuantizationScheme -from .llama import ( - Embedding, - Linear, - LlamaAttentionBase, - LlamaConfig, - LlamaDecoderLayer, - LlamaRMSNorm, - get_param_quant_kind, - rotary_modulate_by_freq, - setup_params, -) -from .modules import ModuleList -from .param_manager import ParamManager - - -def apply_rotary_pos_emb(q, k, positions, position_embedding_base): - def 
f_rotary_embedding(tensor, pos_tensor): - def rotary_compute(*idx): - pos = pos_tensor[idx[0]].astype("float32") - return rotary_modulate_by_freq( - tensor, - idx, - pos, - position_embedding_base, - ) - - return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") - - q_embed = nn.emit_te(f_rotary_embedding, q, positions, primfunc_name_hint="rotary_embedding") - k_embed = nn.emit_te(f_rotary_embedding, k, positions, primfunc_name_hint="rotary_embedding") - return q_embed, k_embed - - -class LlamaAttentionBatched(LlamaAttentionBase): - def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): - super().__init__(config) - self.head_mapping = head_mapping # (num_heads,), used by vLLM for multi-query attention - self.sliding_window = None - - if config.sliding_window: - self.sliding_window = T.IntImm("int32", config.sliding_window) - - def forward( - self, - hidden_states: relax.Expr, # (num_token, hidden_size) - positions: relax.Expr, # (num_token,), for batched RoPE - seq_lens: relax.Expr, # (num_seq,) - kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], - slot_mapping: Optional[relax.Expr], # (num_token,) - max_seqlen: Optional[relax.Expr], # (), must be on CPU - seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill - block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode - indices_within_window: Optional[ - relax.Expr - ], # (num_cached_total,), for prefill with sliding-window attention - ): - num_tokens, _ = hidden_states.struct_info.shape - - queries, keys, values = self.project_qkv( - hidden_states, - (num_tokens, self.num_query_heads, self.head_dim), - (num_tokens, self.num_key_value_heads, self.head_dim), - ) - - queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) - - if kv_cache: - # Paged KV cache update - k_cache, v_cache = kv_cache - - if self.sliding_window is None or block_tables: - # For decode or prefill without sliding window, cache all keys / values. - keys_to_cache = keys - values_to_cache = values - else: - # Cache only the most recent keys and values within the window. 
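# A rough standalone sketch of how a runtime might derive `slot_mapping` from a per-sequence
# block table, matching the (num_blocks, block_size) slot grid described in the vLLM docstring
# further below; block_size=16 follows the constant used in the decode path, while the function
# name, inputs, and the circular-buffer handling are illustrative assumptions.
from typing import List

def slots_for_sequence(block_table: List[int], seq_len: int, block_size: int = 16,
                       sliding_window: int = 0) -> List[int]:
    slots = []
    for pos in range(seq_len):
        logical_block = pos // block_size
        if sliding_window:
            # With sliding-window attention the block table acts as a circular buffer,
            # so logical blocks wrap around once the window is exhausted.
            logical_block %= len(block_table)
        physical_block = block_table[logical_block]
        slots.append(physical_block * block_size + pos % block_size)
    return slots

# 18 tokens stored in physical blocks 7 and 2: tokens 0-15 fill block 7 (slots 112-127),
# tokens 16-17 spill into block 2 (slots 32-33).
assert slots_for_sequence([7, 2], 18)[15:] == [127, 32, 33]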
- keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) - values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) - slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) - - # kv caches are updated inplace, takes ownership of the arguments - kv = nn.emit( - relax.op.call_inplace_packed( - "tvm.contrib.vllm.reshape_and_cache", - keys_to_cache, - values_to_cache, - k_cache, - v_cache, - slot_mapping, - inplace_indices=[2, 3], - sinfo_args=[k_cache.struct_info, v_cache.struct_info], - ) - ) - - k_cache, v_cache = kv[0], kv[1] - else: - k_cache = v_cache = None - - if seqstart: - # Prefill, batched attention over variable sequence lengths - attn_output = nn.emit( - attention_var_len( - nn.emit(expand_dims(queries, axis=0)), - nn.emit(expand_dims(keys, axis=0)), - nn.emit(expand_dims(values, axis=0)), - seqstart_q=seqstart, - max_seqlen_q=max_seqlen, - causal_mask="BottomRight", - window_size=self.sliding_window, - ) - ) - else: - # Decode, using vLLM kernel - attn_output = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.vllm.single_query_cached_kv_attention", - [ - queries, - k_cache, - v_cache, - self.head_mapping, - block_tables, - seq_lens, - 16, # block_size - max_seqlen, - ], - out_sinfo=queries.struct_info, - ) - ) - - attn_output = nn.emit( - reshape(attn_output, (num_tokens, self.num_query_heads * self.head_dim)) - ) - attn_output = self.o_proj(attn_output) - - return attn_output, (k_cache, v_cache) - - -class LlamaDecoderLayerBatched(LlamaDecoderLayer): - def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): - super().__init__(config, False) - self.self_attn = LlamaAttentionBatched(config, head_mapping) - - def forward( - self, - hidden_states: relax.Expr, - positions: relax.Expr, - seq_lens: relax.Expr, - kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], - slot_mapping: Optional[relax.Expr], - max_seqlen: Optional[relax.Expr], - seqstart: Optional[relax.Expr], - block_tables: Optional[relax.Expr], - indices_within_window: Optional[relax.Expr], - ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, new_kv = self.self_attn( - hidden_states=hidden_states, - positions=positions, - seq_lens=seq_lens, - kv_cache=kv_cache, - slot_mapping=slot_mapping, - max_seqlen=max_seqlen, - seqstart=seqstart, - block_tables=block_tables, - indices_within_window=indices_within_window, - ) - - hidden_states = self.post_self_attn(hidden_states, residual) - - return hidden_states, new_kv - - -class LlamaModel(nn.Module): - def __init__( - self, - config: LlamaConfig, - cpu_device: VDevice, - vocab_size_var: tvm.tir.SizeVar, - sep_embed: bool = False, - ): - self.padding_idx = config.pad_token_id - self.embed_tokens = None - - num_query_heads = config.num_attention_heads // config.num_shards - num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - num_queries_per_kv = num_query_heads // num_key_value_heads - head_mapping = relax.const( - tvm.nd.array( - np.repeat(np.arange(num_key_value_heads, dtype="int32"), num_queries_per_kv) - ) - ) - - if not sep_embed: - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - self.layers = ModuleList( - [ - LlamaDecoderLayerBatched(config, head_mapping) - for _ in range(config.num_hidden_layers) - ] - ) - self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) - - self.cpu_device 
= cpu_device - - def forward( - self, - inputs: relax.Expr, - positions: relax.Expr, - seq_lens: relax.Expr, - kv_caches: Optional[relax.Expr], - slot_mapping: Optional[relax.Expr], - seqstart: Optional[relax.Expr], - block_tables: Optional[relax.Expr], - indices_within_window: Optional[relax.Expr], - ): - if self.embed_tokens: - inputs_embeds = self.embed_tokens(inputs) - else: - inputs_embeds = inputs - - hidden_states = inputs_embeds - - # max_seqlen needs to be on CPU, so that vLLM and Flash Attention can directly get the - # integer length by max_seqlen->data[0]. Otherwise, we need to repeatedly do cudaMemcpy - # of a single int32. - max_seqlen = R.to_vdevice(R.max(seq_lens), self.cpu_device) - - new_kvs = () - - for idx, decoder_layer in enumerate(self.layers): - if kv_caches: - cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) - else: - cache = None - - hidden_states, new_kv = decoder_layer( - hidden_states, - positions, - seq_lens, - cache, - slot_mapping, - max_seqlen, - seqstart, - block_tables, - indices_within_window, - ) - new_kvs += new_kv - - return self.norm(hidden_states), new_kvs - - -class LlamaForCausalLM(nn.Module): - def __init__( - self, - config: LlamaConfig, - cpu_device: VDevice, - vocab_size_var: tvm.tir.SizeVar, - sep_embed: bool = False, - ): - self.num_shards = config.num_shards - self.model = LlamaModel(config, cpu_device, vocab_size_var, sep_embed) - self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) - - ############ Rotary embedding constants ############ - assert config.hidden_size % config.num_attention_heads == 0 - head_dim = config.hidden_size // config.num_attention_heads - - # Set the cached sin/cos to the maximum of 2048 and max seq len. - # This will be eliminated further with online rotary embedding calculation. - cache_len = te.var("cached_rotary_embedding_len", "int64") - self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") - self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") - ############ End ############ - - def forward( - self, - input_ids: relax.Expr, # (num_token,) - positions: relax.Expr, # (num_token,), for batched RoPE - seq_lens: relax.Expr, # (num_seq,) - kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate - slot_mapping: Optional[ - relax.Expr - ], # (num_token,), for prefill and decode, not needed for evaluate - block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode - indices_within_window: Optional[ - relax.Expr - ], # (num_cached_total,), for prefill with sliding-window attention - ): - """ - In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other - for values. The tensor has shape (num_blocks, num_kv_heads, head_size, block_size). - (In practice, the key cache has a slightly different shape for an efficiency reason, - but that's not important.) - - The mapping between sequences / tokens to blocks is specified by two inputs. - - block_tables: A list of block IDs allocated for the sequence. - - slot_mapping: A linear index into the 2D grid (num_blocks, block_size), for each token. - - Support for sliding-window attention is realized by making a block table a circular buffer. - So the length of a block table for each sequence is at most ceil(window_size / block_size). - - With sliding window, not all past K / V values need to be cached during prefill. 
- The last input, indices_within_window, tells which tokens among (num_token,) need to have - their K / V values cached. - """ - if self.num_shards > 1: - input_ids = nn.emit(ccl.broadcast_from_worker0(input_ids)) - positions = nn.emit(ccl.broadcast_from_worker0(positions)) - seq_lens = nn.emit(ccl.broadcast_from_worker0(seq_lens)) - - if slot_mapping: - slot_mapping = nn.emit(ccl.broadcast_from_worker0(slot_mapping)) - - if block_tables: - block_tables = nn.emit(ccl.broadcast_from_worker0(block_tables)) - - if indices_within_window: - indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) - - is_prompt = block_tables is None - - if is_prompt: # prefill and evaluate - # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust - cumsum = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info - ) - ) - seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) - else: - seqstart = None - - hidden_states, new_kvs = self.model( - input_ids, - positions, - seq_lens, - kv_caches, - slot_mapping, - seqstart, - block_tables, - indices_within_window, - ) - - if is_prompt: - # Extract logits for the last token in each sequence - - def get_logits_last_tokens(x, seq_len_tensor, seqstart): - return te.compute( - shape=(seq_len_tensor.shape[0], x.shape[-1]), - fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], - name="get_logits_last_tokens", - ) - - logits = self.lm_head( - nn.emit_te( - get_logits_last_tokens, - hidden_states, - seq_lens, - seqstart, - primfunc_name_hint="get_logits_last_tokens", - ) - ) - else: - logits = self.lm_head(hidden_states) - - if logits.struct_info.dtype != "float32": - logits = nn.emit(relax.op.astype(logits, "float32")) - - return logits, new_kvs - - -def get_inputs( - num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True -): - hidden_size = config.hidden_size - - inputs = ( - nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") - if sep_embed - else nn.Placeholder((num_token,), dtype="int32", name="input_ids") - ) - - seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") - positions = nn.Placeholder((num_token,), dtype="int32", name="positions") - - if need_cache: - num_blocks = tvm.tir.SizeVar("num_blocks", "int64") - block_size = 16 - - vec_size = 8 # 128 bit, fp16 x 8 - num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - head_size = hidden_size // config.num_attention_heads - - k_cache_shape = ( - num_blocks, - num_key_value_heads, - head_size // vec_size, - block_size, - vec_size, - ) - v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) - - get_cache_sinfo = lambda i: relax.TensorStructInfo( - k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16" - ) - - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [get_cache_sinfo(i) for i in range(config.num_hidden_layers * 2)] - ), - ) - slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") - else: - past_key_values = None - slot_mapping = None - block_tables = None - - if max_num_blocks_per_seq is None: - block_tables = None - else: - block_tables = nn.Placeholder( - (num_seq, max_num_blocks_per_seq), dtype="int32", name="block_tables" - ) - - return inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables - - -def create_evaluate_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, 
- cpu_dev: VDevice, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" - func_name = "evaluate" - - num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") - num_seq = tvm.tir.SizeVar("batch_size", "int64") - - with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs, positions, seq_lens, _, _, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed - ) - - with bb.dataflow(): - logits, _ = model( - inputs, - positions, - seq_lens, - kv_caches=None, - slot_mapping=None, - block_tables=None, - indices_within_window=None, - ) - params = [ - inputs, - positions, - seq_lens, - ] + model.parameters() - gv = bb.emit_output(logits) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - cpu_dev: VDevice, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - """Batched prefill with vLLM paged KV cache. - - The batched attention op is intended to be offloaded to CUTLASS or Flash Attention - via BYOC. - """ - func_name = "prefill_with_embed" if sep_embed else "prefill" - - num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64") - num_seq = tvm.tir.SizeVar("batch_size", "int64") - - num_inputs = 5 - - with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed - ) - - with bb.dataflow(): - params = [ - input_ids, - positions, - seq_lens, - past_key_values, - slot_mapping, - ] - - inputs = [ - input_ids, - positions, - seq_lens, - past_key_values, - slot_mapping, - None, # block_tables - ] - - if config.sliding_window: - num_inputs += 1 - # The value of num_cached_total is between - # num_token (if seq_len < sliding_window for all seq) and - # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) - num_cached_total = tvm.tir.SizeVar("num_cached_total", "int64") - indices_within_window = nn.Placeholder( - (num_cached_total,), dtype="int32", name="indices_within_window" - ) - inputs.append(indices_within_window) - params.append(indices_within_window) - else: - inputs.append(None) - - logits, new_kvs = model(*inputs) - gv = bb.emit_output((logits, relax.Tuple(new_kvs))) - - bb.emit_func_output(gv, params + model.parameters()) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: LlamaConfig, - cpu_dev: VDevice, - quant_scheme: QuantizationScheme, -) -> None: - """Batched decoding with vLLM paged KV cache.""" - func_name = "decode" - - num_seq = tvm.tir.SizeVar("batch_size", "int64") - max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64") - - with bb.function(func_name): - inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( - num_seq, num_seq, config, 
max_num_blocks_per_seq - ) - - with bb.dataflow(): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - logits, new_kvs = model( - inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None - ) - params = [ - inputs, - positions, - seq_lens, - past_key_values, - slot_mapping, - block_tables, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(new_kvs))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 6)) - - -def get_model(args, hf_config): - dtype = args.quantization.model_dtype - sep_embed = False - - position_embedding_base = 10000 - - if "rope_theta" in hf_config: - position_embedding_base = hf_config["rope_theta"] - - # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, - # while Llama-1 variants use `max_sequence_length`. - # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. - # If none of them is defined, throw an error. - if "max_sequence_length" in hf_config: - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - elif "max_position_embeddings" in hf_config: - config = LlamaConfig( - **hf_config, - dtype=dtype, - max_sequence_length=hf_config["max_position_embeddings"], - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - else: - raise Exception( - "The model config should contain information about maximum sequence length." - ) - - # If there is a user-provided maximum sequence length, override hf config. - if args.max_seq_len != -1: - config.max_sequence_length = args.max_seq_len - - param_manager = ParamManager() - bb = relax.BlockBuilder() - - # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. 
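# A back-of-the-envelope helper for sizing the vLLM-style paged KV cache whose per-block
# layout is defined in `get_inputs` above: K blocks are (num_kv_heads, head_size // 8, 16, 8)
# and V blocks are (num_kv_heads, head_size, 16) in fp16, so both hold the same element count.
# The Llama-7B-like numbers in the example are assumptions for illustration only.

def kv_cache_bytes_per_block(num_kv_heads: int, head_size: int, block_size: int = 16,
                             dtype_bytes: int = 2) -> int:
    # One key block plus one value block; only the key layout is vectorized, not enlarged.
    elems_per_block = num_kv_heads * head_size * block_size
    return 2 * elems_per_block * dtype_bytes

def num_blocks_for_budget(budget_gib: float, num_layers: int, num_kv_heads: int,
                          head_size: int) -> int:
    per_block = kv_cache_bytes_per_block(num_kv_heads, head_size) * num_layers
    return int(budget_gib * (1 << 30) // per_block)

# e.g. 32 layers, 32 KV heads, head_size 128 under a 10 GiB budget -> 1280 blocks,
# i.e. roughly 20k cacheable tokens at 16 tokens per block.
print(num_blocks_for_budget(10, 32, 32, 128))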
- cpu_dev = VDevice("llvm", 0, "global") - - create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) - - mod = bb.get() - - mod.update_global_info("vdevice", [cpu_dev]) - - if args.build_model_only: - return mod, param_manager, None, config - - return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/minigpt.py b/mlc_llm/relax_model/minigpt.py deleted file mode 100644 index 96126bbf5b..0000000000 --- a/mlc_llm/relax_model/minigpt.py +++ /dev/null @@ -1,627 +0,0 @@ -import math -import os -from dataclasses import dataclass - -import torch -import tvm -from tvm import relax -from tvm.relax.testing import nn - - -from ..quantization import ParamQuantKind, QuantizationScheme -from .modules import ModuleList, TransformImage -from .param_manager import ParamManager - - -@dataclass -class MiniGPTConfig: - dtype: str = "float16" - in_chan: int = 4 # represent rgba - image_size: int = 224 - num_query_token: int = 32 - max_txt_len: int = 160 - vocab_size: int = 32000 - patch_size: int = 14 - word_embed: int = 768 - visual_encoder_embed_dim: int = 1408 - visual_encoder_attn_heads: int = 16 - visual_encoder_attn_hidden_dim: int = 257 - visual_encoder_fc_hidden_dim: int = 6144 - visual_encoder_num_blocks: int = 39 - bert_hidden_layers: int = 12 - bert_num_attn_heads: int = 12 - bert_attn_head_size: int = 64 - bert_interm_query: int = 3072 - llama_proj_size: int = 4096 - - -MODEL_CONFIG = { - "minigpt4-7b": {}, -} - - -class MiniGPTPatchEmbed(nn.Module): - def __init__( - self, image_size, patch_size, embed_dim, dtype: str, in_chans=3, bias=True - ): - self.strides = (patch_size, patch_size) - self.embed_dim = embed_dim - self.out_shape = image_size // patch_size - - bs = 1 - self.cls_token = nn.Parameter((bs, 1, embed_dim), dtype=dtype, name="cls_token") - self.pos_embed = nn.Parameter( - (1, self.out_shape * self.out_shape + 1, embed_dim), - dtype=dtype, - name="pos_embed", - ) - self.weight = nn.Parameter( - (embed_dim, in_chans, patch_size, patch_size), - dtype=dtype, - name="patch_embed_weight", - ) - if bias: - self.bias = nn.Parameter((embed_dim,), dtype=dtype, name="patch_embed_bias") - else: - self.bias = None - - def forward(self, input: relax.Expr) -> relax.Var: - bs = 1 - x = nn.emit(relax.op.nn.conv2d(input, self.weight, self.strides)) - if self.bias: - bias = relax.op.reshape(self.bias, [1, self.embed_dim, 1, 1]) - x = relax.op.add(x, bias) - x = relax.op.reshape(x, (bs, self.embed_dim, self.out_shape * self.out_shape)) - x = relax.op.permute_dims(x, [0, 2, 1]) - # concatenate with cls_tokens - x_concat = relax.op.concat([self.cls_token, x], axis=1) - # add with pos_embed - res = relax.op.add(x_concat, self.pos_embed) - return res - - -class MiniGPTVisualEncoderAttention(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.embed_dim = config.visual_encoder_embed_dim - self.num_heads = config.visual_encoder_attn_heads - self.head_dim = self.embed_dim // self.num_heads - self.scale = self.head_dim ** (-0.5) - self.dtype = config.dtype - self.N = config.visual_encoder_attn_hidden_dim - - self.q_bias = nn.Parameter((self.embed_dim,), dtype=self.dtype, name="q_bias") - self.v_bias = nn.Parameter((self.embed_dim,), dtype=self.dtype, name="v_bias") - self.qkv_weight = nn.Parameter( - (self.embed_dim * 3, self.embed_dim), dtype=self.dtype, 
name="qkv_weight" - ) - self.proj_weight = nn.Parameter( - (self.embed_dim, self.embed_dim), dtype=self.dtype, name="proj_weight" - ) - self.proj_bias = nn.Parameter( - (self.embed_dim,), dtype=self.dtype, name="proj_bias" - ) - - def forward(self, input: relax.Expr): - from tvm.relax.op import ( - concat, - linear, - matmul, - permute_dims, - reshape, - squeeze, - strided_slice, - zeros, - ) - - bs = 1 - k_bias = zeros((self.embed_dim,), self.dtype) - qkv_bias = concat([self.q_bias, k_bias, self.v_bias], axis=0) - x = linear(input, self.qkv_weight, qkv_bias) - x = reshape(x, (bs, self.N, 3, self.num_heads, self.head_dim)) - x = permute_dims(x, [2, 0, 3, 1, 4]) - q = squeeze(strided_slice(x, axes=[0], begin=[0], end=[1]), [0]) - k = squeeze(strided_slice(x, axes=[0], begin=[1], end=[2]), [0]) - v = squeeze(strided_slice(x, axes=[0], begin=[2], end=[3]), [0]) - q = q * relax.const(self.scale, self.dtype) - attn = matmul(q, permute_dims(k, [0, 1, 3, 2])) - attn = relax.op.nn.softmax(attn, -1) - res = permute_dims(matmul(attn, v), [0, 2, 1, 3]) - res = reshape(res, (bs, self.N, self.embed_dim)) - res = linear(res, self.proj_weight, self.proj_bias) - return res - - -class MiniGPTMLP(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.hidden_dim = config.visual_encoder_fc_hidden_dim - self.embed_dim = config.visual_encoder_embed_dim - self.dtype = config.dtype - - self.fc1_weight = nn.Parameter( - (self.hidden_dim, self.embed_dim), dtype=self.dtype, name="fc1_weight" - ) - self.fc1_bias = nn.Parameter( - (self.hidden_dim,), dtype=self.dtype, name="fc1_bias" - ) - self.fc2_weight = nn.Parameter( - (self.embed_dim, self.hidden_dim), dtype=self.dtype, name="fc2_weight" - ) - self.fc2_bias = nn.Parameter( - (self.embed_dim,), dtype=self.dtype, name="fc2_bias" - ) - - def forward(self, input: relax.Expr): - res = relax.op.linear(input, self.fc1_weight, self.fc1_bias) - res = relax.op.nn.gelu(res) - res = relax.op.linear(res, self.fc2_weight, self.fc2_bias) - return res - - -class MiniGPTVisualEncoderBlock(nn.Module): - def __init__(self, config: MiniGPTConfig): - embed_dim = config.visual_encoder_embed_dim - dtype = config.dtype - self.norm1_weight = nn.Parameter((embed_dim,), dtype=dtype, name="norm1_weight") - self.norm1_bias = nn.Parameter((embed_dim,), dtype=dtype, name="norm1_bias") - self.attn = MiniGPTVisualEncoderAttention(config) - self.norm2_weight = nn.Parameter((embed_dim,), dtype=dtype, name="norm2_weight") - self.norm2_bias = nn.Parameter((embed_dim,), dtype=dtype, name="norm2_bias") - self.mlp = MiniGPTMLP(config) - - def forward(self, input: relax.Expr): - x = relax.op.nn.layer_norm(input, self.norm1_weight, self.norm1_bias, axes=[-1]) - proj = self.attn(x) - proj = relax.op.add(input, proj) - res = relax.op.nn.layer_norm( - proj, self.norm2_weight, self.norm2_bias, axes=[-1] - ) - res = self.mlp(res) - res = relax.op.add(proj, res) - return res - - -class MiniGPTVisualEncoder(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.embed_dim = config.visual_encoder_embed_dim - self.dtype = config.dtype - self.transform = TransformImage(config.dtype, config.in_chan) - self.patch_embed = MiniGPTPatchEmbed( - config.image_size, - config.patch_size, - config.visual_encoder_embed_dim, - config.dtype, - ) - self.num_blocks = config.visual_encoder_num_blocks - self.blocks = ModuleList( - [MiniGPTVisualEncoderBlock(config) for _ in range(self.num_blocks)] - ) - - self.ln_vision_weight = nn.Parameter( - (self.embed_dim,), dtype=self.dtype, name="ln_vision_weight" - ) - 
self.ln_vision_bias = nn.Parameter( - (self.embed_dim,), dtype=self.dtype, name="ln_vision_bias" - ) - - def forward(self, input_image: relax.Expr): - res = self.transform(input_image) - res = self.patch_embed(res) - for block in self.blocks: - res = block(res) - res = relax.op.nn.layer_norm( - res, self.ln_vision_weight, self.ln_vision_bias, axes=[-1] - ) - return res - - -class MiniGPTEmbedding(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.word_embed = config.word_embed - self.dtype = config.dtype - self.eps = 1e-12 - - self.norm_weight = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_weight" - ) - self.norm_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_bias" - ) - - def forward(self, embedding: relax.Expr): - res = relax.op.nn.layer_norm( - embedding, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps - ) - return res - - -class MiniGPTBertAttention(nn.Module): - def __init__(self, config: MiniGPTConfig, hidden_dim: int): - self.word_embed = config.word_embed - self.num_query_token = config.num_query_token - self.num_attn_heads = config.bert_num_attn_heads - self.attn_head_size = config.bert_attn_head_size - self.visual_encoder_attn_hidden_dim = config.visual_encoder_attn_hidden_dim - self.dtype = config.dtype - self.eps = 1e-12 - - self.query_weight = nn.Parameter( - (self.word_embed, self.word_embed), dtype=self.dtype, name="query_weight" - ) - self.query_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="query_bias" - ) - self.key_weight = nn.Parameter( - (self.word_embed, hidden_dim), dtype=self.dtype, name="key_weight" - ) - self.key_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="key_bias" - ) - self.value_weight = nn.Parameter( - (self.word_embed, hidden_dim), dtype=self.dtype, name="value_weight" - ) - self.value_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="value_bias" - ) - self.dense_weight = nn.Parameter( - (self.word_embed, self.word_embed), dtype=self.dtype, name="dense_weight" - ) - self.dense_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="dense_bias" - ) - self.norm_weight = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_weight" - ) - self.norm_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_bias" - ) - - def forward( - self, - hidden_states: relax.Expr, - attention_mask: relax.Expr, - encoder_hidden_states=None, - encoder_extend_attention_mask=None, - ): - from tvm.relax.op import add, linear, matmul, permute_dims, reshape - - bs = 1 - states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - mask = ( - encoder_extend_attention_mask - if encoder_extend_attention_mask is not None - else attention_mask - ) - hidden_dim = ( - self.visual_encoder_attn_hidden_dim - if encoder_hidden_states is not None - else self.num_query_token - ) - key = linear(states, self.key_weight, self.key_bias) - value = linear(states, self.value_weight, self.value_bias) - key = reshape(key, [bs, hidden_dim, self.num_attn_heads, self.attn_head_size]) - key = permute_dims(key, [0, 2, 1, 3]) - value = reshape( - value, [bs, hidden_dim, self.num_attn_heads, self.attn_head_size] - ) - value = permute_dims(value, [0, 2, 1, 3]) - query = linear(hidden_states, self.query_weight, self.query_bias) - query = reshape( - query, [bs, self.num_query_token, self.num_attn_heads, self.attn_head_size] - ) - query = permute_dims(query, [0, 2, 1, 3]) - scores = matmul(query, 
permute_dims(key, [0, 1, 3, 2])) - scores = scores / relax.const(math.sqrt(self.attn_head_size), dtype=self.dtype) - scores = add(scores, mask) - probs = relax.op.nn.softmax(scores, axis=-1) - context = matmul(probs, value) - context = permute_dims(context, [0, 2, 1, 3]) - context = reshape(context, [bs, self.num_query_token, self.word_embed]) - # calculate the output - context = linear(context, self.dense_weight, self.dense_bias) - context = add(context, hidden_states) - res = relax.op.nn.layer_norm( - context, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps - ) - return res, key, value - - -class MiniGPTBertLayer(nn.Module): - def __init__(self, config: MiniGPTConfig, use_cross_attention=False): - self.word_embed = config.word_embed - self.embed_dim = config.visual_encoder_embed_dim - self.interm_query = config.bert_interm_query - self.dtype = config.dtype - self.eps = 1e-12 - - self.attention = MiniGPTBertAttention(config, self.word_embed) - if use_cross_attention: - self.cross_attention = MiniGPTBertAttention(config, self.embed_dim) - else: - self.cross_attention = None - self.interm_query_weight = nn.Parameter( - (self.interm_query, self.word_embed), - dtype=self.dtype, - name="interm_query_weight", - ) - self.interm_query_bias = nn.Parameter( - (self.interm_query,), dtype=self.dtype, name="interm_query_bias" - ) - self.output_query_weight = nn.Parameter( - (self.word_embed, self.interm_query), - dtype=self.dtype, - name="output_query_weight", - ) - self.output_query_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="output_query_bias" - ) - self.norm_weight = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_weight" - ) - self.norm_bias = nn.Parameter( - (self.word_embed,), dtype=self.dtype, name="norm_bias" - ) - - def forward( - self, - embedding: relax.Expr, - extend_attention_mask: relax.Expr, - encoder_hidden_states: relax.Expr, - encoder_extend_attention_mask: relax.Expr, - ): - attn_output, key, value = self.attention(embedding, extend_attention_mask) - if self.cross_attention: - attn_output, _, _ = self.cross_attention( - attn_output, - extend_attention_mask, - encoder_hidden_states, - encoder_extend_attention_mask, - ) - res = relax.op.linear( - attn_output, self.interm_query_weight, self.interm_query_bias - ) - res = relax.op.nn.gelu(res) - res = relax.op.linear(res, self.output_query_weight, self.output_query_bias) - res = relax.op.add(res, attn_output) - res = relax.op.nn.layer_norm( - res, self.norm_weight, self.norm_bias, axes=[-1], epsilon=self.eps - ) - return res, key, value - - -class MiniGPTQFormer(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.N = config.visual_encoder_attn_hidden_dim - self.num_query_token = config.num_query_token - self.word_embed = config.word_embed - self.num_layers = config.bert_hidden_layers - self.dtype = config.dtype - - bs = 1 - self.query_tokens = nn.Parameter( - (bs, self.num_query_token, self.word_embed), - dtype=self.dtype, - name="query_tokens", - ) - self.embedding = MiniGPTEmbedding(config) - self.bert_layers = ModuleList( - [MiniGPTBertLayer(config, i % 2 == 0) for i in range(self.num_layers)] - ) - - def forward(self, image_embeds: relax.Expr): - from tvm.relax.op import expand_dims, ones - - bs = 1 - image_attns = ones((bs, self.N), self.dtype) - embedding = self.embedding(self.query_tokens) - attention_mask = ones((bs, self.num_query_token), self.dtype) - extend_attention_mask = expand_dims(attention_mask, [1, 2]) - extend_attention_mask = ( - relax.const(1.0, 
self.dtype) - extend_attention_mask - ) * relax.const(-10000.0, self.dtype) - encoder_extend_attention_mask = expand_dims(image_attns, [1, 2]) - encoder_extend_attention_mask = ( - relax.const(1.0, self.dtype) - encoder_extend_attention_mask - ) - for layer in self.bert_layers: - embedding, _, _ = layer( - embedding, - extend_attention_mask, - image_embeds, - encoder_extend_attention_mask, - ) - return embedding - - -class MiniGPTLLaMAProj(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.proj_size = config.llama_proj_size - self.word_embed = config.word_embed - self.dtype = config.dtype - - self.weight = nn.Parameter( - (self.proj_size, self.word_embed), dtype=self.dtype, name="weight" - ) - self.bias = nn.Parameter((self.proj_size,), dtype=self.dtype, name="bias") - - def forward(self, embedding: relax.Expr): - return relax.op.linear(embedding, self.weight, self.bias) - - -class MiniGPTModel(nn.Module): - def __init__(self, config: MiniGPTConfig): - self.visual_encoder = MiniGPTVisualEncoder(config) - self.q_former = MiniGPTQFormer(config) - self.llama_proj = MiniGPTLLaMAProj(config) - - def forward(self, input_image: relax.Expr): - output = self.visual_encoder(input_image) - output = self.q_former(output) - output = self.llama_proj(output) - return output - - -def get_param_quant_kind( - name: str, param_info: relax.TensorStructInfo -) -> ParamQuantKind: - """No quantization for MiniGPT. Use q0f16 or q0f32 when building it.""" - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: MiniGPTConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - bs = 1 - with bb.function(func_name): - model = MiniGPTModel(config) - param_manager.register_params( - model, func_name, quant_scheme, get_param_quant_kind - ) - - input_image = nn.Placeholder( - (bs, config.image_size, config.image_size, config.in_chan), - dtype="uint8", - name="input_image", - ) - with bb.dataflow(): - output = model(input_image) - params = [input_image] + model.parameters() - gv = bb.emit_output(output) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def get_model(args, _config): - model_name = args.model - model_path = args.model_path - - if model_name.startswith("minigpt"): - config = MiniGPTConfig(**MODEL_CONFIG[model_name]) - config.dtype = args.quantization.model_dtype - # build the relax model - param_manager = ParamManager() - bb = relax.BlockBuilder() - create_embed_func(bb, param_manager, config, args.quantization) - mod = bb.get() - - if args.build_model_only: - return mod, param_manager, None, config - - param_manager.set_param_loading_func( - args.model_path, args.use_safetensors, no_lazy_param_loading=True - ) - - # load visual encoder weights - visual_encoder_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" - visual_encoder_cached_file = download_cached_file( - visual_encoder_url, check_hash=False, progress=True - ) - visual_encoder_state_dict = torch.load( - visual_encoder_cached_file, map_location="cpu" - ) - - # load QFormer weights - q_former_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth" - q_former_cached_file = download_cached_file( - q_former_url, check_hash=False, progress=True - ) - q_former_state_dict = torch.load(q_former_cached_file, map_location="cpu")[ - "model" - ] - - 
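# A small NumPy sketch of the additive mask construction used by the Q-Former above:
# a {0, 1} padding mask of shape (batch, seq) is expanded to (batch, 1, 1, seq) and turned
# into 0 / -10000 biases that are added to the attention scores before softmax. The toy
# shapes are assumptions for illustration.
import numpy as np

def extend_attention_mask(mask: np.ndarray) -> np.ndarray:
    # 1 -> keep (bias 0), 0 -> masked (bias -10000), broadcastable over heads and the query dim.
    return (1.0 - mask[:, None, None, :]) * -10000.0

mask = np.array([[1, 1, 1, 0]], dtype="float32")         # last key position is padding
scores = np.zeros((1, 2, 4, 4), dtype="float32")         # (batch, heads, query, key)
biased = scores + extend_attention_mask(mask)
assert biased[0, 0, 0, 3] == -10000.0 and biased[0, 0, 0, 0] == 0.0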
# load llama and llama proj weights - if os.path.isdir(model_path): - raise ValueError( - "MiniGPT model path should be a single file instead of a directory." - ) - llama_state_dict = torch.load(model_path + ".pth", map_location="cpu")["model"] - - param_list = [] - device = tvm.cpu() - visual_encoder_key_list = list(visual_encoder_state_dict.keys())[ - : 4 + 13 * config.visual_encoder_num_blocks - ] - for key in visual_encoder_key_list: - param_list.append( - tvm.nd.array( - visual_encoder_state_dict[key].numpy().astype(config.dtype), device - ) - ) - q_former_key_list = ( - list(q_former_state_dict.keys())[1:3] - + [list(q_former_state_dict.keys())[0]] - + list(q_former_state_dict.keys())[ - 6 : 8 + (26 + 16) * config.bert_hidden_layers // 2 - ] - ) - for key in q_former_key_list: - param_list.append( - tvm.nd.array( - q_former_state_dict[key].numpy().astype(config.dtype), device - ) - ) - llama_key_list = list(llama_state_dict.keys())[-2:] - for key in llama_key_list: - param_list.append( - tvm.nd.array(llama_state_dict[key].numpy().astype(config.dtype), device) - ) - - return mod, param_manager, param_list, config - - raise ValueError(f"Unsupported model: {model_name}") - - -# helper functions for distributed download of model weights from URL -# source: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/common/dist_utils.py (originally credit to Salesforce) - - -def download_cached_file(url, check_hash=True, progress=False): - import timm.models.hub as timm_hub - import torch.distributed as dist - - def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - def is_main_process(): - return get_rank() == 0 - - """ - Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. - If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 
- """ - - def get_cached_file_path(): - # a hack to sync the file path across processes - parts = torch.hub.urlparse(url) - filename = os.path.basename(parts.path) - cached_file = os.path.join(timm_hub.get_cache_dir(), filename) - - return cached_file - - if is_main_process(): - timm_hub.download_cached_file(url, check_hash, progress) - - if is_dist_avail_and_initialized(): - dist.barrier() - - return get_cached_file_path() diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py deleted file mode 100644 index f9959fdb11..0000000000 --- a/mlc_llm/relax_model/mistral.py +++ /dev/null @@ -1,1126 +0,0 @@ -# pylint: disable=too-many-lines, missing-class-docstring, missing-function-docstring -"""Implements the mistal model with sliding window attention.""" - -import math -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple - -import numpy as np -import tvm -from tvm import relax, te -from tvm.relax.op import ccl -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import ModuleList -from .param_manager import ParamManager - - -@dataclass -class MistralConfig: - """Configuration for mistral model.""" - - def __init__( - self, - bos_token_id=1, - eos_token_id=2, - pad_token_id=-1, - hidden_act="silu", - hidden_size=4096, - initializer_range=0.02, - intermediate_size=14336, - max_position_embeddings=32768, - num_attention_heads=32, - num_hidden_layers=32, - num_key_value_heads=8, - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=4096, - attention_sink_size=0, - tie_word_embeddings=False, - vocab_size=32000, - dtype="float32", - max_sequence_length=16384, - combine_matmul=True, - build_model_only=False, - num_shards=1, - **kwargs, - ): - sliding_window = 4096 if sliding_window is None else sliding_window - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.hidden_act = hidden_act - self.hidden_size = hidden_size - self.initializer_range = initializer_range - self.intermediate_size = intermediate_size - self.max_position_embeddings = max_position_embeddings - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - self.num_key_value_heads = num_key_value_heads - self.rms_norm_eps = rms_norm_eps - self.rope_theta = rope_theta - self.sliding_window = sliding_window - self.attention_sink_size = attention_sink_size - self.tie_word_embeddings = tie_word_embeddings - self.vocab_size = vocab_size - self.dtype = dtype - self.max_sequence_length = sliding_window * 4 - self.combine_matmul = combine_matmul - if build_model_only and num_shards > 1: - self.num_shards = num_shards - else: - self.num_shards = 1 - self.kwargs = kwargs - - def get_num_key_value_heads(self): - if self.num_key_value_heads is None: - return self.num_attention_heads - - return self.num_key_value_heads - - -class Linear(nn.Module): - def __init__(self, in_features, out_features, dtype: str, bias=True): - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name="linear_weight") - if bias: - self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") - else: - self.bias = None - - def forward(self, input: relax.Expr) -> relax.Var: - return nn.emit(relax.op.linear(input, self.weight, self.bias)) - - -class Embedding(nn.Module): - def __init__(self, 
num_embeddings, embedding_dim, dtype: str): - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.weight = nn.Parameter( - (num_embeddings, embedding_dim), dtype=dtype, name="embedding_weight" - ) - - def forward(self, x: relax.Expr) -> relax.Var: - from tvm.relax.op import ( # pylint: disable=import-outside-toplevel - reshape, - take, - ) - - ndim = x.struct_info.ndim - if ndim == 1: - return nn.emit(take(self.weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = self.weight.struct_info.shape.values[-1] - x = nn.emit(reshape(x, shape=[-1])) - embedding = nn.emit(take(self.weight, x, axis=0)) - return nn.emit(reshape(embedding, [*x_shape, emb_size])) - - -class MistralRMSNorm(nn.Module): - def __init__(self, hidden_size, dtype, eps=1e-6): - self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") - self.variance_epsilon = tvm.tir.const(eps, dtype) - - def forward(self, hidden_states): - from tvm import te, tir - - def f_rms_norm(x, weight): - is_float32 = x.dtype == "float32" - - def f_square(x): - return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x - - k = te.reduce_axis((0, x.shape[2]), name="k") - square_sum = te.compute( - (x.shape[0], x.shape[1]), - lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), - name=x.op.name + "red_temp", - ) - - def f_div_cast(bsz, i, k): - x_val = x[bsz, i, k] - if not is_float32: - x_val = tir.Cast("float32", x_val) - return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) - - def f_mul_cast(x, y): - value = x * y - if not is_float32: - value = tir.Cast(x.dtype, value) - return value - - return te.compute( - x.shape, - lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), - name="rms_norm", - ) - - return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") - - -class MistralMLP(nn.Module): - def __init__(self, config: MistralConfig): - self.combine_matmul = config.combine_matmul - self.num_shards = config.num_shards - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size // self.num_shards - dtype = config.dtype - if self.combine_matmul: - self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.gate_up_proj.weight.shard_dim = 0 - self.gate_up_proj.weight.shard_strategy = "shard_gate_up" - self.down_proj.weight.shard_dim = 1 - self.down_proj.weight.shard_strategy = "shard_mlp_k" - else: - self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - - def forward(self, x): - if self.combine_matmul: - gate_up_results = nn.emit( - relax.op.split( - self.gate_up_proj(x), - indices_or_sections=2, - axis=-1, - ) - ) - gate_result = relax.TupleGetItem(gate_up_results, 0) - up_result = relax.TupleGetItem(gate_up_results, 1) - else: - gate_result = self.gate_proj(x) - up_result = self.up_proj(x) - - result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) - return result - - -def apply_rotary_pos_emb(q, k, base, q_offset): - def f_rotary_embedding(tensor, offset): - dtype = tensor.dtype - head_dim = tensor.shape[-1] - n_feat_half = tensor.shape[-1] // 2 - - def rotary_compute(*idx): - i, j = idx[-3], idx[-1] - pos = (offset + i).astype("float32") - 
inv_freq = te.const(1, "float32") / ( - te.power( - te.const(base, "float32"), - ((2 * j) % head_dim).astype("float32") / head_dim.astype("float32"), - ) - ) - freq = pos * inv_freq - return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype( - dtype - ) * tvm.tir.Select( - j >= n_feat_half, - tensor[idx[0], i, idx[2], j - n_feat_half], - -tensor[idx[0], i, idx[2], j + n_feat_half], - ) - - return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") - - q_embed = nn.emit_te(f_rotary_embedding, q, q_offset, primfunc_name_hint="rotary_embedding") - k_embed = nn.emit_te(f_rotary_embedding, k, 0, primfunc_name_hint="rotary_embedding") - return q_embed, k_embed - - -class MistralAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: MistralConfig): - dtype = config.dtype - self.num_shards = config.num_shards - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - self.num_query_heads = config.num_attention_heads // self.num_shards - self.head_dim = self.hidden_size // config.num_attention_heads - self.rope_theta = config.rope_theta - self.sliding_window = config.sliding_window - self.attention_sink_size = config.attention_sink_size - - self.combine_matmul = config.combine_matmul - if self.combine_matmul: - self.query_key_value_proj = Linear( - self.hidden_size, - (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, - dtype=dtype, - bias=False, - ) - self.query_key_value_proj.weight.shard_dim = 0 - self.query_key_value_proj.weight.shard_strategy = "shard_qkv" - else: - self.q_proj = Linear( - self.hidden_size, - self.num_query_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.k_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.v_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.q_proj.weight.shard_dim = 0 - self.k_proj.weight.shard_dim = 0 - self.v_proj.weight.shard_dim = 0 - - self.o_proj = Linear( - self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False - ) - self.o_proj.weight.shard_dim = 1 - self.o_proj.weight.shard_strategy = "shard_o_proj_k" - - def interleave_kv( - self, - key_cur: relax.Expr, - value_cur: relax.Expr, - kv_seq_len: int, - rolling_cache_len: int, - cache_offset: int, - attention_sink_size: int, - past_key_value: Tuple[relax.Expr], - ): - from tvm.relax.op import reshape - - def te_cache_unrotate(x_cached, cache_offset, rolling_cache_len): - return te.compute( - (kv_cur_shape[0], rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]), - lambda b, s, h, d: te.if_then_else( - s < attention_sink_size, - x_cached[b, s, h, d], - te.if_then_else( - s < rolling_cache_len - cache_offset + attention_sink_size, - x_cached[b, s + cache_offset - attention_sink_size, h, d], - x_cached[b, s + cache_offset - rolling_cache_len, h, d], - ), - ), - name="te_cache_unrotate", - ) - - def te_cache_cur_concat(x, x_cached, kv_seq_len, rolling_cache_len): - return te.compute( - (kv_cur_shape[0], kv_seq_len, kv_cur_shape[2], kv_cur_shape[3]), - lambda b, s, h, d: te.if_then_else( - s < rolling_cache_len, - x_cached[b, s, h, d], - x[b, s - rolling_cache_len, h, d], - ), - name="te_cache_cur_concat", - ) - - def te_squeeze(x): - return te.compute( - x.shape[1:], - lambda s, h, d: x[0, s, h, d], - name="squeeze_te", - ) - - # [bsz, t, nh, hd] - kv_cur_shape = 
key_cur.struct_info.shape - kv_cur_dtype = key_cur.struct_info.dtype - assert kv_cur_shape[0] == 1 # bsz - kv_batched_cache_shape = R.shape( - [kv_cur_shape[0], rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]] - ) - kv_cache_shape = R.shape([rolling_cache_len, kv_cur_shape[2], kv_cur_shape[3]]) - - # fecth past keys and values from cache - k_cache, v_cache = past_key_value - - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - key_cached = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], - ) - ) - value_cached = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], - ) - ) - key_cached = nn.emit(reshape(key_cached, kv_batched_cache_shape)) - value_cached = nn.emit(reshape(value_cached, kv_batched_cache_shape)) - - key_cached = nn.emit_te( - te_cache_unrotate, - key_cached, - cache_offset, - rolling_cache_len, - primfunc_name_hint="te_cache_unrotate_key", - ) - key = nn.emit_te( - te_cache_cur_concat, - key_cur, - key_cached, - kv_seq_len, - rolling_cache_len, - primfunc_name_hint="te_cache_cur_concat_key", - ) - - value_cached = nn.emit_te( - te_cache_unrotate, - value_cached, - cache_offset, - rolling_cache_len, - primfunc_name_hint="te_cache_unrotate_value", - ) - value = nn.emit_te( - te_cache_cur_concat, - value_cur, - value_cached, - kv_seq_len, - rolling_cache_len, - primfunc_name_hint="te_cache_cur_concat_value", - ) - - # update cache - squeezed_key = nn.emit_te(te_squeeze, key_cur) - squeezed_value = nn.emit_te(te_squeeze, value_cur) - - assert attention_sink_size >= 0 - f_kv_cache_override = relax.extern( - "vm.builtin.attention_kv_cache_window_override_with_sinks" - ) - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_override, - k_cache, - squeezed_key, - relax.PrimValue(self.sliding_window), - relax.PrimValue(attention_sink_size), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_override, - v_cache, - squeezed_value, - relax.PrimValue(self.sliding_window), - relax.PrimValue(attention_sink_size), - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - - return key, value, (k_cache, v_cache) - - def forward( - self, - hidden_states: relax.Expr, - cache_len_shape: relax.Expr, - kv_seq_len_shape: relax.Expr, - cache_offset_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: - # pylint: disable=import-outside-toplevel - from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, split - from tvm.relax.op.nn import softmax - - bsz, q_len, _ = hidden_states.struct_info.shape - assert bsz == 1, "Only support batch size 1 at this moment." 
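# A plain-Python reading of the `te_cache_unrotate` rule above: the rolling cache pins
# `attention_sink_size` sink tokens at the front, while the remaining slots form a ring whose
# next-write position is `cache_offset`; un-rotating restores chronological order. The list
# below stands in for the (seq, heads, dim) cache tensor purely for illustration.

def unrotate(cache, cache_offset, sink_size):
    # Sinks stay first, then the oldest ring entries (from cache_offset to the end),
    # then the newest ring entries (from just after the sinks up to cache_offset).
    return cache[:sink_size] + cache[cache_offset:] + cache[sink_size:cache_offset]

# Ring capacity 6 with 2 sinks: tokens t0 and t1 are pinned; t6 and t7 have overwritten the
# slots right after the sinks, so the write pointer now sits at offset 4.
rolled = ["t0", "t1", "t6", "t7", "t4", "t5"]
assert unrotate(rolled, cache_offset=4, sink_size=2) == ["t0", "t1", "t4", "t5", "t6", "t7"]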
- - if self.combine_matmul: - qkv_cur = nn.emit( - split( - self.query_key_value_proj(hidden_states), - indices_or_sections=[ - self.num_query_heads * self.head_dim, - (self.num_query_heads + self.num_key_value_heads) * self.head_dim, - ], - axis=-1, - ) - ) - query = relax.TupleGetItem(qkv_cur, 0) - key_cur = relax.TupleGetItem(qkv_cur, 1) - value_cur = relax.TupleGetItem(qkv_cur, 2) - else: - query = self.q_proj(hidden_states) - key_cur = self.k_proj(hidden_states) - value_cur = self.v_proj(hidden_states) - - query = nn.emit( - reshape( - query, - (bsz, q_len, self.num_query_heads, self.head_dim), - ), - ) - key_cur = nn.emit( - reshape( - key_cur, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), - ) - value_cur = nn.emit( - reshape( - value_cur, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), - ) - - # concat current kv with cached kv (unrotating the cache) - rolling_cache_len = cache_len_shape.struct_info.values[0] - kv_seq_len = kv_seq_len_shape.struct_info.values[0] - cache_offset = cache_offset_shape.struct_info.values[0] - key, value, updated_key_value = self.interleave_kv( - key_cur, - value_cur, - kv_seq_len, - rolling_cache_len, - cache_offset, - self.attention_sink_size, - past_key_value, - ) - - # cache relative position embeddings (after KV Cache) - query, key = apply_rotary_pos_emb( - query, - key, - self.rope_theta, - q_offset=rolling_cache_len, - ) - - if self.num_key_value_heads != self.num_query_heads: - n_rep = self.num_query_heads // self.num_key_value_heads - key = nn.emit(relax.op.repeat(key, n_rep, axis=2)) - value = nn.emit(relax.op.repeat(value, n_rep, axis=2)) - - query = nn.emit(permute_dims(query, [0, 2, 1, 3])) - key = nn.emit(permute_dims(key, [0, 2, 1, 3])) - value = nn.emit(permute_dims(value, [0, 2, 1, 3])) - - attn_weights = nn.emit( - matmul(query, permute_dims(key, [0, 1, 3, 2])) - / relax.const(math.sqrt(self.head_dim), query.struct_info.dtype) - ) - - tvm.ir.assert_structural_equal( - attention_mask.struct_info.shape.values, - (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), - ) - - attn_weights = nn.emit( - maximum( - attn_weights, - relax.const( - tvm.tir.min_value(attn_weights.struct_info.dtype).value, - attn_weights.struct_info.dtype, - ), - ) - ) - attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) - - # upcast attention to fp32 - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != query.struct_info.dtype: - attn_weights = astype(attn_weights, query.struct_info.dtype) - attn_output = nn.emit(matmul(attn_weights, value)) - - attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) - attn_output = nn.emit( - reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) - ) - - attn_output = self.o_proj(attn_output) - - return attn_output, ((None, None) if updated_key_value is None else updated_key_value) - - -class MistralDecoderLayer(nn.Module): - def __init__(self, config: MistralConfig): - self.hidden_size = config.hidden_size - self.self_attn = MistralAttention(config) - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = MistralRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - - def forward( - self, - hidden_states: relax.Expr, - cache_len_shape: relax.Expr, - kv_seq_len_shape: 
relax.Expr, - cache_offset_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - cache_len_shape=cache_len_shape, - kv_seq_len_shape=kv_seq_len_shape, - cache_offset_shape=cache_offset_shape, - ) - if self.self_attn.num_shards > 1: - residual = nn.emit( - residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.self_attn.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - return hidden_states, present_key_value - - -def _make_sliding_window_mask(input_shape, kv_seq_len, sliding_window, dtype): - # See `tests/python/test_sliding_window_mask.py` for more on its behavior. - # [bsz, tgt_len] -> [bsz, 1, tgt_len, kv_seq_len] - - bsz, tgt_len = input_shape # TODO: only support batch size of 1 for now - cache_len = kv_seq_len - tgt_len # number of elements in cache - - if isinstance(tgt_len, tvm.tir.SizeVar) or tgt_len > 1: - # Either 1. First prefill, or 2. Subsequent prefill - from tvm.relax.op import broadcast_to # pylint: disable=import-outside-toplevel - - def sliding_window_min_max_te(sliding_window): - return te.compute( - (tgt_len, kv_seq_len), - lambda i, j: tvm.tir.Select( - tvm.tir.all(i + cache_len >= j, i + cache_len - j < sliding_window), - tvm.tir.max_value(dtype), - tvm.tir.min_value(dtype), - ), - name="make_diag_mask_sliding_window_te", - ) - - mask = nn.emit_te(sliding_window_min_max_te, sliding_window) - return nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, kv_seq_len))) - - else: - # 3. 
Decode (equivalent to prefilling a chunk of size 1) - # Mask nothing here since WS == cache_size - bsz, tgt_len = input_shape - return nn.emit( - relax.op.full( - (bsz, 1, tgt_len, kv_seq_len), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - ) - - -class MistralEmbedTokens(nn.Module): - def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar): - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = self.embed_tokens(input_ids) - return inputs_embeds - - -class MistralEmbedTokensWrapper(nn.Module): - def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar): - # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent - self.model = MistralEmbedTokens(config, vocab_size_var) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = self.model(input_ids) - return inputs_embeds - - -class MistralModel(nn.Module): - def __init__( - self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False - ): - self.num_shards = config.num_shards - self.padding_idx = config.pad_token_id - self.embed_tokens = None - - if not sep_embed: - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - self.layers = ModuleList( - [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)] - ) - self.norm = MistralRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) - self.sliding_window = config.sliding_window - - def forward( - self, - inputs: relax.Expr, - cache_len_shape: relax.Expr, - kv_seq_len_shape: relax.Expr, - cache_offset_shape: relax.Expr, - past_key_values: relax.Expr, - ): - if self.num_shards > 1: - inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) - if self.embed_tokens: - inputs_embeds = self.embed_tokens(inputs) - else: - inputs_embeds = inputs - # retrieve input_ids - batch_size, seq_length, _ = inputs_embeds.struct_info.shape - kv_seq_len = kv_seq_len_shape.struct_info.values[0] - - # embed positions - attention_mask = _make_sliding_window_mask( - (batch_size, seq_length), - kv_seq_len, - self.sliding_window, - inputs_embeds.struct_info.dtype, - ) - - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = () - - for idx, decoder_layer in enumerate(self.layers): - assert past_key_values is not None - past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) - - hidden_states, key_value_cache = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - cache_len_shape=cache_len_shape, - kv_seq_len_shape=kv_seq_len_shape, - cache_offset_shape=cache_offset_shape, - ) - next_decoder_cache += key_value_cache - - hidden_states = self.norm(hidden_states) - - assert len(next_decoder_cache) == len(self.layers) * 2 - return hidden_states, next_decoder_cache - - -class MistralForCausalLM(nn.Module): - def __init__( - self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False - ): - self.model = MistralModel(config, vocab_size_var, sep_embed) - self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) - - ############ Rotary embedding constants ############ - assert config.hidden_size % config.num_attention_heads == 0 - head_dim = config.hidden_size // config.num_attention_heads - - # Set the cached sin/cos to the maximum of 2048 and max seq len. 
- # This will be eliminated further with online rotary embedding calculation. - rope_cache_len = te.var("rope_cache_len", "int64") - self.cos_cached = nn.Parameter( - (rope_cache_len, head_dim), dtype=config.dtype, name="cos_cached" - ) - self.sin_cached = nn.Parameter( - (rope_cache_len, head_dim), dtype=config.dtype, name="sin_cached" - ) - ############ End ############ - - def forward( - self, - inputs: relax.Expr, - cache_len_shape: relax.Expr, - kv_seq_len_shape: relax.Expr, - cache_offset_shape: relax.Expr, - past_key_values: relax.Expr, - ): - hidden_states, key_value_cache = self.model( - inputs=inputs, - cache_len_shape=cache_len_shape, - kv_seq_len_shape=kv_seq_len_shape, - cache_offset_shape=cache_offset_shape, - past_key_values=past_key_values, - ) - - def te_slicing(x: te.Tensor): - return te.compute( - shape=(1, 1, x.shape[-1]), - fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], - name="slice", - ) - - logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) - if logits.struct_info.dtype != "float32": - logits = nn.emit(relax.op.astype(logits, "float32")) - - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "embed_tokens" in name: - return ParamQuantKind.embedding_table - elif "lm_head.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: MistralConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - bsz = 1 - seq_len = tvm.tir.SizeVar("n", "int64") - with bb.function(func_name): - model = MistralEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - with bb.dataflow(): - inputs_embeds = model(input_ids) - params = [input_ids] + model.parameters() - gv = bb.emit_output(inputs_embeds) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: MistralConfig, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - func_name = "prefill_with_embed" if sep_embed else "prefill" - - bsz = 1 - seq_len = tvm.tir.SizeVar("n", "int64") # number of tokens for the input - rolling_cache_len = tvm.tir.SizeVar( - "c", "int64" - ) # rolling_cache_len captures number of elements in the cache - kv_seq_len = tvm.tir.SizeVar( - "k", "int64" - ) # kv_seq_len captures number of elements in cache + seq_len - cache_offset = tvm.tir.SizeVar("o", "int64") # slidinf window kv cache offset - - hidden_size = config.hidden_size - with bb.function(func_name): - model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = ( - nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") - if sep_embed - else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - ) - cache_len_shape = relax.Var( - "rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,)) - ) - kv_seq_len_shape = 
relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) - cache_offset_shape = relax.Var("cache_offset", relax.ShapeStructInfo((cache_offset,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, - cache_len_shape, - kv_seq_len_shape, - cache_offset_shape, - past_key_values=past_key_values, - ) - params = [ - inputs, - cache_len_shape, - kv_seq_len_shape, - cache_offset_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 5)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: MistralConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - bsz = 1 - rolling_cache_len = tvm.tir.SizeVar( - "c", "int64" - ) # rolling_cache_len captures number of elements in the cache - kv_seq_len = tvm.tir.SizeVar( - "k", "int64" - ) # kv_seq_len captures number of elements in cache + seq_len - cache_offset = tvm.tir.SizeVar("o", "int64") # sliding window kv cache offset - - with bb.function(func_name): - model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") - cache_len_shape = relax.Var( - "rolling_cache_len", relax.ShapeStructInfo((rolling_cache_len,)) - ) - kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) - cache_offset_shape = relax.Var("cache_offset", relax.ShapeStructInfo((cache_offset,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - input_ids, - cache_len_shape, - kv_seq_len_shape, - cache_offset_shape, - past_key_values=past_key_values, - ) - params = [ - input_ids, - cache_len_shape, - kv_seq_len_shape, - cache_offset_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 5)) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: - num_key_value_heads = config.get_num_key_value_heads() // config.num_shards - init_shape = relax.ShapeExpr( - ( - config.sliding_window, - num_key_value_heads, - config.hidden_size // config.num_attention_heads, # head_dim - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.num_hidden_layers * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), 
dtype="float32", name="logits" - ) - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model(args, hf_config): - model_name = args.model - dtype = args.quantization.model_dtype - sep_embed = args.sep_embed - assert not sep_embed, "Mistral does not support separate embedding." - - if args.sliding_window != -1: - hf_config["sliding_window"] = args.sliding_window - if args.attention_sink_size > 0: - hf_config["attention_sink_size"] = args.attention_sink_size - if args.max_seq_len != -1: - hf_config["max_sequence_length"] = args.max_seq_len - - config = MistralConfig( - **hf_config, - dtype=dtype, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - - # prefill chunk size same as sliding window by default - if args.prefill_chunk_size < 1: - args.prefill_chunk_size = config.sliding_window - config.attention_sink_size - - assert config.sliding_window != -1 - assert args.prefill_chunk_size <= config.sliding_window - config.attention_sink_size - - param_manager = ParamManager() - bb = relax.BlockBuilder() - - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model_name, - max_window_size=config.max_sequence_length, - stop_tokens=[2], - add_prefix_space=False, - sliding_window=config.sliding_window, - prefill_chunk_size=args.prefill_chunk_size, - ) - - mod = bb.get() - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr( - "tir_var_upper_bound", - { - "n": args.prefill_chunk_size, - "c": config.sliding_window, - "k": config.sliding_window + args.prefill_chunk_size, - }, - ) - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - if not config.combine_matmul: - return [pname] - - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" - if qkv_str in pname: - return [ - pname.replace(qkv_str, "q_proj"), - pname.replace(qkv_str, "k_proj"), - pname.replace(qkv_str, "v_proj"), - ] - elif gate_up_str in pname: - return [ - pname.replace(gate_up_str, "gate_proj"), - pname.replace(gate_up_str, "up_proj"), - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if not config.combine_matmul: - return [(torch_pname, torch_param.astype(dtype))] - - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(dtype))] - - def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): - # Expected to enter this function only for the combined linear matmul weights. - # Other weights are supposed to be loaded in `f_convert_param_bkwd` since - # each other relax param has a unique corresponding torch param. - if not config.combine_matmul: - # When matmul combination is not turned on, each relax param has a unique - # corresponding torch param, and this function is not expected to be entered. 
- raise NotImplementedError( - "Matmul combination is not turned on, and the function " - "is not expected to be entered" - ) - hidden_size = config.hidden_size - head_dim = config.hidden_size // config.num_attention_heads - - if "query_key_value_proj" in relax_pname: - q_heads = config.num_attention_heads - kv_heads = config.get_num_key_value_heads() - q, k, v = torch_params - assert q.shape == (q_heads * head_dim, hidden_size) - assert k.shape == (kv_heads * head_dim, hidden_size) - assert v.shape == (kv_heads * head_dim, hidden_size) - qkv = np.concatenate([q, k, v], axis=0).astype(dtype) - return qkv - if "gate_up_proj" in relax_pname: - gate, up = torch_params - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - return gate_up - raise ValueError("Unexpected param loading") - - param_manager.set_param_loading_func( - args.model_path, - args.use_safetensors, - f_convert_pname_fwd, - f_convert_param_bkwd, - f_compute_relax_param, - ) - - device = tvm.cpu() - param_list = [None] * param_manager.nparam_to_load - - head_dim = config.hidden_size / config.num_attention_heads - inv_freq = 1.0 / (config.rope_theta ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) - - # The following cos/sin values can be removed but **are kept for compatibility issues**. - t = np.arange(2048, dtype=inv_freq.dtype) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) - param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) - - return mod, param_manager, param_list, config diff --git a/mlc_llm/relax_model/modules.py b/mlc_llm/relax_model/modules.py deleted file mode 100644 index e506938591..0000000000 --- a/mlc_llm/relax_model/modules.py +++ /dev/null @@ -1,280 +0,0 @@ -# pylint: disable=missing-docstring,invalid-name -from typing import Dict, List, Tuple, Optional - -import numpy as np -from tvm import relax, te, tir -from tvm.relax.op import matmul, permute_dims, reshape, take -from tvm.relax.op.nn import layer_norm -from tvm.relax.testing import nn -from tvm.runtime.ndarray import array as tvm_array - - -class ModuleList(nn.Module): - def __init__(self, modules: List[nn.Module]): - self.modules = modules - - def __iter__(self): - return iter(self.modules) - - def __getitem__(self, idx): - return self.modules[idx] - - def __len__(self): - return len(self.modules) - - def forward(self, x: relax.Expr) -> relax.Var: - for module in self.modules: - x = module(x) - return x - - -class Linear(nn.Module): - def __init__( - self, - in_features, - out_features, - dtype, - bias=True, - out_dtype=None, - ): - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - (out_features, in_features), - dtype=dtype, - name="linear_weight", - ) - if bias: - self.bias = nn.Parameter( - (out_features,), - dtype=dtype if out_dtype is None else out_dtype, - name="linear_bias", - ) - else: - self.bias = None - self.dtype = dtype - self.out_dtype = out_dtype - - def forward(self, x: relax.Expr) -> relax.Var: - x = nn.emit(x) - weight = permute_dims(self.weight, axes=None) - x = nn.emit(matmul(x, weight, out_dtype=self.out_dtype)) - if self.bias is not None: - x = nn.emit(x + self.bias) - return x - - -class Embedding(nn.Module): - def __init__(self, num_embeddings, embedding_dim, dtype): - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.weight = nn.Parameter( - (num_embeddings, embedding_dim), dtype=dtype, 
name="weight" - ) - - def forward(self, x: relax.Expr) -> relax.Var: - ndim = x.struct_info.ndim - if ndim == 1: - return nn.emit(take(self.weight, x, axis=0)) - x_shape = x.struct_info.shape.values - emb_size = self.weight.struct_info.shape.values[-1] - x = nn.emit(reshape(x, shape=[-1])) - embedding = nn.emit(take(self.weight, x, axis=0)) - return nn.emit(reshape(embedding, [*x_shape, emb_size])) - - -class LayerNorm(nn.Module): - def __init__( - self, - hidden_size, - dtype, - eps=1e-5, - ): - super().__init__() - self.eps = eps - self.weight = nn.Parameter((hidden_size,), dtype="float32", name="weight") - self.bias = nn.Parameter((hidden_size,), dtype="float32", name="bias") - - def forward(self, x: relax.Expr) -> relax.Var: - if x.struct_info.dtype != "float32": - x = nn.emit(relax.op.astype(x, "float32")) - x = nn.emit( - layer_norm( - x, - gamma=self.weight, - beta=self.bias, - axes=-1, - epsilon=self.eps, - ) - ) - return x - - -class RotaryEmbedding(nn.Module): - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - position_embedding_base: int, - max_sequence_length: int, - rotary_pct: Optional[float] = None, - rotary_dim: Optional[int] = None, - swizzle_style: str = "neox", - dtype: str = "float32", - ): - super().__init__() - head_dim = hidden_size // num_attention_heads - if rotary_dim is not None: - rotary_ndim = rotary_dim - else: - rotary_ndim = int(head_dim * rotary_pct) - inv_freq = 1.0 / ( - position_embedding_base - ** (np.arange(0, rotary_ndim, 2).astype("float32") / rotary_ndim) - ) - t = np.arange(max_sequence_length, dtype=inv_freq.dtype) - freq = np.einsum("i,j->ij", t, inv_freq) - if swizzle_style == "neox": - emb = np.concatenate((freq, freq), axis=-1) - elif swizzle_style in ("gptj", "glm"): - emb = np.repeat(freq, repeats=2, axis=-1) - else: - raise KeyError("Unrecognized swizzle style {}".format(swizzle_style)) - self.swizzle_style = swizzle_style - self.rotary_ndim = rotary_ndim - self.cos_cached = relax.const(tvm_array(np.cos(emb).astype(dtype))) - self.sin_cached = relax.const(tvm_array(np.sin(emb).astype(dtype))) - - def get_x_swizzle(self, x, i_batch_size, i_seq_len, i_num_heads, i_head_dim): - if self.swizzle_style == "neox": - n_feat_half = self.rotary_ndim // 2 - return tir.Select( - i_head_dim < n_feat_half, - -x[ - i_batch_size, - i_seq_len, - i_num_heads, - i_head_dim + n_feat_half, - ], - x[ - i_batch_size, - i_seq_len, - i_num_heads, - i_head_dim - n_feat_half, - ], - ) - elif self.swizzle_style in ("gptj", "glm"): - return tir.Select( - i_head_dim % 2 == 0, - -x[i_batch_size, i_seq_len, i_num_heads, i_head_dim + 1], - x[i_batch_size, i_seq_len, i_num_heads, i_head_dim - 1], - ) - else: - raise KeyError("Unrecognized swizzle style: {}.".format(self.swizzle_style)) - - def forward( - self, - q: relax.Expr, - k: relax.Expr, - offset: relax.Expr, - ) -> Tuple[relax.Expr, relax.Expr]: - def rotary_embedding(x, cos, sin, offset): - def compute( - i_batch_size, - i_seq_len, - i_num_heads, - i_head_dim, - ): - return tir.Select( - i_head_dim < self.rotary_ndim, - cos[ - offset + i_seq_len, - i_head_dim, - ] - * x(i_batch_size, i_seq_len, i_num_heads, i_head_dim) - + sin[ - offset + i_seq_len, - i_head_dim, - ] - * self.get_x_swizzle( - x, i_batch_size, i_seq_len, i_num_heads, i_head_dim - ), - x(i_batch_size, i_seq_len, i_num_heads, i_head_dim), - ) - - return te.compute(x.shape, compute, name="rotary") - - cos, sin = self.cos_cached, self.sin_cached - q_embed = nn.emit_te( - rotary_embedding, - q, - cos, - sin, - offset, - 
primfunc_name_hint="rotary_embedding", - ) - k_embed = nn.emit_te( - rotary_embedding, - k, - cos, - sin, - offset, - primfunc_name_hint="rotary_embedding", - ) - return q_embed, k_embed - - -class TransformImage(nn.Module): - def __init__(self, dtype: str, in_chans: int = 4): - self.in_chans = in_chans - self.dtype = dtype - - # used in normalization, assume channels are RGB - self.r_mean = relax.const(0.48145466, "float32") - self.g_mean = relax.const(0.4578275, "float32") - self.b_mean = relax.const(0.40821073, "float32") - self.r_std = relax.const(0.26862954, "float32") - self.g_std = relax.const(0.26130258, "float32") - self.b_std = relax.const(0.27577711, "float32") - - def forward(self, input: relax.Expr) -> relax.Expr: - from tvm.relax.op import astype, concat, permute_dims, strided_slice - - assert input.struct_info.ndim == 4 - # perform torch.ToTensor on input of shape (bs, height, width, in_chans) - input = permute_dims(input, [0, 3, 1, 2]) - x = astype(input, "float32") / relax.const(255.0, "float32") - r = strided_slice(x, axes=[1], begin=[0], end=[1]) - g = strided_slice(x, axes=[1], begin=[1], end=[2]) - b = strided_slice(x, axes=[1], begin=[2], end=[3]) - - # normalize rgba to rgb - if self.in_chans == 4: - a = strided_slice(x, axes=[1], begin=[3], end=[4]) - r /= a - g /= a - b /= a - - # perform torch.Normalize - r = (r - self.r_mean) / self.r_std - g = (g - self.g_mean) / self.g_std - b = (b - self.b_mean) / self.b_std - res = concat([r, g, b], axis=1) - res = astype(res, self.dtype) - - return res - - -def named_parameters(model: nn.Module) -> Dict[str, nn.Parameter]: - params: Dict[str, nn.Parameter] = {} - for name, module in model.__dict__.items(): - if isinstance(module, nn.Parameter): - params[name] = module - elif isinstance(module, ModuleList): - for i, m in enumerate(module): - for param_name, param in named_parameters(m).items(): - params[f"{name}.{i}.{param_name}"] = param - elif isinstance(module, nn.Module): - for param_name, param in named_parameters(module).items(): - params[f"{name}.{param_name}"] = param - return params diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py deleted file mode 100644 index 1ad1ee6428..0000000000 --- a/mlc_llm/relax_model/param_manager.py +++ /dev/null @@ -1,1259 +0,0 @@ -import json -import os -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union - -import tvm -from torch import Tensor as torchTensor -from tvm import relax, tir -from tvm._ffi.runtime_ctypes import Device -from tvm.relax.analysis import remove_all_unused -from tvm.relax.expr import Expr, Function, Var -from tvm.relax.expr_functor import PyExprMutator, mutator -from tvm.relax.testing import nn - -from .. import quantization -from .modules import named_parameters -from ..transform import ReorderTransformFunc - - -def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any: - """The defualt `f_compute_relax_param` for ParamManager. - See ParamManager for more details. - """ - raise NotImplementedError() - - -class Parameter: - """The abstraction of weight tensors (e.g., linear layer weight, embedding - table, etc.) in a model. - - Attributes - ---------- - name : str - The name of the parameter. - The name of a weight is got by `named_parameters()` method, similar to - PyTorch's `named_parameters()` function. - An example name is `model.layers.11.self_attn.k_proj.weight`. - In a model, the name is the **unique** identifier of a parameter. 
- - param_info_dict : Dict[str, relax.TensorStructInfo] - The shape and dtype of the parameter in each function. - The shape can be accessed by `param_info_dict[func_name].shape`, which is - a relax.ShapeExpr instance. - And the dtype can be accessed by `param_info_dict[func_name].dtype`, - which is a Python string. - - quant_spec : quantization.QuantizationSpec - The quantization specification of this parameter. - It specifies the algorithm to quantize and dequantize this parameter (or - this parameter does not need quantization). - - shard_dim : Optional[int] - The dimension to be sharded. - - shard_strategy : Optional[str] - The strategy to shard the parameter. - """ - - name: str - param_info_dict: Dict[str, relax.TensorStructInfo] - quant_spec: quantization.QuantizationSpec - shard_dim: Optional[int] - shard_strategy: Optional[str] - - def __init__( - self, - name: str, - quant_spec: quantization.QuantizationSpec, - shard_dim: Optional[int], - shard_strategy: Optional[str], - ) -> None: - self.name = name - self.param_info_dict = dict() - self.quant_spec = quant_spec - self.shard_dim = shard_dim - self.shard_strategy = shard_strategy - - def register_func(self, func_name: str, param_info: relax.TensorStructInfo): - self.param_info_dict[func_name] = param_info - - @property - def param_info(self): - """Return the shape and dtype of the parameter (in some arbitrary function).""" - return next(iter(self.param_info_dict.values())) - - -class ParamManager: - """The model-wise data structure which contains the information of every - weight in the model and is in charge of applying quantization and dequantization - to the parameters at the entire model level. - - Attributes - ---------- - params : Dict[str, Parameter] - The mapping from parameter names to parameters. - - param_names : List[str] - The name list of all the parameters. - To enforce a unique order or all the parameters for determinism, the - parameter names are kept in the list, and the parameter order is - uniquely determined by the parameter name list. - - func_raw_param_map : Dict[relax.Var, Tuple[str, Parameter]] - The mapping from each relax.Var that denotes a weight parameter to the - name of the function the var is in (e.g., "prefill" or "decode"), and - the Parameter it corresponds to. - This mapping is used for applying quantization transformation to the - Relax functions (e.g., the "prefill", "decode", etc.) in the model. - - param2qrange : Dict[Parameter, range] - The mapping from each parameter to the range of its quantized tensors - in the list of quantized tensors of all parameters. - Each parameter is quantized into multiple tensors. - For example, assume we have parameters `p0`, `p1`, `p2`. - - `p0` is quantized into `t0_0`, `t0_1`, - - `p1` is quantized into `t1_0`, and - - `p2` is quantized into `t2_0`, `t2_1` and `t2_2`. - Then the list of all quantized tensors is `[t0_0, t0_1, t1_0, t2_0, t2_1, t2_2]`, - and the dict `param2qrange` is - `{p0: range(0, 2), p1: range(2, 3), p2: range(3, 6)}`. - - f_convert_pname_fwd : Callable[[str], List[str]] - The function which converts Relax parameter name (ours) to torch's - parameter names, suggesting "to load this Relax parameter, which torch - parameter(s) are needed". - - Usually, the function maps a name to itself. For example, in LLaMA we - map `lm_head.weight` itself, as the parameter has the same name on both - Relax side and torch side. - - In some cases we map a name to multiple names. 
For example, if we - support combined QKV computing when the torch side separates them, on - Relax side we only have one QKV weight, while on torch side we have - one weight for each of Q, K, V. In this case, we map one name to three - names. - - In some cases we map a name to a single name which is other than - itself. This can happen either when the Relax nn.Module has different - param names than the torch's implementation so we need to map names - for connection, or when a Relax parameter is computed out from a torch - parameter. For example, if the torch implementation supports combined - QKV while the Relax one does not, we need compute the relax parameter - out from torch's parameter. In this case we map the relax parameter - name to the torch's parameter name. - - f_convert_param_bkwd : Callable[[str, Any], Optional[List[Tuple[str, Any]]]] - The function which converts torch parameter and param name back to - Relax parameters with names. `Any` here stands for numpy.ndarray. - - Usually, the function just returns the input torch parameter and - the corresponding Relax parameter's name. - - In some cases, we return multiple Relax parameters. For example, if - the torch implementation supports combined QKV while the Relax one does - not, the function takes torch's combined QKV weight, and return the - separated Q K V weights with their corresponding names. - - In some cases we return `None`. This happens when the input torch - parameter itself does not determine any Relax parameter. For example, - if we support combined QKV computing when the torch side separates them, - we return `None` here for the single Q, K, V weights, as by only having - a Q (or K, V) weight we cannot compute the combined QKV weight. - - f_compute_relax_param : Callable[[str, List[Any]], Any] - The function which computes a Relax parameter from a list of torch - parameters. `Any` here stands for numpy.ndarray. In the case when one - Relax parameter is computed from multiple torch parameters, this - functions is used. - For example, if we support combined QKV computing when the torch side - separates them, we use this function to combine the torch's Q, K, V - weights into one - In usual case, this function is not needed and by default it is - implemented by raising `NotImplementedError` (see f_default_compute_relax_param). - - model_path : str - The path of the Hugging Face model on disk. - - use_safetensors: bool - Whether to use `.safetensors` instead of `.bin` to load model. - - safetensors_load_func: Callable[[Union[str, os.PathLike], str], Dict[str, torch.Tensor]] - A reference to the function `load_file` improted from `safetensors.torch`. - The goal is to prevent repeatedly importing in a tvm registered function. - - pidx2pname : Dict[int, str] - The dictionary from each Relax parameter's index in `param_names` to - the Relax parameter's name. - - torch_pname2binname : Dict[str, str] - The dictionary from each torch parameter's name to the name of the - binary shard where the torch parameter is saved. 
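As a compact illustration of the three converters for the fused-QKV case described above (the layer names mirror the Mistral model definition earlier in this patch; real models supply their own converters):

    import numpy as np

    def convert_pname_fwd(pname):
        # One fused Relax weight is loaded from three separate torch weights.
        if "query_key_value_proj" in pname:
            return [pname.replace("query_key_value_proj", p) for p in ("q_proj", "k_proj", "v_proj")]
        return [pname]

    def convert_param_bkwd(torch_pname, torch_param):
        # A lone Q, K or V weight cannot determine the fused weight, so defer.
        if any(p in torch_pname for p in ("q_proj", "k_proj", "v_proj")):
            return None
        return [(torch_pname, torch_param)]

    def compute_relax_param(relax_pname, torch_params):
        # Invoked once all three torch weights are available; fuse them.
        q, k, v = torch_params
        return np.concatenate([q, k, v], axis=0)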
- """ - - params: Dict[str, Parameter] - param_names: List[str] - func_raw_param_map: Dict[relax.Var, Tuple[str, Parameter]] - param2qrange: Dict[Parameter, range] - - qspec_updater_classes: List[quantization.QuantSpecUpdater] - - nparam_to_load: int - f_convert_pname_fwd: Callable[[str], List[str]] - f_convert_param_bkwd: Callable[[str, Any], Optional[List[Tuple[str, Any]]]] - f_compute_relax_param: Callable[[str, List[Any]], Any] - f_run_prequantize: Optional[Callable[[str], str]] - - model_path: str - use_safetensors: bool - safetensors_load_func: Callable[[Union[str, os.PathLike], str], Dict[str, torchTensor]] - pidx2pname: Dict[int, str] - torch_pname2binname: Dict[str, str] - - def __init__(self) -> None: - self.params = {} - self.param_names = [] - self.params_in_func = {} - - self.func_raw_param_map = {} - self.param2qrange = None - - self.nparam_to_load = None - self.f_convert_pname_fwd = None - self.f_convert_param_bkwd = None - self.f_compute_relax_param = None - self.f_run_prequantize = None - - self.qspec_updater_classes = [] - - def register_params( - self, - model: nn.Module, - func_name: str, - quantization_scheme: quantization.QuantizationScheme, - f_get_param_quant_kind: Callable[ - [str, relax.TensorStructInfo], quantization.ParamQuantKind - ], - ) -> None: - """Register the parameters of the input model (within the context of the - input function) in the parameter manager. - - Parameters - ---------- - model : nn.Module - The input model whose parameters are registered. - - func_name : str - The name of the function the input model is in. - For example, the "prefill" function or the "decode" function. - - quantization_scheme : quantization.QuantizationScheme - The quantization scheme of the input model, which describes how - to quantize the model. - - f_get_param_quant_kind: Callable[[str, relax.TensorStructInfo], quantization.ParamQuantKind] - A function which takes the name and StructInfo (effectively shape - and dtype) of a parameter, and returns which quantization kind this - parameter uses. - This is used for applying quantization to the parameters. - """ - if quantization_scheme.qspec_updater_class is not None: - self.qspec_updater_classes.append(quantization_scheme.qspec_updater_class) - if quantization_scheme.f_convert_param_bkwd is not None: - self.f_convert_param_bkwd = quantization_scheme.f_convert_param_bkwd - if quantization_scheme.f_compute_relax_param is not None: - self.f_compute_relax_param = quantization_scheme.f_compute_relax_param - if quantization_scheme.f_run_prequantize is not None: - self.f_run_prequantize = quantization_scheme.f_run_prequantize - - self.params_in_func[func_name] = [] - # For each parameter in the input model, get its quantization kind and - # register the parameter with its name and quantization kind. 
- for name, relax_param in named_parameters(model).items(): - quant_kind = f_get_param_quant_kind(name, relax_param.struct_info) - param = self._register_param( - name, - relax_param, - getattr(quantization_scheme, quant_kind.name), - func_name, - relax_param.__dict__.get("shard_dim", None), - relax_param.__dict__.get("shard_strategy", None), - ) - - self.params_in_func[func_name].append(param) - - def run_pre_quantize(self, model_path: str): - if self.f_run_prequantize is not None: - model_path = self.f_run_prequantize(model_path) - - self.model_path = model_path - return model_path - - def init_torch_pname_to_bin_name(self, use_safetensors: bool): - assert hasattr(self, "model_path"), ( - "Must call either set_param_loading_func or run_pre_quantize " - "before init_torch_pname_to_bin_name" - ) - - if self.pidx2pname: - mapping = load_torch_pname2binname_map( - self.model_path, - use_safetensors, - set(self.pidx2pname.values()), - self.f_convert_pname_fwd, - ) - else: - mapping = {} - - self.torch_pname2binname = mapping - - def set_param_loading_func( - self, - model_path: str, - use_safetensors: bool, - f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname], - f_convert_param_bkwd: Callable[ - [str, Any], Optional[List[Tuple[str, Any]]] - ] = lambda pname, torch_param: [(pname, torch_param)], - f_compute_relax_param: Callable[[str, List[Any]], Any] = f_default_compute_relax_param, - *, - no_lazy_param_loading: bool = False, - ) -> None: - """Set the parameter loading functions. - - Parameters - ---------- - model_path : str - The path of the Hugging Face model on disk. - - use_safetensors : bool - Whether to use ``.safetensors`` instead of ``.bin`` to load model. - - f_convert_pname_fwd : Callable[[str], List[str]] - The function which converts Relax parameter name (ours) to torch's - parameter names. See the document of ParamManager for more details. - - f_convert_param_bkwd : Callable[[str, Any], Optional[List[Tuple[str, Any]]]] - The function which converts torch parameter and param name back to - Relax parameters with names. `Any` here stands for numpy.ndarray. - See the document of ParamManager for more details. - - f_compute_relax_param : Callable[[str, List[Any]], Any] - The function which computes a Relax parameter from a list of torch - parameters. `Any` here stands for numpy.ndarray. - See the document of ParamManager for more details. - - no_lazy_param_loading : bool - A boolean indicating that no lazy parameter loading from torch is needed. - This needs to be set as True when all the model weights are loaded - at the time of constructing the model. 
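A minimal usage sketch under the default converters (the model path is hypothetical; models with fused weights pass their own converters, as the Mistral definition above does):

    param_manager = ParamManager()
    # ... register_params(...) has been called for each Relax function ...
    param_manager.set_param_loading_func(
        model_path="/path/to/hf-model",  # hypothetical path
        use_safetensors=False,           # load from pytorch_model*.bin shards
    )
    param_manager.init_torch_pname_to_bin_name(use_safetensors=False)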
- """ - self.f_convert_pname_fwd = f_convert_pname_fwd - if self.f_convert_param_bkwd is None: - self.f_convert_param_bkwd = f_convert_param_bkwd - if self.f_compute_relax_param is None: - self.f_compute_relax_param = f_compute_relax_param - - self.model_path = model_path - self.use_safetensors = use_safetensors - if self.use_safetensors: - # Use a pointer here to prevent repeated import in tvm registered function - from safetensors.torch import ( - load_file, # pylint: disable=import-outside-toplevel - ) - - def load_safetensors_func(*args): - params = load_file(*args) - for name, param in params.items(): - dtype = str(param.dtype) - if dtype == "torch.bfloat16": - param = param.float() - params[name] = param - return params - - self.safetensors_load_func = load_safetensors_func - - pnames_to_load = [] - for param_name in self.param_names: - param = self.params[param_name] - loaded_names, _ = param.quant_spec.get_loaded_tensor_info(param_name, param.param_info) - pnames_to_load += loaded_names - - self.nparam_to_load = len(pnames_to_load) - if not no_lazy_param_loading: - self.pidx2pname = {pidx: pname for pidx, pname in enumerate(pnames_to_load)} - else: - self.pidx2pname = dict() - - def transform_dequantize(self) -> tvm.ir.transform.Pass: - """Apply dequantization to the input IRModule. - - Parameters - ---------- - mod : tvm.IRModule - The input IRModule to be applied dequantization. - The IRModule contains all the constructed Relax functions - (e.g., the "prefill"/"decode" functions) and is expected to - have all of its parameters registered in the ParamManager. - - Returns - ------- - updated_mod : tvm.IRModule - The IRModule updated with the dequantization computation. - """ - - @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") - def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. - func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - - for gv, func in mod.functions.items(): - if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: - func_name_to_quantized_params[gv.name_hint] = self.get_quantized_params( - gv.name_hint - ) - - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} - - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - - func_name, param = self.func_raw_param_map[var] - quantized_params = func_name_to_quantized_params[func_name] - relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] - - dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) - - dequantized_cache[var] = dequantized - return dequantized - - # Create the function mutator for applying dequantization. - replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) - # Update the input IRModule with dequantization. 
- mod = replacer.transform() - - return mod - - return transform_func - - def get_quantized_params(self, func_name: str) -> List[relax.Var]: - quantized_params: List[relax.Var] = [] - - bb = relax.BlockBuilder() - with bb.function("main", []): - self.param2qrange = dict() - - for name in self.param_names: - param = self.params[name] - param_info = None - if func_name in param.param_info_dict: - param_info = param.param_info_dict[func_name] - else: - param_info = relax.TensorStructInfo( - tvm.ir.load_json(tvm.ir.save_json(param.param_info.shape)), - param.param_info.dtype, - ) - - loaded_tensor_names, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info( - name, param_info - ) - - provided_tensor_vars: List[relax.Var] = [ - relax.Var(name, sinfo) - for name, sinfo in zip(loaded_tensor_names, loaded_tensor_info) - ] - - # Get the quantization function of this parameter. - f_quantize = param.quant_spec.get_quantize_func(param_info) - if f_quantize is None: - # If the parameter does not have a quantization function, either it - # does not need quantization or it is pre-quantized. - self.param2qrange[param] = range( - len(quantized_params), - len(quantized_params) + len(provided_tensor_vars), - ) - quantized_params.extend(provided_tensor_vars) - else: - # If the parameter has a quantization function, it is not expected - # to be pre-quantized. - assert len(provided_tensor_vars) == 1, ( - "A parameter with quantization function is not expected " - "to be pre-quantized." - ) - - # Apply the quantization function. - quantized_data = bb.normalize(f_quantize(bb, provided_tensor_vars)) - if isinstance(quantized_data.struct_info, relax.TupleStructInfo): - fields = quantized_data.struct_info.fields - n_tensor = len(fields) - assert n_tensor > 1 - # Record the range of quantized tensors of this parameter. - self.param2qrange[param] = range( - len(quantized_params), - len(quantized_params) + n_tensor, - ) - # Collect the quantized tensors to return. - quantized_params.extend( - relax.Var(f"{name}.{field.dtype}.{i}", field) - for i, field in enumerate(fields) - ) - - else: - field = quantized_data.struct_info - assert isinstance(field, relax.TensorStructInfo) - self.param2qrange[param] = range( - len(quantized_params), len(quantized_params) + 1 - ) - quantized_params.append(relax.Var(f"{name}.{field.dtype}", field)) - bb.emit_func_output(relax.const(0, "int64")) - - return quantized_params - - def get_param_get_item( - self, device: Device, model_params: List[Optional[tvm.nd.NDArray]] = [] - ) -> Callable: - """A wrapper function which returns the `get_item` - functions for parameter lazy loading. - - The return value of this function is intended to be registered - as `"get_item"`, for use in a module built with - `LazyTransformParams`. - - .. code-block:: python - - get_item = manager.get_param_get_item(tvm.cuda()) - tvm.register_func(func_name="get_item", f=get_item, override=True) - compiled_function() - - Parameters - ---------- - device : Device - - The device onto which tensor parameters should be loaded. - - model_params : List[Optional[tvm.nd.NDArray]] - - Any pre-loaded model parameters. For parameter at index - `i`, if `model_params[i]` already contains an array, that - array will be returned from `get_item`. Otherwise, the - parameter will be loaded either from disk, or from an - internal cache. - - Returns - ------- - get_item: Callable[[int], tvm.nd.NDArray] - - A function that accepts an index, and returns the tensor - parameter located at that index, loaded onto `device`. 
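Putting the two lazy-loading hooks together (device choice and registered names follow the docstring examples; `compiled_transform` stands in for a module built with `LazyTransformParams`):

    import tvm

    get_item = manager.get_param_get_item(tvm.cuda())
    set_item, loaded_params = manager.get_param_set_item()
    tvm.register_func(func_name="get_item", f=get_item, override=True)
    tvm.register_func(func_name="set_item", f=set_item, override=True)
    compiled_transform()  # hypothetical function compiled with LazyTransformParams
    # `loaded_params` is now populated with the transformed weights.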
- - """ - import torch # pylint: disable=import-outside-toplevel - - assert self.f_convert_pname_fwd is not None - assert self.f_convert_param_bkwd is not None - assert self.f_compute_relax_param is not None - pname2pidx: Dict[str, int] = {pname: pidx for pidx, pname in self.pidx2pname.items()} - - # The set of indices of loaded parameters, serving for - # robustness guarantee to avoid one parameter being loaded for - # multiple times. - loaded_idx_set: Set[int] = set() - - # The set of torch binary filenames, serving for robustness guarantee - # to avoid one torch binary file being loaded for multiple times. - loaded_torch_bins: Set[str] = set() - - # The set of cached Relax parameters. - cached_relax_params: Dict[int, tvm.nd.NDArray] = {} - - # The set of cached torch parameters. `Any` here stands for - # numpy.ndarray. - cached_torch_params: Dict[str, Any] = {} - - device_cpu = tvm.cpu() - - def fetch_torch_param(torch_param): - if str(torch_param.dtype) == "torch.bfloat16": - # Convert to float32 first. - return torch_param.detach().cpu().float().numpy() - else: - return torch_param.detach().cpu().numpy() - - def load_torch_params_from_bin(torch_binname: str): - torch_binpath = os.path.join(self.model_path, torch_binname) - torch_params = None - if self.use_safetensors: - torch_params = self.safetensors_load_func(torch_binpath) - else: - torch_params = torch.load( - torch_binpath, - map_location=torch.device("cpu"), - ) - torch_param_names = list(torch_params.keys()) - for torch_param_name in torch_param_names: - torch_param = fetch_torch_param(torch_params[torch_param_name]) - del torch_params[torch_param_name] - - relax_params = self.f_convert_param_bkwd(torch_param_name, torch_param) - if relax_params is not None: - for param_name, param in relax_params: - if param_name not in pname2pidx.keys(): - continue - pidx = pname2pidx[param_name] - assert pidx not in cached_relax_params - cached_relax_params[pidx] = tvm.nd.array(param, device_cpu) - else: - assert torch_param_name not in cached_torch_params - cached_torch_params[torch_param_name] = torch_param - del torch_param - - def get_item(i): - # If the weight is already provided by `model_params`, directly use it - # and no need to load from binary file. - if model_params and len(model_params) > i and model_params[i] is not None: - assert i not in cached_relax_params - return tvm.nd.array(model_params[i], device=device) - - # Otherwise, we load the weight from its corresponding binary file. 
- assert i in self.pidx2pname - relax_pname = self.pidx2pname[i] - torch_pnames = self.f_convert_pname_fwd(relax_pname) - - if i not in cached_relax_params: - for torch_binname in [ - self.torch_pname2binname[torch_pname] for torch_pname in torch_pnames - ]: - if torch_binname in loaded_torch_bins: - continue - load_torch_params_from_bin(torch_binname) - loaded_torch_bins.add(torch_binname) - - if i not in cached_relax_params: - assert len(torch_pnames) > 1 - assert all([torch_pname in cached_torch_params] for torch_pname in torch_pnames) - cached_relax_params[i] = self.f_compute_relax_param( - relax_pname, - [cached_torch_params[torch_pname] for torch_pname in torch_pnames], - ) - for torch_pname in torch_pnames: - del cached_torch_params[torch_pname] - - assert i in cached_relax_params - assert i not in loaded_idx_set - param_on_device = tvm.nd.array(cached_relax_params[i], device=device) - loaded_idx_set.add(i) - del cached_relax_params[i] - return param_on_device - - return get_item - - def get_param_set_item(self) -> Tuple[Callable, List[tvm.nd.NDArray]]: - """A wrapper function which returns the `set_item` - functions for parameter lazy loading. - - The return value of this function is intended to be registered - as `"set_item"`, for use in a module built with - `LazyTransformParams`. - - .. code-block:: python - - set_item,loaded_params = manager.get_param_set_item() - tvm.register_func(func_name="set_item", f=set_item, override=True) - compiled_function() - # `loaded_params` is now fully populated - - Returns - ------- - set_item: Callable[[int,tvm.nd.NDArray]] - - A function that accepts an index and the return value at - that index. - - loaded_params: List[tvm.nd.NDArray] - - A list of loaded parameters, populated by `set_item`. - When initially returned, this list is empty. After - executing the compiled function with - `LazyTransformParams`, `loaded_params` will be - populated. - """ - device_cpu = tvm.cpu() - loaded_params: List[tvm.nd.NDArray] = [] - - def set_item(i: int, computed_param: tvm.nd.NDArray): - if len(loaded_params) <= i: - loaded_params.extend([None for _ in range(i - len(loaded_params) + 1)]) - loaded_params[i] = tvm.nd.array(computed_param, device=device_cpu) - - return set_item, loaded_params - - #################### Below are internally called methods #################### - - def _register_param( - self, - name: str, - var: relax.Var, - quant_spec: quantization.QuantizationSpec, - func_name: str, - shard_dim: Optional[int], - shard_strategy: Optional[str], - ) -> Parameter: - """Register a single parameter in the parameter manager. - In most cases, this method is not directly used outside this class: - it is called by `register_params` above. - - Parameters - ---------- - name : str - The name of the parameter to register. - Name serves as the unique identifier of the parameter. - - var : relax.Var - The parameter relax.Var on the nn.Module side. - - quant_spec : quantization.QuantizationSpec - The quantization specification of the parameter - - func_name : str - The name of the function the input var is in. - For example, the "prefill" function or the "decode" function. - - shard_dim : Optional[int] - The dimension along which the parameter is sharded. - - shard_strategy : Optional[str] - The strategy of sharding the parameter. - - Returns - ------- - param : Parameter - The registered Parameter. - """ - assert ( - var not in self.func_raw_param_map - ), "The input var is not supposed to be already registered." 
- assert isinstance( - var.struct_info.shape, relax.ShapeExpr - ), "The parameter to register is expected to have shape as a tuple" - - if name in self.params: - # When the input name appears in `self.params`, it means the input - # parameter has been previously registered in some other function. - # Thus, we check if the dtype, shape and the quantization specification - # of both sides are consistent. - param = self.params[name] - assert ( - param.quant_spec == quant_spec - ), "One parameter is expected to be quantized by single specification in all functions." - assert ( - param.param_info.dtype == var.struct_info.dtype - ), "Dtype mismatch of one parameter in two functions." - assert ( - param.param_info.ndim == var.struct_info.ndim - ), "Shape mismatch of one parameter in two functions." - for len0, len1 in zip(param.param_info.shape.values, var.struct_info.shape.values): - if isinstance(len0, tir.IntImm) and isinstance(len1, tir.IntImm): - assert ( - len0.value == len1.value - ), "Shape mismatch of one parameter in two functions." - else: - # Otherwise, the parameter is registered for the first time. - param = Parameter(name, quant_spec, shard_dim, shard_strategy) - self.params[name] = param - self.param_names.append(name) - - param.register_func(func_name, var.struct_info) - # Record the mapping from the input relax.Var to the function name and - # the parameter in the manager. - self.func_raw_param_map[var] = (func_name, param) - return param - - def _dequantize( - self, - param: Parameter, - qparams: List[relax.Var], - bb: relax.BlockBuilder, - func_name: str, - ) -> relax.Var: - """Applying dequantization to the input parameter. - This method is called by `transform_module` below, and is not - directly invoked outside the class. - - Parameters - ---------- - param : Parameter - The parameter whose quantized tensors are to be dequantized. - - qparams : List[relax.Var] - The relax.Var of the quantized tensors of all parameters in the model. - - Returns - ------- - The dequantized parameter, in the form of a relax.Var. - """ - # Get the dequantization function of this parameter. - f_dequantize = param.quant_spec.get_dequantize_func( - param_info=param.param_info_dict[func_name], - qparam_info=[qparam.struct_info for qparam in qparams], - ) - if f_dequantize is None: - # If the parameter does not have a dequantization function, its "quantized - # data" is expected to have only one element. - assert len(qparams) == 1, ( - "A parameter without dequantization function is expected not to have " - 'more than one "quantized data".' - ) - return qparams[0] - else: - # Apply the dequantization function. - return bb.emit(f_dequantize(bb, qparams)) - - def create_parameter_transformation(self, optimize_parameter_order: bool = True): - """Produce an IRModule that can transform the parameters - - Parameters - ---------- - optimize_parameter_order: bool - - If true, reorder the parameter transformations to - prioritize operations that use a currently-open file. If - false, transform the parameters in their default order. 
- - Returns - ------- - tvm.IRModule - The transformation module - - """ - mod = _create_quantize_func(self) - if optimize_parameter_order: - mod = self.optimize_transform_param_order()(mod) - return mod - - def optimize_transform_param_order(self) -> tvm.transform.Pass: - """Produce an transformation that optimizes for minimal memory footprint - - Returns - ------- - tvm.transform.Pass - The transformation - """ - - pidx2binname: Dict[int, str] = { - pidx: self.torch_pname2binname[self.f_convert_pname_fwd(pname)[0]] - for pidx, pname in self.pidx2pname.items() - if self.f_convert_pname_fwd(pname)[0] in self.torch_pname2binname - } - return ReorderTransformFunc(pidx2binname) - - -@mutator -class ParamReplacer(PyExprMutator): - """The function mutator that updates the model with dequantization. - - Attributes - ---------- - mod : tvm.IRModule - The IRModule of the model to be updated. - - func_name_to_quantized_params : Dict[str, List[relax.Var]] - The mapping from each function name to its input var of quantized data tuple. - - f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] - The function for updating a previous parameter in functions with dequantization. - - param_set : Set[relax.Var] - The set of previous parameters (before applying quantization and dequantization) - in the relax functions. - """ - - mod: tvm.IRModule - func_name_to_quantized_params: Dict[str, List[relax.Var]] - f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] - param_set: Set[relax.Var] - - cur_func_name: str - - def __init__( - self, - mod: tvm.IRModule, - func_name_to_quantized_params: Dict[str, relax.Var], - f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], - ): - super().__init__(mod) - self.mod = mod - self.func_name_to_quantized_params = func_name_to_quantized_params - self.f_replace = f_replace - self.cur_func_name = "" - - def transform(self) -> tvm.IRModule: - for gv, func in self.mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - - assert ( - gv.name_hint in self.func_name_to_quantized_params - ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" - updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) - updated_func = remove_all_unused(updated_func) - self.builder_.update_func(gv, updated_func) - return self.builder_.get() - - def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: - num_input = int(func.attrs["num_input"]) - self.param_set = set(func.params[num_input:]) - - body = self.visit_expr(func.body) - return relax.Function( - params=func.params[:num_input] + quantized_params, - body=body, - ret_struct_info=func.ret_struct_info, - is_pure=func.is_pure, - attrs=func.attrs, - ) - - def visit_var_(self, var: Var) -> Expr: - if var in self.param_set: - return self.f_replace(var, self.builder_) - else: - return super().visit_var_(var) - - -################################################################## - - -def load_torch_pname2binname_map( - model_path: str, - use_safetensors: bool, - relax_pnames: Set[str], - f_convert_pname_fwd: Callable[[str], List[str]] = lambda pname: [pname], -) -> Dict[str, str]: - """Constructing the dictionary from each torch parameter's name to - the name of the binary shard where the torch parameter is saved. - - Parameters - ---------- - model_path : str - The path of the Hugging Face model on disk. 
- - use_safetensors: bool - Whether to use ``.safetensors`` instead of ``.bin`` to load model. - - relax_pnames: Set[str] - The name of the Relax parameters. - - f_convert_pname_fwd: Callable[[str], List[str]] - The function which converts Relax parameter name to torch's - parameter names. See ParamManager for more details. - """ - bin_idx_path = None - single_shard_file_name = None - if use_safetensors: - bin_idx_path = os.path.join(model_path, "model.safetensors.index.json") - single_shard_file_name = "model.safetensors" - else: - bin_idx_path = os.path.join(model_path, "pytorch_model.bin.index.json") - single_shard_file_name = "pytorch_model.bin" - single_shard_path = os.path.join(model_path, single_shard_file_name) - - if os.path.isfile(bin_idx_path): - # Multiple weight shards. - with open(bin_idx_path, "r") as f_torch_json: - torch_bin_json = json.load(f_torch_json) - torch_pname2binname = torch_bin_json["weight_map"] - elif os.path.isfile(single_shard_path): - # Single weight shard. - torch_pname2binname = { - torch_pname: single_shard_file_name - for relax_pname in relax_pnames - for torch_pname in f_convert_pname_fwd(relax_pname) - } - else: - suffix = ".safetensors" if use_safetensors else ".bin" - shard_names = [] - # Collect Scan every single file with the suffix - for filename in os.listdir(model_path): - if filename.endswith(suffix): - shard_names.append(filename) - if len(shard_names) == 1: - torch_pname2binname = { - torch_pname: shard_names[0] - for relax_pname in relax_pnames - for torch_pname in f_convert_pname_fwd(relax_pname) - } - else: - raise ValueError("Multiple weight shard files without json map is not supported") - return torch_pname2binname - - -def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: - """Construct the Relax function which computes quantization. - This method is called by `transform_module` below, and is not - directly invoked outside the class. - - Parameters - ---------- - param_manager : ParamManager - The parameter manager which has all the parameter information. - - Returns - ------- - The created function which computes quantization. - Precisely, an IRModule which contains the main quantization Relax function - and a series of TIR functions is returned. - """ - bb = relax.BlockBuilder() - param2qrange = dict() - - # Construct the input of the function. - # We need a list of ranges for each - # parameter to get its corresponding tensors loaded from disk. - input_tensor_info: List[relax.TensorStructInfo] = [] - loaded_tensor_ranges: List[range] = [] - for name in param_manager.param_names: - param = param_manager.params[name] - _, loaded_tensor_info = param.quant_spec.get_loaded_tensor_info(name, param.param_info) - loaded_tensor_ranges.append( - range( - len(input_tensor_info), - len(input_tensor_info) + len(loaded_tensor_info), - ) - ) - input_tensor_info += loaded_tensor_info - raw_param_tuple = relax.Var("params", relax.TupleStructInfo(input_tensor_info)) - - with bb.function("transform_params", params=[raw_param_tuple]): - with bb.dataflow(): - quantized_params: List[relax.Var] = [] - for pidx, name in enumerate(param_manager.param_names): - param = param_manager.params[name] - param_vars: List[relax.Var] = [] - # Emit relax.TupleGetItem to get the raw parameters or pre-quantized params. - for loaded_tensor_idx in loaded_tensor_ranges[pidx]: - param_vars.append( - bb.emit(relax.TupleGetItem(raw_param_tuple, loaded_tensor_idx)) - ) - - # Get the quantization function of this parameter. 
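-                # `get_quantize_func` may return None, meaning this parameter either
-                # needs no quantization or is already pre-quantized; in both cases the
-                # raw tensors are passed through unchanged below.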
- f_quantize = param.quant_spec.get_quantize_func(param.param_info) - if f_quantize is None: - # If the parameter does not have a quantization function, either it - # does not need quantization or it is pre-quantized. - param2qrange[param] = range( - len(quantized_params), - len(quantized_params) + len(param_vars), - ) - quantized_params += param_vars - else: - # If the parameter has a quantization function, it is not expected - # to be pre-quantized. - assert len(param_vars) == 1, ( - "A parameter with quantization function is not expected " - "to be pre-quantized." - ) - - # Apply the quantization function. - quantized_data = bb.emit(f_quantize(bb, param_vars)) - - if isinstance(quantized_data.struct_info, relax.TupleStructInfo): - n_tensor = len(quantized_data.struct_info.fields) - assert n_tensor > 1 - # Record the range of quantized tensors of this parameter. - param2qrange[param] = range( - len(quantized_params), len(quantized_params) + n_tensor - ) - # Collect the quantized tensors to return. - for i in range(n_tensor): - quantized_params.append(bb.emit(relax.TupleGetItem(quantized_data, i))) - else: - assert isinstance(quantized_data.struct_info, relax.TensorStructInfo) - param2qrange[param] = range( - len(quantized_params), len(quantized_params) + 1 - ) - quantized_params.append(quantized_data) - - output = bb.emit_output(relax.Tuple(quantized_params)) - bb.emit_func_output(output) - - mod = bb.get() - param_manager.param2qrange = param2qrange - # Return the created IRModule. - return bb.get() - - -def transform_params_for_each_rank( - num_shards: int, rank_argument_name: str = "rank_arg" -) -> tvm.ir.transform.Pass: - """Update a parameter transform to apply across all ranks - - For use in generating a pre-sharded set of weights. Given a - parameter transformation that generates sharded model weights for - a single shard, produce a parameter transformation that generates - sharded model weights for each shard. - - Parameters - ---------- - mod: tvm.IRModule - - A module containing the parameter transformation function, - named "transform_params", along with any subroutines called by - the parameter transformation. - - num_shards: int - - The number of shards to generate. - - rank_argument_name: str - - The name of the argument that specifies the rank. Should be a - R.ShapeTuple with a single R.PrimStructInfo('int64'). - - Returns - ------- - tvm.IRModule - - The modified parameter transformation - """ - - @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_params_for_each_rank") - def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: - generic_transform = mod["transform_params"] - - if generic_transform.attrs is not None and "num_input" in generic_transform.attrs: - num_input = generic_transform.attrs["num_input"].value - else: - num_input = 0 - - if num_input == 0: - return mod - - tensor_params = generic_transform.params[num_input:] - attrs = {"num_input": num_input - 1} - - bb = relax.BlockBuilder() - - with bb.function("transform_params", params=tensor_params, attrs=attrs): - output = [] - for rank in range(num_shards): - # TODO(Lunderberg): Implement this in terms of a - # generic utility that inlines local functions. 
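-                # For each rank: bind the rank argument to a constant shape, copy the
-                # generic transform with fresh vars, bind the shared tensor parameters,
-                # and splice the resulting output tuple into the combined output.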
- func = generic_transform - func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])}) - func = relax.utils.copy_with_new_vars(func) - func = func.bind_params( - {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)} - ) - shard_tuple = func.body - output.extend([shard_tuple[i] for i in range(len(tensor_params))]) - - with bb.dataflow(): - gv = bb.emit_output(relax.Tuple(output)) - bb.emit_func_output(gv) - - mod = mod.clone() - mod["transform_params"] = bb.get()["transform_params"] - return mod - - return transform_func - - -def chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm.IRModule: - """Chain two sequential parameter transformations - - For use in manipulating sets of model weights. Given two - parameter transformations that could be applied sequentially, - produce a single parameter transformation whose output is the same - as applying the parameter transformations sequentially. - - - .. code-block:: python - - # Before - params_after_a = mod_a['transform_params'](orig_params) - params_after_b = mod_b['transform_params'](params_after_a) - - # After - mod_ab = chain_parameter_transforms(mod_a, mod_b) - params_after_b = mod_ab['transform_params'](orig_params) - - Parameters - ---------- - mod_a: tvm.IRModule - - The module containing the first parameter transformation. - - mod_b: tvm.IRModule - - The module containing the second parameter transformation. - - Returns - ------- - tvm.IRModule - - The module containing the output - - """ - func_a = mod_a["transform_params"] - func_b = mod_b["transform_params"] - - bb = relax.BlockBuilder() - - def get_num_input_attr(func): - if func.attrs is None: - return 0 - - attrs = func.attrs - if "num_input" not in attrs: - return 0 - num_input = attrs["num_input"] - - assert isinstance(num_input, tvm.tir.IntImm) - return num_input.value - - # Either func_a or func_b may have parameters that are provided at - # a later point. The chaining of parameter transforms assumes - # that all model weights accepted by func_b are produced by - # func_a. If func_b accepts non-weight parameters (e.g. the GPU - # rank), these must still be provided. - func_a_num_input = get_num_input_attr(func_a) - func_b_num_input = get_num_input_attr(func_b) - - output_num_input = func_a_num_input + func_b_num_input - output_params = [ - *func_a.params[:func_a_num_input], - *func_b.params[:func_b_num_input], - *func_a.params[func_a_num_input:], - ] - - with bb.function( - "transform_params", params=output_params, attrs={"num_input": output_num_input} - ): - with bb.dataflow(): - # TODO(Lunderberg): Implement this in terms of a - # generic utility that inlines local functions. 
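-            # Inline func_a by emitting its body, then bind its outputs to func_b's
-            # weight parameters so that func_b's body can be emitted in the same
-            # dataflow block.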
- func_a_output = bb.emit(func_a.body) - func_b_param_map = { - param: expr - for (param, expr) in zip(func_b.params[func_b_num_input:], func_a_output) - } - func_b_output = func_b.bind_params(func_b_param_map).body - gv = bb.emit_output(func_b_output) - bb.emit_func_output(gv) - - merged_transform_func = bb.get()["transform_params"] - - new_mod = { - **{ - gvar: func - for gvar, func in mod_a.functions.items() - if gvar.name_hint != "transform_params" - }, - **{ - gvar: func - for gvar, func in mod_b.functions.items() - if gvar.name_hint != "transform_params" - }, - "transform_params": merged_transform_func, - } - return tvm.IRModule(new_mod) diff --git a/mlc_llm/relax_model/rwkv.py b/mlc_llm/relax_model/rwkv.py deleted file mode 100644 index 3c1a9ffa0d..0000000000 --- a/mlc_llm/relax_model/rwkv.py +++ /dev/null @@ -1,613 +0,0 @@ -# pylint: disable=missing-docstring,invalid-name -from dataclasses import dataclass -from typing import List, Literal, Tuple - -from tvm import relax, te, tir -from tvm.relax import Expr, op -from tvm.relax.testing import nn -from tvm.script import relax as R -from tvm.script import tir as T - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .modules import Linear, ModuleList -from .param_manager import ParamManager - -# Reference: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model_run.py - - -@dataclass -class RWKVConfig: - """The configuration class to store the configuration of a `RWKVModel`.""" - - num_hidden_layers: int - vocab_size: int - hidden_size: int - intermediate_size: int - rescale_every: int = 0 - layer_norm_epsilon: float = 1e-5 - max_sequence_length: int = 1024 - dtype: str = "float32" - - def __init__( - self, - num_hidden_layers: int, - vocab_size: int, - hidden_size: int, - intermediate_size: int, - rescale_every: int = 0, - layer_norm_epsilon: float = 1e-5, - context_length: int = 1024, - dtype: str = "float32", - **kwargs, - ) -> None: - self.num_hidden_layers = num_hidden_layers - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.rescale_every = rescale_every - self.layer_norm_epsilon = layer_norm_epsilon - self.max_sequence_length = context_length - self.dtype = dtype - self.kwargs = kwargs - - -class State: - ATT_X = 0 - ATT_A = 1 - ATT_B = 2 - ATT_P = 3 - FFN_X = 4 - - -def _load_state(state: Expr, hidden_size: int, dtype: str) -> Expr: - # Reuse `attention_kv_cache_view` - f_load_cache = relax.extern("vm.builtin.attention_kv_cache_view") - cache = nn.emit( - relax.call_pure_packed( - f_load_cache, - state, - R.shape([1, hidden_size]), - sinfo_args=[R.Tensor((1, hidden_size), dtype)], - ) - ) - return cache - - -def _store_state(state: Expr, value: Expr): - # Reuse `attention_kv_cache_update` - f_store_cache = relax.extern("vm.builtin.attention_kv_cache_update") - - return nn.emit( - relax.op.call_inplace_packed( - f_store_cache, - state, - value, - inplace_indices=[0], - sinfo_args=[R.Object()], - ) - ) - - -def is_one(x: tir.PrimExpr) -> bool: - return isinstance(x, tir.IntImm) and x.value == 1 - - -def create_wkv_func(hidden_size: int, dtype: str, out_dtype: str): - @T.prim_func - def wkv_func( - k: T.handle, - v: T.handle, - time_decay: T.handle, - time_first: T.handle, - saved_a: T.handle, - saved_b: T.handle, - saved_p: T.handle, - wkv: T.handle, - out_a: T.handle, - out_b: T.handle, - out_p: T.handle, - ): - T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1}) - 
context_length = T.int64() - K = T.match_buffer(k, (context_length, hidden_size), dtype=dtype) - V = T.match_buffer(v, (context_length, hidden_size), dtype=dtype) - TimeDecay = T.match_buffer(time_decay, (hidden_size,), dtype=dtype) - TimeFirst = T.match_buffer(time_first, (hidden_size,), dtype=dtype) - SavedA = T.match_buffer(saved_a, (1, hidden_size), dtype=dtype) - SavedB = T.match_buffer(saved_b, (1, hidden_size), dtype=dtype) - SavedP = T.match_buffer(saved_p, (1, hidden_size), dtype=dtype) - WKV = T.match_buffer(wkv, (context_length, hidden_size), dtype=out_dtype) - OutA = T.match_buffer(out_a, (1, hidden_size), dtype=dtype) - OutB = T.match_buffer(out_b, (1, hidden_size), dtype=dtype) - OutP = T.match_buffer(out_p, (1, hidden_size), dtype=dtype) - - P = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - E1 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - E2 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - A_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - B_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - P_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local") - - for bx in T.thread_binding(hidden_size // 32, thread="blockIdx.x"): - for tx in T.thread_binding(32, thread="threadIdx.x"): - with T.block("init"): - vi = T.axis.S(hidden_size, bx * 32 + tx) - A_local[vi] = SavedA[0, vi] - B_local[vi] = SavedB[0, vi] - P_local[vi] = SavedP[0, vi] - for j in range(context_length): - with T.block("main"): - vi = T.axis.S(hidden_size, bx * 32 + tx) - vj = T.axis.opaque(context_length, j) - P[vi] = T.max(P_local[vi], K[vj, vi] + TimeFirst[vi]) - E1[vi] = T.exp(P_local[vi] - P[vi]) - E2[vi] = T.exp(K[vj, vi] + TimeFirst[vi] - P[vi]) - WKV[vj, vi] = T.cast( - (E1[vi] * A_local[vi] + E2[vi] * V[vj, vi]) - / (E1[vi] * B_local[vi] + E2[vi]), - out_dtype, - ) - - P[vi] = T.max(P_local[vi] + TimeDecay[vi], K[vj, vi]) - E1[vi] = T.exp(P_local[vi] + TimeDecay[vi] - P[vi]) - E2[vi] = T.exp(K[vj, vi] - P[vi]) - A_local[vi] = E1[vi] * A_local[vi] + E2[vi] * V[vj, vi] - B_local[vi] = E1[vi] * B_local[vi] + E2[vi] - P_local[vi] = P[vi] - - with T.block("write_back"): - vi = T.axis.S(hidden_size, bx * 32 + tx) - OutA[0, vi] = A_local[vi] - OutB[0, vi] = B_local[vi] - OutP[0, vi] = P_local[vi] - - return wkv_func - - -def _te_concat_saved_x(saved_x: te.Tensor, x: te.Tensor): - return te.compute( - x.shape, - lambda i, j: tir.if_then_else(i == 0, saved_x[0, j], x[i - 1, j]), - ) - - -def _te_get_last_x(x: te.Tensor): - seq_len, hidden_size = x.shape - return te.compute((1, hidden_size), lambda _, j: x[seq_len - 1, j]) - - -class RWKV_Embedding(nn.Module): - def __init__(self, num_embeddings, embedding_dim, dtype): - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.weight = nn.Parameter((num_embeddings, embedding_dim), dtype=dtype, name="weight") - - def forward(self, x: relax.Expr) -> relax.Var: - x = nn.emit(op.reshape(x, shape=[-1])) - return nn.emit(op.take(self.weight, x, axis=0)) - - -class RWKV_LayerNorm(nn.Module): - def __init__(self, intermediate_size, dtype, eps=1e-5, name_prefix=""): - super().__init__() - self.eps = eps - self.weight = nn.Parameter( - (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_weight" - ) - self.bias = nn.Parameter((intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_bias") - - def forward(self, x: relax.Expr) -> relax.Var: - x = nn.emit( - op.nn.layer_norm( - x, - gamma=self.weight, - beta=self.bias, - axes=-1, - epsilon=self.eps, - ) - ) 
- return x - - -class RWKV_FFN(nn.Module): - def __init__(self, config: RWKVConfig, index: int) -> None: - super().__init__() - self.hidden_size = config.hidden_size - self.dtype = config.dtype - self.index = index - self.time_mix_key = nn.Parameter( - (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_k" - ) - self.time_mix_receptance = nn.Parameter( - (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_r" - ) - self.key = Linear( - self.hidden_size, config.intermediate_size, dtype=config.dtype, bias=False - ) - self.receptance = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) - self.value = Linear( - config.intermediate_size, self.hidden_size, dtype=config.dtype, bias=False - ) - - def forward(self, x: Expr, state: Expr) -> Expr: - offset = self.index * 5 + State.FFN_X - context_length = x.struct_info.shape[0] - hidden_size = self.hidden_size - - saved_x = _load_state(state[offset], hidden_size, self.dtype) - if not is_one(context_length): - saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x) - ones = nn.emit(relax.op.ones((hidden_size,), self.dtype)) - xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) - xr = nn.emit(x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)) - if not is_one(context_length): - x = nn.emit_te(_te_get_last_x, x) - assert is_one(x.struct_info.shape[0]) - saved_x = _store_state(state[offset], x) - - r = nn.emit(op.sigmoid(self.receptance(xr))) - xv = nn.emit(op.square(op.nn.relu(self.key(xk)))) - - return nn.emit(r * self.value(xv)), [saved_x] - - -class RWKV_Attention(nn.Module): - def __init__(self, config: RWKVConfig, index: int) -> None: - super().__init__() - self.index = index - self.dtype = config.dtype - self.hidden_size = config.hidden_size - self.time_decay = nn.Parameter( - (self.hidden_size,), dtype="float32", name=f"att_{index}_time_decay" - ) - self.time_first = nn.Parameter( - (self.hidden_size,), dtype="float32", name=f"att_{index}_time_first" - ) - self.time_mix_key = nn.Parameter( - (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_k" - ) - self.time_mix_value = nn.Parameter( - (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_v" - ) - self.time_mix_receptance = nn.Parameter( - (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_r" - ) - self.key = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) - self.value = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) - self.receptance = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) - self.output = Linear(self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False) - - def forward(self, x: Expr, state: Expr) -> Expr: - # Load current state - ones = nn.emit(relax.op.ones((self.hidden_size,), self.dtype)) - index = self.index - hidden_size = self.hidden_size - context_length = x.struct_info.shape[0] - bb = relax.BlockBuilder.current() - - saved_a = _load_state(state[index * 5 + State.ATT_A], hidden_size, "float32") - saved_b = _load_state(state[index * 5 + State.ATT_B], hidden_size, "float32") - saved_p = _load_state(state[index * 5 + State.ATT_P], hidden_size, "float32") - saved_x = _load_state(state[index * 5 + State.ATT_X], hidden_size, self.dtype) - if not is_one(context_length): - saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x) - - xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key)) - xv = nn.emit(x * self.time_mix_value + 
saved_x * (ones - self.time_mix_value)) - xr = nn.emit(x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)) - - r = nn.emit(op.sigmoid(self.receptance(xr))) - k = nn.emit(op.astype(self.key(xk), "float32")) - v = nn.emit(op.astype(self.value(xv), "float32")) - - gv = bb.add_func(create_wkv_func(hidden_size, "float32", self.dtype), "wkv") - ret = nn.emit( - relax.call_tir( - gv, - [k, v, self.time_decay, self.time_first, saved_a, saved_b, saved_p], - [ - R.Tensor((context_length, hidden_size), self.dtype), - R.Tensor((1, hidden_size), "float32"), - R.Tensor((1, hidden_size), "float32"), - R.Tensor((1, hidden_size), "float32"), - ], - ) - ) - if not is_one(context_length): - x = nn.emit_te(_te_get_last_x, x) - - assert is_one(x.struct_info.shape[0]) - saved_x = _store_state(state[self.index * 5 + State.ATT_X], x) - saved_a = _store_state(state[self.index * 5 + State.ATT_A], ret[1]) - saved_b = _store_state(state[self.index * 5 + State.ATT_B], ret[2]) - saved_p = _store_state(state[self.index * 5 + State.ATT_P], ret[3]) - - return nn.emit(self.output(r * ret[0])), [ - saved_x, - saved_a, - saved_b, - saved_p, - ] - - -class RWKVLayer(nn.Module): - def __init__(self, config: RWKVConfig, index: int) -> None: - super().__init__() - if index == 0: - self.pre_ln = RWKV_LayerNorm( - config.hidden_size, - config.dtype, - eps=config.layer_norm_epsilon, - name_prefix="pre_ln", - ) - self.ln1 = RWKV_LayerNorm( - config.hidden_size, - config.dtype, - eps=config.layer_norm_epsilon, - name_prefix=f"att_{index}", - ) - self.ln2 = RWKV_LayerNorm( - config.hidden_size, - config.dtype, - eps=config.layer_norm_epsilon, - name_prefix=f"ffn_{index}", - ) - self.attention = RWKV_Attention(config, index) - self.feed_forward = RWKV_FFN(config, index) - self.rescale_every = config.rescale_every - self.dtype = config.dtype - self.index = index - - def forward(self, x: Expr, state: Expr) -> Tuple[Expr, List[Expr]]: - if self.index == 0: - x = self.pre_ln(x) - att, att_state = self.attention(self.ln1(x), state) - x = nn.emit(x + att) - ffn, ffn_state = self.feed_forward(self.ln2(x), state) - x = nn.emit(x + ffn) - if self.rescale_every > 0 and (self.index + 1) % self.rescale_every == 0: - x = nn.emit(x / relax.const(2, dtype=self.dtype)) - return x, att_state + ffn_state - - -class RWKVModel(nn.Module): - def __init__(self, config: RWKVConfig) -> None: - super().__init__() - self.embeddings = RWKV_Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - dtype=config.dtype, - ) - self.blocks = ModuleList([RWKVLayer(config, i) for i in range(config.num_hidden_layers)]) - self.ln_out = RWKV_LayerNorm( - config.hidden_size, - config.dtype, - eps=config.layer_norm_epsilon, - name_prefix="out_ln", - ) - self.hidden_size = config.hidden_size - self.dtype = config.dtype - - def forward(self, input_ids: Expr, state: Expr) -> Tuple[Expr, List[Expr]]: - hidden_states = self.embeddings(input_ids) - states = [] - for _, layer in enumerate(self.blocks): - hidden_states, layer_states = layer(hidden_states, state) - states += layer_states - context_length = hidden_states.struct_info.shape[0] - if not is_one(context_length): - hidden_states = nn.emit_te(_te_get_last_x, hidden_states) - hidden_states = self.ln_out(hidden_states) - return hidden_states, states - - -class RWKVForCausalLM(nn.Module): - def __init__(self, config: RWKVConfig): - self.rwkv = RWKVModel(config) - self.head = Linear(config.hidden_size, config.vocab_size, dtype=config.dtype, bias=False) - self.vocab_size = 
config.vocab_size - ############ End ############ - - def forward( - self, - input_ids: relax.Expr, - state: relax.Expr, - ): - hidden_states, key_value_cache = self.rwkv(input_ids, state) - logits = nn.emit(self.head(hidden_states)) - logits = nn.emit(op.reshape(logits, (1, 1, self.vocab_size))) - if logits.struct_info.dtype != "float32": - logits = nn.emit(relax.op.astype(logits, "float32")) - - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if name.endswith("embeddings.weight"): - return ParamQuantKind.embedding_table - elif name == "head.weight": - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: RWKVConfig, - quant_scheme: QuantizationScheme, - func_name=Literal["prefill", "decode"], -): - if func_name not in ["prefill", "decode"]: - raise ValueError(f"func_name must be 'prefill' or 'decode', got {func_name}") - seq_len = 1 if func_name == "decode" else tir.SizeVar("n", "int64") - - with bb.function(func_name): - model = RWKVForCausalLM(config) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((1, seq_len), dtype="int32", name="input_ids") - # Placeholder for compatibility to LLAMA - all_seq_len_shape = relax.Var("place_holder", R.Object()) - state = relax.Var("state", R.Tuple([R.Object()] * config.num_hidden_layers * 5)) - with bb.dataflow(): - logits, states = model(input_ids, state) - params = [ - input_ids, - all_seq_len_shape, - state, - ] + model.parameters() - - gv = bb.emit_output((logits, relax.Tuple(states))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - f = mod[gv].with_attr("num_input", 3) - if func_name == "prefill": - f = f.with_attr("tir_var_upper_bound", {"n": config.max_sequence_length}) - bb.update_func(gv, f) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: - """NOTE: It's not typical kv-cache, but try to reuse the logic for the quick hack.""" - init_shape = relax.ShapeExpr((1, config.hidden_size)) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - input_dtype_zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - fp32_zeros = bb.emit(relax.op.zeros(init_shape, "float32")) - fp32_neg_inf = bb.emit(fp32_zeros - relax.const(1e30, "float32")) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - conf = [ - ("att_x", input_dtype_zeros), - ("att_a", fp32_zeros), - ("att_b", fp32_zeros), - ("att_p", fp32_neg_inf), - ("ffn_x", input_dtype_zeros), - ] - for i in range(config.num_hidden_layers): - for name, init_value in conf: - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - init_value, - init_shape, - relax.PrimValue(1), - sinfo_args=[R.Object()], - ), - name_hint=f"{name}_state_{i}", - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_kv_cache_reset_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: - state = relax.Var("state", R.Tuple([R.Object()] * config.num_hidden_layers * 5)) - init_shape = relax.ShapeExpr((1, config.hidden_size)) - with bb.function("reset_kv_cache", [state]): - with bb.dataflow(): - input_dtype_zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - fp32_zeros = 
bb.emit(relax.op.zeros(init_shape, "float32")) - fp32_neg_inf = bb.emit(fp32_zeros - relax.const(1e30, "float32")) - caches = [] - for i in range(config.num_hidden_layers): - caches.append(_store_state(state[i * 5 + State.ATT_X], input_dtype_zeros)) - caches.append(_store_state(state[i * 5 + State.ATT_B], fp32_zeros)) - caches.append(_store_state(state[i * 5 + State.ATT_A], fp32_zeros)) - caches.append(_store_state(state[i * 5 + State.ATT_P], fp32_neg_inf)) - caches.append(_store_state(state[i * 5 + State.FFN_X], input_dtype_zeros)) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: RWKVConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, config.vocab_size), dtype="float32", name="logits") - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def get_model(args, hf_config): - model_name = args.model - max_seq_len = args.max_seq_len - dtype = args.quantization.model_dtype - - if not model_name.lower().startswith("rwkv-"): - raise ValueError(f"Unsupported model name: {model_name}") - - config = RWKVConfig(**hf_config, dtype=dtype) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len - - param_manager = ParamManager() - bb = relax.BlockBuilder() - create_func(bb, param_manager, config, args.quantization, "prefill") - create_func(bb, param_manager, config, args.quantization, "decode") - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model_name, - # RNN model do not have window size limit - max_window_size=-1, - stop_tokens=[0], - add_prefix_space=False, - ) - create_kv_cache_reset_func(bb, config) - mod = bb.get() - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - if ( - "key_weight" in pname - or "value_weight" in pname - or "receptance_weight" in pname - or "output_weight" in pname - or "head_weight" in pname - ): - return [pname.replace("_weight", ".weight")] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - # torch_param: numpy.ndarray - import numpy as np # pylint: disable=import-outside-toplevel - - # rescale_every - if config.rescale_every > 0 and "blocks." 
in torch_pname: - # based-on the assumption that the layer id is the second element in torch_pname - layer_id = int(torch_pname.split(".")[2]) - if ( - "attention.output.weight" in torch_pname - or "feed_forward.value.weight" in torch_pname - ): - torch_param = torch_param / (2 ** (layer_id // config.rescale_every)) - - # reshape - if "time_" in torch_pname: - torch_param = torch_param.squeeze() - - # convert dtype - if "time_decay" in torch_pname: # need fp32 for this - return [(torch_pname, -np.exp(torch_param.astype("float32")))] - elif "time_first" in torch_pname: - return [(torch_pname, torch_param.astype("float32"))] - else: - return [(torch_pname, torch_param.astype(config.dtype))] - - param_manager.set_param_loading_func( - args.model_path, args.use_safetensors, f_convert_pname_fwd, f_convert_param_bkwd - ) - return mod, param_manager, [None] * len(param_manager.param_names), config diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py deleted file mode 100644 index c39b8018ce..0000000000 --- a/mlc_llm/relax_model/stablelm_3b.py +++ /dev/null @@ -1,919 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple - -import numpy as np -import tvm -from tvm import relax, te -from tvm.relax.op import ccl -from tvm.relax.op.nn import layer_norm -from tvm.relax.testing import nn -from tvm.script import relax as R - -from ..quantization import ParamQuantKind, QuantizationScheme -from .commons import create_metadata_func -from .llama import Embedding, Linear -from .modules import ModuleList, RotaryEmbedding -from .param_manager import ParamManager - - -@dataclass -class StableLM3bConfig: - def __init__( - self, - dtype="float32", - max_sequence_length=4096, - vocab_size=50304, - hidden_size=2560, - intermediate_size=6912, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - initializer_range=0.02, - norm_eps=1e-5, - pad_token_id=-1, - bos_token_id=0, - eos_token_id=1, - tie_word_embeddings=False, - position_embedding_base=10000, - combine_matmul=True, - num_shards=1, - build_model_only=False, - convert_weights_only=False, - **kwargs, - ): - self.dtype = dtype - self.max_sequence_length = max_sequence_length - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.norm_eps = norm_eps - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - self.position_embedding_base = position_embedding_base - self.combine_matmul = combine_matmul - if build_model_only and num_shards > 1: - self.num_shards = num_shards - else: - self.num_shards = 1 - self.kwargs = kwargs - - def get_num_key_value_heads(self): - if self.num_key_value_heads is None: - return self.num_attention_heads - return self.num_key_value_heads - - -class LayerNorm(nn.Module): - def __init__( - self, - hidden_size, - dtype, - eps=1e-5, - ): - super().__init__() - self.eps = eps - self.weight = nn.Parameter((hidden_size,), dtype="float16", name="weight") - self.bias = nn.Parameter((hidden_size,), dtype="float16", name="bias") - - def forward(self, x: relax.Expr) -> relax.Var: - x = nn.emit( - layer_norm( - x, - gamma=self.weight, - beta=self.bias, 
- axes=-1, - epsilon=self.eps, - ) - ) - return x - - -class StableLM3bMLP(nn.Module): - def __init__(self, config: StableLM3bConfig): - self.combine_matmul = config.combine_matmul - self.num_shards = config.num_shards - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size // self.num_shards - dtype = config.dtype - if self.combine_matmul: - self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.gate_up_proj.weight.shard_dim = 0 - self.down_proj.weight.shard_dim = 1 - else: - self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) - self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) - self.gate_proj.weight.shard_dim = 0 - self.up_proj.weight.shard_dim = 0 - self.down_proj.weight.shard_dim = 1 - - def forward(self, x): - if self.combine_matmul: - gate_up_results = nn.emit( - relax.op.split( - self.gate_up_proj(x), - indices_or_sections=2, - axis=-1, - ) - ) - gate_result = relax.TupleGetItem(gate_up_results, 0) - up_result = relax.TupleGetItem(gate_up_results, 1) - else: - gate_result = self.gate_proj(x) - up_result = self.up_proj(x) - - result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) - return result - - -class StableLM3bAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): - dtype = config.dtype - self.num_shards = config.num_shards - self.hidden_size = config.hidden_size - self.num_key_value_heads = ( - config.num_key_value_heads is None - and config.num_attention_heads - or config.num_key_value_heads - ) // config.num_shards - self.num_query_heads = config.num_attention_heads // self.num_shards - self.head_dim = self.hidden_size // config.num_attention_heads - self.position_embedding_base = config.position_embedding_base - self.rotary_embedding = rotary_embedding - - self.combine_matmul = config.combine_matmul - if self.combine_matmul: - self.query_key_value_proj = Linear( - self.hidden_size, - (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, - dtype=dtype, - bias=False, - ) - self.query_key_value_proj.weight.shard_dim = 0 - else: - self.q_proj = Linear( - self.hidden_size, - self.num_query_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.k_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.v_proj = Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - dtype=dtype, - bias=False, - ) - self.q_proj.weight.shard_dim = 0 - self.k_proj.weight.shard_dim = 0 - self.v_proj.weight.shard_dim = 0 - - self.o_proj = Linear( - self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False - ) - self.o_proj.weight.shard_dim = 1 - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: - from tvm.relax.op import ( - astype, - matmul, - maximum, - permute_dims, - reshape, - split, - squeeze, - ) - from tvm.relax.op.nn import softmax - - bsz, q_len, _ = hidden_states.struct_info.shape - assert bsz == 1, "Only support batch size 1 at this moment." 
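-        # With combine_matmul enabled, a single fused projection produces Q, K and V,
-        # which are split along the last axis below; otherwise separate q/k/v
-        # projections are applied.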
- - if self.combine_matmul: - qkv_states = nn.emit( - split( - self.query_key_value_proj(hidden_states), - indices_or_sections=[ - self.num_query_heads * self.head_dim, - (self.num_query_heads + self.num_key_value_heads) * self.head_dim, - ], - axis=-1, - ) - ) - query_states = relax.TupleGetItem(qkv_states, 0) - key_states = relax.TupleGetItem(qkv_states, 1) - value_states = relax.TupleGetItem(qkv_states, 2) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = nn.emit( - reshape( - query_states, - (bsz, q_len, self.num_query_heads, self.head_dim), - ), - ) - key_states = nn.emit( - reshape( - key_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), - ) - value_states = nn.emit( - reshape( - value_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), - ) - - kv_seq_len = all_seq_len_shape.struct_info.values[0] - offset = kv_seq_len - q_len - query_states, key_states = self.rotary_embedding(query_states, key_states, offset) - # [bsz, t, nh, hd] - - kv_states_shape = key_states.struct_info.shape - kv_states_dtype = key_states.struct_info.dtype - assert kv_states_shape[0] == 1 # bsz - kv_states_shape = R.shape( - [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] - ) - kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]]) - - squeezed_key = nn.emit(squeeze(key_states, axis=0)) - squeezed_value = nn.emit(squeeze(value_states, axis=0)) - k_cache, v_cache = past_key_value - f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") - k_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - k_cache, - squeezed_key, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - v_cache = nn.emit( - relax.op.call_inplace_packed( - f_kv_cache_append, - v_cache, - squeezed_value, - inplace_indices=[0], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - past_key_value = (k_cache, v_cache) - f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") - k_cache = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - k_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], - ) - ) - v_cache = nn.emit( - relax.call_pure_packed( - f_kv_cache_view, - v_cache, - kv_cache_shape, - sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], - ) - ) - key_states = nn.emit(reshape(k_cache, kv_states_shape)) - value_states = nn.emit(reshape(v_cache, kv_states_shape)) - if self.num_key_value_heads != self.num_query_heads: - n_rep = self.num_query_heads // self.num_key_value_heads - key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) - value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) - - query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) - key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) - value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) - - attn_weights = nn.emit( - matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) - / relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) - ) - - tvm.ir.assert_structural_equal( - attention_mask.struct_info.shape.values, - (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), - ) - - attn_weights = nn.emit( - maximum( - attn_weights, - relax.const( - tvm.tir.min_value(attn_weights.struct_info.dtype).value, - attn_weights.struct_info.dtype, - ), - ) - ) - attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) - 
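-        # The mask stores dtype-max for visible positions and dtype-min for masked
-        # ones, so the maximum()/minimum() pair above drives masked scores down to
-        # dtype-min before the softmax.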
- # upcast attention to fp32 - if attn_weights.struct_info.dtype != "float32": - attn_weights = astype(attn_weights, "float32") - attn_weights = nn.emit(softmax(attn_weights, axis=-1)) - if attn_weights.struct_info.dtype != query_states.struct_info.dtype: - attn_weights = astype(attn_weights, query_states.struct_info.dtype) - attn_output = nn.emit(matmul(attn_weights, value_states)) - - attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) - attn_output = nn.emit( - reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) - ) - - attn_output = self.o_proj(attn_output) - return attn_output, ((None, None) if past_key_value is None else past_key_value) - - -class StableLM3bDecoderLayer(nn.Module): - def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): - self.hidden_size = config.hidden_size - self.self_attn = StableLM3bAttention(config, rotary_embedding) - self.mlp = StableLM3bMLP(config) - self.input_layernorm = LayerNorm( - config.hidden_size, dtype=config.dtype, eps=config.norm_eps - ) - self.post_attention_layernorm = LayerNorm( - config.hidden_size, dtype=config.dtype, eps=config.norm_eps - ) - - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - all_seq_len_shape=all_seq_len_shape, - ) - if self.self_attn.num_shards > 1: - residual = nn.emit( - residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.self_attn.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - return hidden_states, present_key_value - - -def _make_causal_mask(input_ids_shape, dtype, src_len): - from tvm.relax.op import broadcast_to - - bsz, tgt_len = input_ids_shape - - def min_max_triu_te(): - return te.compute( - (tgt_len, tgt_len), - lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), - name="make_diag_mask_te", - ) - - mask = nn.emit_te(min_max_triu_te) - diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) - if src_len == tgt_len: - return diag_mask - - def extend_te(x, tgt_len, src_len): - return te.compute( - (bsz, 1, tgt_len, src_len), - lambda b, _, i, j: te.if_then_else( - j < src_len - tgt_len, - tvm.tir.max_value(dtype), - x[b, _, i, j - (src_len - tgt_len)], - ), - name="concat_te", - ) - - return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) - - -class StableLM3bEmbedTokens(nn.Module): - def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar): - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = 
self.embed_tokens(input_ids) - return inputs_embeds - - -class StableLM3bEmbedTokensWrapper(nn.Module): - def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar): - # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent - self.model = StableLM3bEmbedTokens(config, vocab_size_var) - - def forward(self, input_ids: relax.Expr): - inputs_embeds = self.model(input_ids) - return inputs_embeds - - -class StableLM3bModell(nn.Module): - def __init__( - self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False - ): - rotary_embedding = RotaryEmbedding( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - position_embedding_base=config.position_embedding_base, - max_sequence_length=config.max_sequence_length, - rotary_pct=0.25, - dtype=config.dtype, - ) - self.num_shards = config.num_shards - self.padding_idx = config.pad_token_id - self.embed_tokens = None - - if not sep_embed: - self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) - - self.layers = ModuleList( - [ - StableLM3bDecoderLayer(config, rotary_embedding) - for _ in range(config.num_hidden_layers) - ] - ) - self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps) - - def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if isinstance(input_shape[-1], tvm.tir.SizeVar) or input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) - else: - # Get src_len from input parameters - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - bsz, tgt_len = input_shape - combined_attention_mask = nn.emit( - relax.op.full( - (bsz, 1, tgt_len, src_len), - relax.const(tvm.tir.max_value(dtype).value, dtype), - dtype, - ) - ) - return combined_attention_mask - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - if self.num_shards > 1: - inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) - if self.embed_tokens: - inputs_embeds = self.embed_tokens(inputs) - else: - inputs_embeds = inputs - # retrieve input_ids - batch_size, seq_length, _ = inputs_embeds.struct_info.shape - seq_length_with_past = all_seq_len_shape.struct_info.values[0] - # embed positions - attention_mask = self._prepare_decoder_attention_mask( - (batch_size, seq_length), - seq_length_with_past, - inputs_embeds.struct_info.dtype, - ) - - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = () - - for idx, decoder_layer in enumerate(self.layers): - assert past_key_values is not None - past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) - - hidden_states, key_value_cache = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - all_seq_len_shape=all_seq_len_shape, - ) - next_decoder_cache += key_value_cache - - hidden_states = self.norm(hidden_states) - - assert len(next_decoder_cache) == len(self.layers) * 2 - return hidden_states, next_decoder_cache - - -class StableLM3bForCausalLM(nn.Module): - def __init__( - self, config: StableLM3bConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False - ): - self.model = StableLM3bModell(config, vocab_size_var, sep_embed) - self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) - - assert config.hidden_size % 
config.num_attention_heads == 0 - - def forward( - self, - inputs: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_values: relax.Expr, - ): - hidden_states, key_value_cache = self.model( - inputs=inputs, - all_seq_len_shape=all_seq_len_shape, - past_key_values=past_key_values, - ) - - def te_slicing(x: te.Tensor): - return te.compute( - shape=(1, 1, x.shape[-1]), - fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], - name="slice", - ) - - logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) - if logits.struct_info.dtype != "float32": - logits = nn.emit(relax.op.astype(logits, "float32")) - - return logits, key_value_cache - - -def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: - if "embed_tokens" in name: - return ParamQuantKind.embedding_table - elif "lm_head.weight" in name: - return ParamQuantKind.final_fc_weight - elif param_info.ndim == 2 and name.endswith(".weight"): - return ParamQuantKind.linear_weight - else: - return ParamQuantKind.others - - -def create_embed_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: StableLM3bConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "embed" - - bsz = 1 - seq_len = tvm.tir.SizeVar("m", "int64") - with bb.function(func_name): - model = StableLM3bEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - with bb.dataflow(): - inputs_embeds = model(input_ids) - params = [input_ids] + model.parameters() - gv = bb.emit_output(inputs_embeds) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 1)) - - -def create_encoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: StableLM3bConfig, - quant_scheme: QuantizationScheme, - sep_embed: bool = False, -) -> None: - func_name = "prefill_with_embed" if sep_embed else "prefill" - - bsz = 1 - seq_len = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") - hidden_size = config.hidden_size - with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - inputs = ( - nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") - if sep_embed - else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") - ) - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - inputs, all_seq_len_shape, past_key_values=past_key_values - ) - params = [ - inputs, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_decoding_func( - bb: relax.BlockBuilder, - param_manager: ParamManager, - config: StableLM3bConfig, - quant_scheme: QuantizationScheme, -) -> None: - func_name = "decode" - - bsz = 1 - all_seq_len = tvm.tir.SizeVar("m", "int64") - - with 
bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64")) - param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) - - input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") - all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) - past_key_values = relax.Var( - "kv_cache", - relax.TupleStructInfo( - [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] - ), - ) - with bb.dataflow(): - logits, key_value_cache = model( - input_ids, all_seq_len_shape, past_key_values=past_key_values - ) - params = [ - input_ids, - all_seq_len_shape, - past_key_values, - ] + model.parameters() - gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) - bb.emit_func_output(gv, params) - - mod = bb.get() - gv = mod.get_global_var(func_name) - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) - - -def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: - num_key_value_heads = ( - config.num_attention_heads - if config.num_key_value_heads is None - else config.num_key_value_heads - ) // config.num_shards - init_shape = relax.ShapeExpr( - ( - config.max_sequence_length, - num_key_value_heads, - config.hidden_size // config.num_attention_heads, # head_dim - ) - ) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.num_hidden_layers * 2): - caches.append( - bb.emit( - relax.call_pure_packed( - f_kv_cache_create, - zeros, - init_shape, - relax.PrimValue(0), - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - -def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: - with bb.function("softmax_with_temperature"): - logits = nn.Placeholder( - (1, 1, tvm.tir.SizeVar("vocab_size", "int64")), dtype="float32", name="logits" - ) - temperature = nn.Placeholder((), dtype="float32", name="temperature") - with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) - softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) - gv = bb.emit_output(softmax) - bb.emit_func_output(gv, [logits, temperature]) - - -def emit_shard3d(bb: relax.BlockBuilder) -> None: - from tvm.script import tir as T - - def _emit(dtype: str, global_symbol: str): - @T.prim_func - def shard_3d(a: T.handle, num_shards: T.int64, b: T.handle): - T.func_attr( - { - "tir.noalias": T.bool(True), - "global_symbol": global_symbol, - } - ) - s_0, s_1, s_2 = T.int64(), T.int64(), T.int64() - # pylint: disable=invalid-name - A = T.match_buffer(a, (s_0, s_1, s_2), dtype) - B = T.match_buffer(b, (num_shards, s_0, s_1 // num_shards, s_2), dtype) - # pylint: enable=invalid-name - for j_o, i, j_i, k in T.grid(num_shards, s_0, s_1 // num_shards, s_2): - with T.block("B"): - v_j_o = T.axis.spatial(num_shards, j_o) - v_i = T.axis.spatial(s_0, i) - v_j_i = T.axis.spatial(s_1 // num_shards, j_i) - v_k = T.axis.spatial(s_2, k) - B[v_j_o, v_i, v_j_i, v_k] = A[v_i, v_j_o * (s_1 // num_shards) + v_j_i, v_k] - - bb.add_func(shard_3d, global_symbol) - - _emit("float32", "shard3d_fp32") - _emit("float16", "shard3d_fp16") - _emit("uint32", "shard3d_uint32") - - -def get_model(args, hf_config): - model_name = args.model - dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len - sep_embed = args.sep_embed - - 
position_embedding_base = 10000 - if "rope_theta" in hf_config: - position_embedding_base = hf_config["rope_theta"] - - config = StableLM3bConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - convert_weights_only=args.convert_weights_only, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len - - param_manager = ParamManager() - bb = relax.BlockBuilder() - emit_shard3d(bb) - - if sep_embed: - create_embed_func(bb, param_manager, config, args.quantization) - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) - create_metadata_func( - bb, - model_name=model_name, - max_window_size=config.max_sequence_length, - stop_tokens=[2], - add_prefix_space=False, - prefill_chunk_size=args.prefill_chunk_size, - ) - - mod = bb.get() - - tir_bound_map = dict() - tir_bound_map["n"] = ( - args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length - ) - tir_bound_map["m"] = config.max_sequence_length - for gv in mod.functions: - func = mod[gv] - if isinstance(func, relax.Function): - mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map) - - if args.build_model_only: - return mod, param_manager, None, config - - def f_convert_pname_fwd(pname: str) -> List[str]: - if not config.combine_matmul: - return [pname] - - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" - if qkv_str in pname: - return [ - pname.replace(qkv_str, "q_proj"), - pname.replace(qkv_str, "k_proj"), - pname.replace(qkv_str, "v_proj"), - ] - elif gate_up_str in pname: - return [ - pname.replace(gate_up_str, "gate_proj"), - pname.replace(gate_up_str, "up_proj"), - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if not config.combine_matmul: - return [(torch_pname, torch_param.astype(dtype))] - - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(dtype))] - - def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): - # Expected to enter this function only for the combined linear matmul weights. - # Other weights are supposed to be loaded in `f_convert_param_bkwd` since - # each other relax param has a unique corresponding torch param. - if not config.combine_matmul: - # When matmul combination is not turned on, each relax param has a unique - # corresponding torch param, and this function is not expected to be entered. 
- raise NotImplementedError( - "Matmul combination is not turned on, and the function " - "is not expected to be entered" - ) - num_shards = args.num_shards - hidden_size = config.hidden_size - head_dim = config.hidden_size // config.num_attention_heads - - if "query_key_value_proj" in relax_pname: - q_heads = config.num_attention_heads - kv_heads = config.num_key_value_heads - if kv_heads is None: - kv_heads = q_heads - q, k, v = torch_params - assert q.shape == (q_heads * head_dim, hidden_size) - assert k.shape == (kv_heads * head_dim, hidden_size) - assert v.shape == (kv_heads * head_dim, hidden_size) - q = q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)) - k = k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) - v = v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) - qkv = np.concatenate([q, k, v], axis=1) - qkv = qkv.reshape((-1, hidden_size)).astype(dtype) - return qkv - if "gate_up_proj" in relax_pname: - intermediate_size = config.intermediate_size - gate, up = torch_params - gate = gate.reshape((num_shards, intermediate_size // num_shards, hidden_size)) - up = up.reshape((num_shards, intermediate_size // num_shards, hidden_size)) - gate_up = np.concatenate([gate, up], axis=1) - gate_up = gate_up.reshape((-1, hidden_size)).astype(dtype) - return gate_up - raise ValueError("Unexpected param loading") - - param_manager.set_param_loading_func( - args.model_path, - args.use_safetensors, - f_convert_pname_fwd, - f_convert_param_bkwd, - f_compute_relax_param, - ) - - param_list = [None] * param_manager.nparam_to_load - - return mod, param_manager, param_list, config diff --git a/mlc_llm/transform/__init__.py b/mlc_llm/transform/__init__.py deleted file mode 100644 index 758d8a1081..0000000000 --- a/mlc_llm/transform/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .clean_up_tir_attrs import CleanUpTIRAttrs -from .decode_matmul_ewise import FuseDecodeMatmulEwise -from .decode_take import FuseDecodeTake -from .decode_transpose import FuseDecodeTranspose -from .fuse_split_rotary_embedding import fuse_split_rotary_embedding -from .lift_tir_global_buffer_alloc import LiftTIRGlobalBufferAlloc -from .reorder_transform_func import ReorderTransformFunc -from .rewrite_attention import rewrite_attention -from .transpose_matmul import FuseTransposeMatmul, FuseTranspose1Matmul, FuseTranspose2Matmul -from .set_entry_funcs import SetEntryFuncs diff --git a/mlc_llm/transform/clean_up_tir_attrs.py b/mlc_llm/transform/clean_up_tir_attrs.py deleted file mode 100644 index 93a90f8227..0000000000 --- a/mlc_llm/transform/clean_up_tir_attrs.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Clean up TIR attributes that may affect dispatching""" - -import tvm -from tvm.ir.module import IRModule - - -@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs") -class CleanUpTIRAttrs: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: - undesired_attrs = ["op_pattern"] - - for gv in list(mod.functions): - func = mod[gv] - changed = False - for attr in undesired_attrs: - if func.attrs is not None and attr in func.attrs: - func = func.without_attr(attr) - changed = True - break - - if changed: - mod[gv] = func - return mod diff --git a/mlc_llm/transform/decode_matmul_ewise.py b/mlc_llm/transform/decode_matmul_ewise.py deleted file mode 100644 index 7471848bfb..0000000000 --- a/mlc_llm/transform/decode_matmul_ewise.py +++ /dev/null @@ -1,84 +0,0 @@ -import tvm -from tvm import IRModule, relax, tir -from 
tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard - - -def check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: - call = ctx.annotated_expr["w"] - if not isinstance(call, relax.Call): - return False - gv = call.args[0] - if not isinstance(gv, relax.GlobalVar): - return False - return gv.name_hint.startswith("decode") or gv.name_hint.startswith("fused_decode") - - -def check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: - call = ctx.annotated_expr["matmul"] - if not isinstance(call, relax.Call): - return False - gv = call.args[0] - if not isinstance(gv, relax.GlobalVar): - return False - return ( - gv.name_hint.startswith("matmul") - or gv.name_hint.startswith("fused_matmul") - or gv.name_hint.startswith("NT_matmul") - or gv.name_hint.startswith("fused_NT_matmul") - ) - - -def pattern_check(): - def f_pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: - return check_decoding(ctx) and check_matmul(ctx) - - return f_pattern_check - - -def decode_matmul_pattern(match_ewise: int, n_aux_tensor: int): - assert n_aux_tensor == 1 or n_aux_tensor == 2 or n_aux_tensor == 3 or n_aux_tensor == 4 - - w_scaled = wildcard() - aux_tensors = [wildcard(), wildcard(), wildcard(), wildcard()] - x = wildcard() - w = is_op("relax.call_tir")( - GlobalVarPattern(), - TuplePattern([w_scaled, *aux_tensors[0:n_aux_tensor]]), - add_constraint=False, - ) - matmul_args = [x, w] - for _ in range(match_ewise): - matmul_args.append(wildcard()) - matmul = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern(matmul_args), add_constraint=False - ) - - annotations = { - "matmul": matmul, - "w": w, - "x": x, - "w_scaled": w_scaled, - } - return matmul, annotations, pattern_check() - - -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise") -class FuseDecodeMatmulEwise: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext # pylint: disable=unused-argument - ) -> IRModule: - for n_aux_tensor in [1, 2, 3, 4]: - for match_ewise in [0, 1, 2, 6]: - if match_ewise == 6 and n_aux_tensor != 4: - continue - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_matmul", - *decode_matmul_pattern(match_ewise, n_aux_tensor), - ) - ] - )(mod) - mod = relax.transform.FuseTIR()(mod) - - return mod diff --git a/mlc_llm/transform/decode_take.py b/mlc_llm/transform/decode_take.py deleted file mode 100644 index cd09771126..0000000000 --- a/mlc_llm/transform/decode_take.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Fusing and inlining decode function into embedding table lookup.""" -import tvm -from tvm import relax, tir -from tvm.ir.module import IRModule -from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_const, is_op, wildcard - - -def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: - take = ctx.annotated_expr["take"] - decode = ctx.annotated_expr["decode"] - if not isinstance(decode, relax.expr.Call): - return False - if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( - decode.args[0], relax.GlobalVar - ): - return False - return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint - - -def decode_take_pattern(n_aux_tensor: int, match_tir_vars: bool): - aux_tensors = [wildcard(), wildcard(), wildcard()] - decode = is_op("relax.call_tir")( - GlobalVarPattern(), - TuplePattern([*aux_tensors[0:n_aux_tensor]]), - add_constraint=False, - ) - indices = ~is_const() - take_args = [decode, indices] - call_tir_args_take = [GlobalVarPattern(), TuplePattern(take_args)] - if 
match_tir_vars: - call_tir_args_take.append(wildcard()) - take = is_op("relax.call_tir")(*call_tir_args_take, add_constraint=False) - - annotations = { - "take": take, - "decode": decode, - "indices": indices, - } - - return take, annotations, pattern_check - - -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") -class FuseDecodeTake: - def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: - for n_aux_tensor in [2, 3]: - for match_tir_vars in [False, True]: - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_take", - *decode_take_pattern(n_aux_tensor, match_tir_vars), - ) - ] - )(mod) - mod = relax.transform.FuseTIR()(mod) - - for gv, func in mod.functions.items(): - if not isinstance(func, tir.PrimFunc): - continue - if "fused_decode" not in gv.name_hint or "take" not in gv.name_hint: - continue - - downcasted_mod = tir.transform.ForceNarrowIndexToInt32()(tvm.IRModule({"main": func}))[ - "main" - ] - sch = tir.Schedule(downcasted_mod) - sch.compute_inline("decode") - mod[gv] = sch.mod["main"] - - return mod diff --git a/mlc_llm/transform/decode_transpose.py b/mlc_llm/transform/decode_transpose.py deleted file mode 100644 index be5dccdc91..0000000000 --- a/mlc_llm/transform/decode_transpose.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Fusing and inlining transpose function into decode function.""" -import tvm -from tvm import relax, tir -from tvm.ir.module import IRModule -from tvm.relax.analysis import remove_all_unused -from tvm.relax.expr_functor import PyExprMutator, mutator - - -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTranspose") -class FuseDecodeTranspose: - def __init__(self, skip_gemm=True) -> None: - self.skip_gemm = skip_gemm - - def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: - @mutator - class DecodeTransposeFusor(PyExprMutator): - def __init__(self, mod: IRModule, skip_gemm=True): - super().__init__(mod) - self.mod = mod - self.skip_gemm = skip_gemm - - def transform(self) -> IRModule: - for gv, func in self.mod.functions.items(): - if not isinstance(func, relax.Function): - continue - - updated_func = self.visit_expr(func) - updated_func = remove_all_unused(updated_func) - self.builder_.update_func(gv, updated_func) - - return self.builder_.get() - - def visit_call_(self, call: relax.Call) -> relax.Expr: - call = self.visit_expr_post_order(call) - - if call.op != tvm.ir.Op.get("relax.matmul"): - return call - - # Do not fuse decode-transpose for GeMM - if self.skip_gemm and ( - call.args[0].struct_info.ndim < 2 - or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) - or call.args[0].struct_info.shape[-2].value != 1 - ): - return call - - matmul_rhs = self.lookup_binding(call.args[1]) - if ( - not isinstance(matmul_rhs, relax.Call) - or matmul_rhs.op != tvm.ir.Op.get("relax.permute_dims") - or matmul_rhs.args[0].struct_info.ndim != 2 - or matmul_rhs.attrs.axes is not None - ): - return call - - transpose_input = self.lookup_binding(matmul_rhs.args[0]) - if ( - not isinstance(transpose_input, relax.Call) - or transpose_input.op != tvm.ir.Op.get("relax.call_tir") - or not transpose_input.args[0].name_hint.startswith("decode") - or not isinstance( - transpose_input.struct_info, relax.TensorStructInfo - ) - ): - return call - - decode_tir_func = self.mod[transpose_input.args[0]] - assert isinstance(decode_tir_func, tir.PrimFunc) - if ( - len(decode_tir_func.body.block.alloc_buffers) != 1 - or not isinstance(decode_tir_func.body.block.body, tir.SeqStmt) - or 
len(decode_tir_func.body.block.body) != 2 - or not isinstance(decode_tir_func.body.block.body[1], tir.For) - or not isinstance( - decode_tir_func.body.block.body[1].body.body, tir.BlockRealize - ) - or decode_tir_func.body.block.body[1].body.body.block.name_hint - != "T_transpose" - ): - return call - - new_func_buffers = [ - decode_tir_func.buffer_map[var] for var in decode_tir_func.params - ] - new_func_buffers[-1] = decode_tir_func.body.block.alloc_buffers[0] - new_func = tir.PrimFunc( - params=new_func_buffers, - body=tir.BlockRealize( - iter_values=[], - predicate=True, - block=tir.Block( - iter_vars=[], - reads=[], - writes=[], - name_hint="root", - body=decode_tir_func.body.block.body[0], - ), - ), - ) - # Call `renew_defs` for deep-copy to avoid IR node duplication in - # different PrimFuncs of an IRModule. - new_func = tir.stmt_functor.renew_defs(new_func) - gv = self.builder_.add_func(new_func, func_name="decode") - decoded_matmul_rhs = self.builder_.emit( - relax.call_tir( - gv, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info - ) - ) - return relax.op.matmul( - call.args[0], decoded_matmul_rhs, out_dtype=call.attrs.out_dtype - ) - - return DecodeTransposeFusor(mod, self.skip_gemm).transform() diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py deleted file mode 100644 index ed19a7095c..0000000000 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ /dev/null @@ -1,284 +0,0 @@ -import tvm -from tvm import relax -from tvm.relax.dpl import ( - PatternContext, - is_op, - rewrite_bindings, - wildcard, - is_tuple_get_item, - GlobalVarPattern, - TuplePattern, - is_shape, -) -from tvm.script import relax as R, tir as T - - -def get_dynamic_split_rotary(): - """Implementation of R.split(rotary_embedding(fused_qkv)) - - Implementation is generic over the number of query heads, - key/value heads, sequence length, head dimension, and position - embedding base. These parameters can be replaced with static - values using `PrimFunc.specialize`. 
- """ - - @T.prim_func(private=True) - def split_rotary( - fused_qkv_handle: T.handle, - embedded_query_handle: T.handle, - embedded_key_handle: T.handle, - value_handle: T.handle, - rotary_offset: T.int64, - batch_size: T.int64, - seq_len: T.int64, - num_query_heads: T.int64, - num_kv_heads: T.int64, - head_dim: T.int64, - position_embedding_base: T.float32, - ): - Fused_QKV = T.match_buffer( - fused_qkv_handle, - [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], - dtype="float16", - ) - EmbeddedQuery = T.match_buffer( - embedded_query_handle, - [batch_size, seq_len, num_query_heads, head_dim], - dtype="float16", - ) - EmbeddedKey = T.match_buffer( - embedded_key_handle, - [batch_size, seq_len, num_kv_heads, head_dim], - dtype="float16", - ) - Value = T.match_buffer( - value_handle, - [batch_size, seq_len, num_kv_heads, head_dim], - dtype="float16", - ) - - T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - - for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): - with T.block("FusedRotaryEmbeddingAndSplitQKV"): - batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) - pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) - - inv_freq: T.float32 = T.float32(1) / T.pow( - position_embedding_base, - T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - - input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] - embedded_value = cos_value * input_value + sin_value * T.Select( - head_i < T.int64(head_dim // 2), - Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] - * T.float16(-1), - Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], - ) - if head_num < num_query_heads: - EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value - elif head_num < num_query_heads + num_kv_heads: - EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value - else: - Value[ - batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i - ] = input_value - - param_sinfo = [] - for param in split_rotary.params: - if param in split_rotary.buffer_map: - buf = split_rotary.buffer_map[param] - sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) - else: - sinfo = relax.PrimStructInfo(param.dtype) - param_sinfo.append(sinfo) - - relax.expr._update_struct_info( - split_rotary, - tvm.relax.FuncStructInfo( - params=param_sinfo, - ret=relax.TupleStructInfo([]), - purity=False, - ), - ) - - return split_rotary - - -def fuse_split_rotary_embedding( - num_query_heads, num_kv_heads, hidden_size, position_embedding_base -): - @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") - def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: - head_dim = hidden_size // num_query_heads - split_rotary = get_dynamic_split_rotary() - - ( - dyn_batch_size, - dyn_seq_len, - dyn_num_query_heads, - dyn_num_kv_heads, - dyn_head_dim, - dyn_position_embedding_base, - ) = split_rotary.params[-6:] - - split_rotary = split_rotary.specialize( - { - # Static model parameters - dyn_batch_size: T.int64(1), - dyn_num_query_heads: T.int64(num_query_heads), - dyn_num_kv_heads: T.int64(num_kv_heads), - dyn_head_dim: T.int64(head_dim), - dyn_position_embedding_base: T.float32(position_embedding_base), - # Dynamic parameters, to be inferred from TIR Buffer shapes - dyn_seq_len: 
tvm.tir.Var("query_sequence_length", "int64"), - } - ) - - mod["split_rotary"] = split_rotary - - split_rotary_gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) - - with PatternContext() as ctx: - # flat_qkv_tuple: R.Tuple( - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) - # - # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] - # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_query, R.shape([batch_size, seq_len, 32, 128]) - # ) - # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] - # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_key, R.shape([batch_size, seq_len, 32, 128]) - # ) - # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] - # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_value, R.shape([batch_size, seq_len, 32, 128]) - # ) - # embedded_query = R.call_tir( - # cls.rotary_embedding1, - # [query], - # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), - # tir_vars=R.shape([n]), - # ) - # embedded_key = R.call_tir( - # cls.rotary_embedding1, - # [key], - # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), - # tir_vars=R.shape([n]), - # ) - - pat_rotary_embedding_gvar = GlobalVarPattern() - - pat_flat_fused_qkv = wildcard() - pat_offset = wildcard() - - # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) - pat_query_shape = wildcard() - # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) - pat_key_shape = wildcard() - # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) - pat_value_shape = wildcard() - - pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) - pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) - pat_query = is_op("relax.reshape")( - pat_flat_query, pat_query_shape, add_constraint=False - ) - pat_flat_query.used_by(pat_query) - pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) - pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) - pat_flat_key.used_by(pat_key) - pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) - pat_value = is_op("relax.reshape")( - pat_flat_value, pat_value_shape, add_constraint=False - ) - pat_flat_value.used_by(pat_value) - - pat_embedded_query = is_op("relax.call_tir")( - pat_rotary_embedding_gvar, - TuplePattern([pat_query]), - pat_offset, - add_constraint=False, - ) - pat_embedded_key = is_op("relax.call_tir")( - pat_rotary_embedding_gvar, - TuplePattern([pat_key]), - pat_offset, - add_constraint=False, - ) - - pat_flat_qkv_tuple.used_by(pat_flat_query) - pat_flat_qkv_tuple.used_by(pat_flat_key) - pat_flat_qkv_tuple.used_by(pat_flat_value) - pat_query.used_by(pat_embedded_query) - pat_key.used_by(pat_embedded_key) - - def rewriter(matchings, bindings): - # Extracting all the relax and TIR variables that we'll need - flat_fused_qkv = matchings[pat_flat_fused_qkv] - flat_qkv_tuple = matchings[pat_flat_qkv_tuple] - - flat_query = matchings[pat_flat_query] - flat_key = matchings[pat_flat_key] - flat_value = matchings[pat_flat_value] - - query = matchings[pat_query] - key = matchings[pat_key] - value = 
matchings[pat_value] - - embedded_query = matchings[pat_embedded_query] - embedded_key = matchings[pat_embedded_key] - - # rotary_embedding_offset = bindings[query].args[-1][1] - rotary_embedding_offset = bindings[embedded_query].args[-1][0] - - batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape - _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape - - # Rewriting along the new path - - fused_qkv = relax.op.reshape( - flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] - ) - - split_rotary_sinfo = [ - R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), - R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), - R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), - ] - qkv_tuple_new = R.call_tir( - split_rotary_gvar, - (fused_qkv,), - out_sinfo=split_rotary_sinfo, - tir_vars=[rotary_embedding_offset], - ) - - embedded_query_new = qkv_tuple_new[0] - embedded_key_new = qkv_tuple_new[1] - value_new = qkv_tuple_new[2] - - return { - value: value_new, - embedded_query: embedded_query_new, - embedded_key: embedded_key_new, - } - - new_mod = {} - for gvar, func in mod.functions.items(): - if isinstance(func, relax.Function): - func = rewrite_bindings(ctx, rewriter, func) - new_mod[gvar] = func - - new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) - return new_mod - - return ir_module_pass diff --git a/mlc_llm/transform/lift_tir_global_buffer_alloc.py b/mlc_llm/transform/lift_tir_global_buffer_alloc.py deleted file mode 100644 index 5805e9f1fc..0000000000 --- a/mlc_llm/transform/lift_tir_global_buffer_alloc.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Lift global buffer allocation in TIR to graph level""" - -from typing import Dict, List, Tuple, Optional - -import tvm -from tvm import relax, tir -from tvm.ir.module import IRModule -from tvm.relax.analysis import remove_all_unused -from tvm.relax.expr_functor import PyExprMutator, mutator - - -def remove_global_buf_alloc( - func: tir.PrimFunc, -) -> Optional[Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]]: - """Remove the global buffer allocation for a given TIR PrimFunc.""" - if not isinstance(func.body, tir.BlockRealize): - return None - - params = list(func.params) - buffer_map = dict(func.buffer_map) - tensor_sinfo = [] - alloc_buffers = [] - - insertion_point = len(params) - while params[insertion_point - 1].dtype != "handle": - insertion_point -= 1 - assert insertion_point >= 1 - - prev_root_block = func.body.block - for buf_alloc in func.body.block.alloc_buffers: - if buf_alloc.scope() == "global": - param = tir.Var("var_" + buf_alloc.name, "handle") - params.insert(insertion_point, param) - insertion_point += 1 - buffer_map[param] = buf_alloc - tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype)) - else: - alloc_buffers.append(buf_alloc) - - if len(tensor_sinfo) == 0: - return None - - assert len(prev_root_block.iter_vars) == 0 - assert len(prev_root_block.reads) == 0 - assert len(prev_root_block.writes) == 0 - assert len(prev_root_block.match_buffers) == 0 - assert prev_root_block.name_hint == "root" - assert prev_root_block.init is None - root_block = tir.Block( - iter_vars=[], - reads=[], - writes=[], - name_hint="root", - body=prev_root_block.body, - alloc_buffers=alloc_buffers, - annotations=prev_root_block.annotations, - ) - - updated_func = tir.PrimFunc( - params=params, - body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block), 
- ret_type=func.ret_type, - buffer_map=buffer_map, - attrs=func.attrs, - ) - return updated_func, tensor_sinfo - - -def contain_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool: - assert isinstance(tensor_sinfo.shape, relax.ShapeExpr) - for v in tensor_sinfo.shape.values: - if not isinstance(v, tir.IntImm): - return True - return False - - -def resolve_tir_var_mapping( - func: tir.PrimFunc, call: relax.Call, tensor_sinfo: List[relax.TensorStructInfo] -) -> Tuple[List[relax.TensorStructInfo], bool]: - """Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function""" - var_map: Dict[tir.Var, tir.PrimExpr] = dict() - - n_arg = len(call.args[1].fields) - for i in range(n_arg): - buffer_shape = func.buffer_map[func.params[i]].shape - arg_shape = call.args[1][i].struct_info.shape.values - assert len(buffer_shape) == len(arg_shape) - for vl, vr in zip(buffer_shape, arg_shape): - if isinstance(vl, tir.Var): - var_map[vl] = vr - elif not isinstance(vl, tir.IntImm): - return [], False - - ret_tensors = call.sinfo_args[0] - ret_tensors = ( - [ret_tensors] - if isinstance(ret_tensors, relax.TensorStructInfo) - else list(ret_tensors.fields) - ) - for i in range(len(ret_tensors)): - buffer_shape = func.buffer_map[func.params[n_arg + i]].shape - ret_tensor_shape = ret_tensors[i].shape.values - assert len(buffer_shape) == len(ret_tensor_shape) - for vl, vr in zip(buffer_shape, ret_tensor_shape): - if isinstance(vl, tir.Var): - var_map[vl] = vr - elif not isinstance(vl, tir.IntImm): - return [], False - - updated_tensor_sinfo = [] - for sinfo in tensor_sinfo: - if not contain_symbolic_var(sinfo): - updated_tensor_sinfo.append(sinfo) - continue - - new_shape = [] - for v in sinfo.shape.values: - new_shape.append(tir.stmt_functor.substitute(v, var_map)) - updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype)) - return updated_tensor_sinfo, True - - -def LiftTIRGlobalBufferAlloc(): - @mutator - class TIRGlobalAllocRewriter(PyExprMutator): - def __init__(self, mod: IRModule): - super().__init__(mod) - self.mod = mod - - def transform(self) -> IRModule: - self.mod = self.builder_.get() - for gv, func in self.mod.functions.items(): - if isinstance(func, relax.Function): - updated_func = self.visit_expr(func) - self.builder_.update_func(gv, updated_func) - return self.builder_.get() - - def visit_call_(self, call: relax.Call): - call = self.visit_expr_post_order(call) - if call.op != tvm.ir.Op.get("relax.call_tir"): - return call - - old_gvar = call.args[0] - - func_before_update = self.mod.functions[old_gvar] - updates = remove_global_buf_alloc(func_before_update) - if updates is None: - return call - updated_func, tensor_sinfo = updates - - assert len(call.sinfo_args) == 1 - if any(contain_symbolic_var(sinfo) for sinfo in tensor_sinfo): - tensor_sinfo, success = resolve_tir_var_mapping( - func_before_update, call, tensor_sinfo - ) - if not success: - # Cannot resolve TIR var mapping. Fall back to no lifting. 
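# Shape of the rewrite applied when lifting does succeed (a sketch; tensor shapes
# and the "_lifted" name are hypothetical):
#   before:  lv = R.call_tir(cls.f, (x,), out_sinfo=R.Tensor((n, 4096), "float16"))
#            # where `f` internally allocates a global scratch buffer
#   after:   lv_tup = R.call_tir(cls.f_lifted, (x,),
#                out_sinfo=[R.Tensor((n, 4096), "float16"),
#                           R.Tensor(scratch_shape, scratch_dtype)])
#            lv = lv_tup[0]
# The scratch buffer becomes an extra graph-level output, and the original result
# is recovered with a TupleGetItem, exactly as emitted below.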
- return call - - new_gvar = self.builder_.add_func(updated_func, old_gvar.name_hint) - new_args = [new_gvar, *call.args[1:]] - - if isinstance(call.sinfo_args[0], relax.TensorStructInfo): - new_call = relax.Call( - call.op, - args=new_args, - sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], - attrs=call.attrs, - ) - emitted_tuple = self.builder_.emit(new_call) - return relax.TupleGetItem(emitted_tuple, 0) - elif isinstance(call.sinfo_args[0], relax.TupleStructInfo): - return relax.Call( - call.op, - args=new_args, - sinfo_args=[ - relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo) - ], - attrs=call.attrs, - ) - else: - raise TypeError( - f"Expected {call.op} to return either R.Tensor or R.Tuple, " - f"but instead returned {call.sinfo_args[0]}" - ) - - @tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc.Inner") - def transform_module(mod: IRModule, _: tvm.transform.PassContext) -> IRModule: - return TIRGlobalAllocRewriter(mod).transform() - - return tvm.ir.transform.Sequential( - [ - transform_module, - tvm.relax.transform.DeadCodeElimination(), - ], - name="LiftTIRGlobalBufferAlloc", - ) diff --git a/mlc_llm/transform/reorder_transform_func.py b/mlc_llm/transform/reorder_transform_func.py deleted file mode 100644 index 50b6337e3a..0000000000 --- a/mlc_llm/transform/reorder_transform_func.py +++ /dev/null @@ -1,281 +0,0 @@ -from typing import Callable, Dict, List, Set, Tuple, Optional - -import tvm -from tvm import relax -from tvm.ir.module import IRModule - -""" -This pass in this file reorders the bindings of the weight transform function -according to the weight location in binary files. The goal of the reorder is to -reduce the memory pressure when loading the raw model weights and processing -them. In the ideal case, with this pass, the highest CPU memory usage will -around the size of the largest raw weight binary file. - -Regarding the implementation, the bindings of fetching a raw weight in the -weight transform function are all in the form of `lv = params[idx]`. Here, each -index specifies a raw weight tensor, and the raw weight tensor resides in a -binary file on the disk. - -We group such `lv = params[idx]` into multiple groups, such that all raw weight -tensors in a group come from a same binary file. We reorder the bindings -according to the grouping result based on topological sort. - -In ideal case, after reordering the weight transform function has the following -process during execution: -* load a weight binary file, -* process all weights in this file, -* load another weight binary file, -* process all weights in this file, -* ... - -So the maximum CPU memory usage will be the size of the largest raw weight -binary file, since we process and release all the raw weight tensors immediately -after loading them from the file. -""" - - -def analyze_func( - func: relax.Function, - pidx2binname: Dict[int, str], -) -> Tuple[List[relax.Binding], Dict[relax.Var, List[relax.Binding]], Dict[relax.Binding, int],]: - """Binding grouping analysis function. - It takes the function to be analyzed, and mapping from each raw tensor index - to the name of the binary file where it resides. - - This analysis function - * computes a new order of weight fetching bindings (the bindings in form - `lv = params[idx]`) based on weight location on disk. 
- * collects the dataflow def-use information of the given function for - topological sort (particularly, it collects the consumers of each binding - variables and the number of variables each binding depends on). - - Parameters - ---------- - func : relax.Function - The weight transform function to be analyzed. - - pidx2binname : Dict[int, str] - The mapping from each raw tensor index to the name of the binary - file where it resides. - - Returns - ------- - get_param_bindings : List[relax.Binding] - The weight fetching bindings (`lv = params[idx]`) in the new order. - - var_users : Dict[relax.Var, List[relax.Binding]] - The consumer bindings of each binding variable. - Used for topological sort. - - num_depending_vars : Dict[relax.Binding, int] - The number of variables each binding depends on. - Used for topological sort. - """ - - # The mapping of the weight fetching bindings in each binary file. - # Here empty string means the weight is not in any binary file (e.g., cached - # sin and cos values for rotary embeddings). - binname2get_param_bindings: Dict[str, List[relax.Binding]] = {"": []} - # The set of binding variables. - binding_var_set: Set[relax.Var] = set() - var_users: Dict[relax.Var, List[relax.Binding]] = {} - num_depending_vars: Dict[relax.Binding, int] = {} - - if func.attrs is not None and "num_input" in func.attrs: - num_input = func.attrs["num_input"].value - else: - num_input = 0 - - # Sanity check on the function pattern. - assert isinstance(func.body, relax.SeqExpr) - assert len(func.body.blocks) == 1 - assert isinstance(func.body.blocks[0], relax.DataflowBlock) - assert func.body.blocks[0].bindings[-1].var.same_as(func.body.body) - - if isinstance(func.params[num_input].struct_info, relax.TupleStructInfo): - model_param_tuple = func.params[num_input] - else: - model_param_tuple = None - for i, var in enumerate(func.params[num_input:]): - binname = pidx2binname.get(i, var.name_hint) - if binname not in binname2get_param_bindings: - binname2get_param_bindings[binname] = [] - binname2get_param_bindings[binname].append(var) - - bindings = list(func.body.blocks[0].bindings) - - # Go through each binding except the last one. (The last one is the output - # binding `gv = (lv, lv1, ...)`) which we ignore for analysis. - for binding in bindings[:-1]: - value = binding.value - binding_var_set.add(binding.var) - var_users[binding.var] = [] - - if ( - model_param_tuple is not None - and isinstance(value, relax.TupleGetItem) - and value.tuple_value.same_as(model_param_tuple) - ): - # For weight fetching bindings (`lv = params[idx]`), we group them - # according to the binary file name. - pidx = value.index - if pidx not in pidx2binname: - binname2get_param_bindings[""].append(binding) - continue - - binname = pidx2binname[pidx] - if binname in binname2get_param_bindings: - binname2get_param_bindings[binname].append(binding) - else: - binname2get_param_bindings[binname] = [binding] - else: - # For other bindings, we collect the use-def information for - # topological sort. - num_depending_vars[binding] = 0 - - def fvisit(obj): - if isinstance(obj, relax.Var) and obj in binding_var_set: - assert obj in var_users - var_users[obj].append(binding) - num_depending_vars[binding] += 1 - - relax.analysis.post_order_visit(value, fvisit) - - # Get the weight fetching bindings in new order according to the group results. 
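# Illustrative grouping (the binary file names are hypothetical): with
#   pidx2binname = {0: "weights-0.bin", 1: "weights-1.bin", 2: "weights-0.bin"}
# the bindings `lv0 = params[0]`, `lv1 = params[1]`, `lv2 = params[2]` end up as
#   {"": [], "weights-0.bin": [lv0, lv2], "weights-1.bin": [lv1]}
# so lv0 and lv2 are emitted back to back and each binary file only has to be
# resident in memory once.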
- get_param_bindings: List[relax.Binding] = [] - for bindings in binname2get_param_bindings.values(): - get_param_bindings += bindings - - return get_param_bindings, var_users, num_depending_vars - - -def reorder_func( - func: relax.Function, - pidx2binname: Optional[Dict[int, str]] = None, -) -> relax.Function: - """Reorder the bindings of the input weight transform Relax function - according the weight location in binary files. - - This function first analyzes the input function and gets the reordered - weight fetching bindings and the use-def information for topological sort. - It then reorders all bindings in the function with topological sort. - - Parameters - ---------- - func : relax.Function - The weight transform function to be analyzed. - - pidx2binname : Optional[Dict[int, str]] - - The mapping from each raw tensor index to the name of the - binary file where it resides. If a relax dataflow graph has - multiple valid topological sorts, the order that minimizes the - number of simultaneously open files will be produced - - If `None` (default), the existing order of relax bindings is - preserved in these cases. - - Returns - ------- - func_updated : relax.Function - The returned function where the bindings are updated with the new order. - - """ - - if pidx2binname is None: - pidx2binname = {} - - bindings_to_visit = list(func.body.blocks[0].bindings) - param_lookup = {param: i for i, param in enumerate(func.params)} - binding_lookup = {} - previously_defined = set(func.params) - new_binding_order = [] - - param_tuple = None - if len(func.params) == 1 and isinstance(func.params[0].struct_info, relax.TupleStructInfo): - param_tuple = func.params[0] - - def sort_key(i): - binding = bindings_to_visit[i] - upstream_vars = relax.analysis.free_vars(binding.value) - - valid_ordering = all(var in previously_defined for var in upstream_vars) - last_param_used = max( - (param_lookup[var] for var in upstream_vars if var in param_lookup), default=-1 - ) - earliest_binding_used = min( - (binding_lookup[var] for var in upstream_vars if var in binding_lookup), default=-1 - ) - if ( - param_tuple - and isinstance(binding.value, relax.TupleGetItem) - and binding.value.tuple_value.same_as(param_tuple) - and binding.value.index in pidx2binname - ): - tuple_param_group = pidx2binname[binding.value.index] - else: - tuple_param_group = "" - - return [ - # First, sort by valid orderings, so the min element will - # always be a binding that would be legal to use. - -valid_ordering, - # Next, sort by the function parameter used by this - # binding, in increasing order. That way, we start by - # computing everything that required just the first - # parameter, then move on to variables that can be - # computed with the first two parameters, and so on. - last_param_used, - # Next, sort by the other bindings used. This way, for - # variables that are only used as input in a single - # downstream binding, the variable's required live range - # is minimized. - -earliest_binding_used, - # Finally, if this is a `TupleGetItem(param_tuple, i)`, - # select the option that uses an already-open file. This - # is mainly used relevant when loading from pytorch, which - # require loading the entire file at once. 
- tuple_param_group, - ] - - while bindings_to_visit: - i_binding = min(range(len(bindings_to_visit)), key=sort_key) - binding = bindings_to_visit.pop(i_binding) - - assert all(var in previously_defined for var in relax.analysis.free_vars(binding.value)) - new_binding_order.append(binding) - previously_defined.add(binding.var) - - assert len(new_binding_order) == len(func.body.blocks[0].bindings) - - return relax.Function( - func.params, - relax.SeqExpr( - blocks=[relax.DataflowBlock(new_binding_order)], - body=func.body.body, - ), - func.ret_struct_info, - func.is_pure, - func.attrs, - ) - - -@tvm.transform.module_pass(opt_level=0, name="ReorderTransformFunc") -class ReorderTransformFunc: - def __init__(self, pidx2binname: Optional[Dict[int, str]] = None): - if pidx2binname is None: - pidx2binname = {} - self.pidx2binname = pidx2binname - - def transform_module( - self, - mod: IRModule, - ctx: tvm.transform.PassContext, - ) -> IRModule: - mod = mod.clone() - for gv, func in list(mod.functions.items()): - if isinstance(func, relax.Function) and func.attrs and "global_symbol" in func.attrs: - assert gv.name_hint.endswith("transform_params") - func_updated = reorder_func(func, self.pidx2binname) - mod[gv] = func_updated - return mod diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py deleted file mode 100644 index d6d5693762..0000000000 --- a/mlc_llm/transform/rewrite_attention.py +++ /dev/null @@ -1,46 +0,0 @@ -import tvm -from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard -from tvm.script import relax as R - - -def rewrite_attention(use_flash_mqa=False): - @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention") - def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule: - Q = wildcard() - K = wildcard() - V = wildcard() - - Q_BNSH = is_op("relax.permute_dims")(Q) - - if use_flash_mqa: - K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) - V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) - else: - K_BNSH = is_op("relax.permute_dims")(K) - V_BNSH = is_op("relax.permute_dims")(V) - - K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) - - matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) - divide = is_op("relax.divide")(matmul1, is_const()) - max = is_op("relax.maximum")(divide, is_const()) - min = is_op("relax.minimum")(max, wildcard()) - softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) - matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) - - pattern = is_op("relax.permute_dims")(matmul2) - - def callback(_, matchings): - return R.nn.attention( - matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" - ) - - new_module = {} - for gvar, func in mod.functions.items(): - if isinstance(func, tvm.relax.Function): - func = rewrite_call(pattern, callback, func) - new_module[gvar] = func - - return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos) - - return ir_module_transform diff --git a/mlc_llm/transform/set_entry_funcs.py b/mlc_llm/transform/set_entry_funcs.py deleted file mode 100644 index 714da06dd7..0000000000 --- a/mlc_llm/transform/set_entry_funcs.py +++ /dev/null @@ -1,70 +0,0 @@ -import re - -from typing import List, Union - -import tvm -from tvm.ir import GlobalVar - - -def SetEntryFuncs(*entry_funcs: List[Union[GlobalVar, str]]) -> tvm.ir.transform.Pass: - """Update which functions are externally-exposed - - All functions whose GlobalVar is contained `entry_funcs` list, or - 
whose name matches a regular expression in `entry_funcs`, are set - as externally exposed. All other functions are set as internal. - - This pass does not add or remove any functions from the - `IRModule`. This pass may result in functions no longer being - used by any externally-exposed function. In these cases, users - may use the `relax.transform.DeadCodeElimination` pass to remove - any unnecessary functions. - - Parameters - ---------- - entry_funcs: List[Union[GlobalVar, str]] - - Specifies which functions that should be externally exposed, - either by GlobalVar or by regular expression. - - Returns - ------- - transform: tvm.ir.transform.Pass - - The IRModule-to-IRModule transformation - """ - - def is_entry_func(gvar: GlobalVar) -> bool: - for entry_func in entry_funcs: - if isinstance(entry_func, GlobalVar): - if entry_func.same_as(gvar): - return True - elif isinstance(entry_func, str): - if re.fullmatch(entry_func, gvar.name_hint): - return True - else: - raise TypeError( - f"SetEntryFuncs requires all arguments to be a GlobalVar or a str. " - f"However, argument {entry_func} has type {type(entry_func)}." - ) - - def is_exposed(func: tvm.ir.BaseFunc) -> bool: - return func.attrs is not None and "global_symbol" in func.attrs - - @tvm.ir.transform.module_pass(opt_level=0, name="SetEntryFuncs") - def transform(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: - updates = {} - for gvar, func in mod.functions.items(): - if is_entry_func(gvar): - if not is_exposed(func): - updates[gvar] = func.with_attr("global_symbol", gvar.name_hint) - else: - if is_exposed(func): - updates[gvar] = func.without_attr("global_symbol") - - if updates: - mod = mod.clone() - mod.update(updates) - - return mod - - return transform diff --git a/mlc_llm/transform/transpose_matmul.py b/mlc_llm/transform/transpose_matmul.py deleted file mode 100644 index fd8a9aef41..0000000000 --- a/mlc_llm/transform/transpose_matmul.py +++ /dev/null @@ -1,349 +0,0 @@ -import tvm -from tvm import IRModule, relax, te, tir -from tvm.relax.dpl.pattern import is_op, wildcard - - -@relax.expr_functor.mutator -class TransposeMatmulCodeGenerator(relax.PyExprMutator): - def __init__(self, mod): - super().__init__(mod) - - @staticmethod - def pattern(): - w = wildcard() - x = wildcard() - wT = is_op("relax.permute_dims")(w) - o = is_op("relax.matmul")(x, wT) - annotations = {"o": o, "w": w, "x": x, "wT": wT} - - def _check(context: relax.transform.PatternCheckContext) -> bool: - transpose_call = context.annotated_expr["wT"] - ndim = transpose_call.args[0].struct_info.ndim - if ndim == -1: - return False - if ndim == 2 and transpose_call.attrs.axes is None: - return True - axes = list(range(ndim)) - axes[-1], axes[-2] = axes[-2], axes[-1] - return list(transpose_call.attrs.axes) == axes - - return o, annotations, _check - - def visit_call_(self, call: relax.Call) -> relax.Expr: - out_dtype = None - - def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: - nonlocal out_dtype - a_shape = list(a.shape) - b_shape = list(b.shape) - a_prepended = False - b_appended = False - if len(a_shape) == 1: - a_prepended = True - a_shape.insert(0, 1) - if len(b_shape) == 1: - b_appended = True - b_shape.append(1) - - is_a_larger = len(a_shape) > len(b_shape) - offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) - - a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) - bT_shape = list(b.shape) - bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] - bT_relax = relax.Var("b", 
relax.TensorStructInfo(bT_shape)) - output_shape = self.builder_.normalize( - relax.op.matmul(a_relax, bT_relax) - ).struct_info.shape - - def matmul_compute(*idx_spatial): - k = te.reduce_axis((0, a_shape[-1]), name="k") - - def multiply_compute(idx_reduce): - a_indices = [] - b_indices = [] - - for i in range(offset): - if is_a_larger: - a_indices.append(idx_spatial[i]) - else: - b_indices.append(idx_spatial[i]) - for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): - a_dim = a_shape[i if is_a_larger else i - offset] - b_dim = b_shape[i if not is_a_larger else i - offset] - dim_equal = a_dim == b_dim - if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: - a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 - b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 - a_indices.append(0 if a_dim_is_one else idx_spatial[i]) - b_indices.append(0 if b_dim_is_one else idx_spatial[i]) - else: - a_indices.append(idx_spatial[i]) - b_indices.append(idx_spatial[i]) - - if not a_prepended: - a_indices.append(idx_spatial[-2 + b_appended]) - a_indices.append(idx_reduce) - if not b_appended: - b_indices.append(idx_spatial[-1]) - b_indices.append(idx_reduce) - - dtype = out_dtype - if dtype != "": - return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) - return a(*a_indices) * b(*b_indices) - - return te.sum(multiply_compute(k), axis=k) - - return te.compute( - output_shape, - lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda - name="NT_matmul", - ) - - if isinstance(call.op, relax.GlobalVar): - function = self.builder_.get()[call.op] - if ( - function.attrs - and "Composite" in function.attrs - and function.attrs["Composite"] == "transpose_matmul_fuse" - ): - out_dtype = function.ret_struct_info.dtype - return self.builder_.call_te( - te_transposed_matmul, - call.args[1], - call.args[0], - primfunc_name_hint="NT_matmul", - ) - - return super().visit_call_(call) - - -@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") -class FuseTransposeMatmul: - def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: - mod = relax.transform.FuseOpsByPattern( - [("transpose_matmul_fuse", *TransposeMatmulCodeGenerator.pattern())] - )(mod) - - transpose_matmul_codegen = TransposeMatmulCodeGenerator(mod) - for gv in mod.functions: - func = mod[gv] - if not isinstance(func, relax.Function): - continue - func = transpose_matmul_codegen.visit_expr(func) - transpose_matmul_codegen.builder_.update_func(gv, func) - - return transpose_matmul_codegen.builder_.get() - -@relax.expr_functor.mutator -class Transpose1MatmulCodeGenerator(relax.PyExprMutator): - def __init__(self, mod): - super().__init__(mod) - - @staticmethod - def pattern(): - w = wildcard() - x = wildcard() - xT = is_op("relax.permute_dims")(x) - wT = is_op("relax.permute_dims")(w) - o = is_op("relax.matmul")(xT, wT) - annotations = {"o": o, "w": w, "x": x, "xT": xT, "wT": wT} - - def _check(context: relax.transform.PatternCheckContext) -> bool: - x_transpose_call = context.annotated_expr["o"] - w_transpose_call = context.annotated_expr["o"] - x_shape = context.annotated_expr["x"].struct_info.shape - w_shape = context.annotated_expr["w"].struct_info.shape - xT_shape = x_transpose_call.args[0].struct_info.shape - wT_shape = w_transpose_call.args[1].struct_info.shape - - if not ( - xT_shape[0] == x_shape[0] and xT_shape[1] == x_shape[2] - and xT_shape[2] == x_shape[1] and xT_shape[3] == x_shape[3] - ): - return False - - if not ( - wT_shape[0] == 
w_shape[0] and wT_shape[1] == w_shape[2] - and wT_shape[2] == w_shape[3] and wT_shape[3] == w_shape[1] - ): - return False - - return True - - return o, annotations, _check - - def visit_call_(self, call: relax.Call) -> relax.Expr: - out_dtype = None - - def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: - nonlocal out_dtype - a_shape = list(a.shape) - b_shape = list(b.shape) - - aT_shape = list(a.shape) - aT_shape[-2], aT_shape[-3] = aT_shape[-3], aT_shape[-2] - aT_relax = relax.Var("a", relax.TensorStructInfo(aT_shape)) - bT_shape = list(b.shape) - bT_shape[-1], bT_shape[-2], bT_shape[-3] = bT_shape[-3], bT_shape[-1], bT_shape[-2] - bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) - output_shape = self.builder_.normalize( - relax.op.matmul(aT_relax, bT_relax) - ).struct_info.shape - def matmul_compute(*idx_spatial): - k = te.reduce_axis((0, a_shape[-1]), name="k") - def multiply_compute(idx_reduce): - a_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce] - b_indices = [idx_spatial[0], idx_spatial[3], idx_spatial[1], idx_reduce] - dtype = out_dtype - if dtype != "": - return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) - return a(*a_indices) * b(*b_indices) - - return te.sum(multiply_compute(k), axis=k) - - return te.compute( - output_shape, - lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda - name="NT_matmul", - ) - - if isinstance(call.op, relax.GlobalVar): - function = self.builder_.get()[call.op] - if ( - "Composite" in function.attrs - and function.attrs["Composite"] == "transpose1_matmul_fuse" - ): - out_dtype = function.ret_struct_info.dtype - return self.builder_.call_te( - te_transposed_matmul, - call.args[0], - call.args[1], - primfunc_name_hint="NT_matmul", - ) - - return super().visit_call_(call) - - -@tvm.transform.module_pass(opt_level=0, name="FuseTranspose1Matmul") -class FuseTranspose1Matmul: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: - mod = relax.transform.FuseOpsByPattern( - [("transpose1_matmul_fuse", *Transpose1MatmulCodeGenerator.pattern())] - )(mod) - - transpose_matmul_codegen = Transpose1MatmulCodeGenerator(mod) - for gv in mod.functions: - func = mod[gv] - if not isinstance(func, relax.Function): - continue - func = transpose_matmul_codegen.visit_expr(func) - transpose_matmul_codegen.builder_.update_func(gv, func) - - return transpose_matmul_codegen.builder_.get() - - -@relax.expr_functor.mutator -class Transpose2MatmulCodeGenerator(relax.PyExprMutator): - def __init__(self, mod): - super().__init__(mod) - - @staticmethod - def pattern(): - w = wildcard() - x = wildcard() - wT = is_op("relax.permute_dims")(w) - o = is_op("relax.permute_dims")(is_op("relax.matmul")(x, wT)) - #oT = is_op("relax.permute_dims")(o) - annotations = {"o": o, "w": w, "x": x, "wT": wT} - - def _check(context: relax.transform.PatternCheckContext) -> bool: - w_transpose_call = context.annotated_expr["wT"] - w_shape = w_transpose_call.args[0].struct_info.shape - wT_shape = w_transpose_call.struct_info.shape - oT_call = context.annotated_expr["o"] - o_shape = oT_call.args[0].struct_info.shape - oT_shape = oT_call.struct_info.shape - - if not ( - wT_shape[0] == w_shape[0] and wT_shape[1] == w_shape[2] - and wT_shape[2] == w_shape[1] and wT_shape[3] == w_shape[3] - ): - return False - - if not ( - oT_shape[0] == o_shape[0] and oT_shape[1] == o_shape[2] - and oT_shape[2] == o_shape[1] and oT_shape[3] == o_shape[3] - ): - return False - - return True - - return o, 
annotations, _check - - def visit_call_(self, call: relax.Call) -> relax.Expr: - out_dtype = None - - def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: - nonlocal out_dtype - a_shape = list(a.shape) - b_shape = list(b.shape) - output_shape = [a_shape[0], b_shape[-2], a_shape[2], a_shape[3]] - def matmul_compute(*idx_spatial): - k = te.reduce_axis((0, b_shape[-1]), name="k") - def multiply_compute(idx_reduce): - a_indices = [idx_spatial[0], idx_reduce, idx_spatial[2], idx_spatial[3]] - b_indices = [idx_spatial[0], idx_spatial[2], idx_spatial[1], idx_reduce] - - dtype = out_dtype - if dtype != "": - return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) - return a(*a_indices) * b(*b_indices) - - return te.sum(multiply_compute(k), axis=k) - - return te.compute( - output_shape, - lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda - name="NT_matmul", - ) - - if isinstance(call.op, relax.GlobalVar): - function = self.builder_.get()[call.op] - if ( - "Composite" in function.attrs - and function.attrs["Composite"] == "transpose2_matmul_fuse" - ): - out_dtype = function.ret_struct_info.dtype - #NT_output_shape = function.ret_struct_info.shape - return self.builder_.call_te( - te_transposed_matmul, - call.args[0], - call.args[1], - primfunc_name_hint="NT_matmul", - ) - - return super().visit_call_(call) - - -@tvm.transform.module_pass(opt_level=0, name="FuseTranspose2Matmul") -class FuseTranspose2Matmul: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: - mod = relax.transform.FuseOpsByPattern( - [("transpose2_matmul_fuse", *Transpose2MatmulCodeGenerator.pattern())] - )(mod) - - transpose_matmul_codegen = Transpose2MatmulCodeGenerator(mod) - for gv in mod.functions: - func = mod[gv] - if not isinstance(func, relax.Function): - continue - func = transpose_matmul_codegen.visit_expr(func) - transpose_matmul_codegen.builder_.update_func(gv, func) - - return transpose_matmul_codegen.builder_.get() diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py deleted file mode 100644 index 094c81d25a..0000000000 --- a/mlc_llm/utils.py +++ /dev/null @@ -1,738 +0,0 @@ -# pylint: disable=missing-docstring,invalid-name -import argparse -import functools -import json -import math -import os -import shutil -from typing import Any, Dict, List, Optional, Set - -import numpy as np -import tvm -from tvm import relax - -from .quantization import quantization_schemes -from .relax_model import param_manager - -supported_model_types = set( - [ - "llama", - "gpt_neox", - "gpt_bigcode", - "minigpt", - "moss", - "rwkv", - "gptj", - "chatglm", - "mistral", - "stablelm_epoch", - "gpt2", - "qwen" - ] -) - - -def wrap_tqdm_counter(func, **tqdm_kwargs): - # tqdm isn't a hard requirement, so return the original function - # if it isn't available. 
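# Typical use, mirroring convert_weights further below: wrap the lazy get_item /
# set_item callbacks so every parameter fetch or store advances a tqdm bar, e.g.
#   get_item = wrap_tqdm_counter(get_item, desc="Get old param", position=0,
#                                unit="tensors", total=num_original_params)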
- try: - from tqdm import tqdm - except ImportError: - return func - - pbar = tqdm(**tqdm_kwargs) - - @functools.wraps(func) - def inner(*args, **kwargs): - pbar.update(1) - return func(*args, **kwargs) - - return inner - - -def argparse_postproc_common(args: argparse.Namespace) -> None: - if hasattr(args, "device_name"): - if args.device_name == "auto": - if tvm.cuda().exist: - args.device_name = "cuda" - elif tvm.metal().exist: - args.device_name = "metal" - elif tvm.vulkan().exist: - args.device_name = "vulkan" - elif tvm.opencl().exist: - args.device_name = "opencl" - else: - raise ValueError("Cannot auto deduce device-name, please set it") - - model_category_override = { - "moss-moon-003-sft": "gptj", - "moss-moon-003-base": "gptj", - "rwkv-": "rwkv", - "rwkv_world": "rwkv_world", - "minigpt": "minigpt", - } - try: - with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: - config = json.load(i_f) - args.model_category = config["model_type"] - model_path_lower = args.model_path.lower() - if "rwkv" in model_path_lower and "world" in model_path_lower: - args.model_category = "rwkv_world" - except Exception: - args.model_category = "" - model = args.model.lower() - if "rwkv" in model and "world" in model: - model = "rwkv_world" - for prefix, override_category in model_category_override.items(): - if model.startswith(prefix): - args.model_category = override_category - break - assert args.model_category is not None - - model_conv_templates = { - "llama-2": "llama-2", - "codellama-7b-instruct": "codellama_instruct", - "codellama-13b-instruct": "codellama_instruct", - "codellama-34b-instruct": "codellama_instruct", - "codellama": "codellama_completion", - "gpt2": "gpt2", - "vicuna-": "vicuna_v1.1", - "dolly-": "dolly", - "stablelm-3b-": "stablelm-3b", - "stablelm-": "stablelm", - "redpajama-": "redpajama_chat", - "minigpt": "minigpt", - "moss-moon-003-sft": "moss", - "moss-moon-003-base": "LM", - "gpt-j-": "LM", - "open_llama": "LM", - "rwkv-": "rwkv", - "rwkv_world": "rwkv_world", - "gorilla-": "gorilla", - "guanaco": "guanaco", - "wizardlm-7b": "wizardlm_7b", # first get rid of 7b - "wizardlm-": "vicuna_v1.1", # all others use vicuna template - "wizardmath-": "wizard_coder_or_math", - "wizardcoder-": "wizard_coder_or_math", - "starcoder": "gpt_bigcode", - "gpt_bigcode-santacoder": "gpt_bigcode", - "stablecode-completion": "stablecode_completion", - "stablecode-instruct": "stablecode_instruct", - "chatglm2": "glm", - "chatglm3": "glm", - "codegeex2": "glm", - "tinyllama": "chatml", - "openhermes-2.5-mistral": "open_hermes_mistral", - "neuralhermes-2.5-mistral": "neural_hermes_mistral", - "qwen": "qwen" - } - - for prefix, conv_template in model_conv_templates.items(): - if model.startswith(prefix): - args.conv_template = conv_template - break - else: - args.conv_template = f"{args.model_category}_default" - - if args.quantization not in quantization_schemes: - raise ValueError(f'Quantization "{args.quantization}" is not supported.') - - args.quantization = quantization_schemes[args.quantization] - - use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] - - if use_ft_quant and args.num_shards > 1: - # Preprocess is done after sharding for this case. 
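# (use_ft_quant covers the q4f16_ft / q8f16_ft schemes and their _group variants;
#  with more than one shard, the two flags below turn off eager preprocessing so
#  that, as noted above, it can instead run on the already-sharded weights.)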
- args.quantization.linear_weight.do_preprocess = False - args.quantization.final_fc_weight.do_preprocess = False - - -def debug_dump_script(mod, name, args: argparse.Namespace, show_meta=True): - """Debug dump mode""" - if not args.debug_dump: - return - dump_path = os.path.join(args.artifact_path, "debug", name) - with open(dump_path, "w", encoding="utf-8") as outfile: - outfile.write(mod.script(show_meta=show_meta)) - print(f"Dump mod to {dump_path}") - - -def debug_dump_benchmark_script( - mod: tvm.ir.IRModule, - name: str, - args: argparse.Namespace, -) -> None: - """Extract model level benchmark workloads from relax model.""" - if not args.debug_dump: - return - - from tvm.dlight.benchmark import ( # pylint: disable=import-error,import-outside-toplevel - extract_all_func_info_from_relax, - ) - - dump_path = os.path.join(args.artifact_path, "debug", name + ".py") - with open(dump_path, "w", encoding="utf-8") as outfile: - outfile.write( - "# Please save this file to dlight_bench/models and add\n" - + f"# `from .{name} import *` to dlight_bench/models/__init__.py\n" - + "from dlight_bench import DlightBench\n" - + "from tvm.script import tir as T\n\n" - ) - - stmt = [] - try: - relax_funcs, _ = extract_all_func_info_from_relax(mod) - except NotImplementedError: - return - tvm_script_prefix = "# from tvm.script import tir as T" - for relax_func_gv in relax_funcs: # pylint: disable=consider-using-dict-items - for prim_func_gv in relax_funcs[relax_func_gv]: - # add global_symbol - func_body = ( - mod[prim_func_gv] - .with_attr("global_symbol", prim_func_gv.name_hint) - .script(name=prim_func_gv.name_hint) - ) - # remove prefix - if func_body.startswith(tvm_script_prefix + "\n"): - func_body = func_body[len(tvm_script_prefix) :] - # print out - outfile.write(func_body + "\n") - # register - stmt.append( - f"DlightBench.register_bench_workload({prim_func_gv.name_hint}, " - f"'{name}', '{prim_func_gv.name_hint}')" - ) - outfile.write("\n" + "\n".join(stmt) + "\n") - print(f"Dump benchmarking script to {dump_path}.") - - -def debug_load_script(name: str, args: argparse.Namespace): - input_path = os.path.join(args.artifact_path, "debug", name) - lib = {"__file__": input_path} - with open(input_path, "rb") as i_f: - exec(compile(i_f.read(), input_path, "exec"), lib, lib) # pylint: disable=exec-used - return lib["Module"] - - -def debug_dump_shader(ex: tvm.relax.Executable, name: str, args: argparse.Namespace): - """Debug dump mode""" - if not args.debug_dump: - return - target_kind = args.target.kind.default_keys[0] - suffix_map = { - "webgpu": ".wgsl", - "cuda": ".cu", - "metal": ".mtl", - "opencl": ".cl", - } - suffix = suffix_map.get(target_kind, ".txt") - dump_path = os.path.join(args.artifact_path, "debug", name + suffix) - source = ex.mod.imported_modules[0].imported_modules[0].get_source() - with open(dump_path, "w", encoding="utf-8") as outfile: - outfile.write(source) - print(f"Dump shader to {dump_path}") - - -def convert_weights( - mod_transform: tvm.IRModule, - param_mgr: param_manager.ParamManager, - model_params: List[Optional[tvm.nd.NDArray]], - args: argparse.Namespace, -): - # Save the number of parameters before we lower mod_transform, so - # we can use them in the progress bar. 
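# Sketch of the lazy-loading contract used below: relax.transform.LazyTransformParams
# rewrites the transform function so that each raw weight is fetched through the
# globally registered "get_item" PackedFunc and each converted weight is handed to
# "set_item", letting raw tensors be released as soon as they are consumed.  The
# tqdm wrappers below only add progress counting on top of the ParamManager
# callbacks.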
- transform_func = mod_transform["transform_params"] - num_original_params = len(transform_func.params[0].struct_info.fields) - num_transformed_params = len(transform_func.struct_info.ret.fields) - - # Remove the dataflow block inside the param transform function, - # so that the LazyTransformParams pass can be applied. - mod_transform = relax.transform.ToNonDataflow()(mod_transform) - mod_transform = relax.transform.LazyTransformParams()(mod_transform) - mod_transform = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_transform) - mod_transform = relax.transform.LegalizeOps()(mod_transform) - - debug_dump_script(mod_transform, "mod_convert_weights.py", args) - - target = detect_local_target() - print(f"Automatically using target for weight quantization: {target}") - device = tvm.device(target.kind.default_keys[0]) - - get_item = param_mgr.get_param_get_item( - device, - model_params, - ) - set_item, loaded_params = param_mgr.get_param_set_item() - - get_item = wrap_tqdm_counter( - get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params - ) - set_item = wrap_tqdm_counter( - set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params - ) - - tvm.register_func(func_name="get_item", f=get_item, override=True) - tvm.register_func(func_name="set_item", f=set_item, override=True) - - if target.kind.name != "llvm": - with tvm.target.Target(target): - mod_transform = tvm.tir.transform.DefaultGPUSchedule()(mod_transform) - - ex = relax.build(mod_transform, target=target) - vm = relax.vm.VirtualMachine(ex, device) - print("Start computing and quantizing weights... This may take a while.") - vm["transform_params"]() - print("Finish computing and quantizing weights.") - return loaded_params - - -def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded: int = 1) -> None: - from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel - - assert len(params) % num_presharded == 0 - num_weights = len(params) // num_presharded - - meta_data = {} - param_dict = {} - meta_data["ParamSize"] = len(params) - for i, nd in enumerate(params): - if num_presharded == 1: - param_name = f"param_{i}" - else: - expected_worker_id = i // num_weights - orig_param_id = i % num_weights - param_name = f"param_{orig_param_id}_shard-{expected_worker_id+1}-of-{num_presharded}" - - param_dict[param_name] = nd - - total_size_bytes = sum( - math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params - ) - total_size_gb = total_size_bytes / (1024**3) - print(f"Total param size: {total_size_gb} GB") - tvmjs.dump_ndarray_cache( - param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw" - ) - - -def load_params(artifact_path: str, device) -> List[tvm.nd.NDArray]: - from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel - - params, meta = tvmjs.load_ndarray_cache(f"{artifact_path}/params", device) - plist = [] - size = meta["ParamSize"] - for i in range(size): - plist.append(params[f"param_{i}"]) - return plist - - -def load_params_SLM( - model_weight_path: str, device, model_metadata: Dict[str, Any] -) -> List[tvm.nd.NDArray]: - from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel - - params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) - param_names = [param["name"] for param in model_metadata["params"]] - assert len(param_names) == meta["ParamSize"] - - plist = [] - for param_name in param_names: - plist.append(params[param_name]) - return plist - - -def 
copy_tokenizer(args: argparse.Namespace) -> None: - for filename in os.listdir(args.model_path): - if filename in [ - "tokenizer.model", - "tokenizer.json", - "vocab.json", - "merges.txt", - "added_tokens.json", - "tokenizer_config.json", - ]: - shutil.copy( - os.path.join(args.model_path, filename), - os.path.join(args.artifact_path, "params"), - ) - - # If we have `tokenizer.model` but not `tokenizer.json`, try convert it to - # `tokenizer.json` with `transformers`. - tokenizer_json_path = os.path.join(args.model_path, "tokenizer.json") - tokenizer_model_path = os.path.join(args.model_path, "tokenizer.model") - if os.path.exists(tokenizer_model_path) and (not os.path.exists(tokenizer_json_path)): - print("Attempting to convert `tokenizer.model` to `tokenizer.json`.") - try: - # pylint: disable=import-outside-toplevel - from transformers import AutoTokenizer - - tokenizer_json_save_dest = os.path.join(args.artifact_path, "params/tokenizer.json") - fast_tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True) - fast_tokenizer.backend_tokenizer.save(tokenizer_json_save_dest) - print(f"Succesfully converted `tokenizer.model` to: {tokenizer_json_save_dest}") - except ImportError: - print( - "WARNING: The model has `tokenizer.model` but not `tokenizer.json`. It is" - + "recommended to use `tokenizer.json`, so we try convert it with `transformers`.\n" - + "However, we were unable to import `transformers`, hence skipping this step." - ) - except Exception as error: # pylint: disable=broad-exception-caught - print( - "WARNING: The model has `tokenizer.model` but not `tokenizer.json`. It is" - + "recommended to use `tokenizer.json`, so we try convert it with `transformers`.\n" - + "However, we are skipping this due to an error:\n", - error, - ) - - -def get_tokenizer_files(path) -> List[str]: - tokenizer_set = { - "tokenizer.model", - "tokenizer.json", - "vocab.json", - "merges.txt", - "added_tokens.json", - } - return [x for x in os.listdir(path) if x in tokenizer_set] - - -def _detect_local_metal_host(): - target_triple = tvm._ffi.get_global_func("tvm.codegen.llvm.GetDefaultTargetTriple")() - process_triple = tvm._ffi.get_global_func("tvm.codegen.llvm.GetProcessTriple")() - host_cpu = tvm._ffi.get_global_func("tvm.codegen.llvm.GetHostCPUName")() - print( - f"Host CPU dection:\n Target triple: {target_triple}\n Process triple: {process_triple}\n Host CPU: {host_cpu}" - ) - if target_triple.startswith("x86_64-"): - return tvm.target.Target( - { - "kind": "llvm", - "mtriple": "x86_64-apple-macos", - "mcpu": host_cpu, - } - ) - # should start with "arm64-" - return tvm.target.Target( - { - "kind": "llvm", - "mtriple": "arm64-apple-macos", - "mcpu": host_cpu, - } - ) - - -def _detect_local_metal(): - dev = tvm.metal() - if not dev.exist: - return None - - return tvm.target.Target( - { - "kind": "metal", - "max_shared_memory_per_block": 32768, - "max_threads_per_block": dev.max_threads_per_block, - "thread_warp_size": 32, - }, - host=_detect_local_metal_host(), - ) - - -def _detect_local_cuda(): - dev = tvm.cuda() - if not dev.exist: - return None - return tvm.target.Target( - { - "kind": "cuda", - "max_shared_memory_per_block": dev.max_shared_memory_per_block, - "max_threads_per_block": dev.max_threads_per_block, - "thread_warp_size": dev.warp_size, - "registers_per_block": 65536, - "arch": "sm_" + dev.compute_version.replace(".", ""), - } - ) - - -def _detect_local_rocm(): - dev = tvm.rocm() - if not dev.exist: - return None - return tvm.target.Target( - { - "kind": "rocm", - 
"max_shared_memory_per_block": dev.max_shared_memory_per_block, - "max_threads_per_block": dev.max_threads_per_block, - "thread_warp_size": dev.warp_size, - } - ) - - -def _detect_local_vulkan(): - dev = tvm.vulkan() - if not dev.exist: - return None - return tvm.target.Target( - { - "kind": "vulkan", - "max_threads_per_block": dev.max_threads_per_block, - "max_shared_memory_per_block": dev.max_shared_memory_per_block, - "thread_warp_size": dev.warp_size, - "supports_float16": 1, - "supports_int16": 1, - "supports_int8": 1, - "supports_16bit_buffer": 1, - } - ) - - -def _detect_local_opencl(): - dev = tvm.opencl() - if not dev.exist: - return None - return tvm.target.Target("opencl") - - -def detect_local_target(): - for method in [ - _detect_local_metal, - _detect_local_rocm, - _detect_local_cuda, - _detect_local_vulkan, - _detect_local_opencl, - ]: - target = method() - if target is not None: - return target - - print("Failed to detect local GPU, falling back to CPU as a target") - return tvm.target.Target("llvm") - - -def parse_target(args: argparse.Namespace) -> None: - if not hasattr(args, "target"): - return - if args.target == "auto": - target = detect_local_target() - if target.host is None: - target = tvm.target.Target( - target, - host="llvm", # TODO: detect host CPU - ) - args.target = target - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "cuda" or args.target == "cuda-multiarch": - target = _detect_local_cuda() - if target is None: - raise ValueError("Cannot detect local CUDA GPU target!") - multiarch = args.target == "cuda-multiarch" - args.target = target - args.target_kind = args.target.kind.default_keys[0] - if multiarch: - args.target_kind += "-multiarch" - elif args.target.startswith("nvidia/jetson"): - try: - args.target = tvm.target.Target(args.target) - except ValueError: - raise ValueError("Cannot find configuration of given nvidia/jetson board target!") - if not hasattr(args, "cc_path") or args.cc_path == "": - args.cc_path = "/usr/bin/aarch64-linux-gnu-g++" - from tvm.contrib.cc import ( # pylint: disable=import-outside-toplevel - cross_compiler, - ) - - args.export_kwargs = { - "fcompile": cross_compiler( - args.cc_path, - ), - } - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "metal": - target = _detect_local_metal() - if target is None: - print("Cannot detect local Apple Metal GPU target! 
Falling back...") - target = tvm.target.Target( - tvm.target.Target( - { - "kind": "metal", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - } - ), - host=_detect_local_metal_host(), - ) - args.target = target - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "metal_x86_64": - from tvm.contrib import xcode # pylint: disable=import-outside-toplevel - - args.target = tvm.target.Target( - tvm.target.Target( - { - "kind": "metal", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - } - ), - host="llvm -mtriple=x86_64-apple-darwin", - ) - args.target_kind = "metal_x86_64" - args.export_kwargs = { - "fcompile": xcode.create_dylib, - "sdk": "macosx", - "arch": "x86_64", - } - args.lib_format = "dylib" - elif args.target in ["iphone", "iphone-dylib", "iphone-tar"]: - from tvm.contrib import tar, xcode # pylint: disable=import-outside-toplevel - - if args.target == "iphone-dylib": - args.export_kwargs = { - "fcompile": xcode.create_dylib, - "sdk": "iphoneos", - "arch": "arm64", - } - args.lib_format = "dylib" - else: - args.export_kwargs = {"fcompile": tar.tar} - args.lib_format = "tar" - args.system_lib = True - args.system_lib_prefix = f"{args.model}_{args.quantization}_".replace("-", "_") - - @tvm.register_func("tvm_callback_metal_compile") - def compile_metal(src, target): - if target.libs: - return xcode.compile_metal(src, sdk=target.libs[0]) - return xcode.compile_metal(src) - - target = tvm.target.Target( - tvm.target.Target( - { - "kind": "metal", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - "libs": ["iphoneos"], - } - ), - host="llvm -mtriple=arm64-apple-darwin", - ) - args.target = target - args.target_kind = "iphone" - elif args.target == "vulkan": - target = tvm.target.Target( - tvm.target.Target( - { - "kind": "vulkan", - "max_threads_per_block": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - "supports_float16": 1, - "supports_int16": 1, - "supports_int8": 1, - "supports_8bit_buffer": 1, - "supports_16bit_buffer": 1, - "supports_storage_buffer_storage_class": 1, - } - ), - host="llvm", - ) - args.target = target - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "opencl": - target = tvm.target.Target( - "opencl", - host="llvm", - ) - args.target = target - args.target_kind = args.target.kind.default_keys[0] - elif args.target == "webgpu": - args.target = tvm.target.Target( - "webgpu", - host="llvm -mtriple=wasm32-unknown-unknown-wasm", - ) - args.target_kind = "webgpu" - args.lib_format = "wasm" - args.system_lib = True - if os.environ.get("TVM_HOME", "") == "": - raise RuntimeError( - "Please set TVM_HOME for webgpu build following scripts/prep_emcc_deps.sh" - ) - elif args.target in ["android", "android-dylib"]: # android-opencl - from tvm.contrib import ndk, tar - - if args.target == "android-dylib": - args.export_kwargs = { - "fcompile": ndk.create_shared, - } - args.lib_format = "so" - else: - args.export_kwargs = { - "fcompile": tar.tar, - } - args.lib_format = "tar" - args.system_lib = True - args.system_lib_prefix = f"{args.model}_{args.quantization}_".replace("-", "_") - args.target = tvm.target.Target( - "opencl", - host="llvm -mtriple=aarch64-linux-android", # TODO: Only support arm64 for now - ) - args.target_kind = "android" - elif args.target in ["mali"]: - if "TVM_NDK_CC" in os.environ: - from tvm.contrib import ndk - - args.export_kwargs = { - 
"fcompile": ndk.create_shared, - } - target = tvm.target.Target( - "opencl -device=mali", - host="llvm -mtriple=aarch64-linux-gnu", - ) - args.target = target - args.target_kind = "mali" - else: - args.target = tvm.target.Target(args.target, host="llvm") - args.target_kind = args.target.kind.default_keys[0] - - if args.target_kind == "cuda-multiarch": - from tvm.contrib import nvcc - - assert args.target.arch[3:] != "" - arch_list = os.getenv("CUDA_ARCH_LIST") or os.getenv("TORCH_CUDA_ARCH_LIST") - if arch_list: - compute_versions = [int(v) for v in arch_list.replace(" ", ";").split(";")] - elif int(args.target.arch[3:]) >= 70: - compute_versions = [70, 72, 75, 80, 86, 87, 89, 90] - else: - compute_versions = [60, 61, 62] - - args.target_kind = "cuda" - - @tvm.register_func("tvm_callback_cuda_compile", override=True) - def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument - """use nvcc to generate fatbin code for better optimization""" - arch = [] - for compute_version in compute_versions: - arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] - ptx = nvcc.compile_cuda(code, target_format="fatbin", arch=arch) - return ptx - - # use mingw to cross compile windows - if hasattr(args, "llvm_mingw") and args.llvm_mingw != "": - from tvm.contrib.cc import ( # pylint: disable=import-outside-toplevel - cross_compiler, - ) - - args.export_kwargs = { - "fcompile": cross_compiler( - os.path.join(args.llvm_mingw, "bin", "x86_64-w64-mingw32-clang++"), - output_format="dll", - ), - } - args.target = args.target.with_host("llvm -mtriple=x86_64-w64-windows-gnu") - args.lib_format = "dll" - - print(f"Target configured: {args.target}") diff --git a/setup.py b/setup.py deleted file mode 100644 index b9721497c2..0000000000 --- a/setup.py +++ /dev/null @@ -1,47 +0,0 @@ -from distutils.core import setup -from setuptools.dist import Distribution -from setuptools import find_packages -import os - -# Note there is no need to setup when -# running locally. 
- -CURRENT_DIR = os.path.dirname(__file__) - - -def git_describe_version(original_version): - """Get git describe version.""" - ver_py = os.path.join(CURRENT_DIR, "version.py") - libver = {"__file__": ver_py} - exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) - _, gd_version = libver["git_describe_version"]() - if gd_version is not None and gd_version != original_version: - print("Use git describe based version %s" % gd_version) - return gd_version - - -__version__ = git_describe_version(None) - -setup( - name="mlc_llm", - version=__version__, - description="MLC LLM: Universal Compilation of Large Language Models", - url="https://llm.mlc.ai/", - author="MLC LLM Contributors", - license="Apache 2.0", - # See https://pypi.org/classifiers/ - classifiers=[ - "License :: OSI Approved :: Apache Software License", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - ], - keywords="machine learning", - zip_safe=False, - packages=find_packages(), - package_dir={"mlc_llm": "mlc_llm"}, - install_requires=["numpy", "torch", "transformers", "scipy", "timm"], - entry_points={"console_scripts": ["mlc_llm_build = mlc_llm.build:main"]}, - distclass=Distribution, -) From 716b8e1878e09433b69efc7010fce0295c5ac71e Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 12 Mar 2024 03:09:24 +0800 Subject: [PATCH 053/531] [Serving] Register the StableLM3B conversation template (#1920) Update conversation_template.py --- python/mlc_chat/conversation_template.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_chat/conversation_template.py index 7192cc818b..fb367b7aa3 100644 --- a/python/mlc_chat/conversation_template.py +++ b/python/mlc_chat/conversation_template.py @@ -133,3 +133,22 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_token_ids=[50256], ) ) + +# StableLM3B +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm-3b", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={ + "user": "<|user|>", + "assistant": "<|assistant|>", + "tool": "<|user|>", + }, + seps=["<|endoftext|>", "<|endoftext|>"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|endoftext|>"], + stop_token_ids=[0], + ) +) From 2e6f9cbab9bcbffb71b94cd727cbe36aabdeb55c Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 11 Mar 2024 15:10:34 -0400 Subject: [PATCH 054/531] Remove deprecated build.py --- build.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 build.py diff --git a/build.py b/build.py deleted file mode 100644 index 94df83d6e5..0000000000 --- a/build.py +++ /dev/null @@ -1,4 +0,0 @@ -from mlc_llm.build import main - -if __name__ == "__main__": - main() From 9c801052bf58a78b379e3507962781b5a94584c7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 11 Mar 2024 18:13:41 -0400 Subject: [PATCH 055/531] [Fix] KVCache creation with call_pure_packed (#1930) With https://github.com/apache/tvm/pull/16684 merged in, the KV cache creation will fail when compiling models. This PR fixes the problem by using `call_pure_packed`. 
--- .../dispatch_kv_cache_creation.py | 46 +++++++++---------- python/mlc_chat/nn/kv_cache.py | 30 ++++++------ 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py index 08cf730f5f..1995b3c517 100644 --- a/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py @@ -16,35 +16,33 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: assert len(func.body.blocks[0].bindings) == 2 assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding) assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call) - assert isinstance(func.body.blocks[0].bindings[0].value.op, relax.ExternFunc) - assert ( - func.body.blocks[0].bindings[0].value.op.global_symbol - == "mlc.create_paged_kv_cache_generic" - ) - + assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed") args = func.body.blocks[0].bindings[0].value.args - assert len(args) == 10 - assert isinstance(args[0], relax.ShapeExpr) - assert len(args[0].values) == 4 - for i in range(1, 9): + assert isinstance(args[0], relax.ExternFunc) + assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic" + + assert len(args) == 11 + assert isinstance(args[1], relax.ShapeExpr) + assert len(args[1].values) == 4 + for i in range(2, 10): assert isinstance(args[i], relax.PrimValue) assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm)) - assert isinstance(args[9], relax.DataTypeImm) + assert isinstance(args[10], relax.DataTypeImm) return { - "max_batch_size": args[0].values[0], - "max_total_seq_len": args[0].values[1], - "prefill_chunk_size": args[0].values[2], - "page_size": args[0].values[3], - "num_hidden_layers": args[1].value.value, - "num_attention_heads": args[2].value.value, - "num_key_value_heads": args[3].value.value, - "head_dim": args[4].value.value, - "rope_mode": args[5].value.value, - "rope_scale": args[6].value.value, - "rope_theta": args[7].value.value, - "rotary_dim": args[8].value.value, - "dtype": args[9].value, + "max_batch_size": args[1].values[0], + "max_total_seq_len": args[1].values[1], + "prefill_chunk_size": args[1].values[2], + "page_size": args[1].values[3], + "num_hidden_layers": args[2].value.value, + "num_attention_heads": args[3].value.value, + "num_key_value_heads": args[4].value.value, + "head_dim": args[5].value.value, + "rope_mode": args[6].value.value, + "rope_scale": args[7].value.value, + "rope_theta": args[8].value.value, + "rotary_dim": args[9].value.value, + "dtype": args[10].value, } diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index f63e74d855..027c08bd71 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -62,23 +62,19 @@ def create_generic( # pylint: disable=too-many-arguments if rotary_dim is None: rotary_dim = head_dim return PagedKVCache( - _expr=rx.Call( - rx.extern("mlc.create_paged_kv_cache_generic"), - args=[ - rx.ShapeExpr( - [max_batch_size, max_total_seq_len, prefill_chunk_size, page_size] - ), - rx.PrimValue(num_hidden_layers), - rx.PrimValue(num_attention_heads), - rx.PrimValue(num_key_value_heads), - rx.PrimValue(head_dim), - rx.PrimValue(rope_mode), - rx.PrimValue(rope_scale), - rx.PrimValue(rope_theta), - rx.PrimValue(rotary_dim), - rx.DataTypeImm(dtype), - ], - sinfo_args=[rx.ObjectStructInfo()], + _expr=rx.call_pure_packed( + 
"mlc.create_paged_kv_cache_generic", + rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_mode), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.PrimValue(rotary_dim), + rx.DataTypeImm(dtype), + sinfo_args=rx.ObjectStructInfo(), ), _name=name, ) From d8fedd1b25afc6298c9f77f46fc975b0693c6786 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 11 Mar 2024 19:11:18 -0400 Subject: [PATCH 056/531] [KVCache] Update FlashInfer PackedFunc names (#1931) This PR updates the FlashInfer names given https://github.com/apache/tvm/pull/16692 has been merged. --- python/mlc_chat/nn/kv_cache.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py index 027c08bd71..636861f3bd 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_chat/nn/kv_cache.py @@ -255,15 +255,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # pylint: disable=line-too-long # fmt: off bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - rx.extern("paged_kv_cache.attention_kernel_prefill"), - rx.extern("paged_kv_cache.attention_kernel_decode"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), - rx.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"), - rx.extern("paged_kv_cache.attention_kernel_prefill_end_forward"), - rx.extern("paged_kv_cache.attention_kernel_decode_begin_forward"), - rx.extern("paged_kv_cache.attention_kernel_decode_end_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), + rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), rx.extern("flashinfer.merge_state_in_place"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"), From 4290a053a02dd69ede4162565e29b120052fbe72 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 11 Mar 2024 20:50:25 -0400 Subject: [PATCH 057/531] [REFACTOR] remove tests/legacy-python (#1933) This PR removes the folder tests/legacy-python as a followup cleanup step of the old flow Some of the files like compare lib are useful and we should recover them later at mlc_llm.testing.DebugChat flow --- tests/legacy-python/compare_lib.py | 213 ----------- tests/legacy-python/dump_intermediate.py | 196 ---------- tests/legacy-python/evaluate.py | 202 ----------- tests/legacy-python/module_intercept.py | 147 -------- tests/legacy-python/test_batching_llama.py | 160 --------- tests/legacy-python/test_build_args.py | 175 --------- .../test_build_model_from_args.py | 142 -------- 
.../legacy-python/test_sliding_window_mask.py | 338 ------------------ 8 files changed, 1573 deletions(-) delete mode 100644 tests/legacy-python/compare_lib.py delete mode 100644 tests/legacy-python/dump_intermediate.py delete mode 100644 tests/legacy-python/evaluate.py delete mode 100644 tests/legacy-python/module_intercept.py delete mode 100644 tests/legacy-python/test_batching_llama.py delete mode 100644 tests/legacy-python/test_build_args.py delete mode 100644 tests/legacy-python/test_build_model_from_args.py delete mode 100644 tests/legacy-python/test_sliding_window_mask.py diff --git a/tests/legacy-python/compare_lib.py b/tests/legacy-python/compare_lib.py deleted file mode 100644 index 5bcea1e699..0000000000 --- a/tests/legacy-python/compare_lib.py +++ /dev/null @@ -1,213 +0,0 @@ -import argparse -import json -import os -from typing import List - -import numpy as np -import torch -import tvm -from transformers import AutoTokenizer, LlamaTokenizer -from tvm import relax, rpc -from tvm.relax.testing.lib_comparator import LibCompareVMInstrument - -from mlc_llm import utils - - -class LibCompare(LibCompareVMInstrument): - def __init__(self, mod, device, time_eval, skip_rounds=0): - super().__init__(mod, device, True) - self.time_eval = time_eval - self.time_eval_results = {} - self.visited = set([]) - self.skip_rounds = skip_rounds - self.atol = 1e-2 - self.rtol = 1e-3 - - def skip_instrument(self, func, name, before_run, ret_val, *args): - print(f"run {name}") - if name.startswith("shape_func"): - return True - if self.counter < self.skip_rounds: - self.counter += 1 - print(f"[{self.counter}] Skip validating {name}..") - return True - if name in self.visited: - if self.time_eval and name in self.time_eval_results: - record = self.time_eval_results[name] - self.time_eval_results[name] = (record[0], record[1] + 1) - return True - self.visited.add(name) - return False - - def compare( - self, - name: str, - ref_args: List[tvm.nd.NDArray], - new_args: List[tvm.nd.NDArray], - ret_indices: List[int], - ): - super().compare(name, ref_args, new_args, ret_indices) - - if self.time_eval and name not in self.time_eval_results: - res = self.mod.time_evaluator( - name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 - )(*new_args) - self.time_eval_results[name] = (res.mean, 1) - print(f"Time-eval result {name} on {self.device}: {res}") - - -def print_as_table(sorted_list): - print( - "Name".ljust(50) - + "Time (ms)".ljust(12) - + "Count".ljust(8) - + "Total time (ms)".ljust(18) - + "Percentage (%)" - ) - total_time = sum([record[1][0] * record[1][1] for record in sorted_list]) * 1000 - for record in sorted_list: - time = record[1][0] * 1000 - weighted_time = time * record[1][1] - percentage = weighted_time / total_time * 100 - print( - record[0].ljust(50) - + "{:.4f}".format(time).ljust(12) - + str(record[1][1]).ljust(8) - + "{:.4f}".format(weighted_time).ljust(18) - + "{:.2f}".format(percentage) - ) - print("Total time: {:.4f} ms".format(total_time)) - print() - - -class TestState: - def __init__(self, args): - self.primary_device = tvm.device(args.primary_device) - ex = tvm.runtime.load_module( - os.path.join( - args.artifact_path, - f"{args.model}-{args.quantization.name}-{args.primary_device}.so", - ) - ) - self.vm = relax.VirtualMachine(ex, self.primary_device) - if args.cmp_device == "iphone": - lib_name = f"{args.model}-{args.quantization.name}-{args.cmp_device}.dylib" - local_lib_path = os.path.join(args.artifact_path, lib_name) - proxy_host = 
os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1") - proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090")) - self.sess = rpc.connect(proxy_host, proxy_port, "iphone") - self.sess.upload(local_lib_path) - self.lib = self.sess.load_module(lib_name) - self.cmp_device = self.sess.metal() - elif args.cmp_device == "android": - lib_name = f"{args.model}-{args.quantization.name}-{args.cmp_device}.so" - local_lib_path = os.path.join(args.artifact_path, lib_name) - tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0") - tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190")) - tracker = rpc.connect_tracker(tracker_host, tracker_port) - self.sess = tracker.request("android") - self.sess.upload(local_lib_path) - self.lib = self.sess.load_module(lib_name) - self.cmp_device = self.sess.cl(0) - else: - self.sess = None - self.lib = tvm.runtime.load_module( - os.path.join( - args.artifact_path, - f"{args.model}-{args.quantization.name}-{args.cmp_device}.so", - ) - ) - self.cmp_device = tvm.device(args.cmp_device) - self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device) - self.cmp_instrument = LibCompare( - self.lib, - self.cmp_device, - time_eval=args.time_eval, - skip_rounds=args.skip_rounds, - ) - self.vm.set_instrument(self.cmp_instrument) - - -def deploy_to_pipeline(args) -> None: - with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f: - config = json.load(f) - - primary_device = tvm.device(args.primary_device) - const_params = utils.load_params(args.artifact_path, primary_device) - state = TestState(args) - - if config["model_category"] == "llama": - tokenizer = LlamaTokenizer.from_pretrained( - os.path.join(args.artifact_path, "params"), trust_remote_code=True - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - os.path.join(args.artifact_path, "params"), trust_remote_code=True - ) - - print("Tokenizing...") - inputs = tvm.nd.array( - tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), - primary_device, - ) - first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) - seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) - second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) - kv_caches = state.vm["create_kv_cache"]() - - print("Running inference...") - print("======================= Starts Encoding =======================") - logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) - print_as_table( - sorted( - state.cmp_instrument.time_eval_results.items(), - key=lambda x: -(x[1][0] * x[1][1]), - ) - ) - state.cmp_instrument.time_eval_results.clear() - state.cmp_instrument.visited.clear() - print("======================= Starts Decoding =======================") - logits, kv_caches = state.vm["decode"]( - first_sampled_token, second_seq_len_shape, kv_caches, const_params - ) - print_as_table( - sorted( - state.cmp_instrument.time_eval_results.items(), - key=lambda x: -(x[1][0] * x[1][1]), - ) - ) - state.cmp_instrument.time_eval_results.clear() - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--local-id", type=str, required=True) - args.add_argument("--artifact-path", type=str, default="dist") - args.add_argument("--primary-device", type=str, default="auto") - args.add_argument("--cmp-device", type=str, required=True) - args.add_argument("--prompt", type=str, default="The capital of Canada is") - args.add_argument("--time-eval", default=False, action="store_true") - 
args.add_argument("--skip-rounds", type=int, default=0) - parsed = args.parse_args() - parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) - utils.argparse_postproc_common(parsed) - - parsed.artifact_path = os.path.join( - parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" - ) - - if parsed.primary_device == "auto": - if tvm.cuda().exist: - parsed.primary_device = "cuda" - elif tvm.metal().exist: - parsed.primary_device = "metal" - elif tvm.rocm().exist: - parsed.primary_device = "rocm" - else: - raise ValueError("Cannot auto deduce device-name, please set it") - return parsed - - -if __name__ == "__main__": - args = _parse_args() - deploy_to_pipeline(args) diff --git a/tests/legacy-python/dump_intermediate.py b/tests/legacy-python/dump_intermediate.py deleted file mode 100644 index e1da427c00..0000000000 --- a/tests/legacy-python/dump_intermediate.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Debug a model by printing out argument information before and after each function.""" - -import argparse -import json -import os - -import numpy as np -import torch -import tvm -from mlc_llm import utils -from transformers import AutoTokenizer -from tvm import relax -from tvm.runtime import ShapeTuple - -# pylint: disable=redefined-outer-name - - -def _extract_metadata(model_lib): - # pylint: disable=import-outside-toplevel - from tvm.runtime import device, load_module - from tvm.runtime.relax_vm import VirtualMachine - - # pylint: enable=import-outside-toplevel - - return json.loads(VirtualMachine(load_module(model_lib), device("cpu"))["_metadata"]()) - - -class DumpInstrument: # pylint: disable=too-few-public-methods - """Defines what to do before and after each function.""" - - def __init__(self, verbose=True): - self.verbose = verbose - self.counter = 0 - self.first_nan_occurred = False - self.first_inf_occurred = False - - def __call__(self, func, name, before_run, ret_val, *args): - # Determine what functions to look at - if before_run: # Whether before the function is called or after - return - # if self.first_nan_occurred: - # return - # if self.first_inf_occurred: - # return - if name.startswith("vm.builtin."): - return - if any(not isinstance(x, tvm.nd.NDArray) for x in args): - return - - # Decide what to print or save about the function's arguments (where args[-1] is the - # buffer we write the result to) - func_name = ( - f"f{self.counter}_before_{name}" if before_run else f"f{self.counter}_after_{name}" - ) - print(func_name) - - # Write your own behavior below. 
For example, we can count the number of INF/NaN in args[-1] - num_nans = np.sum(np.isnan(args[-1].numpy())) - num_infs = np.sum(np.isinf(args[-1].numpy())) - if num_nans > 0: - print(f"has NaN: {num_nans}") - self.first_nan_occurred = True - if num_infs > 0: - print(f"has INF: {num_infs}") - self.first_inf_occurred = True - - # You can also save the the arguments to experiment offline - # if self.counter == 769: - # for i, ndarray in enumerate(args): - # save_name = func_name + f"_arg{i}" - # np.save(f"./debug/{save_name}.npy", ndarray.numpy()) - - self.counter += 1 - - -def print_as_table(sorted_list): # pylint: disable=missing-function-docstring - # pylint: disable=consider-using-f-string - print( - "Name".ljust(50) - + "Time (ms)".ljust(12) - + "Count".ljust(8) - + "Total time (ms)".ljust(18) - + "Percentage (%)" - ) - total_time = sum([record[1][0] * record[1][1] for record in sorted_list]) * 1000 - for record in sorted_list: - time = record[1][0] * 1000 - weighted_time = time * record[1][1] - percentage = weighted_time / total_time * 100 - print( - record[0].ljust(50) - + "{:.4f}".format(time).ljust(12) - + str(record[1][1]).ljust(8) - + "{:.4f}".format(weighted_time).ljust(18) - + "{:.2f}".format(percentage) - ) - print("Total time: {:.4f} ms".format(total_time)) - print() - - -class TestState: - """Embodies the virtual machine and instrument.""" - - def __init__(self, args): - self.primary_device = tvm.device(args.primary_device) - ex = tvm.runtime.load_module(args.model_lib_path) - self.vm = relax.VirtualMachine(ex, self.primary_device) - self.sess = None - self.instrument = DumpInstrument(verbose=True) - self.vm.set_instrument(self.instrument) - - -def deploy_to_pipeline(args) -> None: - """Main pipeline forst testing; can be modified for specific testing purposes.""" - primary_device = tvm.device(args.primary_device) - model_metadata = _extract_metadata(args.model_lib_path) - const_params = utils.load_params_SLM(args.model, primary_device, model_metadata) - state = TestState(args) - tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model), trust_remote_code=True) - - print("Tokenizing...") - inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() - inputs = tvm.nd.array(inputs, device=primary_device) - first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) - - kv_cache_method: str - if state.vm.module.implements_function( - "create_tir_paged_kv_cache" - ) or state.vm.module.implements_function("create_flashinfer_paged_kv_cache"): - kv_cache_method = "paged_kv_cache" - raise NotImplementedError() - elif state.vm.module.implements_function("create_rnn_state"): - kv_cache_method = "rnn_state" - max_num_seq, history = ShapeTuple([1]), ShapeTuple([1]) - kv_caches = state.vm.module["create_rnn_state"](max_num_seq, history) - f_add_seq = tvm.get_global_func("vm.builtin.kv_state_add_sequence") - f_begin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward") - f_end_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward") - elif state.vm.module.implements_function("_initialize_effect"): - kv_cache_method = "effect" - kv_caches = state.vm.module["_initialize_effect"]() - else: - raise ValueError("Unknown how to create KVCache") - - def forward(inputs, kv_caches, total_seq_len): - hidden = state.vm["embed"](inputs, const_params) - if inputs.shape[1] > 1: - f_forward = state.vm["prefill"] - else: - f_forward = state.vm["decode"] - if kv_cache_method == "effect": - logits, kv_caches = f_forward( - 
hidden, ShapeTuple([total_seq_len]), kv_caches, const_params - ) - else: - seq_ids, input_shape = ShapeTuple([0]), ShapeTuple([inputs.shape[1]]) - f_begin_forward(kv_caches, seq_ids, input_shape) - logits, kv_caches = f_forward(hidden, kv_caches, const_params) - f_end_forward(kv_caches) - - return logits, kv_caches - - print("Running inference...") - - print("======================= Starts Prefilling ======================") - - if kv_cache_method != "effect": - f_add_seq(kv_caches, 0) - logits, kv_caches = forward(inputs, kv_caches, inputs.shape[1]) - - print("======================= Starts Decoding =======================") - - logits, kv_caches = forward(first_sampled_token, kv_caches, inputs.shape[1] + 1) - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--model", type=str, required=True) # The model weight folder - args.add_argument("--model-lib-path", type=str, required=True) # Path to the model library - args.add_argument("--primary-device", type=str, default="auto") # Device to run on - args.add_argument("--prompt", type=str, default="The capital of Canada is") - parsed = args.parse_args() - - if parsed.primary_device == "auto": - if tvm.cuda().exist: - parsed.primary_device = "cuda" - elif tvm.metal().exist: - parsed.primary_device = "metal" - else: - raise ValueError("Cannot auto deduce device-name, please set it") - return parsed - - -if __name__ == "__main__": - args = _parse_args() - deploy_to_pipeline(args) diff --git a/tests/legacy-python/evaluate.py b/tests/legacy-python/evaluate.py deleted file mode 100644 index 4a370c517c..0000000000 --- a/tests/legacy-python/evaluate.py +++ /dev/null @@ -1,202 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -# Used as reference - -import argparse -import json -import os -import time -from typing import List, Tuple - -import numpy as np -import torch -import tvm -from transformers import AutoTokenizer, LlamaTokenizer # type: ignore[import] -from tvm import relax -from tvm.relax.testing.lib_comparator import LibCompareVMInstrument -from tvm.runtime import ShapeTuple - -from mlc_llm import utils - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--local-id", type=str, required=True) - args.add_argument("--device-name", type=str, default="auto") - args.add_argument("--debug-dump", action="store_true", default=False) - args.add_argument("--artifact-path", type=str, default="dist") - args.add_argument("--prompt", type=str, default="The capital of Canada is") - args.add_argument("--profile", action="store_true", default=False) - parsed = args.parse_args() - parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) - utils.argparse_postproc_common(parsed) - parsed.artifact_path = os.path.join( - parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" - ) - return parsed - - -class LibCompare(LibCompareVMInstrument): - def __init__(self, mod, device): - super().__init__(mod, device, verbose=False) - self.time_eval_results = {} - - def compare( - self, - name: str, - ref_args: List[tvm.nd.NDArray], - new_args: List[tvm.nd.NDArray], - ret_indices: List[int], - ): - if name.startswith("shape_func"): - return - if name not in self.time_eval_results: - super().compare(name, ref_args, new_args, ret_indices) - res = self.mod.time_evaluator( - name, - dev=self.device, - number=100, - repeat=3, - )(*new_args).mean - shapes = [arg.shape for arg in new_args] - total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args) - self.time_eval_results[name] = 
(res, 1, shapes, total_bytes) - else: - record = self.time_eval_results[name] - self.time_eval_results[name] = ( - record[0], - record[1] + 1, - record[2], - record[3], - ) - - -def print_as_table(sorted_list: List[Tuple[str, Tuple[float, int]]]): - print( - "Name".ljust(50) - + "Time (ms)".ljust(12) - + "Count".ljust(8) - + "Total time (ms)".ljust(18) - + "Pct (%)".ljust(10) - + "Memory (MB)".ljust(16) - + "Bandwidth (GB/s)".ljust(18) - + "Shape" - ) - total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000 - for record in sorted_list: - time_used = record[1][0] * 1000 - weighted_time = time_used * record[1][1] - percentage = weighted_time / total_time * 100 - total_bytes = record[1][3] - bandwidth = total_bytes / record[1][0] / (1024**3) - - print( - record[0].ljust(50) - + f"{time_used:.4f}".ljust(12) - + str(record[1][1]).ljust(8) - + f"{weighted_time:.4f}".ljust(18) - + f"{percentage:.2f}".ljust(10) - + f"{total_bytes / (1024 * 1024):.2f}".ljust(16) - + f"{bandwidth:.4f}".format(bandwidth).ljust(18) - + ", ".join(str(s) for s in record[1][2]) - ) - print(f"Total time: {total_time:.4f} ms") - print() - - -def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals - device = tvm.device(args.device_name) - const_params = utils.load_params(args.artifact_path, device) - ex = tvm.runtime.load_module( - os.path.join( - args.artifact_path, - f"{args.model}-{args.quantization.name}-{args.device_name}.so", - ) - ) - vm = relax.VirtualMachine(ex, device) - - with open( - os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), - "r", - encoding="utf-8", - ) as f: - config = json.load(f) - - if config["model_category"] == "llama": - tokenizer = LlamaTokenizer.from_pretrained( - os.path.join(args.artifact_path, "params"), trust_remote_code=True - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - os.path.join(args.artifact_path, "params"), trust_remote_code=True - ) - - print("Tokenizing...") - inputs = tvm.nd.array( - tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), - device, - ) - first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), device) - seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) - second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) - kv_caches = vm["create_kv_cache"]() - # skip warm up - - logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) - logits, kv_caches = vm["decode"]( - first_sampled_token, second_seq_len_shape, kv_caches, const_params - ) - device.sync() - - kv_caches = vm["create_kv_cache"]() - print("Running inference...") - start = time.time() - logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) - device.sync() - encoding_end = time.time() - logits, kv_caches = vm["decode"]( - first_sampled_token, second_seq_len_shape, kv_caches, const_params - ) - device.sync() - end = time.time() - if args.debug_dump: - fcache_view = tvm.get_global_func("vm.builtin.attention_kv_cache_view") - first_k_cache = fcache_view(kv_caches[0], ShapeTuple([7, 32, 128])) - print(f"output kv_cache[0]:\n{first_k_cache.numpy().transpose(1, 0, 2)}") - print(f"output logits:\n{logits.numpy()}") - print( - f"Time elapsed: encoding {(encoding_end - start)} seconds, " - f"decoding {end - encoding_end} secs" - ) - - if args.profile: - cmp_instrument = LibCompare(ex, device) - vm.set_instrument(cmp_instrument) - - print("Profiling...") - kv_caches = vm["create_kv_cache"]() - - logits, kv_caches = vm["prefill"](inputs, 
seq_len_shape, kv_caches, const_params) - print("======================= Encoding Profiling =======================") - print_as_table( - sorted( - cmp_instrument.time_eval_results.items(), - key=lambda x: -(x[1][0] * x[1][1]), - ) - ) - cmp_instrument.time_eval_results.clear() - - logits, kv_caches = vm["decode"]( - first_sampled_token, second_seq_len_shape, kv_caches, const_params - ) - print("======================= Decoding Profiling =======================") - print_as_table( - sorted( - cmp_instrument.time_eval_results.items(), - key=lambda x: -(x[1][0] * x[1][1]), - ) - ) - - -if __name__ == "__main__": - ARGS = _parse_args() - deploy_to_pipeline(ARGS) diff --git a/tests/legacy-python/module_intercept.py b/tests/legacy-python/module_intercept.py deleted file mode 100644 index e63bb21de6..0000000000 --- a/tests/legacy-python/module_intercept.py +++ /dev/null @@ -1,147 +0,0 @@ -"""This script is an example of running and comparing the outputs of two different TVM Relax VMs. -""" -# pylint: disable=missing-docstring,invalid-name -import json - -import numpy as np -import torch -import tvm -from transformers import LlamaTokenizer -from tvm import relax -from tvm.contrib import tvmjs - -KVCACHE_FUNCS = [ - "vm.builtin.attention_kv_cache_append", - "vm.builtin.attention_kv_cache_view", -] -DEVICE = "cuda:0" -PROMPT = "What is the meaning of life?" -TOKENIZER = "./dist/debug-llama/" - -COMBO = { - "CURRENT": { - "model_lib": "./dist/debug-llama/llama.so", - "params": "./dist/debug-llama", - "target_func": "fused_fused_dequantize1_NT_matmul6", - }, - "LEGACY": { - "model_lib": "./dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so", - "params": "./dist/Llama-2-7b-chat-hf-q4f16_1/params", - "target_func": "fused_fused_decode2_NT_matmul", - }, -} - - -class Instrument: # pylint: disable=too-few-public-methods - def __init__( - self, - target_func: str, - ): - self.first_time = True - self.target_func = target_func - self.saved_args = [] # type: ignore - - def __call__( - self, - func, - func_symbol: str, - before_run: bool, - ret_value, - *args, - ): - if before_run: - return - if func_symbol.startswith("vm.builtin."): - if func_symbol not in KVCACHE_FUNCS: - return - if func_symbol == self.target_func and self.first_time: - self.first_time = False - for arg in args: - print(arg.shape, arg.dtype) - self.saved_args.append(arg.numpy()) - - -class TestState: - def __init__(self, device, model_lib, target_func): - self.mod = relax.VirtualMachine( - tvm.runtime.load_module(model_lib), - device, - ) - self.inst = Instrument(target_func=target_func) - self.mod.set_instrument(self.inst) - - -def _tokenize(sentence: str): - tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER, trust_remote_code=True) - tokens = tokenizer(PROMPT, return_tensors="pt").input_ids.to(torch.int32).numpy() - print(f"Tokenizing: {sentence}") - print(f"Tokens: {tokens}") - return tokens - - -def _load_params(params, device, metadata): - param_dict, _ = tvmjs.load_ndarray_cache(params, device) - param_list = [] - for name in [x["name"] for x in metadata["params"]]: - param_list.append(param_dict[name]) - return param_list - - -def _load_params_legacy(params, device): - param_dict, metadata = tvmjs.load_ndarray_cache(params, device) - param_list = [] - for i in range(metadata["ParamSize"]): - param_list.append(param_dict[f"param_{i}"]) - return param_list - - -def _as_input_tuple(scalar): - return tvm.runtime.ShapeTuple([scalar]) - - -@tvm.register_func("debug_save") -def _debug_save(x, _): - return 
tvm.nd.array(x.numpy(), x.device) - - -def main() -> None: - device = tvm.device(DEVICE) - prompt = _tokenize(PROMPT) - - def _run_legacy(model_lib, params, target_func): - state = TestState(device, model_lib, target_func) - kv_cache = state.mod["create_kv_cache"]() - param_list = _load_params_legacy(params, device) - state.mod["prefill"]( - tvm.nd.array(prompt, device), - _as_input_tuple(len(prompt[0])), - kv_cache, - param_list, - ) - return state.inst.saved_args - - def _run_current(model_lib, params, target_func): - state = TestState(device, model_lib, target_func) - metadata = json.loads(state.mod["_metadata"]()) - kv_cache = state.mod["_initialize_effect"]() - param_list = _load_params(params, device, metadata) - state.mod["prefill"]( - tvm.nd.array(prompt, device), - _as_input_tuple(len(prompt[0])), - kv_cache, - param_list, - ) - return state.inst.saved_args - - print("============== Running old flow =================") - new_args = _run_current(**COMBO["CURRENT"]) - print("============== Running new flow =================") - old_args = _run_legacy(**COMBO["LEGACY"]) - - for i, (new_arg, old_arg) in enumerate(zip(new_args, old_args)): - print(f"Checking arg {i}") - np.testing.assert_allclose(new_arg, old_arg, rtol=1e-12, atol=1e-12) - - -if __name__ == "__main__": - main() diff --git a/tests/legacy-python/test_batching_llama.py b/tests/legacy-python/test_batching_llama.py deleted file mode 100644 index ff11188e4b..0000000000 --- a/tests/legacy-python/test_batching_llama.py +++ /dev/null @@ -1,160 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -# Used as reference - -import argparse -import json -import os - -import numpy as np -import torch -import tvm -from transformers import LlamaTokenizer # type: ignore[import] -from tvm import relax -from tvm.runtime import ShapeTuple - -from mlc_llm import utils - -############################################################## -# Test file for e2e Llama with batching enabled by directly -# calling functions in VM. -# -# NOTE: the test will not be runnable until the attention -# compute function is integrated to Llama. This is left as -# an item that we will work on shortly in the future. 
-############################################################## - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--local-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") - args.add_argument("--device-name", type=str, default="auto") - args.add_argument("--artifact-path", type=str, default="dist") - args.add_argument("--prompt", type=str, default="What's the meaning of life?") - args.add_argument("--profile", action="store_true", default=False) - parsed = args.parse_args() - parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) - utils.argparse_postproc_common(parsed) - parsed.artifact_path = os.path.join( - parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" - ) - return parsed - - -def sample_from_logits(vm, logits, device): - temperature = 0.7 - top_p = 0.95 - - num_sequence = logits.shape[0] - temperature_arr = tvm.nd.array(np.full((num_sequence,), temperature, dtype="float32"), device) - probs = vm["softmax_with_temperature"](logits, temperature_arr).numpy() - - sampled_tokens = [] - fsample_top_p_from_prob = tvm.get_global_func("vm.builtin.sample_top_p_from_prob") - for seq_id in range(num_sequence): - token = fsample_top_p_from_prob(tvm.nd.array(probs[seq_id]), top_p, np.random.sample()) - sampled_tokens.append(token) - return sampled_tokens - - -def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals - device = tvm.device(args.device_name) - const_params = utils.load_params(args.artifact_path, device) - ex = tvm.runtime.load_module( - os.path.join( - args.artifact_path, - f"{args.model}-{args.quantization.name}-{args.device_name}.so", - ) - ) - vm = relax.VirtualMachine(ex, device) - - with open( - os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), - "r", - encoding="utf-8", - ) as f: - config = json.load(f) - - assert config["model_category"] == "llama" - tokenizer = LlamaTokenizer.from_pretrained( - os.path.join(args.artifact_path, "params"), trust_remote_code=True - ) - - num_sequences = 4 - generated_tokens = [[], [], [], []] - prompts = [ - "What's the meaning of life?", - "Introduce the history of Pittsburgh to me.", - "Write a three-day Seattle travel plan.", - "What is Alaska famous of?", - ] - num_decode_steps = 256 - - print("Create KV cache...") - max_total_seq_len = 16384 - page_size = 16 - kv_cache = vm["create_kv_cache"](ShapeTuple([num_sequences, max_total_seq_len, page_size])) - - fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence") - freset_append_length = tvm.get_global_func( - "vm.builtin.paged_attention_kv_cache_reset_append_lengths" - ) - freserve = tvm.get_global_func( - "vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append" - ) - fsync = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") - - for seq_id in range(num_sequences): - print(f"Process seq {seq_id} for prefill...") - inputs = tvm.nd.array( - tokenizer(prompts[seq_id], return_tensors="pt").input_ids.to(torch.int32).numpy(), - device, - ) - seq_length = inputs.shape[1] - embedding = vm["embed"](inputs, const_params) - - seq_id_in_cache = fadd_sequence(kv_cache) - assert seq_id_in_cache == seq_id - - freset_append_length(kv_cache) - freserve(kv_cache, seq_id, seq_length) - fsync(kv_cache) - - print(f"Prefilling seq {seq_id}...") - logits, _ = vm["prefill_with_embed"](embedding, kv_cache, const_params) - - tokens = sample_from_logits(vm, logits, device) - assert len(tokens) == 1 - generated_tokens[seq_id].append(tokens[0]) - - 
print("Decoding...") - for step in range(num_decode_steps): - inputs = tvm.nd.array( - np.array( - [[generated_tokens[seq_id][-1]] for seq_id in range(num_sequences)], dtype="int32" - ), - device, - ) - embedding = vm["embed"](inputs, const_params) - freset_append_length(kv_cache) - for seq_id in range(num_sequences): - freserve(kv_cache, seq_id, 1) - fsync(kv_cache) - - logits, _ = vm["decode_with_embed"](embedding, kv_cache, const_params) - tokens = sample_from_logits(vm, logits, device) - assert len(tokens) == num_sequences - - for seq_id in range(num_sequences): - generated_tokens[seq_id].append(tokens[seq_id]) - - for seq_id in range(num_sequences): - output = tokenizer.decode(generated_tokens[seq_id]) - print("====================================================================") - print(f"Prompt {seq_id}: {prompts[seq_id]}") - print(f"Output: {output}") - print("\n\n") - - -if __name__ == "__main__": - ARGS = _parse_args() - deploy_to_pipeline(ARGS) diff --git a/tests/legacy-python/test_build_args.py b/tests/legacy-python/test_build_args.py deleted file mode 100644 index 8f32d123b6..0000000000 --- a/tests/legacy-python/test_build_args.py +++ /dev/null @@ -1,175 +0,0 @@ -"""For testing the functionality of `BuildArgs` and `convert_build_args_to_argparser`.""" -import argparse -import dataclasses -import unittest - -from mlc_llm import BuildArgs, core, utils - - -def old_make_args(): - """The exact old way of creating `ArgumentParser`, used to test whether - `BuildArgs` is equivalent to this.""" - args = argparse.ArgumentParser() - args.add_argument( - "--model", - type=str, - default="auto", - help=( - 'The name of the model to build. If it is "auto", we will ' - 'automatically set the model name according to "--model-path", ' - '"hf-path" or the model folders under "--artifact-path/models"' - ), - ) - args.add_argument( - "--hf-path", - type=str, - default=None, - help="Hugging Face path from which to download params, tokenizer, and config", - ) - args.add_argument( - "--quantization", - type=str, - choices=[*utils.quantization_schemes.keys()], - default=list(utils.quantization_schemes.keys())[0], - help="The quantization mode we use to compile.", - ) - args.add_argument( - "--max-seq-len", - type=int, - default=-1, - help="The maximum allowed sequence length for the model.", - ) - args.add_argument( - "--target", type=str, default="auto", help="The target platform to compile the model for." - ) - args.add_argument( - "--reuse-lib", - type=str, - default=None, - help="Whether to reuse a previously generated lib.", - ) - args.add_argument( - "--artifact-path", type=str, default="dist", help="Where to store the output." - ) - args.add_argument( - "--use-cache", - type=int, - default=1, - help="Whether to use previously pickled IRModule and skip trace.", - ) - args.add_argument( - "--debug-dump", - action="store_true", - default=False, - help="Whether to dump debugging files during compilation.", - ) - args.add_argument( - "--debug-load-script", - action="store_true", - default=False, - help="Whether to load the script for debugging.", - ) - args.add_argument( - "--llvm-mingw", - type=str, - default="", - help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.", - ) - args.add_argument( - "--system-lib", action="store_true", default=False, help="A parameter to `relax.build`." - ) - args.add_argument( - "--sep-embed", - action="store_true", - default=False, - help=( - "Build with separated embedding layer, only applicable to LlaMa. 
" - "This feature is in testing stage, and will be formally replaced after " - "massive overhaul of embedding feature for all models and use cases" - ), - ) - - return args - - -# Referred to HfArgumentParserTest from https://github.com/huggingface/ -# transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils -# /test_hf_argparser.py#L143 -class BuildArgsTest(unittest.TestCase): - """Tests whether BuildArgs reaches parity with regular ArgumentParser.""" - - def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser): - """ - Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances. - """ - self.assertEqual( - len(parse_a._actions), len(parse_b._actions) - ) # pylint: disable=protected-access - for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access - xx = {k: v for k, v in vars(x).items() if k != "container"} - yy = {k: v for k, v in vars(y).items() if k != "container"} - # Choices with mixed type have custom function as "type" - # So we need to compare results directly for equality - if xx.get("choices", None) and yy.get("choices", None): - for expected_choice in yy["choices"] + xx["choices"]: - self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice)) - del xx["type"], yy["type"] - - self.assertEqual(xx, yy) - - def test_new_and_old_arg_parse_are_equivalent(self): - """Tests whether creating `ArgumentParser` from `BuildArgs` is equivalent - to the conventional way of creating it.""" - self.argparsers_equal(core.convert_build_args_to_argparser(), old_make_args()) - - def test_namespaces_are_equivalent_str(self): - """Tests whether the resulting namespaces from command line entry - and Python API entry are equivalent, as they are passed down to the - same workflow.""" - # Namespace that would be created through Python API build_model - build_args = BuildArgs(model="RedPJ", target="cuda") - build_args_as_dict = dataclasses.asdict(build_args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - - # Namespace that would be created through commandline - empty_args = core.convert_build_args_to_argparser() - parsed_args = empty_args.parse_args(["--model", "RedPJ", "--target", "cuda"]) - - self.assertEqual(build_args_namespace, parsed_args) - - # Modify build_args so that it would not be equivalent - build_args = BuildArgs(model="RedPJ", target="vulkan") - build_args_as_dict = dataclasses.asdict(build_args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - - self.assertNotEqual(build_args_namespace, parsed_args) - - def test_namespaces_are_equivalent_str_boolean_int(self): - """Same test, but for a mixture of argument types.""" - # 1. Equal - build_args = BuildArgs(model="RedPJ", max_seq_len=20, debug_dump=True) - build_args_as_dict = dataclasses.asdict(build_args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - - # Namespace that would be created through commandline - empty_args = core.convert_build_args_to_argparser() - parsed_args = empty_args.parse_args( - ["--model", "RedPJ", "--max-seq-len", "20", "--debug-dump"] - ) - self.assertEqual(build_args_namespace, parsed_args) - - # 2. Not equal - missing boolean - build_args = BuildArgs(model="RedPJ", max_seq_len=20) - build_args_as_dict = dataclasses.asdict(build_args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - self.assertNotEqual(build_args_namespace, parsed_args) - - # 3. 
Not equal - different integer - build_args = BuildArgs(model="RedPJ", max_seq_len=18, debug_dump=True) - build_args_as_dict = dataclasses.asdict(build_args) - build_args_namespace = argparse.Namespace(**build_args_as_dict) - self.assertNotEqual(build_args_namespace, parsed_args) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/legacy-python/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py deleted file mode 100644 index b342e035bb..0000000000 --- a/tests/legacy-python/test_build_model_from_args.py +++ /dev/null @@ -1,142 +0,0 @@ -import argparse -import os -import unittest -from unittest.mock import MagicMock, mock_open, patch - -from mlc_llm import utils -from mlc_llm.core import build_model_from_args - - -class MockMkdir(object): - def __init__(self): - self.received_args = None - - def __call__(self, *args): - self.received_args = args - - -class BuildModelTest(unittest.TestCase): - def setUp(self): - self._orig_mkdir = os.mkdir - os.mkdir = MockMkdir() - - self.mock_args = argparse.Namespace() - self.mock_args.quantization = utils.quantization_schemes["q8f16_1"] - self.mock_args.debug_dump = False - self.mock_args.use_cache = False - self.mock_args.sep_embed = False - self.mock_args.build_model_only = True - self.mock_args.use_safetensors = False - self.mock_args.convert_weights_only = False - self.mock_args.no_cutlass_attn = True - self.mock_args.no_cutlass_norm = True - self.mock_args.reuse_lib = True - self.mock_args.artifact_path = "/tmp/" - self.mock_args.model_path = "/tmp/" - self.mock_args.model = "/tmp/" - self.mock_args.target_kind = "cuda" - self.mock_args.max_seq_len = 2048 - - def tearDown(self): - os.mkdir = self._orig_mkdir - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect=[{}])) - def test_llama_model(self, mock_file): - self.mock_args.model_category = "llama" - - build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch( - "json.load", - MagicMock( - side_effect=[ - { - "use_parallel_residual": False, - "hidden_size": 32, - "intermediate_size": 32, - "num_attention_heads": 32, - "num_hidden_layers": 28, - "vocab_size": 1024, - "rotary_pct": 1, - "rotary_emb_base": 1, - "layer_norm_eps": 1, - } - ] - ), - ) - def test_gpt_neox_model(self, mock_file): - self.mock_args.model_category = "gpt_neox" - self.mock_args.model = "dolly-test" - - build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect=[{}])) - def test_gpt_bigcode_model(self, mock_file): - self.mock_args.model_category = "gpt_bigcode" - self.mock_args.model = "gpt_bigcode" - - build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect=[{}])) - def test_minigpt_model(self, mock_file): - self.mock_args.model_category = "minigpt" - self.mock_args.model = "minigpt4-7b" - - build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch( - "json.load", - MagicMock( - side_effect=[ - { - "vocab_size": 1024, - "n_embd": 32, - "n_inner": 32, - "n_head": 32, - "n_layer": 28, - "bos_token_id": 28, - "eos_token_id": 1, - "rotary_dim": 1, - "tie_word_embeddings": 1, - } - ] - ), - ) - def test_gptj_model(self, mock_file): - self.mock_args.model_category = "gptj" - self.mock_args.model = "gpt-j-" - - 
build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch( - "json.load", - MagicMock( - side_effect=[ - { - "num_hidden_layers": 16, - "vocab_size": 1024, - "hidden_size": 16, - "intermediate_size": 32, - } - ] - ), - ) - def test_rwkv_model(self, mock_file): - self.mock_args.model_category = "rwkv" - self.mock_args.model = "rwkv-" - - build_model_from_args(self.mock_args) - - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect=[{}])) - def test_chatglm_model(self, mock_file): - self.mock_args.model_category = "chatglm" - self.mock_args.model = "chatglm2" - - build_model_from_args(self.mock_args) diff --git a/tests/legacy-python/test_sliding_window_mask.py b/tests/legacy-python/test_sliding_window_mask.py deleted file mode 100644 index 51be2d0749..0000000000 --- a/tests/legacy-python/test_sliding_window_mask.py +++ /dev/null @@ -1,338 +0,0 @@ -# fmt: off -"""For testing `_make_sliding_window_mask` in mistral.py""" - -import unittest - -import numpy as np -import tvm -from mlc_llm.relax_model.mistral import _make_sliding_window_mask -from tvm import relax -from tvm.runtime import ShapeTuple - - -def _create_vm(): - # pylint: disable=too-many-locals - bb = relax.BlockBuilder() - - # Step 1: Build `_make_sliding_window_mask()` into an IRModule - bsz = tvm.tir.Var("bsz", "int64") - seq_length = tvm.tir.Var("seq_length", "int64") # tgt_len - kv_seq_len = tvm.tir.Var("kv_seq_len", "int64") - sliding_window = tvm.tir.Var("sliding_window", "int64") - - with bb.function("main"): - # Convert to relax.Var because params to an IRModule function needs to be relax.Var - bsz_shape = relax.Var("bsz", relax.ShapeStructInfo((bsz,))) - seq_length_shape = relax.Var("seq_length", relax.ShapeStructInfo((seq_length,))) - kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) - sliding_window_shape = relax.Var("sliding_window", relax.ShapeStructInfo((sliding_window,))) - - # Convert back to tir.Var since `_prepare_sliding_window_mask` needs it to be tir.Var - with bb.dataflow(): - bsz_input = bsz_shape.struct_info.values[0] - seq_length_input = seq_length_shape.struct_info.values[0] - kv_seq_len_input = kv_seq_len_shape.struct_info.values[0] - sliding_window_input = sliding_window_shape.struct_info.values[0] - mask = _make_sliding_window_mask( - (bsz_input, seq_length_input), - kv_seq_len_input, - sliding_window_input, - "float32", - ) - params = [ - bsz_shape, - seq_length_shape, - kv_seq_len_shape, - sliding_window_shape, - ] - gv = bb.emit_output(mask) - bb.emit_func_output(gv, params) - - # Step 2. Optimize IRModule - mod = bb.get() - mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter - with tvm.target.Target("cuda"): - mod = tvm.tir.transform.DefaultGPUSchedule()(mod) - - # Step 3. Deploy to GPU - ex = relax.build(mod, "cuda") - vm = relax.VirtualMachine(ex, tvm.cuda()) #pylint: disable=redefined-outer-name - return vm - - -vm = _create_vm() - -class SlidingWindowMaskTest(unittest.TestCase): - """ - The sliding window mask is based on figure 3 of the Mistral paper. - There are three cases when making a mask: first prefill, subsequent prefill, - and decoding. - - 1. First Prefill - This is when the cache is empty (i.e. kv_seq_len == 0). If tgt_len <= sliding_window, - this is just a normal causal mask. Otherwise, e.g. tgt_len = 3, WS = 2, we create a - mask below: - 1, 0, 0 - 1, 1, 0 - 0, 1, 1 - - 2. 
Subsequent Prefill - This is when the cache is not empty and yet tgt_len > 1. - e.g. t0-t4 in cache; current input is t5-t7; WS=5 - 0, 1, 2, 3, 4, | 5, 6, 7 - - 0, 1, 1, 1, 1, | 1, 0, 0 - 0, 0, 1, 1, 1, | 1, 1, 0 - 0, 0, 0, 1, 1, | 1, 1, 1 - [in cache] [current] - - 3. Decode - It will always be ones with shape (1 + kv_seq_len) since cache_size equals sliding_window. - Note that a prefilling (first or subsequent) with chunk_size of 1 is equivalent to a decode - in mask making. - """ - - ################### 1. TESTS FOR FIRST PREFILL ################### - def test_first_prefill_chunk_size_smaller_than_ws(self): - """ - When chunk size < WS, we return a normal causal mask. - Here, chunk size 3, WS 5. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([3]) # chunk size is 3 - kv_seq_len = ShapeTuple([3]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - [3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, 3.402823e38], - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_first_prefill_chunk_size_equals_ws(self): - """ - When chunk_size == WS, we also return a normal causal mask. - Here both chunk size and WS are 5. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([5]) - kv_seq_len = ShapeTuple([5]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_first_prefill_chunk_size_greater_than_ws(self): - """ - When chunk_size > WS, return a normal causal mask but each row only has at most WS 1's. - Here chunk_size = 5, WS=3. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([5]) - kv_seq_len = ShapeTuple([5]) - sliding_window = ShapeTuple([3]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], - [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], - [-3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], - [-3.402823e38, -3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_first_prefill_chunk_size_one(self): - """ - Corner case: the prompt only has 1 token. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([1]) - kv_seq_len = ShapeTuple([1]) - sliding_window = ShapeTuple([3]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - [3.402823e38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - ################### 2. TESTS FOR SUBSEQUENT PREFILL ################### - def test_subsequent_prefill_1(self): - """ - Test 1: chunk size is 3, WS is 5, cache carrying t0, t1, t2; input t3, t4, t5. 
- """ - - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([3]) - kv_seq_len = ShapeTuple([6]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE | CURRENT CHUNK | - # t0 t1 t2 t3 t4 t5 - [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], - [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_subsequent_prefill_2(self): - """ - Test 2: chunk size is 3, WS is 5, cache carrying t1 - t5 (t0 is overwritten); - input t6, t7, t8. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([3]) - kv_seq_len = ShapeTuple([8]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE | CURRENT CHUNK | - # t1 t2 t3 t4 t5 t6 t7 t8 - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_subsequent_prefill_3(self): - """ - Test 3: chunk size is 5, WS is 5, cache carrying t0-t4; input t5-t9. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([5]) - kv_seq_len = ShapeTuple([10]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE | CURRENT CHUNK | - # t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_subsequent_prefill_4(self): - """ - Test 4: chunk size is 5, WS is 3, cache carrying t2-t4 (t0, t1 did not - stay in cache); input t5-t9. 
- """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([5]) - kv_seq_len = ShapeTuple([8]) - sliding_window = ShapeTuple([3]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE | CURRENT CHUNK | - # t2 t3 t4 t5 t6 t7 t8 t9 - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], - [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_subsequent_prefill_5(self): - """ - Test 5: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); - input t10 (remainder of a prompt). Note that this test can also be - viewed as a decode. That is, prefilling a chunk of size 1, is the same is decoding. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([1]) - kv_seq_len = ShapeTuple([6]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE |CURRENT CHUNK| - # t5 t6 t7 t8 t9 t10 - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - ################### 3. TESTS FOR DECODE ################### - def test_decode_1(self): - """ - Test 1: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); - input t10 (decoding). - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([1]) - kv_seq_len = ShapeTuple([6]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # pylint: disable=line-too-long - # | IN CACHE |CURRENT CHUNK| - # t5 t6 t7 t8 t9 t10 - [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - def test_decode_2(self): - """ - Test 2 (Cache not full): prompt is size 4, WS is 5, cache carrying t0-t3; input t4. - """ - bsz = ShapeTuple([1]) - seq_length = ShapeTuple([1]) - kv_seq_len = ShapeTuple([5]) - sliding_window = ShapeTuple([5]) - - result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) - - correct = np.array([[[ - # | IN CACHE |CURRENT CHUNK| - # t0 t1 t2 t3 t4 - [3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] - ]]]).astype("float32") - - np.testing.assert_array_equal(result.numpy(), correct) - - -if __name__ == "__main__": - unittest.main() From 8beed7a706fae9d857407e507b9def8e6c95e0e8 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 12 Mar 2024 00:02:37 -0400 Subject: [PATCH 058/531] [REFACTOR] rename mlc_chat => mlc_llm (#1932) This PR renames the mlc_chat pckage to the mlc_llm package now that this is the new official flow. We also update the necessary locations that might touch the package. 
--- ci/task/build_clean.sh | 2 +- cpp/llm_chat.cc | 3 +- cpp/serve/function_table.cc | 2 +- docs/compilation/compile_models.rst | 224 +++++++++--------- docs/compilation/convert_weights.rst | 30 +-- docs/deploy/android.rst | 18 +- docs/deploy/cli.rst | 24 +- docs/deploy/ios.rst | 40 ++-- docs/deploy/javascript.rst | 46 ++-- docs/deploy/python.rst | 48 ++-- docs/deploy/rest.rst | 24 +- docs/get_started/mlc_chat_config.rst | 6 +- docs/index.rst | 40 +--- docs/install/mlc_llm.rst | 38 +-- docs/prebuilt_models.rst | 182 +++++++------- examples/python/benchmark.py | 2 +- examples/python/sample_chat_stream.py | 4 +- examples/python/sample_mlc_chat.py | 6 +- examples/rest/nodejs/README.MD | 4 +- examples/rest/python/sample_langchain.py | 81 ++++--- pyproject.toml | 2 +- python/README.md | 5 - python/{mlc_chat => mlc_llm}/__init__.py | 0 python/{mlc_chat => mlc_llm}/__main__.py | 14 +- python/{mlc_chat => mlc_llm}/_ffi_api.py | 2 +- python/{mlc_chat => mlc_llm}/base.py | 0 python/{mlc_chat => mlc_llm}/callback.py | 0 python/{mlc_chat => mlc_llm}/chat_module.py | 34 ++- python/{mlc_chat => mlc_llm}/cli/__init__.py | 0 python/{mlc_chat => mlc_llm}/cli/bench.py | 8 +- python/{mlc_chat => mlc_llm}/cli/benchmark.py | 2 +- python/{mlc_chat => mlc_llm}/cli/chat.py | 6 +- .../{mlc_chat => mlc_llm}/cli/check_device.py | 0 python/{mlc_chat => mlc_llm}/cli/compile.py | 19 +- .../cli/convert_weight.py | 16 +- python/{mlc_chat => mlc_llm}/cli/delivery.py | 14 +- .../{mlc_chat => mlc_llm}/cli/gen_config.py | 12 +- .../cli/model_metadata.py | 8 +- python/{mlc_chat => mlc_llm}/cli/worker.py | 0 .../compiler_pass/__init__.py | 0 .../compiler_pass/attach_to_ir_module.py | 0 .../compiler_pass/clean_up_tir_attrs.py | 0 .../compiler_pass/cublas_dispatch.py | 0 .../dispatch_kv_cache_creation.py | 2 +- .../compiler_pass/estimate_memory_usage.py | 2 +- .../compiler_pass/fuse_add_norm.py | 0 .../fuse_dequantize_matmul_ewise.py | 0 .../compiler_pass/fuse_dequantize_take.py | 0 .../fuse_dequantize_transpose.py | 0 .../fuse_ft_dequantize_matmul_epilogue.py | 0 .../compiler_pass/fuse_transpose_matmul.py | 0 .../compiler_pass/lift_global_buffer_alloc.py | 0 .../compiler_pass/low_batch_specialization.py | 0 .../compiler_pass/pipeline.py | 2 +- .../compiler_pass/scatter_tuple_get_item.py | 0 .../conversation_template.py | 0 .../embeddings/__init__.py | 0 .../embeddings/openai.py | 2 +- python/{mlc_chat => mlc_llm}/gradio.py | 0 python/{mlc_chat => mlc_llm}/help.py | 0 .../interface/__init__.py | 0 .../{mlc_chat => mlc_llm}/interface/bench.py | 2 +- .../{mlc_chat => mlc_llm}/interface/chat.py | 8 +- .../interface/compile.py | 16 +- .../interface/compiler_flags.py | 6 +- .../interface/convert_weight.py | 12 +- .../interface/gen_config.py | 8 +- python/{mlc_chat => mlc_llm}/interface/jit.py | 12 +- .../interface/openai_api.py | 0 python/{mlc_chat => mlc_llm}/libinfo.py | 0 .../{mlc_chat => mlc_llm}/loader/__init__.py | 0 .../loader/huggingface_loader.py | 6 +- python/{mlc_chat => mlc_llm}/loader/loader.py | 0 .../{mlc_chat => mlc_llm}/loader/mapping.py | 0 python/{mlc_chat => mlc_llm}/loader/stats.py | 4 +- python/{mlc_chat => mlc_llm}/loader/utils.py | 2 +- .../{mlc_chat => mlc_llm}/model/__init__.py | 0 .../model/baichuan/__init__.py | 0 .../model/baichuan/baichuan_loader.py | 4 +- .../model/baichuan/baichuan_model.py | 10 +- .../model/baichuan/baichuan_quantization.py | 4 +- .../model/gemma/__init__.py | 0 .../model/gemma/gemma_loader.py | 4 +- .../model/gemma/gemma_model.py | 12 +- .../model/gemma/gemma_quantization.py | 4 +- 
.../model/gpt2/__init__.py | 0 .../model/gpt2/gpt2_loader.py | 4 +- .../model/gpt2/gpt2_model.py | 12 +- .../model/gpt2/gpt2_quantization.py | 4 +- .../model/gpt_bigcode/__init__.py | 0 .../model/gpt_bigcode/gpt_bigcode_loader.py | 4 +- .../model/gpt_bigcode/gpt_bigcode_model.py | 12 +- .../gpt_bigcode/gpt_bigcode_quantization.py | 4 +- .../model/gpt_neox/__init__.py | 0 .../model/gpt_neox/gpt_neox_loader.py | 4 +- .../model/gpt_neox/gpt_neox_model.py | 10 +- .../model/gpt_neox/gpt_neox_quantization.py | 4 +- .../model/internlm/__init__.py | 0 .../model/internlm/internlm_loader.py | 4 +- .../model/internlm/internlm_model.py | 10 +- .../model/internlm/internlm_quantization.py | 4 +- .../model/llama/__init__.py | 0 .../model/llama/llama_loader.py | 4 +- .../model/llama/llama_model.py | 12 +- .../model/llama/llama_quantization.py | 4 +- .../model/mistral/__init__.py | 0 .../model/mistral/mistral_loader.py | 4 +- .../model/mistral/mistral_model.py | 10 +- .../model/mistral/mistral_quantization.py | 4 +- .../model/mixtral/__init__.py | 0 .../model/mixtral/mixtral_loader.py | 4 +- .../model/mixtral/mixtral_model.py | 12 +- .../model/mixtral/mixtral_quantization.py | 4 +- python/{mlc_chat => mlc_llm}/model/model.py | 4 +- .../model/model_preset.py | 0 .../model/orion/__init__.py | 0 .../model/orion/orion_loader.py | 4 +- .../model/orion/orion_model.py | 12 +- .../model/orion/orion_quantization.py | 4 +- .../model/phi/__init__.py | 0 .../model/phi/phi_loader.py | 4 +- .../model/phi/phi_model.py | 12 +- .../model/phi/phi_quantization.py | 4 +- .../model/qwen/__init__.py | 0 .../model/qwen/qwen_loader.py | 4 +- .../model/qwen/qwen_model.py | 10 +- .../model/qwen/qwen_quantization.py | 4 +- .../model/qwen2/__init__.py | 0 .../model/qwen2/qwen2_loader.py | 4 +- .../model/qwen2/qwen2_model.py | 10 +- .../model/qwen2/qwen2_quantization.py | 4 +- .../model/rwkv5/__init__.py | 0 .../model/rwkv5/rwkv5_loader.py | 0 .../model/rwkv5/rwkv5_model.py | 6 +- .../model/rwkv5/rwkv5_quantization.py | 0 .../model/stable_lm/__init__.py | 0 .../model/stable_lm/stablelm_loader.py | 4 +- .../model/stable_lm/stablelm_model.py | 10 +- .../model/stable_lm/stablelm_quantization.py | 4 +- python/{mlc_chat => mlc_llm}/nn/__init__.py | 0 python/{mlc_chat => mlc_llm}/nn/expert.py | 2 +- python/{mlc_chat => mlc_llm}/nn/kv_cache.py | 2 +- python/{mlc_chat => mlc_llm}/nn/rnn_state.py | 0 python/{mlc_chat => mlc_llm}/op/__init__.py | 0 python/{mlc_chat => mlc_llm}/op/attention.py | 2 +- python/{mlc_chat => mlc_llm}/op/extern.py | 0 python/{mlc_chat => mlc_llm}/op/ft_gemm.py | 0 python/{mlc_chat => mlc_llm}/op/moe_matmul.py | 0 python/{mlc_chat => mlc_llm}/op/moe_misc.py | 0 .../op/position_embedding.py | 0 .../protocol/__init__.py | 0 .../protocol/conversation_protocol.py | 0 .../protocol/openai_api_protocol.py | 2 +- .../protocol/protocol_utils.py | 0 .../quantization/__init__.py | 0 .../quantization/awq_quantization.py | 2 +- .../quantization/ft_quantization.py | 0 .../quantization/group_quantization.py | 10 +- .../quantization/no_quantization.py | 0 .../quantization/quantization.py | 0 .../quantization/utils.py | 0 python/{mlc_chat => mlc_llm}/rest.py | 6 +- .../{mlc_chat => mlc_llm}/serve/__init__.py | 0 .../{mlc_chat => mlc_llm}/serve/_ffi_api.py | 2 +- .../serve/async_engine.py | 0 python/{mlc_chat => mlc_llm}/serve/config.py | 0 python/{mlc_chat => mlc_llm}/serve/data.py | 0 python/{mlc_chat => mlc_llm}/serve/engine.py | 16 +- .../serve/entrypoints/__init__.py | 0 .../serve/entrypoints/debug_entrypoints.py | 0 
.../serve/entrypoints/entrypoint_utils.py | 0 .../serve/entrypoints/openai_entrypoints.py | 0 .../serve/event_trace_recorder.py | 0 python/{mlc_chat => mlc_llm}/serve/grammar.py | 0 python/{mlc_chat => mlc_llm}/serve/request.py | 0 .../serve/server/__init__.py | 0 .../serve/server/__main__.py | 0 .../serve/server/popen_server.py | 4 +- .../serve/server/server_context.py | 0 python/{mlc_chat => mlc_llm}/streamer.py | 0 .../{mlc_chat => mlc_llm}/support/__init__.py | 0 .../{mlc_chat => mlc_llm}/support/argparse.py | 0 .../support/auto_config.py | 16 +- .../support/auto_device.py | 2 +- .../support/auto_target.py | 2 +- .../support/auto_weight.py | 0 .../{mlc_chat => mlc_llm}/support/config.py | 0 .../support/constants.py | 6 +- .../support/convert_tiktoken.py | 0 .../{mlc_chat => mlc_llm}/support/download.py | 0 .../{mlc_chat => mlc_llm}/support/logging.py | 0 .../support/max_thread_check.py | 0 .../{mlc_chat => mlc_llm}/support/preshard.py | 0 .../{mlc_chat => mlc_llm}/support/random.py | 0 python/{mlc_chat => mlc_llm}/support/style.py | 0 .../support/tensor_parallel.py | 0 python/{mlc_chat => mlc_llm}/support/tqdm.py | 0 python/{mlc_chat => mlc_llm}/tokenizer.py | 0 python/setup.py | 18 +- rust/README.md | 2 +- tests/python/api/test_python.py | 4 +- tests/python/api/test_rest.py | 2 +- ...test_fuse_ft_dequantize_matmul_epilogue.py | 2 +- .../python/integration/test_model_compile.py | 12 +- tests/python/loader/test_awq.py | 8 +- tests/python/loader/test_huggingface.py | 6 +- tests/python/model/test_gpt2.py | 2 +- tests/python/model/test_gptNeox.py | 2 +- tests/python/model/test_kv_cache.py | 2 +- tests/python/model/test_llama.py | 2 +- tests/python/model/test_llama_quantization.py | 6 +- tests/python/model/test_mistral.py | 2 +- tests/python/model/test_phi.py | 2 +- .../quantization/test_awq_quantization.py | 4 +- .../quantization/test_group_quantization.py | 6 +- tests/python/serve/benchmark.py | 4 +- tests/python/serve/evaluate_engine.py | 4 +- tests/python/serve/server/conftest.py | 2 +- tests/python/serve/server/test_server.py | 2 +- .../python/serve/test_event_trace_recorder.py | 2 +- tests/python/serve/test_grammar_parser.py | 2 +- .../test_grammar_state_matcher_custom.py | 4 +- .../serve/test_grammar_state_matcher_json.py | 4 +- tests/python/serve/test_serve_async_engine.py | 4 +- .../serve/test_serve_async_engine_spec.py | 4 +- tests/python/serve/test_serve_engine.py | 4 +- .../python/serve/test_serve_engine_grammar.py | 8 +- tests/python/serve/test_serve_engine_spec.py | 4 +- tests/python/support/test_auto_config.py | 4 +- tests/python/support/test_auto_weight.py | 4 +- tests/python/support/test_streamer.py | 4 +- 231 files changed, 754 insertions(+), 788 deletions(-) delete mode 100644 python/README.md rename python/{mlc_chat => mlc_llm}/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/__main__.py (75%) rename python/{mlc_chat => mlc_llm}/_ffi_api.py (88%) rename python/{mlc_chat => mlc_llm}/base.py (100%) rename python/{mlc_chat => mlc_llm}/callback.py (100%) rename python/{mlc_chat => mlc_llm}/chat_module.py (97%) rename python/{mlc_chat => mlc_llm}/cli/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/cli/bench.py (89%) rename python/{mlc_chat => mlc_llm}/cli/benchmark.py (98%) rename python/{mlc_chat => mlc_llm}/cli/chat.py (88%) rename python/{mlc_chat => mlc_llm}/cli/check_device.py (100%) rename python/{mlc_chat => mlc_llm}/cli/compile.py (90%) rename python/{mlc_chat => mlc_llm}/cli/convert_weight.py (86%) rename python/{mlc_chat => mlc_llm}/cli/delivery.py 
(97%) rename python/{mlc_chat => mlc_llm}/cli/gen_config.py (90%) rename python/{mlc_chat => mlc_llm}/cli/model_metadata.py (97%) rename python/{mlc_chat => mlc_llm}/cli/worker.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/attach_to_ir_module.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/clean_up_tir_attrs.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/cublas_dispatch.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/dispatch_kv_cache_creation.py (99%) rename python/{mlc_chat => mlc_llm}/compiler_pass/estimate_memory_usage.py (98%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_add_norm.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_dequantize_matmul_ewise.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_dequantize_take.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_dequantize_transpose.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/fuse_transpose_matmul.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/lift_global_buffer_alloc.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/low_batch_specialization.py (100%) rename python/{mlc_chat => mlc_llm}/compiler_pass/pipeline.py (99%) rename python/{mlc_chat => mlc_llm}/compiler_pass/scatter_tuple_get_item.py (100%) rename python/{mlc_chat => mlc_llm}/conversation_template.py (100%) rename python/{mlc_chat => mlc_llm}/embeddings/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/embeddings/openai.py (99%) rename python/{mlc_chat => mlc_llm}/gradio.py (100%) rename python/{mlc_chat => mlc_llm}/help.py (100%) rename python/{mlc_chat => mlc_llm}/interface/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/interface/bench.py (93%) rename python/{mlc_chat => mlc_llm}/interface/chat.py (96%) rename python/{mlc_chat => mlc_llm}/interface/compile.py (96%) rename python/{mlc_chat => mlc_llm}/interface/compiler_flags.py (96%) rename python/{mlc_chat => mlc_llm}/interface/convert_weight.py (95%) rename python/{mlc_chat => mlc_llm}/interface/gen_config.py (97%) rename python/{mlc_chat => mlc_llm}/interface/jit.py (94%) rename python/{mlc_chat => mlc_llm}/interface/openai_api.py (100%) rename python/{mlc_chat => mlc_llm}/libinfo.py (100%) rename python/{mlc_chat => mlc_llm}/loader/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/loader/huggingface_loader.py (98%) rename python/{mlc_chat => mlc_llm}/loader/loader.py (100%) rename python/{mlc_chat => mlc_llm}/loader/mapping.py (100%) rename python/{mlc_chat => mlc_llm}/loader/stats.py (97%) rename python/{mlc_chat => mlc_llm}/loader/utils.py (98%) rename python/{mlc_chat => mlc_llm}/model/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/baichuan/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/baichuan/baichuan_loader.py (92%) rename python/{mlc_chat => mlc_llm}/model/baichuan/baichuan_model.py (95%) rename python/{mlc_chat => mlc_llm}/model/baichuan/baichuan_quantization.py (89%) rename python/{mlc_chat => mlc_llm}/model/gemma/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/gemma/gemma_loader.py (97%) rename python/{mlc_chat => mlc_llm}/model/gemma/gemma_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/gemma/gemma_quantization.py (90%) rename python/{mlc_chat => mlc_llm}/model/gpt2/__init__.py (100%) rename python/{mlc_chat => 
mlc_llm}/model/gpt2/gpt2_loader.py (96%) rename python/{mlc_chat => mlc_llm}/model/gpt2/gpt2_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/gpt2/gpt2_quantization.py (93%) rename python/{mlc_chat => mlc_llm}/model/gpt_bigcode/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/gpt_bigcode/gpt_bigcode_loader.py (94%) rename python/{mlc_chat => mlc_llm}/model/gpt_bigcode/gpt_bigcode_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/gpt_bigcode/gpt_bigcode_quantization.py (93%) rename python/{mlc_chat => mlc_llm}/model/gpt_neox/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/gpt_neox/gpt_neox_loader.py (97%) rename python/{mlc_chat => mlc_llm}/model/gpt_neox/gpt_neox_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/gpt_neox/gpt_neox_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/model/internlm/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/internlm/internlm_loader.py (97%) rename python/{mlc_chat => mlc_llm}/model/internlm/internlm_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/internlm/internlm_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/model/llama/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/llama/llama_loader.py (98%) rename python/{mlc_chat => mlc_llm}/model/llama/llama_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/llama/llama_quantization.py (93%) rename python/{mlc_chat => mlc_llm}/model/mistral/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/mistral/mistral_loader.py (98%) rename python/{mlc_chat => mlc_llm}/model/mistral/mistral_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/mistral/mistral_quantization.py (93%) rename python/{mlc_chat => mlc_llm}/model/mixtral/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/mixtral/mixtral_loader.py (97%) rename python/{mlc_chat => mlc_llm}/model/mixtral/mixtral_model.py (96%) rename python/{mlc_chat => mlc_llm}/model/mixtral/mixtral_quantization.py (93%) rename python/{mlc_chat => mlc_llm}/model/model.py (98%) rename python/{mlc_chat => mlc_llm}/model/model_preset.py (100%) rename python/{mlc_chat => mlc_llm}/model/orion/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/orion/orion_loader.py (96%) rename python/{mlc_chat => mlc_llm}/model/orion/orion_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/orion/orion_quantization.py (90%) rename python/{mlc_chat => mlc_llm}/model/phi/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/phi/phi_loader.py (98%) rename python/{mlc_chat => mlc_llm}/model/phi/phi_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/phi/phi_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/model/qwen/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/qwen/qwen_loader.py (95%) rename python/{mlc_chat => mlc_llm}/model/qwen/qwen_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/qwen/qwen_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/model/qwen2/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/qwen2/qwen2_loader.py (96%) rename python/{mlc_chat => mlc_llm}/model/qwen2/qwen2_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/qwen2/qwen2_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/model/rwkv5/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/rwkv5/rwkv5_loader.py (100%) rename python/{mlc_chat => mlc_llm}/model/rwkv5/rwkv5_model.py (99%) rename python/{mlc_chat => mlc_llm}/model/rwkv5/rwkv5_quantization.py (100%) rename python/{mlc_chat => 
mlc_llm}/model/stable_lm/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/model/stable_lm/stablelm_loader.py (97%) rename python/{mlc_chat => mlc_llm}/model/stable_lm/stablelm_model.py (98%) rename python/{mlc_chat => mlc_llm}/model/stable_lm/stablelm_quantization.py (92%) rename python/{mlc_chat => mlc_llm}/nn/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/nn/expert.py (95%) rename python/{mlc_chat => mlc_llm}/nn/kv_cache.py (99%) rename python/{mlc_chat => mlc_llm}/nn/rnn_state.py (100%) rename python/{mlc_chat => mlc_llm}/op/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/op/attention.py (99%) rename python/{mlc_chat => mlc_llm}/op/extern.py (100%) rename python/{mlc_chat => mlc_llm}/op/ft_gemm.py (100%) rename python/{mlc_chat => mlc_llm}/op/moe_matmul.py (100%) rename python/{mlc_chat => mlc_llm}/op/moe_misc.py (100%) rename python/{mlc_chat => mlc_llm}/op/position_embedding.py (100%) rename python/{mlc_chat => mlc_llm}/protocol/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/protocol/conversation_protocol.py (100%) rename python/{mlc_chat => mlc_llm}/protocol/openai_api_protocol.py (99%) rename python/{mlc_chat => mlc_llm}/protocol/protocol_utils.py (100%) rename python/{mlc_chat => mlc_llm}/quantization/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/quantization/awq_quantization.py (99%) rename python/{mlc_chat => mlc_llm}/quantization/ft_quantization.py (100%) rename python/{mlc_chat => mlc_llm}/quantization/group_quantization.py (98%) rename python/{mlc_chat => mlc_llm}/quantization/no_quantization.py (100%) rename python/{mlc_chat => mlc_llm}/quantization/quantization.py (100%) rename python/{mlc_chat => mlc_llm}/quantization/utils.py (100%) rename python/{mlc_chat => mlc_llm}/rest.py (98%) rename python/{mlc_chat => mlc_llm}/serve/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/serve/_ffi_api.py (87%) rename python/{mlc_chat => mlc_llm}/serve/async_engine.py (100%) rename python/{mlc_chat => mlc_llm}/serve/config.py (100%) rename python/{mlc_chat => mlc_llm}/serve/data.py (100%) rename python/{mlc_chat => mlc_llm}/serve/engine.py (98%) rename python/{mlc_chat => mlc_llm}/serve/entrypoints/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/serve/entrypoints/debug_entrypoints.py (100%) rename python/{mlc_chat => mlc_llm}/serve/entrypoints/entrypoint_utils.py (100%) rename python/{mlc_chat => mlc_llm}/serve/entrypoints/openai_entrypoints.py (100%) rename python/{mlc_chat => mlc_llm}/serve/event_trace_recorder.py (100%) rename python/{mlc_chat => mlc_llm}/serve/grammar.py (100%) rename python/{mlc_chat => mlc_llm}/serve/request.py (100%) rename python/{mlc_chat => mlc_llm}/serve/server/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/serve/server/__main__.py (100%) rename python/{mlc_chat => mlc_llm}/serve/server/popen_server.py (97%) rename python/{mlc_chat => mlc_llm}/serve/server/server_context.py (100%) rename python/{mlc_chat => mlc_llm}/streamer.py (100%) rename python/{mlc_chat => mlc_llm}/support/__init__.py (100%) rename python/{mlc_chat => mlc_llm}/support/argparse.py (100%) rename python/{mlc_chat => mlc_llm}/support/auto_config.py (92%) rename python/{mlc_chat => mlc_llm}/support/auto_device.py (98%) rename python/{mlc_chat => mlc_llm}/support/auto_target.py (99%) rename python/{mlc_chat => mlc_llm}/support/auto_weight.py (100%) rename python/{mlc_chat => mlc_llm}/support/config.py (100%) rename python/{mlc_chat => mlc_llm}/support/constants.py (93%) rename python/{mlc_chat => mlc_llm}/support/convert_tiktoken.py 
(100%) rename python/{mlc_chat => mlc_llm}/support/download.py (100%) rename python/{mlc_chat => mlc_llm}/support/logging.py (100%) rename python/{mlc_chat => mlc_llm}/support/max_thread_check.py (100%) rename python/{mlc_chat => mlc_llm}/support/preshard.py (100%) rename python/{mlc_chat => mlc_llm}/support/random.py (100%) rename python/{mlc_chat => mlc_llm}/support/style.py (100%) rename python/{mlc_chat => mlc_llm}/support/tensor_parallel.py (100%) rename python/{mlc_chat => mlc_llm}/support/tqdm.py (100%) rename python/{mlc_chat => mlc_llm}/tokenizer.py (100%) diff --git a/ci/task/build_clean.sh b/ci/task/build_clean.sh index 997979f701..c08ae9d129 100755 --- a/ci/task/build_clean.sh +++ b/ci/task/build_clean.sh @@ -8,4 +8,4 @@ set -x rm -rf ${WORKSPACE_CWD}/build/ \ ${WORKSPACE_CWD}/python/dist/ \ ${WORKSPACE_CWD}/python/build/ \ - ${WORKSPACE_CWD}/python/mlc_chat.egg-info + ${WORKSPACE_CWD}/python/mlc_llm.egg-info diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index cfb08082f5..e0f653841e 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -127,8 +127,7 @@ struct FunctionTable { device_ids[i] = i; } this->use_disco = true; - this->sess = - Session::ProcessSession(num_shards, f_create_process_pool, "mlc_chat.cli.worker"); + this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), lib_path, null_device); diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index bbeb23ec89..70c855d5f7 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -85,7 +85,7 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object device_ids[i] = i; } this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_chat.cli.worker"); + this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), lib_path, null_device); diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 855c805094..b30076f018 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -20,7 +20,7 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al .. note:: Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get @@ -33,20 +33,20 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al 0. Verify Installation ---------------------- -**Step 1. Verify mlc_chat** +**Step 1. Verify mlc_llm** -We use the python package ``mlc_chat`` to compile models. This can be installed by +We use the python package ``mlc_llm`` to compile models. This can be installed by following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_chat`` installation in command line via: +installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: .. 
code:: bash - $ mlc_chat --help + $ mlc_llm --help # You should see help information with this line usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} .. note:: - If it runs into error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``. + If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. **Step 2. Verify TVM** @@ -75,7 +75,7 @@ can share the same compiled/quantized weights. git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC @@ -103,11 +103,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device cuda -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so @@ -118,11 +118,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so Cross-Compiling for Intel Mac on M-chip Mac: @@ -130,11 +130,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib For Intel Mac: @@ -142,38 +142,38 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib .. group-tab:: Vulkan - For Linux: + For Linux: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so - For Windows: + For Windows: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.dll .. group-tab:: iOS/iPadOS @@ -183,11 +183,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ --conv-template redpajama_chat --context-window-size 768 \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar .. note:: @@ -207,11 +207,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ --conv-template redpajama_chat --context-window-size 768 \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device android -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar .. group-tab:: WebGPU @@ -219,15 +219,15 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error .. code:: text @@ -243,13 +243,13 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. -.. note:: +.. note:: For the ``conv-template``, `conv_template.cc `__ contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. Follow `this PR `__ as an example. - However, adding your own template would require you :ref:`build mlc_chat from source ` + However, adding your own template would require you :ref:`build mlc_llm from source ` in order for it to be recognized by the runtime. For more details, please see :ref:`configure-mlc-chat-json`. @@ -283,7 +283,7 @@ We can check the output with the commands below: .. code:: shell python - >>> from mlc_chat import ChatModule + >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") >>> cm.generate("hi") @@ -310,7 +310,7 @@ We can check the output with the commands below: .. code:: shell python - >>> from mlc_chat import ChatModule + >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") >>> cm.generate("hi") @@ -338,7 +338,7 @@ We can check the output with the commands below: .. code:: shell python - >>> from mlc_chat import ChatModule + >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") >>> cm.generate("hi") @@ -426,8 +426,8 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell - mlc_chat convert_weight ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC - + mlc_llm convert_weight ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC + Afterwards, run the following command to generate mlc config and compile the model. .. code:: shell @@ -442,10 +442,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so .. tab:: Metal @@ -455,10 +455,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal.so Cross-Compiling for Intel Mac on M-chip Mac: @@ -466,11 +466,11 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib For Intel Mac: @@ -478,34 +478,34 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal_x86_64.dylib .. tab:: Vulkan - For Linux: + For Linux: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.so - For Windows: + For Windows: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.dll .. tab:: WebGPU @@ -513,14 +513,14 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --context-window-size 2048 --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device webgpu -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-webgpu.wasm .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error .. code:: text @@ -534,10 +534,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-iphone.tar .. tab:: Android @@ -545,10 +545,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ --device android -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-android.tar .. tab:: Mistral-7B-Instruct-v0.2 @@ -571,7 +571,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell - mlc_chat convert_weight ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm convert_weight ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC Afterwards, run the following command to generate mlc config and compile the model. @@ -588,10 +588,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device cuda -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so .. tab:: Metal @@ -601,10 +601,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal.so @@ -613,34 +613,34 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal_x86_64.dylib .. tab:: Vulkan - For Linux: + For Linux: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.so - For Windows: + For Windows: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.dll .. tab:: WebGPU @@ -648,15 +648,15 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --prefill-chunk-size 1024 --conv-template mistral_default \ -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device webgpu -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-webgpu.wasm .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error .. code:: text @@ -679,11 +679,11 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 \ -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-iphone.tar .. tab:: Android @@ -691,10 +691,10 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ + mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ --device android -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar .. tab:: Other models @@ -714,7 +714,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell - mlc_chat convert_weight ./dist/models/HF_MODEL/ --quantization q4f16_1 -o dist/OUTPUT-MLC + mlc_llm convert_weight ./dist/models/HF_MODEL/ --quantization q4f16_1 -o dist/OUTPUT-MLC Afterwards, run the following command to generate mlc config and compile the model. @@ -730,9 +730,9 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device cuda -o dist/libs/OUTPUT-cuda.so + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device cuda -o dist/libs/OUTPUT-cuda.so .. tab:: Metal @@ -741,9 +741,9 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal.so + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal.so For Intel Mac: @@ -751,41 +751,41 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal_x86_64.dylib + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal_x86_64.dylib .. tab:: Vulkan - For Linux: + For Linux: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.so + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.so - For Windows: + For Windows: .. code:: shell - + # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.dll + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.dll .. tab:: WebGPU .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device webgpu -o dist/libs/OUTPUT-webgpu.wasm + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device webgpu -o dist/libs/OUTPUT-webgpu.wasm .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_chat``. Besides, you also need to follow :ref:`install-web-build`. + To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error .. code:: text @@ -808,20 +808,20 @@ generalized to any model variant, as long as mlc-llm supports the architecture. .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ --context-window-size 768 -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device iphone -o dist/libs/OUTPUT-iphone.tar + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device iphone -o dist/libs/OUTPUT-iphone.tar .. tab:: Android .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ + mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ --context-window-size 768 -o dist/OUTPUT-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device android -o dist/libs/OUTPUT-android.tar + mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device android -o dist/libs/OUTPUT-android.tar For each model and each backend, the above only provides the most recommended build command (which is the most optimized). 
You can also try with different argument values (e.g., different quantization modes, context window size, etc.), @@ -852,7 +852,7 @@ Weight conversion command follows the pattern below: .. code:: text - mlc_chat convert_weight \ + mlc_llm convert_weight \ CONFIG \ --quantization QUANTIZATION_MODE \ [--model-type MODEL_TYPE] \ @@ -880,7 +880,7 @@ Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` ar Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. For existing pre-defined model architecture, see ``MODEL_PRESETS`` - `here `_. + `here `_. --quantization QUANTIZATION_MODE The quantization mode we use to compile. @@ -914,7 +914,7 @@ Config generation command follows the pattern below: .. code:: text - mlc_chat gen_config \ + mlc_llm gen_config \ CONFIG \ --quantization QUANTIZATION_MODE \ [--model-type MODEL_TYPE] \ @@ -944,7 +944,7 @@ Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` ar Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. For existing pre-defined model architecture, see ``MODEL_PRESETS`` - `here `_. + `here `_. --quantization QUANTIZATION_MODE The quantization mode we use to compile. @@ -959,11 +959,11 @@ Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` ar --conv-template CONV_TEMPLATE Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model For existing pre-defined templates, see ``CONV_TEMPLATES`` - `here `_. + `here `_. --context-window-size CONTEXT_WINDOW_SIZE Option to provide the maximum sequence length supported by the model. This is usually explicitly shown as context length or context window in the model card. - If this option is not set explicitly, by default, + If this option is not set explicitly, by default, it will be determined by ``context_window_size`` or ``max_position_embeddings`` in ``config.json``, and the latter is usually inaccurate for some models. @@ -990,7 +990,7 @@ Model compilation command follows the pattern below: .. code:: text - mlc_chat compile \ + mlc_llm compile \ MODEL \ [--quantization QUANTIZATION_MODE] \ [--model-type MODEL_TYPE] \ @@ -1031,7 +1031,7 @@ Note that ``MODEL`` is a positional argument. Arguments wrapped with ``[ ]`` are denoted as ``O0``, ``O1``, ``O2``, ``O3``, where ``O0`` means no optimization, ``O2`` means majority of them, and ``O3`` represents extreme optimization that could potentially break the system. - + Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. ``--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"``. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index 7657bca7d8..2507687c21 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -8,8 +8,8 @@ To run a model with MLC LLM in any platform, you need: 1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `_.) 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). -In many cases, we only need to convert weights and reuse existing model library. -This page demonstrates adding a model variant with ``mlc_chat convert_weight``, which +In many cases, we only need to convert weights and reuse existing model library. +This page demonstrates adding a model variant with ``mlc_llm convert_weight``, which takes a hugginface model as input and converts/quantizes into MLC-compatible weights. 
Specifically, we add RedPjama-INCITE-**Instruct**-3B-v1, while MLC already @@ -23,7 +23,7 @@ This can be extended to, e.g.: .. note:: Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get @@ -38,20 +38,20 @@ This can be extended to, e.g.: 0. Verify installation ---------------------- -**Step 1. Verify mlc_chat** +**Step 1. Verify mlc_llm** -We use the python package ``mlc_chat`` to compile models. This can be installed by +We use the python package ``mlc_llm`` to compile models. This can be installed by following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_chat`` installation in command line via: +installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: .. code:: bash - $ mlc_chat --help + $ mlc_llm --help # You should see help information with this line usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} .. note:: - If it runs into error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``. + If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. **Step 2. Verify TVM** @@ -80,7 +80,7 @@ for specification of ``convert_weight``. git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1 cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ + mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ --quantization q4f16_1 \ -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC @@ -89,12 +89,12 @@ for specification of ``convert_weight``. 2. Generate MLC Chat Config --------------------------- -Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. +Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. See :ref:`compile-command-specification` for specification of ``gen_config``. .. code:: shell - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/ @@ -102,7 +102,7 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. .. note:: The file ``mlc-chat-config.json`` is crucial in both model compilation and runtime chatting. Here we only care about the latter case. - + You can **optionally** customize ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). You can also simply use the default configuration. @@ -111,7 +111,7 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. Follow `this PR `__ as an example. However, - adding your own template would require you :ref:`build mlc_chat from source ` in order for it + adding your own template would require you :ref:`build mlc_llm from source ` in order for it to be recognized by the runtime. 
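As a quick sanity check, you can load the generated ``mlc-chat-config.json`` and confirm that the ``conv_template`` field matches what was passed to ``gen_config`` (a minimal sketch, assuming the output directory used in the command above):

.. code:: python

    import json

    # Path assumes the gen_config output directory shown above
    config_path = "dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json"
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Should print "redpajama_chat", the template passed via --conv-template
    print(config["conv_template"])
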
By now, you should have the following files. @@ -132,7 +132,7 @@ By now, you should have the following files. (Optional) 3. Upload weights to HF ---------------------------------- -Optionally, you can upload what we have to huggingface. +Optionally, you can upload what we have to huggingface. .. code:: shell @@ -175,7 +175,7 @@ Running the distributed models are similar to running prebuilt model weights and # Run the model in Python; note that we reuse `-Chat` model library python - >>> from mlc_chat import ChatModule + >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ model_lib_path="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend >>> cm.generate("hi") diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 7bcda64ff4..a9b2fcb18f 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -37,8 +37,8 @@ Prerequisite **JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. It could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. Set up the following environment variable: -- ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. - +- ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. + Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. We recommended setting the `JAVA_HOME` to the JDK bundled with Android Studio. e.g. `export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home` for macOS. **TVM Unity runtime** is placed under `3rdparty/tvm `__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable: @@ -92,14 +92,14 @@ To deploy models on Android with reasonable performance, one has to cross-compil .. code-block:: bash # convert weights - mlc_chat convert_weight ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION -o dist/$MODEL_NAME-$QUANTIZATION-MLC/ + mlc_llm convert_weight ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION -o dist/$MODEL_NAME-$QUANTIZATION-MLC/ # create mlc-chat-config.json - mlc_chat gen_config ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION \ + mlc_llm gen_config ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION \ --conv-template llama-2 --context-window-size 768 -o dist/${MODEL_NAME}-${QUANTIZATION}-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/mlc-chat-config.json \ --device android -o ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/${MODEL_NAME}-${QUANTIZATION}-android.tar This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION-MLC`` which contains the necessary components to run the model, as explained below. @@ -131,19 +131,19 @@ The source code for MLC LLM is available under ``android/``, including scripts t (Required) Unique local identifier to identify the model. ``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_chat compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during ``mlc_chat compile``, the ``model_lib`` field should be updated accordingly. 
+ (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using + ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. ``estimated_vram_bytes`` (Optional) Estimated requirements of VRAM to run the model. - + To change the configuration, edit ``app-config.json``: .. code-block:: bash vim ./src/main/assets/app-config.json -Then bundle the android library ``${MODEL_NAME}-${QUANTIZATION}-android.tar`` compiled from ``mlc_chat compile`` in the previous steps, with TVM Unity's Java runtime by running the commands below: +Then bundle the android library ``${MODEL_NAME}-${QUANTIZATION}-android.tar`` compiled from ``mlc_llm compile`` in the previous steps, with TVM Unity's Java runtime by running the commands below: .. code-block:: bash diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index 83a2a9dcf1..f341e31e71 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -19,8 +19,8 @@ To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source .. code:: shell conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly - mlc_chat chat -h + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly + mlc_llm chat -h .. note:: The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source. @@ -29,7 +29,7 @@ To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source Option 2. Build MLC Runtime from Source ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -We also provide options to build mlc runtime libraries and ``mlc_chat`` from source. +We also provide options to build mlc runtime libraries and ``mlc_llm`` from source. This step is useful if the prebuilt is unavailable on your platform, or if you would like to build a runtime that supports other GPU runtime than the prebuilt version. We can build a customized version of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. @@ -44,7 +44,7 @@ Then please follow the instructions in :ref:`mlcchat_build_from_source` to build Run Models through MLCChat CLI ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Once ``mlc_chat`` is installed, you are able to run any MLC-compiled model on the command line. +Once ``mlc_llm`` is installed, you are able to run any MLC-compiled model on the command line. To run a model with MLC LLM in any platform, you can either: @@ -53,14 +53,14 @@ To run a model with MLC LLM in any platform, you can either: **Option 1: Use model prebuilts** -To run ``mlc_chat``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. +To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. For example, to run the MLC Llama 2 7B Q4F16_1 model (`Repo link `_), simply use ``HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC``. The model weights and library will be downloaded automatically from Huggingface. .. 
code:: shell - mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 .. code:: shell @@ -75,10 +75,10 @@ automatically from Huggingface. Multi-line input: Use escape+enter to start a new line. [INST]: What's the meaning of life - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! The meaning - of life is a deeply personal and subjective topic, and there are many different - perspectives on what it might be. However, here are some possible answers that have been + [/INST]: + Ah, a question that has puzzled philosophers and theologians for centuries! The meaning + of life is a deeply personal and subjective topic, and there are many different + perspectives on what it might be. However, here are some possible answers that have been proposed by various thinkers and cultures: ... @@ -91,14 +91,14 @@ For models other than the prebuilt ones we provided: follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. 2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. -Once you have the model locally compiled with a model library and model weights, to run ``mlc_chat``, simply +Once you have the model locally compiled with a model library and model weights, to run ``mlc_llm``, simply - Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model`` - Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib-path`` .. code:: shell - mlc_chat chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ + mlc_llm chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ --device "cuda:0" --overrides context_window_size=1024 \ --model-lib-path dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 0d3b4f6ff1..c0217db9e9 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -160,10 +160,10 @@ controls the list of local and remote models to be packaged into the app, given (Required) Unique local identifier to identify the model. ``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_chat compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` - for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during - ``mlc_chat compile``, the ``model_lib`` field should be updated accordingly. + (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using + ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` + for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during + ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. ``required_vram_bytes`` (Required) Estimated requirements of VRAM to run the model. @@ -192,7 +192,7 @@ In this section, we walk you through adding ``NeuralHermes-2.5-Mistral-7B-q3f16_ According to the model's ``config.json`` on `its Huggingface repo `_, it reuses the Mistral model architecture. -.. note:: +.. 
note:: This section largely replicates :ref:`convert-weights-via-MLC`. See that page for more details. Note that the weights are shared across @@ -213,26 +213,26 @@ for specification of ``convert_weight``. git clone https://huggingface.co/mlabonne/NeuralHermes-2.5-Mistral-7B cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ + mlc_llm convert_weight ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ --quantization q4f16_1 \ -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC **Step 2 Generate MLC Chat Config** -Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. +Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. See :ref:`compile-command-specification` for specification of ``gen_config``. .. code:: shell - mlc_chat gen_config ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ + mlc_llm gen_config ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ --quantization q3f16_1 --conv-template neural_hermes_mistral \ -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC For the ``conv-template``, `conv_template.cc `__ contains a full list of conversation templates that MLC provides. -If the model you are adding requires a new conversation template, you would need to add your own. -Follow `this PR `__ as an example. +If the model you are adding requires a new conversation template, you would need to add your own. +Follow `this PR `__ as an example. We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. For more details, please see :ref:`configure-mlc-chat-json`. @@ -250,7 +250,7 @@ For more details, please see :ref:`configure-mlc-chat-json`. git add . && git commit -m "Add mistral model weights" git push origin main -After successfully following all steps, you should end up with a Huggingface repo similar to +After successfully following all steps, you should end up with a Huggingface repo similar to `NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC `__, which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. @@ -261,11 +261,11 @@ Finally, we modify the code snippet for `app-config.json `__ pasted above. -We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib`` for +We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib`` for ``Mistral-7B``. .. code:: javascript - + "model_list": [ // Other records here omitted... { @@ -304,7 +304,7 @@ more details, specifically the ``iOS`` option. **Step 0. Install dependencies** -To compile model libraries for iOS, you need to :ref:`build mlc_chat from source `. +To compile model libraries for iOS, you need to :ref:`build mlc_llm from source `. **Step 1. Clone from HF and convert_weight** @@ -320,7 +320,7 @@ can share the same compiled/quantized weights. git clone https://huggingface.co/microsoft/phi-2 cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/phi-2/ \ + mlc_llm convert_weight ./dist/models/phi-2/ \ --quantization q4f16_1 \ -o dist/phi-2-q4f16_1-MLC @@ -338,11 +338,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/phi-2/ \ + mlc_llm gen_config ./dist/models/phi-2/ \ --quantization q4f16_1 --conv-template phi-2 \ -o dist/phi-2-q4f16_1-MLC/ # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar .. note:: @@ -396,7 +396,7 @@ hardware. We can calculate this estimate using the following command: .. code:: shell - ~/mlc-llm > python -m mlc_chat.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ + ~/mlc-llm > python -m mlc_llm.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` @@ -408,12 +408,12 @@ Finally, we update the code snippet for `app-config.json `__ pasted above. -We simply specify the Huggingface link as ``model_url``, while using the new ``model_lib`` for +We simply specify the Huggingface link as ``model_url``, while using the new ``model_lib`` for ``phi-2``. Regarding the field ``estimated_vram_bytes``, we can use the output of the last step rounded up to MB. .. code:: javascript - + "model_list": [ // Other records here omitted... { diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index 06a1d3fdcb..57f192f61a 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -33,7 +33,7 @@ is powered by the WebLLM npm package, specifically with the code in the `simple-chat `__ example. Each of the model in the `WebLLM prebuilt webpage `__ -is registered as an instance of ``ModelRecord``. Looking at the most straightforward example +is registered as an instance of ``ModelRecord``. Looking at the most straightforward example `get-started `__, we see the code snippet: @@ -61,7 +61,7 @@ we see the code snippet: Just like any other platforms, to run a model with on WebLLM, you need: -1. **Model weights** converted to MLC format (e.g. `Llama-2-7b-hf-q4f32_1-MLC +1. **Model weights** converted to MLC format (e.g. `Llama-2-7b-hf-q4f32_1-MLC `_.): downloaded through ``model_url`` 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__): downloaded through ``model_lib_url``. @@ -69,22 +69,22 @@ Verify Installation for Adding Models ------------------------------------- In sections below, we walk you through two examples of adding models to WebLLM. Before proceeding, -please verify installation of ``mlc_chat`` and ``tvm``: +please verify installation of ``mlc_llm`` and ``tvm``: -**Step 1. Verify mlc_chat** +**Step 1. Verify mlc_llm** -We use the python package ``mlc_chat`` to compile models. This can be installed by +We use the python package ``mlc_llm`` to compile models. This can be installed by following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_chat`` installation in command line via: +installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: .. code:: bash - $ mlc_chat --help + $ mlc_llm --help # You should see help information with this line usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} .. note:: - If it runs into error ``command not found: mlc_chat``, try ``python -m mlc_chat --help``. + If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. 
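The same check can also be done from Python (a minimal sketch; it only confirms that the package resolves in the active environment):

.. code:: python

    # Should print the module info of the installed mlc_llm package
    import mlc_llm
    print(mlc_llm)
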
**Step 2. Verify TVM** @@ -109,12 +109,12 @@ model, we only need to convert weights and reuse existing model library. For ins - Adding ``Llama2-uncensored`` when MLC supports ``Llama2`` -In this section, we walk you through adding ``WizardMath-7B-V1.1-q4f16_1`` to the +In this section, we walk you through adding ``WizardMath-7B-V1.1-q4f16_1`` to the `get-started `__ example. According to the model's ``config.json`` on `its Huggingface repo `_, it reuses the Mistral model architecture. -.. note:: +.. note:: This section largely replicates :ref:`convert-weights-via-MLC`. See that page for more details. Note that the weights are shared across @@ -135,18 +135,18 @@ for specification of ``convert_weight``. git clone https://huggingface.co/WizardLM/WizardMath-7B-V1.1 cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/WizardMath-7B-V1.1/ \ + mlc_llm convert_weight ./dist/models/WizardMath-7B-V1.1/ \ --quantization q4f16_1 \ -o dist/WizardMath-7B-V1.1-q4f16_1-MLC **Step 2 Generate MLC Chat Config** -Use ``mlc_chat gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. +Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. See :ref:`compile-command-specification` for specification of ``gen_config``. .. code:: shell - mlc_chat gen_config ./dist/models/WizardMath-7B-V1.1/ \ + mlc_llm gen_config ./dist/models/WizardMath-7B-V1.1/ \ --quantization q4f16_1 --conv-template wizard_coder_or_math \ -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/ @@ -159,11 +159,11 @@ We look up the template to use with the ``conv_template`` field in ``mlc-chat-co For more details, please see :ref:`configure-mlc-chat-json`. -.. note:: +.. note:: If you added your conversation template in ``src/conversation.ts``, you need to build WebLLM from source following the instruction in - `the WebLLM repo's README `_. + `the WebLLM repo's README `_. Alternatively, you could use the ``"custom"`` conversation template so that you can pass in your own ``ConvTemplateConfig`` in runtime without having to build the package from source. @@ -181,7 +181,7 @@ For more details, please see :ref:`configure-mlc-chat-json`. git add . && git commit -m "Add wizardMath model weights" git push origin main -After successfully following all steps, you should end up with a Huggingface repo similar to +After successfully following all steps, you should end up with a Huggingface repo similar to `WizardMath-7B-V1.1-q4f16_1-MLC `__, which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. @@ -192,7 +192,7 @@ Finally, we modify the code snippet for `get-started `__ pasted above. -We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib_url`` for +We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib_url`` for ``Mistral-7B``. Note that we need the suffix to be ``/resolve/main/``. .. code:: typescript @@ -215,7 +215,7 @@ We simply specify the Huggingface link as ``model_url``, while reusing the ``mod Now, running the ``get-started`` example will use the ``WizardMath`` model you just added. See `get-started's README `__ -on how to run it. +on how to run it. Bring Your Own Model Library @@ -241,7 +241,7 @@ more details, specifically the ``WebGPU`` option. **Step 0. Install dependencies** -To compile model libraries for webgpu, you need to :ref:`build mlc_chat from source `. +To compile model libraries for webgpu, you need to :ref:`build mlc_llm from source `. 
Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error: .. code:: text @@ -262,7 +262,7 @@ can share the same compiled/quantized weights. git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 cd ../.. # Convert weight - mlc_chat convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC @@ -280,11 +280,11 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. code:: shell # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_chat gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ + mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ --quantization q4f16_1 --conv-template redpajama_chat \ -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_chat compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ + mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm .. note:: @@ -357,4 +357,4 @@ Finally, we are able to run the model we added in WebLLM's `get-started `__ -on how to run it. \ No newline at end of file +on how to run it. \ No newline at end of file diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index 3dd1b67743..d5edcf82aa 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -32,9 +32,9 @@ Verify Installation .. code:: bash - python -c "from mlc_chat import ChatModule; print(ChatModule)" + python -c "from mlc_llm import ChatModule; print(ChatModule)" -You are expected to see the information about the :class:`mlc_chat.ChatModule` class. +You are expected to see the information about the :class:`mlc_llm.ChatModule` class. If the command above results in error, follow :ref:`install-mlc-packages` (either install the prebuilt pip wheels or :ref:`mlcchat_build_from_source`). @@ -44,7 +44,7 @@ Run MLC Models w/ Python To run a model with MLC LLM in any platform/runtime, you need: -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC +1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC `_.) 2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). @@ -77,14 +77,14 @@ Skip this step if you have already obtained the model weights and libraries. **Step 2: Run the model in Python** -Use the conda environment you used to install ``mlc_chat``. +Use the conda environment you used to install ``mlc_llm``. From the ``mlc-llm`` directory, you can create a Python -file ``sample_mlc_chat.py`` and paste the following lines: +file ``sample_mlc_llm.py`` and paste the following lines: .. code:: python - from mlc_chat import ChatModule - from mlc_chat.callback import StreamToStdout + from mlc_llm import ChatModule + from mlc_llm.callback import StreamToStdout # Create a ChatModule instance cm = ChatModule( @@ -125,7 +125,7 @@ Now run the Python file to start the chat .. code:: bash - python sample_mlc_chat.py + python sample_mlc_llm.py .. collapse:: See output @@ -173,14 +173,14 @@ option of overriding any field you'd like in Python, so that you do not need to ``mlc-chat-config.json``. 
Since there are two concepts -- `MLCChat Configuration` and `Conversation Configuration` -- we correspondingly -provide two dataclasses :class:`mlc_chat.ChatConfig` and :class:`mlc_chat.ConvConfig`. +provide two dataclasses :class:`mlc_llm.ChatConfig` and :class:`mlc_llm.ConvConfig`. We provide an example below. .. code:: python - from mlc_chat import ChatModule, ChatConfig, ConvConfig - from mlc_chat.callback import StreamToStdout + from mlc_llm import ChatModule, ChatConfig, ConvConfig + from mlc_llm.callback import StreamToStdout # Using a `ConvConfig`, we modify `system`, a field in the conversation template # `system` refers to the prompt encoded before starting the chat @@ -232,12 +232,12 @@ We provide an example below. | -.. note:: +.. note:: You do not need to specify the entire ``ChatConfig`` or ``ConvConfig``. Instead, we will first load all the fields defined in ``mlc-chat-config.json``, a file required when instantiating - a :class:`mlc_chat.ChatModule`. Then, we will load in the optional ``ChatConfig`` you provide, overriding the + a :class:`mlc_llm.ChatModule`. Then, we will load in the optional ``ChatConfig`` you provide, overriding the fields specified. - + It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template specified by the field ``conv_template`` in the chat configuration. Learn more about it in :ref:`Configure MLCChat in JSON`. @@ -245,7 +245,7 @@ We provide an example below. Raw Text Generation in Python ----------------------------- -Raw text generation allows the user to have more flexibility over his prompts, +Raw text generation allows the user to have more flexibility over his prompts, without being forced to create a new conversational template, making prompt customization easier. This serves other demands for APIs to handle LLM generation without the usual system prompts and other items. @@ -253,8 +253,8 @@ We provide an example below. .. code:: python - from mlc_chat import ChatModule, ChatConfig, ConvConfig - from mlc_chat.callback import StreamToStdout + from mlc_llm import ChatModule, ChatConfig, ConvConfig + from mlc_llm.callback import StreamToStdout # Use a `ConvConfig` to define the generation settings # Since the "LM" template only supports raw text generation, @@ -293,9 +293,9 @@ We provide an example below. progress_callback=StreamToStdout(callback_interval=2), ) -.. note:: +.. note:: The ``LM`` is a template without memory, which means that every execution will be cleared. - Additionally, system prompts will not be run when instantiating a `mlc_chat.ChatModule`, + Additionally, system prompts will not be run when instantiating a `mlc_llm.ChatModule`, unless explicitly given inside the prompt. Stream Iterator in Python @@ -308,8 +308,8 @@ We provide an example below. .. code:: python - from mlc_chat import ChatModule - from mlc_chat.callback import StreamIterator + from mlc_llm import ChatModule + from mlc_llm.callback import StreamIterator # Create a ChatModule instance cm = ChatModule( @@ -340,10 +340,10 @@ We provide an example below. API Reference ------------- -User can initiate a chat module by creating :class:`mlc_chat.ChatModule` class, which is a wrapper of the MLC-Chat model. -The :class:`mlc_chat.ChatModule` class provides the following methods: +User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-Chat model. +The :class:`mlc_llm.ChatModule` class provides the following methods: -.. currentmodule:: mlc_chat +.. currentmodule:: mlc_llm .. 
autoclass:: ChatModule :members: diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index d12029a80d..d955d6066f 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -18,7 +18,7 @@ Verify Installation .. code:: bash - python -m mlc_chat.rest --help + python -m mlc_llm.rest --help You are expected to see the help information of the REST API. @@ -32,14 +32,14 @@ that supports other GPU runtime than the prebuilt version. We can build a custom of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). -You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-chat-nightly`. +You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-llm-nightly`. Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. -You can now use ``mlc_chat`` package by including the `python` directory to ``PYTHONPATH`` environment variable. +You can now use ``mlc_llm`` package by including the `python` directory to ``PYTHONPATH`` environment variable. .. code:: bash - PYTHONPATH=python python -m mlc_chat.rest --help + PYTHONPATH=python python -m mlc_llm.rest --help Launch the Server ----------------- @@ -48,7 +48,7 @@ To launch the REST server for MLC-Chat, run the following command in your termin .. code:: bash - python -m mlc_chat.rest --model MODEL [--lib-path LIB_PATH] [--device DEVICE] [--host HOST] [--port PORT] + python -m mlc_llm.rest --model MODEL [--lib-path LIB_PATH] [--device DEVICE] [--host HOST] [--port PORT] --model The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme @@ -115,10 +115,10 @@ The REST API provides the following endpoints: For more details on how repetition penalty controls text generation, please check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). **presence_penalty**: *float* (optional) - Positive values penalize new tokens if they are already present in the text so far, + Positive values penalize new tokens if they are already present in the text so far, decreasing the model's likelihood to repeat tokens. **frequency_penalty**: *float* (optional) - Positive values penalize new tokens based on their existing frequency in the text so far, + Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat tokens. **mean_gen_len**: *int* (optional) The approximated average number of generated tokens in each round. Used @@ -129,7 +129,7 @@ The REST API provides the following endpoints: ------------------------------------------------ -**Returns** +**Returns** If ``stream`` is set to ``False``, the response will be a ``CompletionResponse`` object. If ``stream`` is set to ``True``, the response will be a stream of ``CompletionStreamResponse`` objects. @@ -177,10 +177,10 @@ The REST API provides the following endpoints: For more details on how repetition penalty controls text generation, please check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). **presence_penalty**: *float* (optional) - Positive values penalize new tokens if they are already present in the text so far, + Positive values penalize new tokens if they are already present in the text so far, decreasing the model's likelihood to repeat tokens. 
**frequency_penalty**: *float* (optional) - Positive values penalize new tokens based on their existing frequency in the text so far, + Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat tokens. **mean_gen_len**: *int* (optional) The approximated average number of generated tokens in each round. Used @@ -200,7 +200,7 @@ The REST API provides the following endpoints: ------------------------------------------------ -**Returns** +**Returns** If ``stream`` is set to ``False``, the response will be a ``ChatCompletionResponse`` object. If ``stream`` is set to ``True``, the response will be a stream of ``ChatCompletionStreamResponse`` objects. @@ -344,7 +344,7 @@ Response Objects The role(author) of the message. It can be either ``user`` or ``assistant``. **content**: *str* The content of the message. - + ------------------------------------------------ diff --git a/docs/get_started/mlc_chat_config.rst b/docs/get_started/mlc_chat_config.rst index c583c1659a..ccaa97b4fc 100644 --- a/docs/get_started/mlc_chat_config.rst +++ b/docs/get_started/mlc_chat_config.rst @@ -62,7 +62,7 @@ Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: "conv_template": "llama-2", } -.. note:: +.. note:: Fields in the first part of ``mlc-chat-config.json`` (e.g. ``context-window-size``) is only for compile-time. Changing them during runtime may lead to unexpected behavior. @@ -224,7 +224,7 @@ If you're tired of the default system prompt, here's an example of how you can r } -The next time you run ``mlc_chat`` CLI, you will start a chat with Vicuna using a new system prompt. +The next time you run ``mlc_llm`` CLI, you will start a chat with Vicuna using a new system prompt. .. _example_resume_chat_history: @@ -251,4 +251,4 @@ The following example demonstrates how to chat with Vicuna and resume from a cha } -The next time you start ``mlc_chat`` CLI, or use Python API, you will initiate a chat with Vicuna and resume from the provided chat history. +The next time you start ``mlc_llm`` CLI, or use Python API, you will initiate a chat with Vicuna and resume from the provided chat history. diff --git a/docs/index.rst b/docs/index.rst index 596e5d3877..504b667285 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,7 +17,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. tab:: Python - **Install MLC Chat Python**. :doc:`MLC LLM ` is available via pip. + **Install MLC LLM Python**. :doc:`MLC LLM ` is available via pip. It is always recommended to install it in an isolated conda virtual environment. **Download pre-quantized weights**. The commands below download the int4-quantized Llama2-7B from HuggingFace: @@ -38,8 +38,8 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: python - from mlc_chat import ChatModule - from mlc_chat.callback import StreamToStdout + from mlc_llm import ChatModule + from mlc_llm.callback import StreamToStdout cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", @@ -63,44 +63,16 @@ It is recommended to have at least 6GB free VRAM to run it. .. tab:: Command Line - **Install MLC Chat CLI.** MLC Chat CLI is available via conda using the command below. + **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. It is always recommended to install it in an isolated conda virtual environment. - For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. - - .. 
code:: bash - - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - - **Download pre-quantized weights**. The comamnds below download the int4-quantized Llama2-7B from HuggingFace: - - .. code:: bash - - git lfs install && mkdir dist/ - git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ - dist/Llama-2-7b-chat-hf-q4f16_1-MLC - **Download pre-compiled model library**. The pre-compiled model library is available as below: - - .. code:: bash - - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. **Run in command line**. .. code:: bash - mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/Llama2-macOS.gif - :width: 500 - :align: center - - MLC LLM on CLI - - .. note:: - The MLC Chat CLI package is only built with Vulkan (Windows/Linux) and Metal (macOS). - To use other GPU backends such as CUDA and ROCm, please use the prebuilt Python package or build from source. + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC .. tab:: Web Browser diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 004ee1529e..b4eff63041 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -29,49 +29,49 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. tab:: CUDA 11.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu117 mlc-ai-nightly-cu117 .. tab:: CUDA 11.8 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu118 mlc-ai-nightly-cu118 .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu121 mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu122 mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 - + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm56 mlc-ai-nightly-rocm56 + .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm57 mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -101,7 +101,7 @@ Select your operating system/compute platform and run the command in your termin .. 
code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: @@ -122,7 +122,7 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: If encountering the error below: @@ -142,8 +142,8 @@ Then you can verify installation in command line: .. code-block:: bash - python -c "import mlc_chat; print(mlc_chat)" - # Prints out: + python -c "import mlc_llm; print(mlc_llm)" + # Prints out: | @@ -152,7 +152,7 @@ Then you can verify installation in command line: Option 2. Build from Source --------------------------- -We also provide options to build mlc runtime libraries ``mlc_chat`` from source. +We also provide options to build mlc runtime libraries ``mlc_llm`` from source. This step is useful when you want to make modification or obtain a specific version of mlc runtime. @@ -203,11 +203,11 @@ This step is useful when you want to make modification or obtain a specific vers If you are using CUDA and your compute capability is above 80, then it is require to build with ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during runtime. - + To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. -**Step 3. Install via Python.** We recommend that you install ``mlc_chat`` as a Python package, giving you -access to ``mlc_chat.compile``, ``mlc_chat.ChatModule``, and the CLI. +**Step 3. Install via Python.** We recommend that you install ``mlc_llm`` as a Python package, giving you +access to ``mlc_llm.compile``, ``mlc_llm.ChatModule``, and the CLI. There are two ways to do so: .. tabs :: @@ -223,7 +223,7 @@ There are two ways to do so: cd /path-to-mlc-llm/python pip install -e . -**Step 4. Validate installation.** You may validate if MLC libarires and mlc_chat CLI is compiled successfully using the following command: +**Step 4. Validate installation.** You may validate if MLC libarires and mlc_llm CLI is compiled successfully using the following command: .. code-block:: bash :caption: Validate installation @@ -231,10 +231,10 @@ There are two ways to do so: # expected to see `libmlc_llm.so` and `libtvm_runtime.so` ls -l ./build/ # expected to see help message - mlc_chat chat -h + mlc_llm chat -h Finally, you can verify installation in command line. You should see the path you used to build from source with: .. code:: bash - python -c "import mlc_chat; print(mlc_chat)" + python -c "import mlc_llm; print(mlc_llm)" diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index 6d848d57d0..e299f68138 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -12,8 +12,8 @@ Model Prebuilts Overview -------- -MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ -(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the +MLC-LLM is a universal solution for deploying different language models. 
Any models that can be described in `TVM Relax `__ +(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the help of :doc:`TVM Unity `. There are two ways to run a model on MLC-LLM (this page focuses on the second one): @@ -68,7 +68,7 @@ For more, please see :doc:`the CLI page `, and the :doc:`the Python .. code:: shell - mlc_chat chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). @@ -86,7 +86,7 @@ For more, please see :doc:`the iOS page `. .. collapse:: Click to show details - The `iOS app `_ has builtin RedPajama-3B and Mistral-7B-Instruct-v0.2 support. + The `iOS app `_ has builtin RedPajama-3B and Mistral-7B-Instruct-v0.2 support. All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: @@ -175,7 +175,7 @@ MLC-LLM supports the following model architectures: - Unavailable in MLC Prebuilts * - `LLaMA `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ + * `MLC Implementation `__ - * :ref:`Llama-2-chat ` - * `Code Llama `__ * `Vicuna `__ @@ -191,40 +191,40 @@ MLC-LLM supports the following model architectures: * `YuLan-Chat `__ * - `Mistral `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ + * `MLC Implementation `__ - * :ref:`Mistral-7B-Instruct-v0.2 ` * :ref:`NeuralHermes-2.5-Mistral-7B ` * :ref:`OpenHermes-2.5-Mistral-7B ` * :ref:`WizardMath-7B-V1.1 ` - - + - * - `GPT-NeoX `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`RedPajama ` + * `MLC Implementation `__ + - * :ref:`RedPajama ` - * `Dolly `__ * `Pythia `__ * `StableCode `__ * - `GPTBigCode `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - + * `MLC Implementation `__ + - - * `StarCoder `__ * `SantaCoder `__ * `WizardCoder (old) `__ * - `Phi `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ + * `MLC Implementation `__ - * :ref:`Phi-1_5 ` * :ref:`Phi-2 ` - - + - * - `GPT2 `__ - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ + * `MLC Implementation `__ - * :ref:`GPT2 ` - - + - If the model variant you are interested in uses one of these model architectures we support, -(but we have not provided the prebuilt weights yet), you can check out +(but we have not provided the prebuilt weights yet), you can check out :doc:`/compilation/convert_weights` on how to convert the weights. Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt weights to hugging face, and submit a PR that adds an entry to this page, @@ -291,59 +291,59 @@ Llama - `q4f16_1 `__ `q4f32_1 `__ - - + - - `q4f16_1 `__ `q4f32_1 `__ - - + - - `q4f16_1 `__ `q4f32_1 `__ - - - - + - + - - `q4f16_1 `__ `q4f32_1 `__ - `q4f16_1 `__ `q4f32_1 `__ - - + - * - 13B - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - - - - - + - + - + - - `q4f16_1 `__ - - + - * - 34B - - - - - - - - - - - - - - - - - - + - + - + - + - + - + - + - + - + - - * - 70B - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - - - - - + - + - + - - `q4f16_1 `__ - - + - .. _mistral_library_table: - + Mistral ^^^^^^^ .. 
list-table:: Mistral @@ -372,11 +372,11 @@ Mistral - mali * - 7B - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - + - - `q4f16_1 `__ - - + - - `q3f16_1 `__ - `q4f16_1 `__ - `q4f16_1 `__ @@ -384,7 +384,7 @@ Mistral .. _gpt_neox_library_table: - + GPT-NeoX (RedPajama-INCITE) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: GPT-NeoX (RedPajama-INCITE) @@ -413,23 +413,23 @@ GPT-NeoX (RedPajama-INCITE) - mali * - 3B - `q4f16_1 `__ - + `q4f32_1 `__ - - + - - `q4f16_1 `__ - + `q4f32_1 `__ - - + - - `q4f16_1 `__ - + `q4f32_1 `__ - - + - - `q4f16_1 `__ - `q4f16_1 `__ `q4f32_1 `__ - `q4f16_1 `__ - + `q4f32_1 `__ - @@ -463,19 +463,19 @@ GPTBigCode - webgpu - mali * - 15B - - - - - - - - - - - - - - - - - - - - + - + - + - + - + - + - + - + - + - + - .. _phi_library_table: - + Phi ^^^ .. list-table:: Phi @@ -503,50 +503,50 @@ Phi - webgpu - mali * - Phi-2 - + (2.7B) - `q0f16 `__ `q4f16_1 `__ - - + - - `q0f16 `__ `q4f16_1 `__ - - + - - `q0f16 `__ `q4f16_1 `__ - - - - - - + - + - + - - `q0f16 `__ `q4f16_1 `__ - * - Phi-1.5 - + (1.3B) - `q0f16 `__ `q4f16_1 `__ - - + - - `q0f16 `__ `q4f16_1 `__ - - + - - `q0f16 `__ `q4f16_1 `__ - - - - - - + - + - + - - `q0f16 `__ `q4f16_1 `__ - .. _gpt2_library_table: - + GPT2 ^^^^ .. list-table:: GPT2 @@ -573,30 +573,30 @@ GPT2 - Android - webgpu - mali - * - GPT2 - + * - GPT2 + (124M) - `q0f16 `__ - - + - - `q0f16 `__ - - + - - `q0f16 `__ - - - - - - + - + - + - - `q0f16 `__ - * - GPT2-med - + (355M) - `q0f16 `__ - - + - - `q0f16 `__ - - + - - `q0f16 `__ - - - - - - + - + - + - - `q0f16 `__ - diff --git a/examples/python/benchmark.py b/examples/python/benchmark.py index 7cdbe7899c..7c897215d1 100644 --- a/examples/python/benchmark.py +++ b/examples/python/benchmark.py @@ -1,4 +1,4 @@ -from mlc_chat import ChatModule +from mlc_llm import ChatModule # From the mlc-llm directory, run # $ python examples/python/benchmark.py diff --git a/examples/python/sample_chat_stream.py b/examples/python/sample_chat_stream.py index 980e833d20..7b6beea0a3 100644 --- a/examples/python/sample_chat_stream.py +++ b/examples/python/sample_chat_stream.py @@ -1,5 +1,5 @@ -from mlc_chat import ChatModule -from mlc_chat.callback import StreamToStdout, StreamIterator +from mlc_llm import ChatModule +from mlc_llm.callback import StreamToStdout, StreamIterator # From the mlc-llm directory, run # $ python examples/python/sample_chat_stream.py diff --git a/examples/python/sample_mlc_chat.py b/examples/python/sample_mlc_chat.py index 6d20d0c1ce..de00e84ff6 100644 --- a/examples/python/sample_mlc_chat.py +++ b/examples/python/sample_mlc_chat.py @@ -1,8 +1,8 @@ -from mlc_chat import ChatModule -from mlc_chat.callback import StreamToStdout +from mlc_llm import ChatModule +from mlc_llm.callback import StreamToStdout # From the mlc-llm directory, run -# $ python examples/python/sample_mlc_chat.py +# $ python examples/python/sample_mlc_llm.py # Create a ChatModule instance cm = ChatModule( diff --git a/examples/rest/nodejs/README.MD b/examples/rest/nodejs/README.MD index 1d63d546cf..419b959ef3 100755 --- a/examples/rest/nodejs/README.MD +++ b/examples/rest/nodejs/README.MD @@ -1,4 +1,4 @@ -# Node/Javascript/Typescript Access Examples for MLC_CHAT REST APIs +# Node/Javascript/Typescript Access Examples for mlc_llm REST APIs Please make sure you are running v18.17.x of node (and npm v9.6.7) -- v20.x currently has some compatibility problems with typescript used in the langchain example. @@ -8,7 +8,7 @@ First install dependencies. Copy `dotenv.exmaple` to `.env`. 
-To run JS chat completion (both streaming and non-streaming) example: +To run JS chat completion (both streaming and non-streaming) example: `node sample_client.js` diff --git a/examples/rest/python/sample_langchain.py b/examples/rest/python/sample_langchain.py index cda326f470..1bfe80bd26 100644 --- a/examples/rest/python/sample_langchain.py +++ b/examples/rest/python/sample_langchain.py @@ -12,8 +12,7 @@ # Note that Langchain support for embedding documents using MLC is currently blocked on # https://github.com/langchain-ai/langchain/pull/7815 # We have subclassed `OpenAIEmbeddings` in the meantime to get around this dependency. -from mlc_chat.embeddings.openai import MLCEmbeddings - +from mlc_llm.embeddings.openai import MLCEmbeddings # First set the following in your environment: @@ -24,17 +23,19 @@ # https://github.com/langchain-ai/langchain/issues/6841 # Please ensure that your `pydantic` version is < 2.0 + class color: - PURPLE = '\033[95m' - CYAN = '\033[96m' - DARKCYAN = '\033[36m' - BLUE = '\033[94m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - RED = '\033[91m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - END = '\033[0m' + PURPLE = "\033[95m" + CYAN = "\033[96m" + DARKCYAN = "\033[36m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + def llm_chain_example(): template = """ @@ -42,28 +43,29 @@ def llm_chain_example(): USER: {human_input} ASSISTANT:""" - prompt = PromptTemplate( - input_variables=["history", "human_input"], - template=template - ) + prompt = PromptTemplate(input_variables=["history", "human_input"], template=template) llm_chain = LLMChain( llm=ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()]), prompt=prompt, verbose=True, - memory=ConversationBufferWindowMemory(human_prefix="USER", ai_prefix="ASSISTANT") + memory=ConversationBufferWindowMemory(human_prefix="USER", ai_prefix="ASSISTANT"), ) output = llm_chain.predict(human_input="Write a short poem about Pittsburgh.") output = llm_chain.predict(human_input="What does the poem mean?") + def load_qa_chain_example(): - loader = TextLoader('../resources/linux.txt') + loader = TextLoader("../resources/linux.txt") documents = loader.load() chain = load_qa_chain(llm=OpenAI(), chain_type="stuff", verbose=False) query = "When was Linux released?" print(f"{color.BOLD}Query:{color.END} {color.BLUE} {query}{color.END}") - print(f"{color.BOLD}Response:{color.END} {color.GREEN}{chain.run(input_documents=documents, question=query)}{color.END}") + print( + f"{color.BOLD}Response:{color.END} {color.GREEN}{chain.run(input_documents=documents, question=query)}{color.END}" + ) + def retrieval_qa_sotu_example(): prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. 
@@ -73,11 +75,9 @@ def retrieval_qa_sotu_example(): USER: {question} ASSISTANT:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) + PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) - loader = TextLoader('../resources/state_of_the_union.txt') + loader = TextLoader("../resources/state_of_the_union.txt") documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100) @@ -85,29 +85,32 @@ def retrieval_qa_sotu_example(): # print(texts) embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) db = Chroma.from_documents(documents=texts, embedding=embeddings) - retriever = db.as_retriever(search_type="similarity", search_kwargs={"k":2}) + retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2}) qa = RetrievalQA.from_chain_type( llm=OpenAI(), chain_type="stuff", retriever=retriever, return_source_documents=True, - chain_type_kwargs={"prompt": PROMPT} + chain_type_kwargs={"prompt": PROMPT}, ) questions = [ "What is the American Rescue Plan?", "What did the president say about Ketanji Brown Jackson?", "Who is mentioned in the speech?", "To whom is the speech addressed?", - "Tell me more about the Made in America campaign." + "Tell me more about the Made in America campaign.", ] for qn in questions: print(f"{color.BOLD}QUESTION:{color.END} {qn}") - res = qa({'query': qn}) + res = qa({"query": qn}) print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") - print(f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}") + print( + f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}" + ) print() + def retrieval_qa_mlc_docs_example(): prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. 
@@ -116,29 +119,35 @@ def retrieval_qa_mlc_docs_example(): USER: {question} ASSISTANT:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) + PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) - loader = DirectoryLoader("../../../docs", glob='*/*.rst', show_progress=True, loader_cls=UnstructuredRSTLoader, loader_kwargs={"mode": "single"}) + loader = DirectoryLoader( + "../../../docs", + glob="*/*.rst", + show_progress=True, + loader_cls=UnstructuredRSTLoader, + loader_kwargs={"mode": "single"}, + ) documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100) texts = text_splitter.split_documents(documents) embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) db = Chroma.from_documents(collection_name="abc", documents=texts, embedding=embeddings) - retriever = db.as_retriever(search_type="similarity", search_kwargs={"k":3}) + retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) qa = RetrievalQA.from_chain_type( llm=OpenAI(), chain_type="stuff", retriever=retriever, return_source_documents=True, - chain_type_kwargs={"prompt": PROMPT} + chain_type_kwargs={"prompt": PROMPT}, ) while True: qn = input(f"{color.BOLD}QUESTION:{color.END} ") - res = qa({'query': qn}) + res = qa({"query": qn}) print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") - print(f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}") + print( + f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}" + ) print() # Some example questions: diff --git a/pyproject.toml b/pyproject.toml index 1ffd135abf..d52c094ba6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ # under the License. [tool.isort] profile = "black" -src_paths = ["python/mlc_chat"] +src_paths = ["python/mlc_llm"] known_third_party = ["numpy", "tvm", "tqdm", "torch", "transformers"] [tool.black] diff --git a/python/README.md b/python/README.md deleted file mode 100644 index a1866eedab..0000000000 --- a/python/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# MLC-Chat Python Package - -This folder contains the source code of MLC-Chat python package, -please refer to the [REST API](https://llm.mlc.ai/docs/deploy/rest.html) -and [Python API](https://llm.mlc.ai/docs/deploy/python.html) documentation for usage. 
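The renamed Python examples above (``examples/python/benchmark.py``, ``examples/python/sample_chat_stream.py``, ``examples/python/sample_mlc_chat.py``) keep the same ``ChatModule`` API and only change the package name from ``mlc_chat`` to ``mlc_llm``. A minimal usage sketch under that assumption, presuming the int4-quantized Llama-2-7B weights referenced in ``docs/index.rst`` have already been downloaded into ``dist/``:

.. code:: python

   # Minimal sketch mirroring the renamed examples above; assumes the
   # Llama-2-7b-chat-hf-q4f16_1-MLC weights were cloned into dist/ first.
   from mlc_llm import ChatModule
   from mlc_llm.callback import StreamToStdout

   cm = ChatModule(model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC")
   # Stream the response to stdout as it is generated.
   cm.generate(
       prompt="What is the meaning of life?",
       progress_callback=StreamToStdout(callback_interval=2),
   )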
diff --git a/python/mlc_chat/__init__.py b/python/mlc_llm/__init__.py similarity index 100% rename from python/mlc_chat/__init__.py rename to python/mlc_llm/__init__.py diff --git a/python/mlc_chat/__main__.py b/python/mlc_llm/__main__.py similarity index 75% rename from python/mlc_chat/__main__.py rename to python/mlc_llm/__main__.py index 8cb80a65e0..3888b6839f 100644 --- a/python/mlc_chat/__main__.py +++ b/python/mlc_llm/__main__.py @@ -1,8 +1,8 @@ """Entrypoint of all CLI commands from MLC LLM""" import sys -from mlc_chat.support import logging -from mlc_chat.support.argparse import ArgumentParser +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser logging.enable_logging() @@ -19,23 +19,23 @@ def main(): parsed = parser.parse_args(sys.argv[1:2]) # pylint: disable=import-outside-toplevel if parsed.subcommand == "compile": - from mlc_chat.cli import compile as cli + from mlc_llm.cli import compile as cli cli.main(sys.argv[2:]) elif parsed.subcommand == "convert_weight": - from mlc_chat.cli import convert_weight as cli + from mlc_llm.cli import convert_weight as cli cli.main(sys.argv[2:]) elif parsed.subcommand == "gen_config": - from mlc_chat.cli import gen_config as cli + from mlc_llm.cli import gen_config as cli cli.main(sys.argv[2:]) elif parsed.subcommand == "chat": - from mlc_chat.cli import chat as cli + from mlc_llm.cli import chat as cli cli.main(sys.argv[2:]) elif parsed.subcommand == "bench": - from mlc_chat.cli import bench as cli + from mlc_llm.cli import bench as cli cli.main(sys.argv[2:]) else: diff --git a/python/mlc_chat/_ffi_api.py b/python/mlc_llm/_ffi_api.py similarity index 88% rename from python/mlc_chat/_ffi_api.py rename to python/mlc_llm/_ffi_api.py index b0074ad821..ee303681fc 100644 --- a/python/mlc_chat/_ffi_api.py +++ b/python/mlc_llm/_ffi_api.py @@ -1,4 +1,4 @@ -"""FFI APIs for mlc_chat""" +"""FFI APIs for mlc_llm""" import tvm._ffi # Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc" prefix. diff --git a/python/mlc_chat/base.py b/python/mlc_llm/base.py similarity index 100% rename from python/mlc_chat/base.py rename to python/mlc_llm/base.py diff --git a/python/mlc_chat/callback.py b/python/mlc_llm/callback.py similarity index 100% rename from python/mlc_chat/callback.py rename to python/mlc_llm/callback.py diff --git a/python/mlc_chat/chat_module.py b/python/mlc_llm/chat_module.py similarity index 97% rename from python/mlc_chat/chat_module.py rename to python/mlc_llm/chat_module.py index 62ca013569..675e1e7c94 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -16,14 +16,14 @@ import tvm from tvm.runtime import disco # pylint: disable=unused-import -from mlc_chat.support import logging -from mlc_chat.support.auto_device import detect_device -from mlc_chat.support.config import ConfigBase +from mlc_llm.support import logging +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.config import ConfigBase from . import base as _ if TYPE_CHECKING: - from mlc_chat.interface.openai_api import ChatMessage + from mlc_llm.interface.openai_api import ChatMessage # pylint: disable=line-too-long _PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" @@ -37,8 +37,8 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined partial configuration for conversation template. 
- This is an attribute of :class:`mlc_chat.ChatConfig`, which can then be passed in to the - instantiation of a :class:`mlc_chat.ChatModule` instance to override the default + This is an attribute of :class:`mlc_llm.ChatConfig`, which can then be passed in to the + instantiation of a :class:`mlc_llm.ChatModule` instance to override the default setting in ``mlc-chat-config.json`` under the model folder. Note that we will first load the predefined template with the name specified in ``conv_template``. @@ -104,7 +104,7 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes chat config file. An instance of ``ChatConfig`` can be passed in to the instantiation of a - :class:`mlc_chat.ChatModule` instance to override the default setting in + :class:`mlc_llm.ChatModule` instance to override the default setting in ``mlc-chat-config.json`` under the model folder. Since the configuration is partial, everything will be ``Optional``. @@ -225,7 +225,7 @@ class GenerationConfig(ConfigBase): # pylint: disable=too-many-instance-attribu r"""A dataclass that represents user-defined generation configuration. An instance of ``GenerationConfig`` can be passed in to the generate function - of a :class:`mlc_chat.ChatModule` instance to override the default generation + of a :class:`mlc_llm.ChatModule` instance to override the default generation setting in ``mlc-chat-config.json`` and ``ChatConfig`` under the model folder. Once the generation ends, ``GenerationConfig`` is discarded, since the values @@ -349,7 +349,7 @@ def _get_model_path(model: str) -> Tuple[str, str]: FileNotFoundError: if we cannot find a valid `model_path`. """ if model.startswith("HF://"): - from mlc_chat.support.download import ( # pylint: disable=import-outside-toplevel + from mlc_llm.support.download import ( # pylint: disable=import-outside-toplevel download_mlc_weights, ) @@ -642,7 +642,7 @@ def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): cmd = [ sys.executable, "-m", - "mlc_chat.cli.model_metadata", + "mlc_llm.cli.model_metadata", model_lib_path, "--memory-only", "--mlc-chat-config", @@ -659,8 +659,8 @@ class ChatModule: # pylint: disable=too-many-instance-attributes .. code:: python - from mlc_chat import ChatModule - from mlc_chat.callback import StreamToStdout + from mlc_llm import ChatModule + from mlc_llm.callback import StreamToStdout # Create a ChatModule instance cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") @@ -763,9 +763,7 @@ def __init__( # pylint: disable=too-many-arguments ) except FileNotFoundError: logger.info("Model lib not found. Now compiling model lib on device...") - from mlc_chat.interface import ( # pylint: disable=import-outside-toplevel - jit, - ) + from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel self.model_lib_path = str( jit.jit( @@ -811,7 +809,7 @@ def generate( The generation config object to override the ChatConfig generation settings. progress_callback: object The optional callback method used upon receiving a newly generated message from the - chat module. See `mlc_chat/callback.py` for a full list of available callback classes. + chat module. See `mlc_llm/callback.py` for a full list of available callback classes. Currently, only streaming to stdout callback method is supported, see `Examples` for more detailed usage. @@ -829,7 +827,7 @@ def generate( # the chat module streaming to stdout piece by piece, and in the end we receive the # full response as a single string `output`. 
- from mlc_chat import ChatModule, GenerationConfig, callback + from mlc_llm import ChatModule, GenerationConfig, callback cm = ChatModule(xxx) prompt = "what's the color of banana?" output = cm.generate( @@ -936,7 +934,7 @@ def benchmark_generate(self, prompt: str, generate_length: int) -> str: .. code:: python - from mlc_chat import ChatModule + from mlc_llm import ChatModule cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") output = cm.benchmark_generate("What's the meaning of life?", generate_length=256) diff --git a/python/mlc_chat/cli/__init__.py b/python/mlc_llm/cli/__init__.py similarity index 100% rename from python/mlc_chat/cli/__init__.py rename to python/mlc_llm/cli/__init__.py diff --git a/python/mlc_chat/cli/bench.py b/python/mlc_llm/cli/bench.py similarity index 89% rename from python/mlc_chat/cli/bench.py rename to python/mlc_llm/cli/bench.py index 4b9af7c661..26b74b1f10 100644 --- a/python/mlc_chat/cli/bench.py +++ b/python/mlc_llm/cli/bench.py @@ -1,8 +1,8 @@ """Command line entrypoint of benchmark.""" -from mlc_chat.help import HELP -from mlc_chat.interface.bench import bench -from mlc_chat.interface.chat import ChatConfigOverride -from mlc_chat.support.argparse import ArgumentParser +from mlc_llm.help import HELP +from mlc_llm.interface.bench import bench +from mlc_llm.interface.chat import ChatConfigOverride +from mlc_llm.support.argparse import ArgumentParser def main(argv): diff --git a/python/mlc_chat/cli/benchmark.py b/python/mlc_llm/cli/benchmark.py similarity index 98% rename from python/mlc_chat/cli/benchmark.py rename to python/mlc_llm/cli/benchmark.py index e6014aa267..72c86fab03 100644 --- a/python/mlc_chat/cli/benchmark.py +++ b/python/mlc_llm/cli/benchmark.py @@ -2,7 +2,7 @@ import argparse from pathlib import Path -from mlc_chat import ChatConfig, ChatModule +from mlc_llm import ChatConfig, ChatModule parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.") parser.add_argument( diff --git a/python/mlc_chat/cli/chat.py b/python/mlc_llm/cli/chat.py similarity index 88% rename from python/mlc_chat/cli/chat.py rename to python/mlc_llm/cli/chat.py index 7ec6efb213..13c83a64ec 100644 --- a/python/mlc_chat/cli/chat.py +++ b/python/mlc_llm/cli/chat.py @@ -1,7 +1,7 @@ """Command line entrypoint of chat.""" -from mlc_chat.help import HELP -from mlc_chat.interface.chat import ChatConfigOverride, chat -from mlc_chat.support.argparse import ArgumentParser +from mlc_llm.help import HELP +from mlc_llm.interface.chat import ChatConfigOverride, chat +from mlc_llm.support.argparse import ArgumentParser def main(argv): diff --git a/python/mlc_chat/cli/check_device.py b/python/mlc_llm/cli/check_device.py similarity index 100% rename from python/mlc_chat/cli/check_device.py rename to python/mlc_llm/cli/check_device.py diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_llm/cli/compile.py similarity index 90% rename from python/mlc_chat/cli/compile.py rename to python/mlc_llm/cli/compile.py index c56b4044b6..7d7025a91f 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_llm/cli/compile.py @@ -6,24 +6,21 @@ from pathlib import Path from typing import Union -from mlc_chat.help import HELP -from mlc_chat.interface.compile import ( # pylint: disable=redefined-builtin +from mlc_llm.help import HELP +from mlc_llm.interface.compile import ( # pylint: disable=redefined-builtin ModelConfigOverride, OptimizationFlags, compile, ) -from mlc_chat.model import MODELS -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.support.argparse import 
ArgumentParser -from mlc_chat.support.auto_config import ( +from mlc_llm.model import MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.auto_config import ( detect_mlc_chat_config, detect_model_type, detect_quantization, ) -from mlc_chat.support.auto_target import ( - detect_system_lib_prefix, - detect_target_and_host, -) +from mlc_llm.support.auto_target import detect_system_lib_prefix, detect_target_and_host def main(argv): @@ -55,7 +52,7 @@ def _check_system_lib_prefix(prefix: str) -> str: "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)." ) - parser = ArgumentParser("mlc_chat compile") + parser = ArgumentParser("mlc_llm compile") parser.add_argument( "model", type=detect_mlc_chat_config, diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_llm/cli/convert_weight.py similarity index 86% rename from python/mlc_chat/cli/convert_weight.py rename to python/mlc_llm/cli/convert_weight.py index 5e97cc7486..08d98c421d 100644 --- a/python/mlc_chat/cli/convert_weight.py +++ b/python/mlc_llm/cli/convert_weight.py @@ -3,14 +3,14 @@ from pathlib import Path from typing import Union -from mlc_chat.help import HELP -from mlc_chat.interface.convert_weight import convert_weight -from mlc_chat.model import MODELS -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.support.argparse import ArgumentParser -from mlc_chat.support.auto_config import detect_config, detect_model_type -from mlc_chat.support.auto_device import detect_device -from mlc_chat.support.auto_weight import detect_weight +from mlc_llm.help import HELP +from mlc_llm.interface.convert_weight import convert_weight +from mlc_llm.model import MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.auto_config import detect_config, detect_model_type +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.auto_weight import detect_weight def main(argv): diff --git a/python/mlc_chat/cli/delivery.py b/python/mlc_llm/cli/delivery.py similarity index 97% rename from python/mlc_chat/cli/delivery.py rename to python/mlc_llm/cli/delivery.py index cc5fd079df..50b9c7e170 100644 --- a/python/mlc_chat/cli/delivery.py +++ b/python/mlc_llm/cli/delivery.py @@ -12,11 +12,11 @@ from huggingface_hub import HfApi # pylint: disable=import-error from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error -from mlc_chat.support import logging -from mlc_chat.support.argparse import ArgumentParser -from mlc_chat.support.constants import MLC_TEMP_DIR -from mlc_chat.support.download import git_clone -from mlc_chat.support.style import bold, green, red +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.download import git_clone +from mlc_llm.support.style import bold, green, red logging.enable_logging() logger = logging.getLogger(__name__) @@ -113,7 +113,7 @@ def _run_quantization( cmd = [ sys.executable, "-m", - "mlc_chat", + "mlc_llm", "gen_config", str(model_info.model), "--quantization", @@ -135,7 +135,7 @@ def _run_quantization( cmd = [ sys.executable, "-m", - "mlc_chat", + "mlc_llm", "convert_weight", str(model_info.model), "--quantization", diff --git a/python/mlc_chat/cli/gen_config.py b/python/mlc_llm/cli/gen_config.py similarity index 90% rename from python/mlc_chat/cli/gen_config.py rename to python/mlc_llm/cli/gen_config.py index 
dd6848499d..b58b546678 100644 --- a/python/mlc_chat/cli/gen_config.py +++ b/python/mlc_llm/cli/gen_config.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import Union -from mlc_chat.help import HELP -from mlc_chat.interface.gen_config import CONV_TEMPLATES, gen_config -from mlc_chat.model import MODELS -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.support.argparse import ArgumentParser -from mlc_chat.support.auto_config import detect_config, detect_model_type +from mlc_llm.help import HELP +from mlc_llm.interface.gen_config import CONV_TEMPLATES, gen_config +from mlc_llm.model import MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.auto_config import detect_config, detect_model_type def main(argv): diff --git a/python/mlc_chat/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py similarity index 97% rename from python/mlc_chat/cli/model_metadata.py rename to python/mlc_llm/cli/model_metadata.py index 2ba9e2aa88..9b45561665 100644 --- a/python/mlc_chat/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -8,10 +8,10 @@ import numpy as np -from mlc_chat.support import logging -from mlc_chat.support.argparse import ArgumentParser -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import green, red +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import green, red logging.enable_logging() logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/cli/worker.py b/python/mlc_llm/cli/worker.py similarity index 100% rename from python/mlc_chat/cli/worker.py rename to python/mlc_llm/cli/worker.py diff --git a/python/mlc_chat/compiler_pass/__init__.py b/python/mlc_llm/compiler_pass/__init__.py similarity index 100% rename from python/mlc_chat/compiler_pass/__init__.py rename to python/mlc_llm/compiler_pass/__init__.py diff --git a/python/mlc_chat/compiler_pass/attach_to_ir_module.py b/python/mlc_llm/compiler_pass/attach_to_ir_module.py similarity index 100% rename from python/mlc_chat/compiler_pass/attach_to_ir_module.py rename to python/mlc_llm/compiler_pass/attach_to_ir_module.py diff --git a/python/mlc_chat/compiler_pass/clean_up_tir_attrs.py b/python/mlc_llm/compiler_pass/clean_up_tir_attrs.py similarity index 100% rename from python/mlc_chat/compiler_pass/clean_up_tir_attrs.py rename to python/mlc_llm/compiler_pass/clean_up_tir_attrs.py diff --git a/python/mlc_chat/compiler_pass/cublas_dispatch.py b/python/mlc_llm/compiler_pass/cublas_dispatch.py similarity index 100% rename from python/mlc_chat/compiler_pass/cublas_dispatch.py rename to python/mlc_llm/compiler_pass/cublas_dispatch.py diff --git a/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py similarity index 99% rename from python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py rename to python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 1995b3c517..0c8846d670 100644 --- a/python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -5,7 +5,7 @@ import tvm from tvm import IRModule, relax -from mlc_chat.nn import RopeMode, kv_cache +from mlc_llm.nn import RopeMode, kv_cache def extract_creation_args(func: relax.Function) -> Dict[str, Any]: diff --git a/python/mlc_chat/compiler_pass/estimate_memory_usage.py 
b/python/mlc_llm/compiler_pass/estimate_memory_usage.py similarity index 98% rename from python/mlc_chat/compiler_pass/estimate_memory_usage.py rename to python/mlc_llm/compiler_pass/estimate_memory_usage.py index f3ac747e0f..9b4de3a5cc 100644 --- a/python/mlc_chat/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -7,7 +7,7 @@ from tvm.ir import IRModule, Op from tvm.relax.expr_functor import PyExprVisitor, visitor -from mlc_chat.support import logging +from mlc_llm.support import logging logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/compiler_pass/fuse_add_norm.py b/python/mlc_llm/compiler_pass/fuse_add_norm.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_add_norm.py rename to python/mlc_llm/compiler_pass/fuse_add_norm.py diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_dequantize_matmul_ewise.py rename to python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_take.py b/python/mlc_llm/compiler_pass/fuse_dequantize_take.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_dequantize_take.py rename to python/mlc_llm/compiler_pass/fuse_dequantize_take.py diff --git a/python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py b/python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_dequantize_transpose.py rename to python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py diff --git a/python/mlc_chat/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py b/python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py rename to python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py diff --git a/python/mlc_chat/compiler_pass/fuse_transpose_matmul.py b/python/mlc_llm/compiler_pass/fuse_transpose_matmul.py similarity index 100% rename from python/mlc_chat/compiler_pass/fuse_transpose_matmul.py rename to python/mlc_llm/compiler_pass/fuse_transpose_matmul.py diff --git a/python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py similarity index 100% rename from python/mlc_chat/compiler_pass/lift_global_buffer_alloc.py rename to python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py diff --git a/python/mlc_chat/compiler_pass/low_batch_specialization.py b/python/mlc_llm/compiler_pass/low_batch_specialization.py similarity index 100% rename from python/mlc_chat/compiler_pass/low_batch_specialization.py rename to python/mlc_llm/compiler_pass/low_batch_specialization.py diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py similarity index 99% rename from python/mlc_chat/compiler_pass/pipeline.py rename to python/mlc_llm/compiler_pass/pipeline.py index e13ff2a404..d8f98b84eb 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -9,7 +9,7 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from tvm.relax.frontend import nn -from mlc_chat.support import logging +from mlc_llm.support import logging from .attach_to_ir_module import ( AttachAdditionalPrimFuncs, diff --git 
a/python/mlc_chat/compiler_pass/scatter_tuple_get_item.py b/python/mlc_llm/compiler_pass/scatter_tuple_get_item.py similarity index 100% rename from python/mlc_chat/compiler_pass/scatter_tuple_get_item.py rename to python/mlc_llm/compiler_pass/scatter_tuple_get_item.py diff --git a/python/mlc_chat/conversation_template.py b/python/mlc_llm/conversation_template.py similarity index 100% rename from python/mlc_chat/conversation_template.py rename to python/mlc_llm/conversation_template.py diff --git a/python/mlc_chat/embeddings/__init__.py b/python/mlc_llm/embeddings/__init__.py similarity index 100% rename from python/mlc_chat/embeddings/__init__.py rename to python/mlc_llm/embeddings/__init__.py diff --git a/python/mlc_chat/embeddings/openai.py b/python/mlc_llm/embeddings/openai.py similarity index 99% rename from python/mlc_chat/embeddings/openai.py rename to python/mlc_llm/embeddings/openai.py index 022d55be70..39f66ef51a 100644 --- a/python/mlc_chat/embeddings/openai.py +++ b/python/mlc_llm/embeddings/openai.py @@ -10,7 +10,7 @@ embed_with_retry, ) -from mlc_chat.support import logging +from mlc_llm.support import logging logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/gradio.py b/python/mlc_llm/gradio.py similarity index 100% rename from python/mlc_chat/gradio.py rename to python/mlc_llm/gradio.py diff --git a/python/mlc_chat/help.py b/python/mlc_llm/help.py similarity index 100% rename from python/mlc_chat/help.py rename to python/mlc_llm/help.py diff --git a/python/mlc_chat/interface/__init__.py b/python/mlc_llm/interface/__init__.py similarity index 100% rename from python/mlc_chat/interface/__init__.py rename to python/mlc_llm/interface/__init__.py diff --git a/python/mlc_chat/interface/bench.py b/python/mlc_llm/interface/bench.py similarity index 93% rename from python/mlc_chat/interface/bench.py rename to python/mlc_llm/interface/bench.py index a1d4e27034..6a7d833447 100644 --- a/python/mlc_chat/interface/bench.py +++ b/python/mlc_llm/interface/bench.py @@ -1,7 +1,7 @@ """Python entrypoint of benchmark.""" from typing import Optional -from mlc_chat.chat_module import ChatConfig, ChatModule +from mlc_llm.chat_module import ChatConfig, ChatModule from .chat import ChatConfigOverride diff --git a/python/mlc_chat/interface/chat.py b/python/mlc_llm/interface/chat.py similarity index 96% rename from python/mlc_chat/interface/chat.py rename to python/mlc_llm/interface/chat.py index cd473f7968..9c0763a6ef 100644 --- a/python/mlc_chat/interface/chat.py +++ b/python/mlc_llm/interface/chat.py @@ -5,10 +5,10 @@ from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error -from mlc_chat.callback import StreamToStdout -from mlc_chat.chat_module import ChatConfig, ChatModule, GenerationConfig -from mlc_chat.support import argparse -from mlc_chat.support.config import ConfigOverrideBase +from mlc_llm.callback import StreamToStdout +from mlc_llm.chat_module import ChatConfig, ChatModule, GenerationConfig +from mlc_llm.support import argparse +from mlc_llm.support.config import ConfigOverrideBase @dataclasses.dataclass diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_llm/interface/compile.py similarity index 96% rename from python/mlc_chat/interface/compile.py rename to python/mlc_llm/interface/compile.py index 768871532d..b6052a935a 100644 --- a/python/mlc_chat/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -11,14 +11,14 @@ from tvm.relax.frontend 
import nn from tvm.target import Target -from mlc_chat import compiler_pass as _ -from mlc_chat import op as op_ext -from mlc_chat.cli.model_metadata import _report_memory_usage -from mlc_chat.model import Model -from mlc_chat.quantization import Quantization -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import compiler_pass as _ +from mlc_llm import op as op_ext +from mlc_llm.cli.model_metadata import _report_memory_usage +from mlc_llm.model import Model +from mlc_llm.quantization import Quantization +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold from .compiler_flags import ModelConfigOverride, OptimizationFlags diff --git a/python/mlc_chat/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py similarity index 96% rename from python/mlc_chat/interface/compiler_flags.py rename to python/mlc_llm/interface/compiler_flags.py index 7eeedaf6fc..fd820e7124 100644 --- a/python/mlc_chat/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -3,8 +3,8 @@ from io import StringIO from typing import Optional -from mlc_chat.support import argparse, logging -from mlc_chat.support.config import ConfigOverrideBase +from mlc_llm.support import argparse, logging +from mlc_llm.support.config import ConfigOverrideBase logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ def update(self, target, quantization) -> None: """Update optimization flags based on additional information.""" def _flashinfer(target) -> bool: - from mlc_chat.support.auto_target import ( # pylint: disable=import-outside-toplevel + from mlc_llm.support.auto_target import ( # pylint: disable=import-outside-toplevel detect_cuda_arch_list, ) diff --git a/python/mlc_chat/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py similarity index 95% rename from python/mlc_chat/interface/convert_weight.py rename to python/mlc_llm/interface/convert_weight.py index 1e28417eaa..fad6114c6e 100644 --- a/python/mlc_chat/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -13,12 +13,12 @@ from tvm.runtime import cpu as cpu_device from tvm.target import Target -from mlc_chat.loader import LOADER -from mlc_chat.model import Model -from mlc_chat.quantization import Quantization -from mlc_chat.support import logging, tqdm -from mlc_chat.support.preshard import apply_preshard -from mlc_chat.support.style import bold, green +from mlc_llm.loader import LOADER +from mlc_llm.model import Model +from mlc_llm.quantization import Quantization +from mlc_llm.support import logging, tqdm +from mlc_llm.support.preshard import apply_preshard +from mlc_llm.support.style import bold, green logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py similarity index 97% rename from python/mlc_chat/interface/gen_config.py rename to python/mlc_llm/interface/gen_config.py index d45e1daff0..f4d39aa8ba 100644 --- a/python/mlc_chat/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from mlc_chat.model import Model -from mlc_chat.quantization import Quantization -from mlc_chat.support import convert_tiktoken, logging -from mlc_chat.support.style import bold, green, red +from mlc_llm.model import Model +from mlc_llm.quantization import Quantization 
+from mlc_llm.support import convert_tiktoken, logging +from mlc_llm.support.style import bold, green, red from .compiler_flags import ModelConfigOverride diff --git a/python/mlc_chat/interface/jit.py b/python/mlc_llm/interface/jit.py similarity index 94% rename from python/mlc_chat/interface/jit.py rename to python/mlc_llm/interface/jit.py index 6d9b131c67..06a22eb8fd 100644 --- a/python/mlc_chat/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -13,16 +13,16 @@ from tvm.runtime import Device -from mlc_chat.model import MODELS -from mlc_chat.support import logging -from mlc_chat.support.auto_device import device2str -from mlc_chat.support.constants import ( +from mlc_llm.model import MODELS +from mlc_llm.support import logging +from mlc_llm.support.auto_device import device2str +from mlc_llm.support.constants import ( MLC_CACHE_DIR, MLC_DSO_SUFFIX, MLC_JIT_POLICY, MLC_TEMP_DIR, ) -from mlc_chat.support.style import blue, bold +from mlc_llm.support.style import blue, bold from .compiler_flags import ModelConfigOverride, OptimizationFlags @@ -78,7 +78,7 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): cmd = [ sys.executable, "-m", - "mlc_chat", + "mlc_llm", "compile", str(model_path), "--opt", diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_llm/interface/openai_api.py similarity index 100% rename from python/mlc_chat/interface/openai_api.py rename to python/mlc_llm/interface/openai_api.py diff --git a/python/mlc_chat/libinfo.py b/python/mlc_llm/libinfo.py similarity index 100% rename from python/mlc_chat/libinfo.py rename to python/mlc_llm/libinfo.py diff --git a/python/mlc_chat/loader/__init__.py b/python/mlc_llm/loader/__init__.py similarity index 100% rename from python/mlc_chat/loader/__init__.py rename to python/mlc_llm/loader/__init__.py diff --git a/python/mlc_chat/loader/huggingface_loader.py b/python/mlc_llm/loader/huggingface_loader.py similarity index 98% rename from python/mlc_chat/loader/huggingface_loader.py rename to python/mlc_llm/loader/huggingface_loader.py index 5334242c6e..1f72197150 100644 --- a/python/mlc_chat/loader/huggingface_loader.py +++ b/python/mlc_llm/loader/huggingface_loader.py @@ -10,9 +10,9 @@ from tvm.runtime import Device, NDArray from tvm.runtime.ndarray import array as as_ndarray -from mlc_chat.support import logging -from mlc_chat.support.preshard import _sharded_param_name -from mlc_chat.support.style import bold +from mlc_llm.support import logging +from mlc_llm.support.preshard import _sharded_param_name +from mlc_llm.support.style import bold from .mapping import ExternMapping, QuantizeMapping from .stats import Stats diff --git a/python/mlc_chat/loader/loader.py b/python/mlc_llm/loader/loader.py similarity index 100% rename from python/mlc_chat/loader/loader.py rename to python/mlc_llm/loader/loader.py diff --git a/python/mlc_chat/loader/mapping.py b/python/mlc_llm/loader/mapping.py similarity index 100% rename from python/mlc_chat/loader/mapping.py rename to python/mlc_llm/loader/mapping.py diff --git a/python/mlc_chat/loader/stats.py b/python/mlc_llm/loader/stats.py similarity index 97% rename from python/mlc_chat/loader/stats.py rename to python/mlc_llm/loader/stats.py index 6a97cf993c..4710e47307 100644 --- a/python/mlc_chat/loader/stats.py +++ b/python/mlc_llm/loader/stats.py @@ -3,8 +3,8 @@ import time from contextlib import contextmanager -from mlc_chat.support import logging -from mlc_chat.support.style import green +from mlc_llm.support import logging +from mlc_llm.support.style import green 
logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/loader/utils.py b/python/mlc_llm/loader/utils.py similarity index 98% rename from python/mlc_chat/loader/utils.py rename to python/mlc_llm/loader/utils.py index b35f9a934d..a838841b7e 100644 --- a/python/mlc_chat/loader/utils.py +++ b/python/mlc_llm/loader/utils.py @@ -5,7 +5,7 @@ import numpy as np -from mlc_chat.support import logging +from mlc_llm.support import logging if TYPE_CHECKING: from tvm.runtime import NDArray diff --git a/python/mlc_chat/model/__init__.py b/python/mlc_llm/model/__init__.py similarity index 100% rename from python/mlc_chat/model/__init__.py rename to python/mlc_llm/model/__init__.py diff --git a/python/mlc_chat/model/baichuan/__init__.py b/python/mlc_llm/model/baichuan/__init__.py similarity index 100% rename from python/mlc_chat/model/baichuan/__init__.py rename to python/mlc_llm/model/baichuan/__init__.py diff --git a/python/mlc_chat/model/baichuan/baichuan_loader.py b/python/mlc_llm/model/baichuan/baichuan_loader.py similarity index 92% rename from python/mlc_chat/model/baichuan/baichuan_loader.py rename to python/mlc_llm/model/baichuan/baichuan_loader.py index 2807060438..6114cc1b71 100644 --- a/python/mlc_chat/model/baichuan/baichuan_loader.py +++ b/python/mlc_llm/model/baichuan/baichuan_loader.py @@ -7,8 +7,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .baichuan_model import BaichuanConfig, BaichuanForCausalLM diff --git a/python/mlc_chat/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py similarity index 95% rename from python/mlc_chat/model/baichuan/baichuan_model.py rename to python/mlc_llm/model/baichuan/baichuan_model.py index 266d9678c3..334c32d7d5 100644 --- a/python/mlc_chat/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -10,11 +10,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/baichuan/baichuan_quantization.py b/python/mlc_llm/model/baichuan/baichuan_quantization.py similarity index 89% rename from python/mlc_chat/model/baichuan/baichuan_quantization.py rename to python/mlc_llm/model/baichuan/baichuan_quantization.py index 2558942ba7..70522b599d 100644 --- a/python/mlc_chat/model/baichuan/baichuan_quantization.py +++ b/python/mlc_llm/model/baichuan/baichuan_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .baichuan_model import BaichuanConfig, BaichuanForCausalLM diff --git a/python/mlc_chat/model/gemma/__init__.py b/python/mlc_llm/model/gemma/__init__.py similarity index 100% rename from python/mlc_chat/model/gemma/__init__.py rename to python/mlc_llm/model/gemma/__init__.py diff 
--git a/python/mlc_chat/model/gemma/gemma_loader.py b/python/mlc_llm/model/gemma/gemma_loader.py similarity index 97% rename from python/mlc_chat/model/gemma/gemma_loader.py rename to python/mlc_llm/model/gemma/gemma_loader.py index c839978147..6910b40af0 100644 --- a/python/mlc_chat/model/gemma/gemma_loader.py +++ b/python/mlc_llm/model/gemma/gemma_loader.py @@ -7,8 +7,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .gemma_model import GemmaConfig, GemmaForCausalLM diff --git a/python/mlc_chat/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py similarity index 98% rename from python/mlc_chat/model/gemma/gemma_model.py rename to python/mlc_llm/model/gemma/gemma_model.py index 94768a0d89..9303e2552e 100644 --- a/python/mlc_chat/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -7,12 +7,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/gemma/gemma_quantization.py b/python/mlc_llm/model/gemma/gemma_quantization.py similarity index 90% rename from python/mlc_chat/model/gemma/gemma_quantization.py rename to python/mlc_llm/model/gemma/gemma_quantization.py index 28b42343a4..9108dbc1ff 100644 --- a/python/mlc_chat/model/gemma/gemma_quantization.py +++ b/python/mlc_llm/model/gemma/gemma_quantization.py @@ -5,8 +5,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import GroupQuantize, NoQuantize from .gemma_model import GemmaConfig, GemmaForCausalLM diff --git a/python/mlc_chat/model/gpt2/__init__.py b/python/mlc_llm/model/gpt2/__init__.py similarity index 100% rename from python/mlc_chat/model/gpt2/__init__.py rename to python/mlc_llm/model/gpt2/__init__.py diff --git a/python/mlc_chat/model/gpt2/gpt2_loader.py b/python/mlc_llm/model/gpt2/gpt2_loader.py similarity index 96% rename from python/mlc_chat/model/gpt2/gpt2_loader.py rename to python/mlc_llm/model/gpt2/gpt2_loader.py index 43c4ff14e1..0c28461242 100644 --- a/python/mlc_chat/model/gpt2/gpt2_loader.py +++ b/python/mlc_llm/model/gpt2/gpt2_loader.py @@ -4,8 +4,8 @@ """ import functools -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .gpt2_model import GPT2Config, GPT2LMHeadModel diff --git a/python/mlc_chat/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py similarity index 98% rename from python/mlc_chat/model/gpt2/gpt2_model.py rename to python/mlc_llm/model/gpt2/gpt2_model.py index 83f65502f8..cf2a967cac 100644 --- a/python/mlc_chat/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ 
-10,12 +10,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/gpt2/gpt2_quantization.py b/python/mlc_llm/model/gpt2/gpt2_quantization.py similarity index 93% rename from python/mlc_chat/model/gpt2/gpt2_quantization.py rename to python/mlc_llm/model/gpt2/gpt2_quantization.py index b953d8cd84..9d8ce427d4 100644 --- a/python/mlc_chat/model/gpt2/gpt2_quantization.py +++ b/python/mlc_llm/model/gpt2/gpt2_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .gpt2_model import GPT2Config, GPT2LMHeadModel diff --git a/python/mlc_chat/model/gpt_bigcode/__init__.py b/python/mlc_llm/model/gpt_bigcode/__init__.py similarity index 100% rename from python/mlc_chat/model/gpt_bigcode/__init__.py rename to python/mlc_llm/model/gpt_bigcode/__init__.py diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_loader.py similarity index 94% rename from python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py rename to python/mlc_llm/model/gpt_bigcode/gpt_bigcode_loader.py index 1504719045..0c07a7768e 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_loader.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_loader.py @@ -5,8 +5,8 @@ import functools -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py similarity index 98% rename from python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py rename to python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index 302b093125..d98871964f 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -10,12 +10,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py 
b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py similarity index 93% rename from python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py rename to python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py index 021cc0872a..78d68f501a 100644 --- a/python/mlc_chat/model/gpt_bigcode/gpt_bigcode_quantization.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py @@ -5,8 +5,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .gpt_bigcode_model import GPTBigCodeConfig, GPTBigCodeForCausalLM diff --git a/python/mlc_chat/model/gpt_neox/__init__.py b/python/mlc_llm/model/gpt_neox/__init__.py similarity index 100% rename from python/mlc_chat/model/gpt_neox/__init__.py rename to python/mlc_llm/model/gpt_neox/__init__.py diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py b/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py similarity index 97% rename from python/mlc_chat/model/gpt_neox/gpt_neox_loader.py rename to python/mlc_llm/model/gpt_neox/gpt_neox_loader.py index b7e4027ce2..7f4d5f56c4 100644 --- a/python/mlc_chat/model/gpt_neox/gpt_neox_loader.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py similarity index 98% rename from python/mlc_chat/model/gpt_neox/gpt_neox_model.py rename to python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 895655d60b..0a0c494685 100644 --- a/python/mlc_chat/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -11,11 +11,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py similarity index 92% rename from python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py rename to python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py index 9f1daaf42b..f751426708 100644 --- a/python/mlc_chat/model/gpt_neox/gpt_neox_quantization.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM diff --git a/python/mlc_chat/model/internlm/__init__.py b/python/mlc_llm/model/internlm/__init__.py 
similarity index 100% rename from python/mlc_chat/model/internlm/__init__.py rename to python/mlc_llm/model/internlm/__init__.py diff --git a/python/mlc_chat/model/internlm/internlm_loader.py b/python/mlc_llm/model/internlm/internlm_loader.py similarity index 97% rename from python/mlc_chat/model/internlm/internlm_loader.py rename to python/mlc_llm/model/internlm/internlm_loader.py index 7e80aeeb64..60039d7fc6 100644 --- a/python/mlc_chat/model/internlm/internlm_loader.py +++ b/python/mlc_llm/model/internlm/internlm_loader.py @@ -7,8 +7,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .internlm_model import InternLMConfig, InternLMForCausalLM diff --git a/python/mlc_chat/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py similarity index 98% rename from python/mlc_chat/model/internlm/internlm_model.py rename to python/mlc_llm/model/internlm/internlm_model.py index 153905f55e..cf39437dd6 100644 --- a/python/mlc_chat/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -10,11 +10,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/internlm/internlm_quantization.py b/python/mlc_llm/model/internlm/internlm_quantization.py similarity index 92% rename from python/mlc_chat/model/internlm/internlm_quantization.py rename to python/mlc_llm/model/internlm/internlm_quantization.py index 22f2eae2f5..114e9e193e 100644 --- a/python/mlc_chat/model/internlm/internlm_quantization.py +++ b/python/mlc_llm/model/internlm/internlm_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .internlm_model import InternLMConfig, InternLMForCausalLM diff --git a/python/mlc_chat/model/llama/__init__.py b/python/mlc_llm/model/llama/__init__.py similarity index 100% rename from python/mlc_chat/model/llama/__init__.py rename to python/mlc_llm/model/llama/__init__.py diff --git a/python/mlc_chat/model/llama/llama_loader.py b/python/mlc_llm/model/llama/llama_loader.py similarity index 98% rename from python/mlc_chat/model/llama/llama_loader.py rename to python/mlc_llm/model/llama/llama_loader.py index 5dd902d04d..070753bc2b 100644 --- a/python/mlc_chat/model/llama/llama_loader.py +++ b/python/mlc_llm/model/llama/llama_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .llama_model import LlamaConfig, LlamaForCasualLM from .llama_quantization import awq_quant diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py similarity 
index 98% rename from python/mlc_chat/model/llama/llama_model.py rename to python/mlc_llm/model/llama/llama_model.py index 69884e8492..fb5f5637b8 100644 --- a/python/mlc_chat/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -10,12 +10,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/llama/llama_quantization.py b/python/mlc_llm/model/llama/llama_quantization.py similarity index 93% rename from python/mlc_chat/model/llama/llama_quantization.py rename to python/mlc_llm/model/llama/llama_quantization.py index 0460c98b51..cf67288585 100644 --- a/python/mlc_chat/model/llama/llama_quantization.py +++ b/python/mlc_llm/model/llama/llama_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .llama_model import LlamaConfig, LlamaForCasualLM diff --git a/python/mlc_chat/model/mistral/__init__.py b/python/mlc_llm/model/mistral/__init__.py similarity index 100% rename from python/mlc_chat/model/mistral/__init__.py rename to python/mlc_llm/model/mistral/__init__.py diff --git a/python/mlc_chat/model/mistral/mistral_loader.py b/python/mlc_llm/model/mistral/mistral_loader.py similarity index 98% rename from python/mlc_chat/model/mistral/mistral_loader.py rename to python/mlc_llm/model/mistral/mistral_loader.py index 71a8f1abe9..d9748f1fc5 100644 --- a/python/mlc_chat/model/mistral/mistral_loader.py +++ b/python/mlc_llm/model/mistral/mistral_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .mistral_model import MistralConfig, MistralForCasualLM from .mistral_quantization import awq_quant diff --git a/python/mlc_chat/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py similarity index 98% rename from python/mlc_chat/model/mistral/mistral_model.py rename to python/mlc_llm/model/mistral/mistral_model.py index d2b5c57bf2..9374df595c 100644 --- a/python/mlc_chat/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -9,11 +9,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff 
--git a/python/mlc_chat/model/mistral/mistral_quantization.py b/python/mlc_llm/model/mistral/mistral_quantization.py similarity index 93% rename from python/mlc_chat/model/mistral/mistral_quantization.py rename to python/mlc_llm/model/mistral/mistral_quantization.py index e3622fda29..7efaa00b06 100644 --- a/python/mlc_chat/model/mistral/mistral_quantization.py +++ b/python/mlc_llm/model/mistral/mistral_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .mistral_model import MistralConfig, MistralForCasualLM diff --git a/python/mlc_chat/model/mixtral/__init__.py b/python/mlc_llm/model/mixtral/__init__.py similarity index 100% rename from python/mlc_chat/model/mixtral/__init__.py rename to python/mlc_llm/model/mixtral/__init__.py diff --git a/python/mlc_chat/model/mixtral/mixtral_loader.py b/python/mlc_llm/model/mixtral/mixtral_loader.py similarity index 97% rename from python/mlc_chat/model/mixtral/mixtral_loader.py rename to python/mlc_llm/model/mixtral/mixtral_loader.py index 12e96ebad2..dad152b784 100644 --- a/python/mlc_chat/model/mixtral/mixtral_loader.py +++ b/python/mlc_llm/model/mixtral/mixtral_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .mixtral_model import MixtralConfig, MixtralForCasualLM diff --git a/python/mlc_chat/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py similarity index 96% rename from python/mlc_chat/model/mixtral/mixtral_model.py rename to python/mlc_llm/model/mixtral/mixtral_model.py index 2a707b0a77..3f41988788 100644 --- a/python/mlc_chat/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -6,17 +6,17 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.model.llama.llama_model import ( +from mlc_llm import op as op_ext +from mlc_llm.model.llama.llama_model import ( LlamaAttention, LlamaConfig, LlamaForCasualLM, LlamaModel, ) -from mlc_chat.nn import PagedKVCache -from mlc_chat.nn.expert import MixtralExperts -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp +from mlc_llm.nn import PagedKVCache +from mlc_llm.nn.expert import MixtralExperts +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/mixtral/mixtral_quantization.py b/python/mlc_llm/model/mixtral/mixtral_quantization.py similarity index 93% rename from python/mlc_chat/model/mixtral/mixtral_quantization.py rename to python/mlc_llm/model/mixtral/mixtral_quantization.py index 37f7ad5f55..0e8130e051 100644 --- a/python/mlc_chat/model/mixtral/mixtral_quantization.py +++ b/python/mlc_llm/model/mixtral/mixtral_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .mixtral_model import MixtralConfig, 
MixtralForCasualLM diff --git a/python/mlc_chat/model/model.py b/python/mlc_llm/model/model.py similarity index 98% rename from python/mlc_chat/model/model.py rename to python/mlc_llm/model/model.py index ef67c8e5ab..607cec2918 100644 --- a/python/mlc_chat/model/model.py +++ b/python/mlc_llm/model/model.py @@ -5,8 +5,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import ExternMapping, QuantizeMapping -from mlc_chat.quantization.quantization import Quantization +from mlc_llm.loader import ExternMapping, QuantizeMapping +from mlc_llm.quantization.quantization import Quantization from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization diff --git a/python/mlc_chat/model/model_preset.py b/python/mlc_llm/model/model_preset.py similarity index 100% rename from python/mlc_chat/model/model_preset.py rename to python/mlc_llm/model/model_preset.py diff --git a/python/mlc_chat/model/orion/__init__.py b/python/mlc_llm/model/orion/__init__.py similarity index 100% rename from python/mlc_chat/model/orion/__init__.py rename to python/mlc_llm/model/orion/__init__.py diff --git a/python/mlc_chat/model/orion/orion_loader.py b/python/mlc_llm/model/orion/orion_loader.py similarity index 96% rename from python/mlc_chat/model/orion/orion_loader.py rename to python/mlc_llm/model/orion/orion_loader.py index 61c8138634..d735052ba9 100644 --- a/python/mlc_chat/model/orion/orion_loader.py +++ b/python/mlc_llm/model/orion/orion_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .orion_model import OrionConfig, OrionForCasualLM diff --git a/python/mlc_chat/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py similarity index 98% rename from python/mlc_chat/model/orion/orion_model.py rename to python/mlc_llm/model/orion/orion_model.py index 5894a5ab61..9964ab911f 100644 --- a/python/mlc_chat/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -10,12 +10,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/orion/orion_quantization.py b/python/mlc_llm/model/orion/orion_quantization.py similarity index 90% rename from python/mlc_chat/model/orion/orion_quantization.py rename to python/mlc_llm/model/orion/orion_quantization.py index d34f59b2dd..740253351b 100644 --- a/python/mlc_chat/model/orion/orion_quantization.py +++ b/python/mlc_llm/model/orion/orion_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import GroupQuantize, NoQuantize from .orion_model import OrionConfig, OrionForCasualLM diff --git 
a/python/mlc_chat/model/phi/__init__.py b/python/mlc_llm/model/phi/__init__.py similarity index 100% rename from python/mlc_chat/model/phi/__init__.py rename to python/mlc_llm/model/phi/__init__.py diff --git a/python/mlc_chat/model/phi/phi_loader.py b/python/mlc_llm/model/phi/phi_loader.py similarity index 98% rename from python/mlc_chat/model/phi/phi_loader.py rename to python/mlc_llm/model/phi/phi_loader.py index d393c61f2e..70b277c6b2 100644 --- a/python/mlc_chat/model/phi/phi_loader.py +++ b/python/mlc_llm/model/phi/phi_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .phi_model import Phi1Config, PhiConfig, PhiForCausalLM diff --git a/python/mlc_chat/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py similarity index 98% rename from python/mlc_chat/model/phi/phi_model.py rename to python/mlc_llm/model/phi/phi_model.py index 372598d5ae..0b3f3f092f 100644 --- a/python/mlc_chat/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -10,12 +10,12 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/phi/phi_quantization.py b/python/mlc_llm/model/phi/phi_quantization.py similarity index 92% rename from python/mlc_chat/model/phi/phi_quantization.py rename to python/mlc_llm/model/phi/phi_quantization.py index 52089c26ba..3a620d0200 100644 --- a/python/mlc_chat/model/phi/phi_quantization.py +++ b/python/mlc_llm/model/phi/phi_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .phi_model import PhiConfig, PhiForCausalLM diff --git a/python/mlc_chat/model/qwen/__init__.py b/python/mlc_llm/model/qwen/__init__.py similarity index 100% rename from python/mlc_chat/model/qwen/__init__.py rename to python/mlc_llm/model/qwen/__init__.py diff --git a/python/mlc_chat/model/qwen/qwen_loader.py b/python/mlc_llm/model/qwen/qwen_loader.py similarity index 95% rename from python/mlc_chat/model/qwen/qwen_loader.py rename to python/mlc_llm/model/qwen/qwen_loader.py index 810efedb35..5b5f8fe5be 100644 --- a/python/mlc_chat/model/qwen/qwen_loader.py +++ b/python/mlc_llm/model/qwen/qwen_loader.py @@ -6,8 +6,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .qwen_model import QWenConfig, QWenLMHeadModel diff --git a/python/mlc_chat/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py similarity index 98% rename from python/mlc_chat/model/qwen/qwen_model.py rename to 
python/mlc_llm/model/qwen/qwen_model.py index b5879a92a2..54157c7eb3 100644 --- a/python/mlc_chat/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -10,11 +10,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/qwen/qwen_quantization.py b/python/mlc_llm/model/qwen/qwen_quantization.py similarity index 92% rename from python/mlc_chat/model/qwen/qwen_quantization.py rename to python/mlc_llm/model/qwen/qwen_quantization.py index c69f5835ef..862cd6fd8c 100644 --- a/python/mlc_chat/model/qwen/qwen_quantization.py +++ b/python/mlc_llm/model/qwen/qwen_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .qwen_model import QWenConfig, QWenLMHeadModel diff --git a/python/mlc_chat/model/qwen2/__init__.py b/python/mlc_llm/model/qwen2/__init__.py similarity index 100% rename from python/mlc_chat/model/qwen2/__init__.py rename to python/mlc_llm/model/qwen2/__init__.py diff --git a/python/mlc_chat/model/qwen2/qwen2_loader.py b/python/mlc_llm/model/qwen2/qwen2_loader.py similarity index 96% rename from python/mlc_chat/model/qwen2/qwen2_loader.py rename to python/mlc_llm/model/qwen2/qwen2_loader.py index 559a911316..0a421b5f64 100644 --- a/python/mlc_chat/model/qwen2/qwen2_loader.py +++ b/python/mlc_llm/model/qwen2/qwen2_loader.py @@ -7,8 +7,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .qwen2_model import QWen2Config, QWen2LMHeadModel diff --git a/python/mlc_chat/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py similarity index 98% rename from python/mlc_chat/model/qwen2/qwen2_model.py rename to python/mlc_llm/model/qwen2/qwen2_model.py index a5dc351a9e..ad55c83bb4 100644 --- a/python/mlc_chat/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -10,11 +10,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/qwen2/qwen2_quantization.py b/python/mlc_llm/model/qwen2/qwen2_quantization.py similarity index 92% rename from python/mlc_chat/model/qwen2/qwen2_quantization.py rename to python/mlc_llm/model/qwen2/qwen2_quantization.py index a59802dd57..b5e3791331 100644 --- 
a/python/mlc_chat/model/qwen2/qwen2_quantization.py +++ b/python/mlc_llm/model/qwen2/qwen2_quantization.py @@ -5,8 +5,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .qwen2_model import QWen2Config, QWen2LMHeadModel diff --git a/python/mlc_chat/model/rwkv5/__init__.py b/python/mlc_llm/model/rwkv5/__init__.py similarity index 100% rename from python/mlc_chat/model/rwkv5/__init__.py rename to python/mlc_llm/model/rwkv5/__init__.py diff --git a/python/mlc_chat/model/rwkv5/rwkv5_loader.py b/python/mlc_llm/model/rwkv5/rwkv5_loader.py similarity index 100% rename from python/mlc_chat/model/rwkv5/rwkv5_loader.py rename to python/mlc_llm/model/rwkv5/rwkv5_loader.py diff --git a/python/mlc_chat/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py similarity index 99% rename from python/mlc_chat/model/rwkv5/rwkv5_model.py rename to python/mlc_llm/model/rwkv5/rwkv5_model.py index e88efa4aec..49386720da 100644 --- a/python/mlc_chat/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -8,9 +8,9 @@ from tvm.relax.frontend.nn import Object, Tensor, op from tvm.script import tir as T -from mlc_chat.nn.rnn_state import RNNState -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase +from mlc_llm.nn.rnn_state import RNNState +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/rwkv5/rwkv5_quantization.py b/python/mlc_llm/model/rwkv5/rwkv5_quantization.py similarity index 100% rename from python/mlc_chat/model/rwkv5/rwkv5_quantization.py rename to python/mlc_llm/model/rwkv5/rwkv5_quantization.py diff --git a/python/mlc_chat/model/stable_lm/__init__.py b/python/mlc_llm/model/stable_lm/__init__.py similarity index 100% rename from python/mlc_chat/model/stable_lm/__init__.py rename to python/mlc_llm/model/stable_lm/__init__.py diff --git a/python/mlc_chat/model/stable_lm/stablelm_loader.py b/python/mlc_llm/model/stable_lm/stablelm_loader.py similarity index 97% rename from python/mlc_chat/model/stable_lm/stablelm_loader.py rename to python/mlc_llm/model/stable_lm/stablelm_loader.py index d2cc4d93c8..b5764947d3 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_loader.py +++ b/python/mlc_llm/model/stable_lm/stablelm_loader.py @@ -7,8 +7,8 @@ import numpy as np -from mlc_chat.loader import ExternMapping -from mlc_chat.quantization import Quantization +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization from .stablelm_model import StableLmConfig, StableLmForCausalLM diff --git a/python/mlc_chat/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py similarity index 98% rename from python/mlc_chat/model/stable_lm/stablelm_model.py rename to python/mlc_llm/model/stable_lm/stablelm_model.py index 8193c15ccc..b32372ce6d 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -10,11 +10,11 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from mlc_chat import op as op_ext -from mlc_chat.nn import PagedKVCache, RopeMode -from mlc_chat.support import logging -from mlc_chat.support.config import ConfigBase -from mlc_chat.support.style import bold +from 
mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/model/stable_lm/stablelm_quantization.py b/python/mlc_llm/model/stable_lm/stablelm_quantization.py similarity index 92% rename from python/mlc_chat/model/stable_lm/stablelm_quantization.py rename to python/mlc_llm/model/stable_lm/stablelm_quantization.py index 327082aeaa..5f502b0970 100644 --- a/python/mlc_chat/model/stable_lm/stablelm_quantization.py +++ b/python/mlc_llm/model/stable_lm/stablelm_quantization.py @@ -4,8 +4,8 @@ from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize from .stablelm_model import StableLmConfig, StableLmForCausalLM diff --git a/python/mlc_chat/nn/__init__.py b/python/mlc_llm/nn/__init__.py similarity index 100% rename from python/mlc_chat/nn/__init__.py rename to python/mlc_llm/nn/__init__.py diff --git a/python/mlc_chat/nn/expert.py b/python/mlc_llm/nn/expert.py similarity index 95% rename from python/mlc_chat/nn/expert.py rename to python/mlc_llm/nn/expert.py index a4ff0cf2c2..b6659d3d60 100644 --- a/python/mlc_chat/nn/expert.py +++ b/python/mlc_llm/nn/expert.py @@ -2,7 +2,7 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor -from mlc_chat.op import extern, ft_gemm, moe_matmul +from mlc_llm.op import extern, ft_gemm, moe_matmul class MixtralExperts(nn.Module): diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py similarity index 99% rename from python/mlc_chat/nn/kv_cache.py rename to python/mlc_llm/nn/kv_cache.py index 636861f3bd..2863ed47b7 100644 --- a/python/mlc_chat/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -12,7 +12,7 @@ from tvm.script import tir as T from tvm.target import Target -from mlc_chat.op.position_embedding import ( +from mlc_llm.op.position_embedding import ( llama_inplace_rope, llama_rope_with_position_map, rope_freq, diff --git a/python/mlc_chat/nn/rnn_state.py b/python/mlc_llm/nn/rnn_state.py similarity index 100% rename from python/mlc_chat/nn/rnn_state.py rename to python/mlc_llm/nn/rnn_state.py diff --git a/python/mlc_chat/op/__init__.py b/python/mlc_llm/op/__init__.py similarity index 100% rename from python/mlc_chat/op/__init__.py rename to python/mlc_llm/op/__init__.py diff --git a/python/mlc_chat/op/attention.py b/python/mlc_llm/op/attention.py similarity index 99% rename from python/mlc_chat/op/attention.py rename to python/mlc_llm/op/attention.py index 02f21a6dfd..801dbd66ba 100644 --- a/python/mlc_chat/op/attention.py +++ b/python/mlc_llm/op/attention.py @@ -5,7 +5,7 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import op -from mlc_chat.support import logging +from mlc_llm.support import logging from . 
import extern as _extern diff --git a/python/mlc_chat/op/extern.py b/python/mlc_llm/op/extern.py similarity index 100% rename from python/mlc_chat/op/extern.py rename to python/mlc_llm/op/extern.py diff --git a/python/mlc_chat/op/ft_gemm.py b/python/mlc_llm/op/ft_gemm.py similarity index 100% rename from python/mlc_chat/op/ft_gemm.py rename to python/mlc_llm/op/ft_gemm.py diff --git a/python/mlc_chat/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py similarity index 100% rename from python/mlc_chat/op/moe_matmul.py rename to python/mlc_llm/op/moe_matmul.py diff --git a/python/mlc_chat/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py similarity index 100% rename from python/mlc_chat/op/moe_misc.py rename to python/mlc_llm/op/moe_misc.py diff --git a/python/mlc_chat/op/position_embedding.py b/python/mlc_llm/op/position_embedding.py similarity index 100% rename from python/mlc_chat/op/position_embedding.py rename to python/mlc_llm/op/position_embedding.py diff --git a/python/mlc_chat/protocol/__init__.py b/python/mlc_llm/protocol/__init__.py similarity index 100% rename from python/mlc_chat/protocol/__init__.py rename to python/mlc_llm/protocol/__init__.py diff --git a/python/mlc_chat/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py similarity index 100% rename from python/mlc_chat/protocol/conversation_protocol.py rename to python/mlc_llm/protocol/conversation_protocol.py diff --git a/python/mlc_chat/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py similarity index 99% rename from python/mlc_chat/protocol/openai_api_protocol.py rename to python/mlc_llm/protocol/openai_api_protocol.py index 8e56d3855f..c2cff9c4fd 100644 --- a/python/mlc_chat/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -10,7 +10,7 @@ import shortuuid from pydantic import BaseModel, Field, field_validator, model_validator -from mlc_chat.serve.config import ResponseFormat +from mlc_llm.serve.config import ResponseFormat ################ Commons ################ diff --git a/python/mlc_chat/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py similarity index 100% rename from python/mlc_chat/protocol/protocol_utils.py rename to python/mlc_llm/protocol/protocol_utils.py diff --git a/python/mlc_chat/quantization/__init__.py b/python/mlc_llm/quantization/__init__.py similarity index 100% rename from python/mlc_chat/quantization/__init__.py rename to python/mlc_llm/quantization/__init__.py diff --git a/python/mlc_chat/quantization/awq_quantization.py b/python/mlc_llm/quantization/awq_quantization.py similarity index 99% rename from python/mlc_chat/quantization/awq_quantization.py rename to python/mlc_llm/quantization/awq_quantization.py index 116582f0b0..0b89e5db6a 100644 --- a/python/mlc_chat/quantization/awq_quantization.py +++ b/python/mlc_llm/quantization/awq_quantization.py @@ -7,7 +7,7 @@ from tvm.relax.frontend import nn from tvm.runtime import NDArray -from mlc_chat.loader import QuantizeMapping +from mlc_llm.loader import QuantizeMapping from .utils import convert_uint_to_float, is_final_fc diff --git a/python/mlc_chat/quantization/ft_quantization.py b/python/mlc_llm/quantization/ft_quantization.py similarity index 100% rename from python/mlc_chat/quantization/ft_quantization.py rename to python/mlc_llm/quantization/ft_quantization.py diff --git a/python/mlc_chat/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py similarity index 98% rename from 
python/mlc_chat/quantization/group_quantization.py rename to python/mlc_llm/quantization/group_quantization.py index baf8662963..3431b5415e 100644 --- a/python/mlc_chat/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py @@ -11,10 +11,10 @@ from tvm.runtime import NDArray from tvm.target import Target -from mlc_chat.loader import QuantizeMapping -from mlc_chat.nn import MixtralExperts -from mlc_chat.support import logging -from mlc_chat.support import tensor_parallel as tp +from mlc_llm.loader import QuantizeMapping +from mlc_llm.nn import MixtralExperts +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from .utils import convert_uint_to_float, is_final_fc @@ -628,7 +628,7 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa ret : nn.Tensor The output tensor for the group quantized mistral experts layer. """ - from mlc_chat.op import moe_matmul # pylint: disable=import-outside-toplevel + from mlc_llm.op import moe_matmul # pylint: disable=import-outside-toplevel assert x.ndim == 2 if indptr.ndim == 2: # single-batch diff --git a/python/mlc_chat/quantization/no_quantization.py b/python/mlc_llm/quantization/no_quantization.py similarity index 100% rename from python/mlc_chat/quantization/no_quantization.py rename to python/mlc_llm/quantization/no_quantization.py diff --git a/python/mlc_chat/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py similarity index 100% rename from python/mlc_chat/quantization/quantization.py rename to python/mlc_llm/quantization/quantization.py diff --git a/python/mlc_chat/quantization/utils.py b/python/mlc_llm/quantization/utils.py similarity index 100% rename from python/mlc_chat/quantization/utils.py rename to python/mlc_llm/quantization/utils.py diff --git a/python/mlc_chat/rest.py b/python/mlc_llm/rest.py similarity index 98% rename from python/mlc_chat/rest.py rename to python/mlc_llm/rest.py index d2911a15f4..011ef4df29 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_llm/rest.py @@ -13,8 +13,8 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from mlc_chat.chat_module import GenerationConfig -from mlc_chat.support.random import set_global_random_seed +from mlc_llm.chat_module import GenerationConfig +from mlc_llm.support.random import set_global_random_seed from .chat_module import ChatModule from .interface.openai_api import ( @@ -489,4 +489,4 @@ async def request_llm_vscode(request: VisualStudioCodeCompletionRequest): ARGS = convert_args_to_argparser().parse_args() if __name__ == "__main__": - uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) + uvicorn.run("mlc_llm.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) diff --git a/python/mlc_chat/serve/__init__.py b/python/mlc_llm/serve/__init__.py similarity index 100% rename from python/mlc_chat/serve/__init__.py rename to python/mlc_llm/serve/__init__.py diff --git a/python/mlc_chat/serve/_ffi_api.py b/python/mlc_llm/serve/_ffi_api.py similarity index 87% rename from python/mlc_chat/serve/_ffi_api.py rename to python/mlc_llm/serve/_ffi_api.py index 282c80c4d1..d755fea6d3 100644 --- a/python/mlc_chat/serve/_ffi_api.py +++ b/python/mlc_llm/serve/_ffi_api.py @@ -1,4 +1,4 @@ -"""FFI APIs for mlc_chat.serve""" +"""FFI APIs for mlc_llm.serve""" import tvm._ffi # Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.serve" prefix. 
diff --git a/python/mlc_chat/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py similarity index 100% rename from python/mlc_chat/serve/async_engine.py rename to python/mlc_llm/serve/async_engine.py diff --git a/python/mlc_chat/serve/config.py b/python/mlc_llm/serve/config.py similarity index 100% rename from python/mlc_chat/serve/config.py rename to python/mlc_llm/serve/config.py diff --git a/python/mlc_chat/serve/data.py b/python/mlc_llm/serve/data.py similarity index 100% rename from python/mlc_chat/serve/data.py rename to python/mlc_llm/serve/data.py diff --git a/python/mlc_chat/serve/engine.py b/python/mlc_llm/serve/engine.py similarity index 98% rename from python/mlc_chat/serve/engine.py rename to python/mlc_llm/serve/engine.py index c4b3e5d9b4..994a5f4e9e 100644 --- a/python/mlc_chat/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -11,10 +11,10 @@ import tvm from tvm.runtime import Device -from mlc_chat.serve import data -from mlc_chat.support import logging -from mlc_chat.support.auto_device import detect_device -from mlc_chat.support.style import green +from mlc_llm.serve import data +from mlc_llm.support import logging +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.style import green from ..chat_module import _get_chat_config, _get_lib_module_path, _get_model_path from ..streamer import TextStreamer @@ -109,9 +109,7 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: config_file_path=config_file_path, ) except FileNotFoundError: - from mlc_chat.interface import ( # pylint: disable=import-outside-toplevel - jit, - ) + from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel model_lib_path = str( jit.jit( @@ -155,7 +153,7 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals cmd = [ sys.executable, "-m", - "mlc_chat.cli.model_metadata", + "mlc_llm.cli.model_metadata", model.model_lib_path, "--print-memory-usage-in-json", "--mlc-chat-config", @@ -169,7 +167,7 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals cmd = [ sys.executable, "-m", - "mlc_chat.cli.model_metadata", + "mlc_llm.cli.model_metadata", model.model_lib_path, "--print-kv-cache-metadata-in-json", ] diff --git a/python/mlc_chat/serve/entrypoints/__init__.py b/python/mlc_llm/serve/entrypoints/__init__.py similarity index 100% rename from python/mlc_chat/serve/entrypoints/__init__.py rename to python/mlc_llm/serve/entrypoints/__init__.py diff --git a/python/mlc_chat/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py similarity index 100% rename from python/mlc_chat/serve/entrypoints/debug_entrypoints.py rename to python/mlc_llm/serve/entrypoints/debug_entrypoints.py diff --git a/python/mlc_chat/serve/entrypoints/entrypoint_utils.py b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py similarity index 100% rename from python/mlc_chat/serve/entrypoints/entrypoint_utils.py rename to python/mlc_llm/serve/entrypoints/entrypoint_utils.py diff --git a/python/mlc_chat/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py similarity index 100% rename from python/mlc_chat/serve/entrypoints/openai_entrypoints.py rename to python/mlc_llm/serve/entrypoints/openai_entrypoints.py diff --git a/python/mlc_chat/serve/event_trace_recorder.py b/python/mlc_llm/serve/event_trace_recorder.py similarity index 100% rename from python/mlc_chat/serve/event_trace_recorder.py rename to python/mlc_llm/serve/event_trace_recorder.py diff 
--git a/python/mlc_chat/serve/grammar.py b/python/mlc_llm/serve/grammar.py similarity index 100% rename from python/mlc_chat/serve/grammar.py rename to python/mlc_llm/serve/grammar.py diff --git a/python/mlc_chat/serve/request.py b/python/mlc_llm/serve/request.py similarity index 100% rename from python/mlc_chat/serve/request.py rename to python/mlc_llm/serve/request.py diff --git a/python/mlc_chat/serve/server/__init__.py b/python/mlc_llm/serve/server/__init__.py similarity index 100% rename from python/mlc_chat/serve/server/__init__.py rename to python/mlc_llm/serve/server/__init__.py diff --git a/python/mlc_chat/serve/server/__main__.py b/python/mlc_llm/serve/server/__main__.py similarity index 100% rename from python/mlc_chat/serve/server/__main__.py rename to python/mlc_llm/serve/server/__main__.py diff --git a/python/mlc_chat/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py similarity index 97% rename from python/mlc_chat/serve/server/popen_server.py rename to python/mlc_llm/serve/server/popen_server.py index 09e468850e..6a668419cc 100644 --- a/python/mlc_chat/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -25,7 +25,7 @@ def __init__( # pylint: disable=too-many-arguments host: str = "127.0.0.1", port: int = 8000, ) -> None: - """Please check out `python/mlc_chat/serve/server/__main__.py` + """Please check out `python/mlc_llm/serve/server/__main__.py` for the server arguments.""" self.model = model self.model_lib_path = model_lib_path @@ -42,7 +42,7 @@ def start(self) -> None: Wait until the server becomes ready before return. """ cmd = [sys.executable] - cmd += ["-m", "mlc_chat.serve.server"] + cmd += ["-m", "mlc_llm.serve.server"] cmd += ["--model", self.model] cmd += ["--model-lib-path", self.model_lib_path] cmd += ["--device", self.device] diff --git a/python/mlc_chat/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py similarity index 100% rename from python/mlc_chat/serve/server/server_context.py rename to python/mlc_llm/serve/server/server_context.py diff --git a/python/mlc_chat/streamer.py b/python/mlc_llm/streamer.py similarity index 100% rename from python/mlc_chat/streamer.py rename to python/mlc_llm/streamer.py diff --git a/python/mlc_chat/support/__init__.py b/python/mlc_llm/support/__init__.py similarity index 100% rename from python/mlc_chat/support/__init__.py rename to python/mlc_llm/support/__init__.py diff --git a/python/mlc_chat/support/argparse.py b/python/mlc_llm/support/argparse.py similarity index 100% rename from python/mlc_chat/support/argparse.py rename to python/mlc_llm/support/argparse.py diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_llm/support/auto_config.py similarity index 92% rename from python/mlc_chat/support/auto_config.py rename to python/mlc_llm/support/auto_config.py index a5b73b73d4..f0247a6ef9 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -8,8 +8,8 @@ from .style import bold, green if TYPE_CHECKING: - from mlc_chat.model import Model # pylint: disable=unused-import - from mlc_chat.quantization import Quantization # pylint: disable=unused-import + from mlc_llm.model import Model # pylint: disable=unused-import + from mlc_llm.quantization import Quantization # pylint: disable=unused-import logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: The path points to mlc_chat_config.json. 
""" # pylint: disable=import-outside-toplevel - from mlc_chat.model import MODEL_PRESETS + from mlc_llm.model import MODEL_PRESETS from .download import download_mlc_weights @@ -85,7 +85,7 @@ def detect_config(config: str) -> Path: config_json_path : pathlib.Path The path points to config.json. """ - from mlc_chat.model import MODEL_PRESETS # pylint: disable=import-outside-toplevel + from mlc_llm.model import MODEL_PRESETS # pylint: disable=import-outside-toplevel if isinstance(config, str) and config in MODEL_PRESETS: logger.info("%s preset model: %s", FOUND, config) @@ -131,11 +131,11 @@ def detect_model_type(model_type: str, config: Path) -> "Model": Returns ------- - model : mlc_chat.compiler.Model + model : mlc_llm.compiler.Model The model type. """ - from mlc_chat.model import MODELS # pylint: disable=import-outside-toplevel + from mlc_llm.model import MODELS # pylint: disable=import-outside-toplevel if model_type == "auto": with open(config, "r", encoding="utf-8") as config_file: @@ -171,10 +171,10 @@ def detect_quantization(quantization_arg: str, config: Path) -> "Quantization": Returns ------- - quantization : mlc_chat.quantization.Quantization + quantization : mlc_llm.quantization.Quantization The model quantization scheme. """ - from mlc_chat.quantization import ( # pylint: disable=import-outside-toplevel + from mlc_llm.quantization import ( # pylint: disable=import-outside-toplevel QUANTIZATION, ) diff --git a/python/mlc_chat/support/auto_device.py b/python/mlc_llm/support/auto_device.py similarity index 98% rename from python/mlc_chat/support/auto_device.py rename to python/mlc_llm/support/auto_device.py index 6d18de479b..cf6d09495a 100644 --- a/python/mlc_chat/support/auto_device.py +++ b/python/mlc_llm/support/auto_device.py @@ -54,7 +54,7 @@ def _device_exists(device: Device) -> bool: cmd = [ sys.executable, "-m", - "mlc_chat.cli.check_device", + "mlc_llm.cli.check_device", device_type, ] prefix = "check_device:" diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_llm/support/auto_target.py similarity index 99% rename from python/mlc_chat/support/auto_target.py rename to python/mlc_llm/support/auto_target.py index a4bb853bc7..434cfff8d0 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -15,7 +15,7 @@ from .style import bold, green, red if TYPE_CHECKING: - from mlc_chat.compiler.compile import CompileArgs + from mlc_llm.compiler.compile import CompileArgs logger = logging.getLogger(__name__) diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_llm/support/auto_weight.py similarity index 100% rename from python/mlc_chat/support/auto_weight.py rename to python/mlc_llm/support/auto_weight.py diff --git a/python/mlc_chat/support/config.py b/python/mlc_llm/support/config.py similarity index 100% rename from python/mlc_chat/support/config.py rename to python/mlc_llm/support/config.py diff --git a/python/mlc_chat/support/constants.py b/python/mlc_llm/support/constants.py similarity index 93% rename from python/mlc_chat/support/constants.py rename to python/mlc_llm/support/constants.py index 09e489348c..82697ff71a 100644 --- a/python/mlc_chat/support/constants.py +++ b/python/mlc_llm/support/constants.py @@ -17,13 +17,13 @@ def _get_cache_dir() -> Path: result = Path(os.environ["MLC_CACHE_DIR"]) elif sys.platform == "win32": result = Path(os.environ["LOCALAPPDATA"]) - result = result / "mlc_chat" + result = result / "mlc_llm" elif os.getenv("XDG_CACHE_HOME", None) is not None: result = 
Path(os.getenv("XDG_CACHE_HOME")) - result = result / "mlc_chat" + result = result / "mlc_llm" else: result = Path(os.path.expanduser("~/.cache")) - result = result / "mlc_chat" + result = result / "mlc_llm" result.mkdir(parents=True, exist_ok=True) if not result.is_dir(): raise ValueError( diff --git a/python/mlc_chat/support/convert_tiktoken.py b/python/mlc_llm/support/convert_tiktoken.py similarity index 100% rename from python/mlc_chat/support/convert_tiktoken.py rename to python/mlc_llm/support/convert_tiktoken.py diff --git a/python/mlc_chat/support/download.py b/python/mlc_llm/support/download.py similarity index 100% rename from python/mlc_chat/support/download.py rename to python/mlc_llm/support/download.py diff --git a/python/mlc_chat/support/logging.py b/python/mlc_llm/support/logging.py similarity index 100% rename from python/mlc_chat/support/logging.py rename to python/mlc_llm/support/logging.py diff --git a/python/mlc_chat/support/max_thread_check.py b/python/mlc_llm/support/max_thread_check.py similarity index 100% rename from python/mlc_chat/support/max_thread_check.py rename to python/mlc_llm/support/max_thread_check.py diff --git a/python/mlc_chat/support/preshard.py b/python/mlc_llm/support/preshard.py similarity index 100% rename from python/mlc_chat/support/preshard.py rename to python/mlc_llm/support/preshard.py diff --git a/python/mlc_chat/support/random.py b/python/mlc_llm/support/random.py similarity index 100% rename from python/mlc_chat/support/random.py rename to python/mlc_llm/support/random.py diff --git a/python/mlc_chat/support/style.py b/python/mlc_llm/support/style.py similarity index 100% rename from python/mlc_chat/support/style.py rename to python/mlc_llm/support/style.py diff --git a/python/mlc_chat/support/tensor_parallel.py b/python/mlc_llm/support/tensor_parallel.py similarity index 100% rename from python/mlc_chat/support/tensor_parallel.py rename to python/mlc_llm/support/tensor_parallel.py diff --git a/python/mlc_chat/support/tqdm.py b/python/mlc_llm/support/tqdm.py similarity index 100% rename from python/mlc_chat/support/tqdm.py rename to python/mlc_llm/support/tqdm.py diff --git a/python/mlc_chat/tokenizer.py b/python/mlc_llm/tokenizer.py similarity index 100% rename from python/mlc_chat/tokenizer.py rename to python/mlc_llm/tokenizer.py diff --git a/python/setup.py b/python/setup.py index 4602f55cb8..2f1b632bf5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -13,7 +13,7 @@ def get_lib_path(): """Get library path, name and version""" # Directly exec libinfo to get the right setup - libinfo_py = os.path.join(CURRENT_DIR, "./mlc_chat/libinfo.py") + libinfo_py = os.path.join(CURRENT_DIR, "./mlc_llm/libinfo.py") libinfo = {"__file__": libinfo_py} with open(libinfo_py, "rb") as f: exec(compile(f.read(), libinfo_py, "exec"), libinfo, libinfo) @@ -69,15 +69,15 @@ def main(): with open("MANIFEST.in", "w", encoding="utf-8") as fo: for path in LIB_LIST: if os.path.isfile(path): - shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) + shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_llm")) _, libname = os.path.split(path) - fo.write(f"include mlc_chat/{libname}\n") + fo.write(f"include mlc_llm/{libname}\n") setup_kwargs = {"include_package_data": True} setup( - name="mlc_chat", + name="mlc_llm", version=__version__, - description="MLC Chat: an universal runtime running LLMs", + description="MLC LLM: an universal LLM deployment engine via ML compilation.", url="https://llm.mlc.ai/", author="MLC LLM Contributors", license="Apache 2.0", @@ -93,11 
+93,9 @@ def main(): zip_safe=False, packages=find_packages(), entry_points={ - "console_scripts": [ - "mlc_chat = mlc_chat.__main__:main", - ], + "console_scripts": ["mlc_llm = mlc_llm.__main__:main"], }, - package_dir={"mlc_chat": "mlc_chat"}, + package_dir={"mlc_llm": "mlc_llm"}, install_requires=[ "fastapi", "uvicorn", @@ -126,7 +124,7 @@ def _remove_path(path): os.remove("MANIFEST.in") for path in LIB_LIST: _, libname = os.path.split(path) - _remove_path(f"mlc_chat/{libname}") + _remove_path(f"mlc_llm/{libname}") main() diff --git a/rust/README.md b/rust/README.md index 8c92525772..971fb11200 100644 --- a/rust/README.md +++ b/rust/README.md @@ -20,6 +20,6 @@ To start using the package, you can refer to the example code provided in the ex Execute the example with Cargo using the following command: ```bash -cargo run --example mlc_chat +cargo run --example mlc_llm ``` diff --git a/tests/python/api/test_python.py b/tests/python/api/test_python.py index ceba066a13..d4945f9503 100644 --- a/tests/python/api/test_python.py +++ b/tests/python/api/test_python.py @@ -1,8 +1,8 @@ # pylint: disable=missing-docstring import pytest -from mlc_chat import ChatModule, GenerationConfig -from mlc_chat.callback import StreamToStdout +from mlc_llm import ChatModule, GenerationConfig +from mlc_llm.callback import StreamToStdout MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] diff --git a/tests/python/api/test_rest.py b/tests/python/api/test_rest.py index f4ef4428a2..f617c5727d 100644 --- a/tests/python/api/test_rest.py +++ b/tests/python/api/test_rest.py @@ -13,7 +13,7 @@ @pytest.fixture def run_rest_server(model): - cmd = f"python -m mlc_chat.rest --model {model}" + cmd = f"python -m mlc_llm.rest --model {model}" print(cmd) os.environ["PYTHONPATH"] = "./python" with subprocess.Popen(cmd.split()) as server_proc: diff --git a/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py b/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py index eed1010cf2..1035ce96fd 100644 --- a/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py +++ b/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py @@ -4,7 +4,7 @@ from tvm.script import ir as I from tvm.script import relax as R -from mlc_chat.compiler_pass.fuse_ft_dequantize_matmul_epilogue import ( +from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import ( FuseFTDequantizeEpilogue, ) diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py index c70b1b5b20..2f136f3f16 100644 --- a/tests/python/integration/test_model_compile.py +++ b/tests/python/integration/test_model_compile.py @@ -9,10 +9,10 @@ import tvm -from mlc_chat.model import MODEL_PRESETS -from mlc_chat.model import MODELS as SUPPORTED_MODELS -from mlc_chat.quantization import QUANTIZATION as SUPPORTED_QUANTS -from mlc_chat.support.constants import MLC_TEMP_DIR +from mlc_llm.model import MODEL_PRESETS +from mlc_llm.model import MODELS as SUPPORTED_MODELS +from mlc_llm.quantization import QUANTIZATION as SUPPORTED_QUANTS +from mlc_llm.support.constants import MLC_TEMP_DIR OPT_LEVEL = "O2" DEVICE2TARGET = { @@ -61,7 +61,7 @@ "ios": "tar", } MODELS = list(MODEL_PRESETS.keys()) -QUANTS = [ # TODO(@junrushao): use `list(mlc_chat.quantization.QUANTIZATION.keys())` +QUANTS = [ # TODO(@junrushao): use `list(mlc_llm.quantization.QUANTIZATION.keys())` "q0f16", "q0f32", "q3f16_1", @@ -117,7 +117,7 @@ def test_model_compile(): # pylint: disable=too-many-locals cmd = [ sys.executable, "-m", - 
"mlc_chat", + "mlc_llm", "compile", model, "--quantization", diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py index d945a95db0..3ab5bd911e 100644 --- a/tests/python/loader/test_awq.py +++ b/tests/python/loader/test_awq.py @@ -5,10 +5,10 @@ import pytest import tvm -from mlc_chat.loader import HuggingFaceLoader -from mlc_chat.model import MODEL_PRESETS, MODELS -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.support import logging, tqdm +from mlc_llm.loader import HuggingFaceLoader +from mlc_llm.model import MODEL_PRESETS, MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.support import logging, tqdm logging.enable_logging() diff --git a/tests/python/loader/test_huggingface.py b/tests/python/loader/test_huggingface.py index dfbef55c28..1b7bd3c02d 100644 --- a/tests/python/loader/test_huggingface.py +++ b/tests/python/loader/test_huggingface.py @@ -5,9 +5,9 @@ import pytest import tvm -from mlc_chat.loader import HuggingFaceLoader -from mlc_chat.model import MODELS -from mlc_chat.support import logging, tqdm +from mlc_llm.loader import HuggingFaceLoader +from mlc_llm.model import MODELS +from mlc_llm.support import logging, tqdm logging.enable_logging() diff --git a/tests/python/model/test_gpt2.py b/tests/python/model/test_gpt2.py index 9517ad1c45..cdbe7ff222 100644 --- a/tests/python/model/test_gpt2.py +++ b/tests/python/model/test_gpt2.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_llm.model import MODEL_PRESETS, MODELS @pytest.mark.parametrize("model_name", ["gpt2"]) diff --git a/tests/python/model/test_gptNeox.py b/tests/python/model/test_gptNeox.py index d4fcfdd142..5983a5b491 100644 --- a/tests/python/model/test_gptNeox.py +++ b/tests/python/model/test_gptNeox.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_llm.model import MODEL_PRESETS, MODELS @pytest.mark.parametrize("model_name", ["redpajama_3b_v1"]) diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py index 970b7bac16..be4cc4a507 100644 --- a/tests/python/model/test_kv_cache.py +++ b/tests/python/model/test_kv_cache.py @@ -6,7 +6,7 @@ from tvm.script import relax as R from tvm.script import tir as T -from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode +from mlc_llm.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode # mypy: disable-error-code="attr-defined" # pylint: disable=invalid-name,unused-argument,too-many-locals,too-many-statements diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index 6e1b38dbca..5591dcdca2 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_llm.model import MODEL_PRESETS, MODELS @pytest.mark.parametrize( diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py index 4d4c761fb1..87d9d2b282 100644 --- a/tests/python/model/test_llama_quantization.py +++ b/tests/python/model/test_llama_quantization.py @@ -1,9 +1,9 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.quantization.group_quantization import ( +from 
mlc_llm.model import MODEL_PRESETS, MODELS +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.quantization.group_quantization import ( GroupQuantizeEmbedding, GroupQuantizeLinear, ) diff --git a/tests/python/model/test_mistral.py b/tests/python/model/test_mistral.py index 631b592979..c1d47eba77 100644 --- a/tests/python/model/test_mistral.py +++ b/tests/python/model/test_mistral.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_llm.model import MODEL_PRESETS, MODELS @pytest.mark.parametrize("model_name", ["mistral_7b"]) diff --git a/tests/python/model/test_phi.py b/tests/python/model/test_phi.py index e3f55f263e..e72effab35 100644 --- a/tests/python/model/test_phi.py +++ b/tests/python/model/test_phi.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.model import MODEL_PRESETS, MODELS +from mlc_llm.model import MODEL_PRESETS, MODELS @pytest.mark.parametrize("model_name", ["phi-1_5", "phi-2"]) diff --git a/tests/python/quantization/test_awq_quantization.py b/tests/python/quantization/test_awq_quantization.py index 244271aff7..0222a29b6f 100644 --- a/tests/python/quantization/test_awq_quantization.py +++ b/tests/python/quantization/test_awq_quantization.py @@ -9,8 +9,8 @@ from tvm import DataType from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import QUANTIZATION, AWQQuantize +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import QUANTIZATION, AWQQuantize def dequantize_np( diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py index 72133ff013..b3f9d8034c 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -9,9 +9,9 @@ from tvm import DataType from tvm.relax.frontend import nn -from mlc_chat.loader import QuantizeMapping -from mlc_chat.quantization import QUANTIZATION -from mlc_chat.quantization.group_quantization import ( +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import QUANTIZATION +from mlc_llm.quantization.group_quantization import ( GroupQuantize, GroupQuantizeEmbedding, GroupQuantizeLinear, diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index 26e9d9af40..94d48c12af 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -10,8 +10,8 @@ import numpy as np from transformers import AutoTokenizer -from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine import ModelInfo def _parse_args(): diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 9fd21f6f53..bbd2089f4c 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -4,8 +4,8 @@ import random from typing import List, Tuple -from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine import ModelInfo def _parse_args(): diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py index 004b148788..807739ace6 100644 --- a/tests/python/serve/server/conftest.py +++ 
b/tests/python/serve/server/conftest.py @@ -4,7 +4,7 @@ import pytest -from mlc_chat.serve import PopenServer +from mlc_llm.serve import PopenServer @pytest.fixture(scope="session") diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 1436de34d7..88734455cf 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -13,7 +13,7 @@ launch the server in ahead before running this file. This can be done in two steps: - start a new shell session, run - python -m mlc_chat.serve.server --model "YOUR_MODEL_LIB" + python -m mlc_llm.serve.server --model "YOUR_MODEL_LIB" - start another shell session, run this file MLC_SERVE_MODEL_LIB="YOUR_MODEL_LIB" python tests/python/serve/server/test_server.py """ diff --git a/tests/python/serve/test_event_trace_recorder.py b/tests/python/serve/test_event_trace_recorder.py index fb2a5f2974..b22dfeddad 100644 --- a/tests/python/serve/test_event_trace_recorder.py +++ b/tests/python/serve/test_event_trace_recorder.py @@ -1,7 +1,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring import json -from mlc_chat.serve.event_trace_recorder import EventTraceRecorder +from mlc_llm.serve.event_trace_recorder import EventTraceRecorder def test_event_trace_recorder(): diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index 87228b1c18..325b0a5117 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -5,7 +5,7 @@ import tvm.testing from tvm import TVMError -from mlc_chat.serve import BNFGrammar +from mlc_llm.serve import BNFGrammar def test_bnf_simple(): diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index d9a9a09bab..37c9af0d9b 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -10,8 +10,8 @@ import tvm import tvm.testing -from mlc_chat.serve import BNFGrammar, GrammarStateMatcher -from mlc_chat.tokenizer import Tokenizer +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher +from mlc_llm.tokenizer import Tokenizer def get_json_grammar(): diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py index a38a0edefe..dfc0257b04 100644 --- a/tests/python/serve/test_grammar_state_matcher_json.py +++ b/tests/python/serve/test_grammar_state_matcher_json.py @@ -9,8 +9,8 @@ import tvm.testing from tvm import TVMError -from mlc_chat.serve import BNFGrammar, GrammarStateMatcher -from mlc_chat.tokenizer import Tokenizer +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher +from mlc_llm.tokenizer import Tokenizer @pytest.fixture(scope="function") diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index c7616df5f7..a1a2791bf7 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,8 +3,8 @@ import asyncio from typing import List -from mlc_chat.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine import ModelInfo prompts = [ "What is the meaning of life?", diff --git a/tests/python/serve/test_serve_async_engine_spec.py 
b/tests/python/serve/test_serve_async_engine_spec.py index becc594622..10ed7a4729 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,13 +3,13 @@ import asyncio from typing import List -from mlc_chat.serve import ( +from mlc_llm.serve import ( AsyncThreadedEngine, EngineMode, GenerationConfig, KVCacheConfig, ) -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve.engine import ModelInfo prompts = [ "What is the meaning of life?", diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 5cd13be91e..9f56f507ca 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -4,7 +4,7 @@ import numpy as np -from mlc_chat.serve import ( +from mlc_llm.serve import ( Engine, GenerationConfig, KVCacheConfig, @@ -12,7 +12,7 @@ RequestStreamOutput, data, ) -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve.engine import ModelInfo prompts = [ "What is the meaning of life?", diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index e96eac9dda..b5430acd39 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -5,10 +5,10 @@ import pytest -from mlc_chat.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_chat.serve.async_engine import AsyncThreadedEngine -from mlc_chat.serve.config import ResponseFormat -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.async_engine import AsyncThreadedEngine +from mlc_llm.serve.config import ResponseFormat +from mlc_llm.serve.engine import ModelInfo prompts_list = [ "Generate a JSON string containing 20 objects:", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 663744305d..828146afc9 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -4,7 +4,7 @@ import numpy as np -from mlc_chat.serve import ( +from mlc_llm.serve import ( Engine, EngineMode, GenerationConfig, @@ -13,7 +13,7 @@ RequestStreamOutput, data, ) -from mlc_chat.serve.engine import ModelInfo +from mlc_llm.serve.engine import ModelInfo prompts = [ "What is the meaning of life?", diff --git a/tests/python/support/test_auto_config.py b/tests/python/support/test_auto_config.py index 77c6a0d80a..90e797b14e 100644 --- a/tests/python/support/test_auto_config.py +++ b/tests/python/support/test_auto_config.py @@ -5,8 +5,8 @@ import pytest -from mlc_chat.support import logging -from mlc_chat.support.auto_config import detect_config +from mlc_llm.support import logging +from mlc_llm.support.auto_config import detect_config logging.enable_logging() diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py index dfbefff3e6..2b3ad48393 100644 --- a/tests/python/support/test_auto_weight.py +++ b/tests/python/support/test_auto_weight.py @@ -6,8 +6,8 @@ import pytest -from mlc_chat.support import logging -from mlc_chat.support.auto_weight import detect_weight +from mlc_llm.support import logging +from mlc_llm.support.auto_weight import detect_weight logging.enable_logging() diff --git a/tests/python/support/test_streamer.py b/tests/python/support/test_streamer.py index 4f51ea1dd7..4ea4573c08 100644 --- a/tests/python/support/test_streamer.py +++ 
b/tests/python/support/test_streamer.py @@ -22,8 +22,8 @@ import pytest -from mlc_chat.streamer import StopStrHandler, TextStreamer -from mlc_chat.tokenizer import Tokenizer +from mlc_llm.streamer import StopStrHandler, TextStreamer +from mlc_llm.tokenizer import Tokenizer # fmt: off para_input_tokens = [18585, 29892, 1244, 29915, 29879, 263, 3273, 14880, 1048, 953, 29877, 2397, From c268f950178e961a50f8a4778fb93547b5d08b25 Mon Sep 17 00:00:00 2001 From: Git bot Date: Tue, 12 Mar 2024 14:10:15 +0000 Subject: [PATCH 059/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index f06d486b4a..1d4da926c7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f06d486b4a1a27f0bbb072688a5fc41e7b15323c +Subproject commit 1d4da926c726e2700593c7f62006545bda6a46f9 From d6d972c4256dcbfe8de0ecc0db913852cbb6cde5 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 11:58:52 -0400 Subject: [PATCH 060/531] [Docs] Deprecating CUDA 11.7/11.8 support (#1939) We have deprecated the wheel support for CUDA 11.7/11.8 due to TVM thrust compatibility with old CUDA versions. --- docs/install/mlc_llm.rst | 14 -------------- docs/install/tvm.rst | 14 -------------- 2 files changed, 28 deletions(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index b4eff63041..3003abdc72 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -31,20 +31,6 @@ Select your operating system/compute platform and run the command in your termin conda activate your-environment python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly - .. tab:: CUDA 11.7 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu117 mlc-ai-nightly-cu117 - - .. tab:: CUDA 11.8 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu118 mlc-ai-nightly-cu118 - .. tab:: CUDA 12.1 .. code-block:: bash diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index f5cb460dfd..7fbd3d08ad 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -39,20 +39,6 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. conda activate your-environment python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly - .. tab:: CUDA 11.7 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu117 - - .. tab:: CUDA 11.8 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu118 - .. tab:: CUDA 12.1 .. code-block:: bash From 9df8f035b1694f6c60fb25dd70f9ffa3eb44fe3e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 12:01:17 -0400 Subject: [PATCH 061/531] [Fix] Fix KV cache call in mistral (#1938) The latest TVM introduces the wellformedness check of the IR. The mistral model definition breaks the wellformedness due to the purity. This PR fixes this issue. 
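Editor's note: the fix below swaps a hand-built `rx.Call` on the impure extern for `rx.call_pure_packed`, which marks the packed call as side-effect free so the new IR well-formedness check accepts it. A minimal sketch of the pattern, with `cache_expr` and `new_element_expr` standing in for the actual Relax expressions used in the Mistral model (not the real model code):

```python
from tvm import relax as rx


def emit_window_override(bb: rx.BlockBuilder, cache_expr, new_element_expr,
                         max_cache_size: int, attention_sink_size: int):
    # call_pure_packed marks the packed-function call as pure, so the emitted
    # call stays well-formed inside a dataflow block under the purity check.
    return bb.emit(
        rx.call_pure_packed(
            "vm.builtin.attention_kv_cache_window_override_with_sinks",
            cache_expr,
            new_element_expr,
            rx.PrimValue(max_cache_size),
            rx.PrimValue(attention_sink_size),
            sinfo_args=[rx.ObjectStructInfo()],
        )
    )
```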
--- python/mlc_llm/model/mistral/mistral_model.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 9374df595c..88be860628 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -1,6 +1,7 @@ """ Implementation for Mistral architecture. """ + import dataclasses from typing import Any, Dict, Optional @@ -279,14 +280,12 @@ def override(self, new_element: Tensor, max_cache_size: int, attention_sink_size f'but got "{new_element.dtype}"' ) self.cache = rx.BlockBuilder.current().emit( - rx.Call( - rx.extern("vm.builtin.attention_kv_cache_window_override_with_sinks"), - args=[ - self.cache, - new_element._expr, # pylint: disable=protected-access - rx.PrimValue(max_cache_size), - rx.PrimValue(attention_sink_size), - ], + rx.call_pure_packed( + "vm.builtin.attention_kv_cache_window_override_with_sinks", + self.cache, + new_element._expr, # pylint: disable=protected-access + rx.PrimValue(max_cache_size), + rx.PrimValue(attention_sink_size), sinfo_args=[rx.ObjectStructInfo()], ) ) From 48934150281dfde05552d3e86c95ff83fa0bced1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 13:16:39 -0400 Subject: [PATCH 062/531] [ChatModule] Remove eos_token_ids (#1940) This PR removes the eos_token_ids from the ChatModule given it is nowhere used actually. --- cpp/llm_chat.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index e0f653841e..aca13db863 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -575,10 +575,6 @@ class LLMChat { CHECK(config["bos_token_id"].is()); this->bos_token_id_ = config["bos_token_id"].get(); } - if (config.count("eos_token_id")) { - CHECK(config["eos_token_id"].is()); - this->eos_token_id_ = config["eos_token_id"].get(); - } } /*! @@ -1628,8 +1624,6 @@ class LLMChat { Tokenizer tokenizer_; // bos token int32_t bos_token_id_{1}; - // eos token id - int32_t eos_token_id_{2}; //---------------------------- // TVM related states //---------------------------- From 738e353a55af1f8e12c64d4fa90b7826588cefde Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 15:04:30 -0400 Subject: [PATCH 063/531] [SLM] Weight conversion with generator (#1916) This PR enhances weight conversion so that it passes a generator to `tvmjs.dump_ndarray_cache`. This effectively reduces the CPU memory pressure when converting weights, especially when the total converted weight size is close to or larger to the CPU memory size. 
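Editor's note: the idea is to hand `tvmjs.dump_ndarray_cache` an iterator so each converted parameter can be written to disk and released before the next one is produced, instead of first accumulating the whole parameter dict in CPU memory. A minimal sketch under the assumption of a loader object with a `load(device=...)` method yielding `(name, NDArray)` pairs; the loader interface here is a placeholder, not the actual MLC loader API:

```python
from typing import Iterator, Tuple

import tvm
from tvm.contrib import tvmjs
from tvm.runtime import NDArray


def convert_and_dump(loader, device, output_dir: str) -> None:
    """Stream converted parameters straight into the ndarray cache on disk."""

    def param_stream() -> Iterator[Tuple[str, NDArray]]:
        # Yield one converted parameter at a time; it can be written and
        # released before the next one is produced, so peak CPU memory stays
        # close to a single parameter instead of the whole model.
        for name, param in loader.load(device=device):  # hypothetical loader API
            yield name, param.copyto(tvm.cpu())

    # dump_ndarray_cache accepts the generator in place of a full dict.
    tvmjs.dump_ndarray_cache(
        param_stream(),
        output_dir,
        encode_format="f32-to-bf16",
        show_progress=False,
    )
```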
--- python/mlc_llm/interface/convert_weight.py | 63 +++++++++++++--------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index fad6114c6e..0d5cd53fea 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -5,6 +5,7 @@ import os from io import StringIO from pathlib import Path +from typing import Any, Dict, Iterator, Tuple import numpy as np from tvm import tir @@ -83,7 +84,7 @@ def _check_param(name: str, param: NDArray): nonlocal named_params if name not in named_params: raise ValueError(f"Parameter not found in model: {name}") - if name in param_dict: + if name in param_names: raise ValueError(f"Duplication: Parameter {name} already computed") # Check shape (possibly dynamic) @@ -112,20 +113,43 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var del named_params[name] # load and quantize - param_dict = {} + param_names = set() total_bytes = 0.0 - with Target.from_device(args.device), tqdm.redirect(): - loader = LOADER[args.source_format]( - path=args.source, - extern_param_map=args.model.source[args.source_format](model_config, args.quantization), - quantize_param_map=quantize_map, - ) - for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs): - _check_param(name, param) - param = param.copyto(cpu_device()) - param_dict[name] = param - total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize - total_params = loader.stats.total_param_num + total_params: int + + def _param_generator() -> Iterator[Tuple[str, NDArray]]: + nonlocal total_params, total_bytes + with Target.from_device(args.device), tqdm.redirect(): + loader = LOADER[args.source_format]( + path=args.source, + extern_param_map=args.model.source[args.source_format]( + model_config, args.quantization + ), + quantize_param_map=quantize_map, + ) + for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs): + _check_param(name, param) + param_names.add(name) + param = param.copyto(cpu_device()) + total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + yield name, param + total_params = loader.stats.total_param_num + + def _metadata_callback() -> Dict[str, Any]: + return { + "ParamSize": len(param_names), + "ParamBytes": total_bytes, + "BitsPerParam": total_bytes * 8.0 / total_params, + } + + # dump to output directory + tvmjs.dump_ndarray_cache( + _param_generator(), + str(args.output), + meta_data=_metadata_callback, + encode_format="f32-to-bf16", + show_progress=False, + ) if named_params: raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") # Log necessary statistics @@ -140,17 +164,6 @@ def _check_shape(actual: tuple, expect: tuple): # expect can have tir.Var green("Bits per parameter"), total_bytes * 8.0 / total_params, ) - # dump to output directory - tvmjs.dump_ndarray_cache( - param_dict, - str(args.output), - meta_data={ - "ParamSize": len(param_dict), - "ParamBytes": total_bytes, - "BitsPerParam": total_bytes * 8.0 / total_params, - }, - encode_format="f32-to-bf16", - ) logger.info("Saved to directory: %s", bold(str(args.output))) From 5b8c529e9704abd09b0432da6dcb4b013fdf43b1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 12 Mar 2024 17:12:07 -0400 Subject: [PATCH 064/531] [Serve] Introducing GPU sampler for CUDA (#1934) This PR introduces the GPU sampler for CUDA only. 
The GPU sampler makes use of the GPU sampling ops introduced in apache/tvm#16575. We will follow up to benchmark the performance of the GPU sampler over CPU sampler. --- cpp/serve/engine.cc | 10 +- cpp/serve/engine_actions/action.h | 2 +- cpp/serve/engine_actions/batch_decode.cc | 2 +- cpp/serve/engine_actions/batch_draft.cc | 2 +- cpp/serve/engine_actions/batch_verify.cc | 2 +- .../engine_actions/new_request_prefill.cc | 2 +- cpp/serve/function_table.cc | 7 + cpp/serve/function_table.h | 4 + cpp/serve/model.cc | 9 + cpp/serve/model.h | 8 + .../{sampler.cc => sampler/cpu_sampler.cc} | 24 +- cpp/serve/sampler/gpu_sampler.cc | 328 ++++++++++++++++++ cpp/serve/{ => sampler}/sampler.h | 42 ++- .../attach_embedding_allocator.py | 39 +++ ...ir_module.py => attach_logit_processor.py} | 81 +---- .../mlc_llm/compiler_pass/attach_sampler.py | 274 +++++++++++++++ .../compiler_pass/attach_support_info.py | 48 +++ python/mlc_llm/compiler_pass/pipeline.py | 9 +- 18 files changed, 769 insertions(+), 124 deletions(-) rename cpp/serve/{sampler.cc => sampler/cpu_sampler.cc} (97%) create mode 100644 cpp/serve/sampler/gpu_sampler.cc rename cpp/serve/{ => sampler}/sampler.h (74%) create mode 100644 python/mlc_llm/compiler_pass/attach_embedding_allocator.py rename python/mlc_llm/compiler_pass/{attach_to_ir_module.py => attach_logit_processor.py} (60%) create mode 100644 python/mlc_llm/compiler_pass/attach_sampler.py create mode 100644 python/mlc_llm/compiler_pass/attach_support_info.py diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index f043b4bcac..39c84a1c8d 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -25,7 +25,7 @@ #include "model.h" #include "request.h" #include "request_state.h" -#include "sampler.h" +#include "sampler/sampler.h" namespace mlc { namespace llm { @@ -78,13 +78,13 @@ class EngineImpl : public Engine { this->models_.push_back(model); this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()}); } - int max_logit_processor_num_token = kv_cache_config_->max_num_sequence; + int max_num_tokens = kv_cache_config_->max_num_sequence; if (engine_mode_->enable_speculative) { - max_logit_processor_num_token *= engine_mode_->spec_draft_length; + max_num_tokens *= engine_mode_->spec_draft_length; } LogitProcessor logit_processor = - this->models_[0]->CreateLogitProcessor(max_logit_processor_num_token, trace_recorder); - Sampler sampler = Sampler::Create(/*sampler_kind=*/"cpu", trace_recorder_); + this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); + Sampler sampler = this->models_[0]->CreateSampler(max_num_tokens, trace_recorder); // Step 3. Initialize engine actions that represent state transitions. if (this->engine_mode_->enable_speculative) { // Speculative decoding is only possible for more than one model. 
diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 7a5e217569..e355168365 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -11,7 +11,7 @@ #include "../engine_state.h" #include "../event_trace_recorder.h" #include "../model.h" -#include "../sampler.h" +#include "../sampler/sampler.h" namespace mlc { namespace llm { diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 2af5d86404..eea7e79fb4 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -8,7 +8,7 @@ #include "../../random.h" #include "../config.h" #include "../model.h" -#include "../sampler.h" +#include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index cef66443db..b56f7fa9b6 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -7,7 +7,7 @@ #include "../config.h" #include "../model.h" -#include "../sampler.h" +#include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 79c2a17b95..df1737c547 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -11,7 +11,7 @@ #include "../../random.h" #include "../config.h" #include "../model.h" -#include "../sampler.h" +#include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 9a2722ff1c..715105a043 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -5,7 +5,7 @@ #include "../config.h" #include "../model.h" -#include "../sampler.h" +#include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 70c855d5f7..1c42caae1e 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -18,6 +18,7 @@ #include #include "../support/load_bytes_from_file.h" +#include "sampler/sampler.h" namespace mlc { namespace llm { @@ -221,6 +222,12 @@ void FunctionTable::_InitFunctions() { this->kv_cache_popn_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn"); this->kv_cache_get_num_available_pages_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages"); + if (Sampler::SupportGPUSampler(local_gpu_device)) { + gpu_multinomial_from_uniform_func_ = mod->GetFunction("multinomial_from_uniform", true); + gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true); + gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true); + gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true); + } this->nd_view_func_ = get_global_func("vm.builtin.reshape"); this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 9cc0ecb8e2..f3466506ff 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -87,6 +87,10 @@ struct FunctionTable { PackedFunc kv_cache_attention_func_; PackedFunc kv_cache_popn_func_; PackedFunc 
kv_cache_get_num_available_pages_func_; + PackedFunc gpu_multinomial_from_uniform_func_; + PackedFunc gpu_argsort_probs_func_; + PackedFunc gpu_sample_with_top_p_func_; + PackedFunc gpu_sampler_take_probs_func_; PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index b5cb5c6b5a..da332b3775 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -284,6 +284,15 @@ class ModelImpl : public ModelObj { std::move(trace_recorder)); } + Sampler CreateSampler(int max_num_sample, Optional trace_recorder) { + if (Sampler::SupportGPUSampler(device_)) { + return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } else { + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } + } + void CreateKVCache(KVCacheConfig kv_cache_config) final { IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; diff --git a/cpp/serve/model.h b/cpp/serve/model.h index acc50187d2..7bce2cafd4 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -15,6 +15,7 @@ #include "event_trace_recorder.h" #include "function_table.h" #include "logit_processor.h" +#include "sampler/sampler.h" namespace mlc { namespace llm { @@ -23,6 +24,9 @@ namespace serve { using tvm::Device; using namespace tvm::runtime; +// Declare the sampler class for `Model::CreateSampler`. +class Sampler; + /*! * \brief The workspace tensors that may be shared across different * calls to Model. For example, the prefill action use the `embeddings` @@ -144,6 +148,10 @@ class ModelObj : public Object { virtual LogitProcessor CreateLogitProcessor(int max_num_token, Optional trace_recorder) = 0; + /*! \brief Create a sampler from this model. */ + virtual Sampler CreateSampler(int max_num_sample, + Optional trace_recorder) = 0; + /*! * \brief Estimate number of CPU units required to drive the model * executing during TP. diff --git a/cpp/serve/sampler.cc b/cpp/serve/sampler/cpu_sampler.cc similarity index 97% rename from cpp/serve/sampler.cc rename to cpp/serve/sampler/cpu_sampler.cc index 4a59cefaff..e1316e57f0 100644 --- a/cpp/serve/sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -1,10 +1,8 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/sampler.cc - * \brief The implementation for runtime module of sampler functions. + * \file serve/sampler/cpu_sampler.cc + * \brief The implementation for CPU sampler functions. 
*/ -#include "sampler.h" - #include #include #include @@ -12,7 +10,8 @@ #include -#include "../random.h" +#include "../../random.h" +#include "sampler.h" namespace mlc { namespace llm { @@ -250,6 +249,8 @@ inline std::vector ComputeTopProbs(NDArray prob, int unit_offset, /********************* CPU Sampler *********************/ +TVM_REGISTER_OBJECT_TYPE(SamplerObj); + class CPUSampler : public SamplerObj { public: explicit CPUSampler(Optional trace_recorder) @@ -430,17 +431,8 @@ class CPUSampler : public SamplerObj { const float eps_ = 1e-5; }; -/*********************** Sampler ***********************/ - -TVM_REGISTER_OBJECT_TYPE(SamplerObj); - -Sampler Sampler::Create(std::string sampler_kind, Optional trace_recorder) { - if (sampler_kind == "cpu") { - return Sampler(make_object(std::move(trace_recorder))); - } else { - LOG(FATAL) << "Unsupported sampler_kind \"" << sampler_kind << "\""; - throw; - } +Sampler Sampler::CreateCPUSampler(Optional trace_recorder) { + return Sampler(make_object(std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc new file mode 100644 index 0000000000..d8a54001d3 --- /dev/null +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -0,0 +1,328 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/sampler/gpu_sampler.cc + * \brief The implementation for GPU sampler functions. + */ +#include +#include + +#include "../../random.h" +#include "sampler.h" + +namespace mlc { +namespace llm { +namespace serve { + +inline void CopyArray(NDArray src, NDArray dst) { + DLTensor dl_dst = *(dst.operator->()); + NDArray::CopyFromTo(src.operator->(), &dl_dst); +} + +/*********************** GPU Sampler ***********************/ + +class GPUSampler : public SamplerObj { + public: + explicit GPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, DLDevice device, + Optional trace_recorder) + : max_num_sample_(max_num_sample), + vocab_size_(vocab_size), + device_(device), + gpu_multinomial_from_uniform_func_(ft->gpu_multinomial_from_uniform_func_), + gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_), + gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_), + gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_), + trace_recorder_(std::move(trace_recorder)) { + ICHECK(gpu_multinomial_from_uniform_func_.defined()); + ICHECK(gpu_argsort_probs_func_.defined()); + ICHECK(gpu_sample_with_top_p_func_.defined()); + ICHECK(gpu_sampler_take_probs_func_.defined()); + + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; + // We support at most 5 top prob results for each sequence. + // Initialize auxiliary arrays on CPU. + uniform_samples_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); + sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); + top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); + sampled_token_ids_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + sampled_probs_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); + top_prob_probs_host_ = NDArray::Empty({max_num_sample * 5}, dtype_f32_, device_cpu); + top_prob_indices_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); + // Initialize auxiliary arrays on GPU. 
+ uniform_samples_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); + sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); + top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); + } + + std::vector BatchSampleTokens(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist) final { + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK_EQ(probs_on_device->ndim, 2); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + ICHECK_EQ(rngs.size(), num_samples); + + // - Generate random numbers. + // Copy the random numbers and sample indices. + auto [uniform_samples_device, sample_indices_device] = + CopySamplesAndIndicesToGPU(sample_indices, rngs, num_samples); + + // - Check if there is need for applying top p or prob values, + // so that argsort is needed. + bool need_top_p = false; + bool need_prob_values = false; + // The indptr array of the number of top probs for each sample. + std::vector top_prob_offset_indptr; + CheckTopPAndProbValues(generation_cfg, sample_indices, num_probs, num_samples, vocab_size, + &need_top_p, &need_prob_values, &top_prob_offset_indptr); + + // - Sample tokens on GPU, and take out the probability values if needed. + std::vector device_arrays = + SampleOnGPU(probs_on_device, uniform_samples_device, sample_indices_device, need_top_p, + need_prob_values, num_probs, top_prob_offset_indptr); + + // - Copy the GPU sampling function results to CPU. + std::vector host_arrays = CopyArraysToCPU(device_arrays, num_samples, need_prob_values, + top_prob_offset_indptr.back()); + + // - Collect the sampling results. + const int* p_sampled_token_ids = static_cast(host_arrays[0]->data); + const float* p_sampled_probs = nullptr; + const float* p_top_prob_probs = nullptr; + const int* p_top_prob_indices = nullptr; + if (need_prob_values) { + p_sampled_probs = static_cast(host_arrays[1]->data); + p_top_prob_probs = static_cast(host_arrays[2]->data); + p_top_prob_indices = static_cast(host_arrays[3]->data); + } + std::vector sample_results; + sample_results.reserve(num_samples); + ICHECK_EQ(top_prob_offset_indptr.size(), num_samples + 1); + for (int i = 0; i < num_samples; ++i) { + // Note: we set the probability in SampleResult to 1.0 since prob value is not needed. + float sampled_prob = need_prob_values ? 
p_sampled_probs[i] : 1.0; + std::vector top_prob_tokens; + top_prob_tokens.reserve(top_prob_offset_indptr[i + 1] - top_prob_offset_indptr[i]); + for (int j = top_prob_offset_indptr[i]; j < top_prob_offset_indptr[i + 1]; ++j) { + top_prob_tokens.emplace_back(p_top_prob_indices[j], p_top_prob_probs[j]); + } + sample_results.push_back( + SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens}); + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + + std::vector> BatchVerifyDraftTokens( + NDArray probs_on_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + LOG(FATAL) << "GPU sampler does not support batch verification for now."; + } + + private: + /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. */ + std::pair CopySamplesAndIndicesToGPU(const std::vector& sample_indices, + const std::vector& rngs, + int num_samples) { + // Generate random numbers. + float* p_uniform_samples = static_cast(uniform_samples_host_->data); + int* p_sample_indices = static_cast(sample_indices_host_->data); + for (int i = 0; i < num_samples; ++i) { + p_uniform_samples[i] = rngs[i]->GetRandomNumber(); + p_sample_indices[i] = sample_indices[i]; + } + // Copy the random numbers and sample indices to GPU. + NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_samples}, dtype_f32_); + NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_samples}, dtype_f32_); + NDArray sample_indices_host = sample_indices_host_.CreateView({num_samples}, dtype_i32_); + NDArray sample_indices_device = sample_indices_device_.CreateView({num_samples}, dtype_i32_); + CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device); + CopyArray(/*src=*/sample_indices_host, /*dst=*/sample_indices_device); + return {uniform_samples_device, sample_indices_device}; + } + + /*! \brief Check if top p and prob values are needed, and collect info when necessary. */ + void CheckTopPAndProbValues(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, + int num_samples, int vocab_size, bool* need_top_p, + bool* need_prob_values, std::vector* top_prob_offset_indptr) { + top_prob_offset_indptr->reserve(num_samples + 1); + top_prob_offset_indptr->push_back(0); + // Initialize top p values with -1. + float* p_top_p = static_cast(top_p_host_->data); + for (int i = 0; i < num_probs; ++i) { + p_top_p[i] = -1.0; + } + int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); + int num_top_probs = 0; + for (int i = 0; i < num_samples; ++i) { + if (p_top_p[sample_indices[i]] == -1.0) { + p_top_p[sample_indices[i]] = generation_cfg[i]->top_p; + *need_top_p |= generation_cfg[i]->top_p != 1.0; + } else { + CHECK(fabs(p_top_p[sample_indices[i]] - generation_cfg[i]->top_p) < eps_) + << "GPU sampler requires the top_p values for each prob distribution are the same."; + } + + *need_prob_values |= generation_cfg[i]->logprobs; + for (int j = 0; j < generation_cfg[i]->top_logprobs; ++j) { + p_top_prob_offsets[num_top_probs++] = sample_indices[i] * vocab_size + j; + } + top_prob_offset_indptr->push_back(top_prob_offset_indptr->back() + + generation_cfg[i]->top_logprobs); + } + ICHECK_EQ(num_top_probs, top_prob_offset_indptr->back()); + } + + /*! \brief Sample tokens on GPU. Take out the probability values when needed. 
*/ + std::vector SampleOnGPU(NDArray probs_on_device, NDArray uniform_samples_device, + NDArray sample_indices_device, // + bool need_top_p, bool need_prob_values, int num_probs, + const std::vector& top_prob_offset_indptr) { + NDArray sampled_token_ids_device{nullptr}; + NDArray sampled_probs_device{nullptr}; + NDArray top_prob_probs_device{nullptr}; + NDArray top_prob_indices_device{nullptr}; + + if (!need_top_p && !need_prob_values) { + // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, + top_prob_indices_device}; + } + + // - Argsort the probability. + Array argsort_results = gpu_argsort_probs_func_(probs_on_device); + ICHECK_EQ(argsort_results.size(), 2); + NDArray sorted_probs_on_device = argsort_results[0]; + NDArray sorted_indices_on_device = argsort_results[1]; + + if (need_top_p) { + // - Sample with top_p applied. + NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); + NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); + CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device); + sampled_token_ids_device = + gpu_sample_with_top_p_func_(sorted_probs_on_device, sorted_indices_on_device, + uniform_samples_device, sample_indices_device, top_p_device); + } else { + // - Sample without top_p. + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } + + if (need_prob_values) { + // - Take the probability values. + int num_top_probs = top_prob_offset_indptr.back(); + NDArray top_prob_offsets_host = + top_prob_offsets_host_.CreateView({num_top_probs}, dtype_i32_); + NDArray top_prob_offsets_device = + top_prob_offsets_device_.CreateView({num_top_probs}, dtype_i32_); + CopyArray(/*src=*/top_prob_offsets_host, /*dst=*/top_prob_offsets_device); + Array prob_value_results = gpu_sampler_take_probs_func_( + probs_on_device, sorted_indices_on_device, sample_indices_device, + sampled_token_ids_device, top_prob_offsets_device); + sampled_probs_device = prob_value_results[0]; + top_prob_probs_device = prob_value_results[1]; + top_prob_indices_device = prob_value_results[2]; + } + + return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, + top_prob_indices_device}; + } + + /*! \brief Copy the results of GPU sampling functions back to CPU. 
*/ + std::vector CopyArraysToCPU(const std::vector& device_arrays, // + int num_samples, bool need_prob_values, int num_top_probs) { + NDArray sampled_token_ids_device = device_arrays[0]; + NDArray sampled_probs_device = device_arrays[1]; + NDArray top_prob_probs_device = device_arrays[2]; + NDArray top_prob_indices_device = device_arrays[3]; + ICHECK(sampled_token_ids_device.defined()); + ICHECK_EQ(sampled_token_ids_device->ndim, 1); + ICHECK_EQ(sampled_token_ids_device->shape[0], num_samples); + NDArray sampled_token_ids_host = sampled_token_ids_host_.CreateView({num_samples}, dtype_i32_); + CopyArray(/*src=*/sampled_token_ids_device, /*dst=*/sampled_token_ids_host); + + NDArray sampled_probs_host{nullptr}; + NDArray top_prob_probs_host{nullptr}; + NDArray top_prob_indices_host{nullptr}; + if (need_prob_values) { + ICHECK(sampled_probs_device.defined()); + ICHECK(top_prob_probs_device.defined()); + ICHECK(top_prob_indices_device.defined()); + ICHECK_EQ(sampled_probs_device->ndim, 1); + ICHECK_EQ(top_prob_probs_device->ndim, 1); + ICHECK_EQ(top_prob_indices_device->ndim, 1); + ICHECK_EQ(sampled_probs_device->shape[0], num_samples); + ICHECK_EQ(top_prob_probs_device->shape[0], num_top_probs); + ICHECK_EQ(top_prob_indices_device->shape[0], num_top_probs); + sampled_probs_host = sampled_probs_host_.CreateView({num_samples}, dtype_i32_); + top_prob_probs_host = top_prob_probs_host_.CreateView({num_top_probs}, dtype_f32_); + top_prob_indices_host = top_prob_indices_host_.CreateView({num_top_probs}, dtype_i32_); + CopyArray(/*src=*/sampled_probs_device, /*dst=*/sampled_probs_host); + if (num_top_probs > 0) { + CopyArray(/*src=*/top_prob_probs_device, /*dst=*/top_prob_probs_host); + CopyArray(/*src=*/top_prob_indices_device, /*dst=*/top_prob_indices_host); + } + } + + // Synchronize for CPU to get the correct array results. + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + + return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host}; + } + + // Model configurations + const int max_num_sample_; + const int vocab_size_; + const DLDataType dtype_i32_ = DataType::Int(32); + const DLDataType dtype_f32_ = DataType::Float(32); + // Functions for sampling on GPU. + Device device_; + PackedFunc gpu_multinomial_from_uniform_func_; + PackedFunc gpu_argsort_probs_func_; + PackedFunc gpu_sample_with_top_p_func_; + PackedFunc gpu_sampler_take_probs_func_; + // Auxiliary NDArrays on CPU + NDArray uniform_samples_host_; + NDArray sample_indices_host_; + NDArray top_p_host_; + NDArray top_prob_offsets_host_; + NDArray sampled_token_ids_host_; + NDArray sampled_probs_host_; + NDArray top_prob_probs_host_; + NDArray top_prob_indices_host_; + // Auxiliary NDArrays on GPU + NDArray uniform_samples_device_; + NDArray sample_indices_device_; + NDArray top_p_device_; + NDArray top_prob_offsets_device_; + // The event trace recorder for requests. 
*/ + Optional trace_recorder_; + const float eps_ = 1e-5; +}; + +Sampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, + DLDevice device, Optional trace_recorder) { + return Sampler( + make_object(max_num_sample, vocab_size, ft, device, std::move(trace_recorder))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/sampler.h b/cpp/serve/sampler/sampler.h similarity index 74% rename from cpp/serve/sampler.h rename to cpp/serve/sampler/sampler.h index c48702c0c7..03d031bdb7 100644 --- a/cpp/serve/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -1,21 +1,21 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/sampler.h + * \file serve/sampler/sampler.h * \brief The header for runtime module of sampler functions. */ -#ifndef MLC_LLM_SERVE_SAMPLER_H_ -#define MLC_LLM_SERVE_SAMPLER_H_ +#ifndef MLC_LLM_SERVE_SAMPLER_SAMPLER_H_ +#define MLC_LLM_SERVE_SAMPLER_SAMPLER_H_ #include #include -#include "../base.h" -#include "../random.h" -#include "data.h" -#include "event_trace_recorder.h" -#include "model.h" -#include "request_state.h" +#include "../../base.h" +#include "../../random.h" +#include "../data.h" +#include "../event_trace_recorder.h" +#include "../model.h" +#include "../request_state.h" namespace mlc { namespace llm { @@ -84,14 +84,24 @@ class SamplerObj : public Object { class Sampler : public ObjectRef { public: + /*! * \brief Create a CPU sampler. */ + TVM_DLL static Sampler CreateCPUSampler(Optional trace_recorder); /*! - * \brief Create the runtime sampler module. - * \param sampler_kind The sampler name denoting which sampler to create. - * \param trace_recorder The event trace recorder for requests. - * \return The created runtime module. + * \brief Create a GPU sampler. + * \param max_num_sample The max number of samples to sample at a time. + * \param vocab_size The model's vocabulary size. + * \param ft The packed function table. + * \param device The device that the model runs on. + * \param trace_recorder The event trace recorder. */ - TVM_DLL static Sampler Create(std::string sampler_kind, - Optional trace_recorder); + TVM_DLL static Sampler CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, + DLDevice device, + Optional trace_recorder); + + /*! \brief Check if the given device supports GPU sampling. 
*/ + static bool SupportGPUSampler(Device device) { + return device.device_type == DLDeviceType::kDLCUDA; + } TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Sampler, ObjectRef, SamplerObj); }; @@ -100,4 +110,4 @@ class Sampler : public ObjectRef { } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_SAMPLER_H_ +#endif // MLC_LLM_SERVE_SAMPLER_SAMPLER_H_ diff --git a/python/mlc_llm/compiler_pass/attach_embedding_allocator.py b/python/mlc_llm/compiler_pass/attach_embedding_allocator.py new file mode 100644 index 0000000000..270c67523c --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_embedding_allocator.py @@ -0,0 +1,39 @@ +"""The pass that attaches embedding allocation function to the IRModule.""" + +from typing import Any, Dict + +import tvm +from tvm import IRModule, relax + + +@tvm.transform.module_pass(opt_level=0, name="AttachAllocEmbeddingTensorFunc") +class AttachAllocEmbeddingTensorFunc: # pylint: disable=too-few-public-methods + """Attach embedding tensor allocation Relax function to IRModule.""" + + def __init__(self, metadata: Dict[str, Any]): + self.metadata = metadata + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + embed_func = None + for gv, func in mod.functions_items(): + if gv.name_hint == "embed": + embed_func = func + + if embed_func is None: + return mod + + hidden_size = embed_func.ret_struct_info.shape[-1] + dtype = embed_func.ret_struct_info.dtype + bb = relax.BlockBuilder(mod) + with bb.function("alloc_embedding_tensor", []): + bb.emit_func_output( + bb.emit( + relax.op.builtin.alloc_tensor( + relax.ShapeExpr([self.metadata["prefill_chunk_size"], hidden_size]), + dtype, + runtime_device_index=0, + ) + ) + ) + return bb.finalize() diff --git a/python/mlc_llm/compiler_pass/attach_to_ir_module.py b/python/mlc_llm/compiler_pass/attach_logit_processor.py similarity index 60% rename from python/mlc_llm/compiler_pass/attach_to_ir_module.py rename to python/mlc_llm/compiler_pass/attach_logit_processor.py index 9f1271dcf6..1b3b5c4994 100644 --- a/python/mlc_llm/compiler_pass/attach_to_ir_module.py +++ b/python/mlc_llm/compiler_pass/attach_logit_processor.py @@ -1,54 +1,10 @@ -"""A couple of passes that simply attach additional information onto the IRModule.""" - -from typing import Any, Dict +"""The pass that attaches logit processor functions to the IRModule.""" import tvm -from tvm import IRModule, relax, tir +from tvm import IRModule from tvm.script import tir as T -@tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds") -class AttachVariableBounds: # pylint: disable=too-few-public-methods - """Attach variable bounds to each Relax function, which primarily helps with memory planning.""" - - def __init__(self, variable_bounds: Dict[str, int]): - # Specifically for RWKV workloads, which contains -1 max_seq_len - self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0} - - def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: - """Entrypoint""" - for g_var, func in mod.functions_items(): - if isinstance(func, relax.Function): - mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds) - return mod - - -@tvm.transform.module_pass(opt_level=0, name="AttachAdditionalPrimFuncs") -class AttachAdditionalPrimFuncs: # pylint: disable=too-few-public-methods - """Attach extra TIR PrimFuncs to the IRModule""" - - def __init__(self, functions: Dict[str, tir.PrimFunc]): - self.functions = functions - - def transform_module(self, mod: 
IRModule, _ctx: tvm.transform.PassContext) -> IRModule: - """Entrypoint""" - for func_name, func in self.functions.items(): - mod[func_name] = func.with_attr("global_symbol", func_name) - return mod - - -@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr") -class AttachMemoryPlanAttr: # pylint: disable=too-few-public-methods - """Attach memory planning attribute for dynamic function output planning to Relax functions.""" - - def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: - """Entrypoint""" - for g_var, func in mod.functions_items(): - if isinstance(func, relax.Function): - mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) - return mod - - @tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc") class AttachLogitProcessFunc: # pylint: disable=too-few-public-methods """Attach logit processing TIR functions to IRModule.""" @@ -62,39 +18,6 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod -@tvm.transform.module_pass(opt_level=0, name="AttachAllocEmbeddingTensorFunc") -class AttachAllocEmbeddingTensorFunc: # pylint: disable=too-few-public-methods - """Attach embedding tensor allocation Relax function to IRModule.""" - - def __init__(self, metadata: Dict[str, Any]): - self.metadata = metadata - - def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: - """Entrypoint""" - embed_func = None - for gv, func in mod.functions_items(): - if gv.name_hint == "embed": - embed_func = func - - if embed_func is None: - return mod - - hidden_size = embed_func.ret_struct_info.shape[-1] - dtype = embed_func.ret_struct_info.dtype - bb = relax.BlockBuilder(mod) - with bb.function("alloc_embedding_tensor", []): - bb.emit_func_output( - bb.emit( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr([self.metadata["prefill_chunk_size"], hidden_size]), - dtype, - runtime_device_index=0, - ) - ) - ) - return bb.finalize() - - @T.prim_func def _apply_logit_bias_inplace( var_logits: T.handle, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py new file mode 100644 index 0000000000..64faf93bf3 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -0,0 +1,274 @@ +"""The pass that attaches GPU sampler functions to the IRModule.""" + +from typing import Dict + +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.frontend import nn +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") +class AttachGPUSamplingFunc: # pylint: disable=too-few-public-methods + """Attach GPU sampling functions to IRModule.""" + + def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]): + # Specifically for RWKV workloads, which contains -1 max_seq_len + max_batch_size = variable_bounds["batch_size"] + self.variable_bounds = { + "batch_size": max_batch_size, + "num_samples": max_batch_size, + "num_positions": 6 * max_batch_size, + } + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + if str(self.target.kind) != "cuda": + # Only enable GPU sampling for CUDA. 
+ return mod + + bb = relax.BlockBuilder(mod) + vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] + gv_names = [ + gv.name_hint + for gv in [ + _attach_multinomial_sampling_func(bb, vocab_size), + _attach_argsort_func(bb, vocab_size), + _attach_sample_with_top_p(bb, vocab_size), + _attach_take_probs_func(bb, vocab_size), + ] + ] + + mod = bb.finalize() + for gv_name in gv_names: + mod[gv_name] = mod[gv_name].with_attr("tir_var_upper_bound", self.variable_bounds) + return mod + + +def _attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + num_samples = tir.Var("num_samples", "int64") + probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) + uniform_samples = relax.Var( + "uniform_samples", relax.TensorStructInfo((num_samples,), "float32") + ) + sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) + with bb.function("multinomial_from_uniform", [probs, uniform_samples, sample_indices]): + with bb.dataflow(): + sample_shape = relax.ShapeExpr([num_samples, 1]) + probs_tensor = nn.wrap_nested(probs, name="probs") + uniform_samples_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + uniform_samples, + sample_shape, + sinfo_args=relax.TensorStructInfo(sample_shape, "float32"), + ), + name="uniform_samples", + ) + sample_indices_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + sample_indices, + sample_shape, + sinfo_args=relax.TensorStructInfo(sample_shape, "int32"), + ), + name="sample_indices", + ) + result_tensor = nn.multinomial_from_uniform( # pylint:disable=too-many-function-args + probs_tensor, uniform_samples_tensor, sample_indices_tensor, "int32" + ) + result = bb.emit( + relax.call_pure_packed( + "vm.builtin.reshape", + result_tensor._expr, # pylint: disable=protected-access + sample_indices.struct_info.shape, # pylint: disable=no-member + sinfo_args=sample_indices.struct_info, # pylint: disable=no-member + ) + ) + gv = bb.emit_func_output(result) + return gv + + +def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) + with bb.function("argsort_probs", [probs]): + with bb.dataflow(): + sorted_indices = bb.emit(relax.op.argsort(probs, descending=True, dtype="int32")) + sorted_values = bb.emit_te( + lambda unsorted_probs, sorted_indices: te.compute( + (batch_size, vocab_size), + lambda i, j: unsorted_probs[i, sorted_indices[i, j]], + name="take_sorted_probs", + ), + probs, + sorted_indices, + primfunc_name_hint="take_sorted_probs", + ) + gv = bb.emit_func_output([sorted_values, sorted_indices]) + return gv + + +def _attach_sample_with_top_p( # pylint: disable=too-many-locals + bb: relax.BlockBuilder, vocab_size: tir.PrimExpr +): + batch_size = tir.Var("batch_size", "int64") + num_samples = tir.Var("num_samples", "int64") + sorted_probs = relax.Var( + "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") + ) + sorted_indices = relax.Var( + "sorted_indices", relax.TensorStructInfo((batch_size, vocab_size), "int32") + ) + uniform_samples = relax.Var( + "uniform_samples", relax.TensorStructInfo((num_samples,), "float32") + ) + sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) + top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) + + 
@T.prim_func + def full(var_result: T.handle, value: T.int32): + batch_size = T.int32(is_size_var=True) + result = T.match_buffer(var_result, (batch_size, 1), "int32") + for i in T.serial(batch_size): + with T.block("block"): + vi = T.axis.spatial(batch_size, i) + result[vi, 0] = value + + with bb.function( + "sample_with_top_p", + [sorted_probs, sorted_indices, uniform_samples, sample_indices, top_p], + ): + with bb.dataflow(): + sample_shape = relax.ShapeExpr([num_samples, 1]) + top_p_shape = relax.ShapeExpr([batch_size, 1]) + sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") + sorted_indices_tensor = nn.wrap_nested(sorted_indices, name="sorted_indices") + uniform_samples_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + uniform_samples, + sample_shape, + sinfo_args=relax.TensorStructInfo(sample_shape, "float32"), + ), + name="uniform_samples", + ) + sample_indices_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + sample_indices, + sample_shape, + sinfo_args=relax.TensorStructInfo(sample_shape, "int32"), + ), + name="sample_indices", + ) + top_p_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + top_p, + top_p_shape, + sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), + ), + name="sample_indices", + ) + top_k_tensor = nn.tensor_ir_op( + full, + name_hint="full", + args=[vocab_size], + out=nn.Tensor.placeholder( + [batch_size, 1], + "int32", + ), + ) + + result_tensor = ( + nn.sample_top_p_top_k_from_sorted_prob( # pylint:disable=too-many-function-args + sorted_probs_tensor, + sorted_indices_tensor, + top_p_tensor, + top_k_tensor, + uniform_samples_tensor, + sample_indices_tensor, + ) + ) + result = bb.emit( + relax.call_pure_packed( + "vm.builtin.reshape", + result_tensor._expr, # pylint: disable=protected-access + sample_indices.struct_info.shape, # pylint: disable=no-member + sinfo_args=sample_indices.struct_info, # pylint: disable=no-member + ) + ) + gv = bb.emit_func_output(result) + return gv + + +def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + num_samples = tir.Var("num_samples", "int64") + num_positions = tir.Var("num_positions", "int64") + unsorted_probs = relax.Var( + "unsorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") + ) + sorted_indices = relax.Var( + "sorted_indices", relax.TensorStructInfo((batch_size, vocab_size), "int32") + ) + sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) + sampling_results = relax.Var("sampling_result", relax.TensorStructInfo((num_samples,), "int32")) + top_prob_offsets = relax.Var( + "lobprob_offsets", relax.TensorStructInfo((num_positions,), "int32") + ) + + @T.prim_func + def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-arguments + var_unsorted_probs: T.handle, + var_sorted_indices: T.handle, + var_sample_indices: T.handle, + var_sampling_results: T.handle, + var_top_prob_offsets: T.handle, + var_sampled_values: T.handle, + var_top_prob_probs: T.handle, + var_top_prob_indices: T.handle, + ): + batch_size = T.int32(is_size_var=True) + num_samples = T.int32(is_size_var=True) + num_positions = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size), "float32") + sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32") + sample_indices = 
T.match_buffer(var_sample_indices, (num_samples,), "int32") + sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32") + top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32") + sampled_values = T.match_buffer(var_sampled_values, (num_samples,), "float32") + top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,), "float32") + top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32") + for i in T.serial(num_positions + num_samples): + with T.block("block"): + vi = T.axis.spatial(num_positions + num_samples, i) + if vi < num_positions: + row = T.floordiv(top_prob_offsets[vi], vocab_size) + col = T.floormod(top_prob_offsets[vi], vocab_size) + top_prob_indices[vi] = sorted_indices[row, col] + top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]] + else: + vj: T.int32 = vi - num_positions + sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]] + + args = [unsorted_probs, sorted_indices, sample_indices, sampling_results, top_prob_offsets] + with bb.function("sampler_take_probs", args): + with bb.dataflow(): + taken_probs_indices = bb.emit( + relax.call_tir( + bb.add_func(sampler_take_probs_tir, "sampler_take_probs_tir"), + args, + out_sinfo=[ + relax.TensorStructInfo((num_samples,), "float32"), + relax.TensorStructInfo((num_positions,), "float32"), + relax.TensorStructInfo((num_positions,), "int32"), + ], + ) + ) + gv = bb.emit_func_output(taken_probs_indices) + return gv diff --git a/python/mlc_llm/compiler_pass/attach_support_info.py b/python/mlc_llm/compiler_pass/attach_support_info.py new file mode 100644 index 0000000000..c6ec834b13 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_support_info.py @@ -0,0 +1,48 @@ +"""A couple of passes that simply supportive information onto the IRModule.""" + +from typing import Dict + +import tvm +from tvm import IRModule, relax, tir + + +@tvm.transform.module_pass(opt_level=0, name="AttachVariableBounds") +class AttachVariableBounds: # pylint: disable=too-few-public-methods + """Attach variable bounds to each Relax function, which primarily helps with memory planning.""" + + def __init__(self, variable_bounds: Dict[str, int]): + # Specifically for RWKV workloads, which contains -1 max_seq_len + self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0} + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachAdditionalPrimFuncs") +class AttachAdditionalPrimFuncs: # pylint: disable=too-few-public-methods + """Attach extra TIR PrimFuncs to the IRModule""" + + def __init__(self, functions: Dict[str, tir.PrimFunc]): + self.functions = functions + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for func_name, func in self.functions.items(): + mod[func_name] = func.with_attr("global_symbol", func_name) + return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachMemoryPlanAttr") +class AttachMemoryPlanAttr: # pylint: disable=too-few-public-methods + """Attach memory planning attribute for dynamic function output planning to Relax functions.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for g_var, 
func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) + return mod diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index d8f98b84eb..933b8ad6bb 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -11,10 +11,11 @@ from mlc_llm.support import logging -from .attach_to_ir_module import ( +from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc +from .attach_logit_processor import AttachLogitProcessFunc +from .attach_sampler import AttachGPUSamplingFunc +from .attach_support_info import ( AttachAdditionalPrimFuncs, - AttachAllocEmbeddingTensorFunc, - AttachLogitProcessFunc, AttachMemoryPlanAttr, AttachVariableBounds, ) @@ -95,6 +96,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I AttachLogitProcessFunc(), AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), + AttachGPUSamplingFunc(target, variable_bounds), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), @@ -108,6 +110,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I _DebugDump("debug-phase1.py", debug_dump, show_meta=False), # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), + tvm.relax.backend.DispatchSortScan(), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), From 73b99655d4dec2f66c16b907e8bacc35414e7e6a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 13 Mar 2024 08:50:12 -0400 Subject: [PATCH 065/531] [Serve] Constrain KV cache capacity on Metal (#1943) This PR constrains the KV cache capacity for Metal devices to 32768, in order to avoid large tensors in KV cache. This is because right now Metal runtime has performance issue when running a kernel where when some input buffer is very large, even if little of the large buffer is accesed in the kernel. --- python/mlc_llm/serve/engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 994a5f4e9e..7d19532d2b 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -227,6 +227,11 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" ) + if models[0].device.device_type == Device.kDLMetal: + # NOTE: Metal runtime has severe performance issues with large buffers. + # To work around the issue, we limit the KV cache capacity to 32768. + max_total_sequence_length = min(max_total_sequence_length, 32768) + total_size = ( params_bytes + temp_func_bytes From 8a29ee16232e73315050b725d6f418874584c43c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 13 Mar 2024 08:50:30 -0400 Subject: [PATCH 066/531] [CI] Add windows ci (#1942) This PR adds windows CI. 
--- .../{documentation.yml => documentation.yaml} | 0 .../{update-relax.yml => update-relax.yaml} | 0 .github/workflows/windows-build.yaml | 38 +++++++++++++++++++ 3rdparty/tvm | 2 +- ci/build-environment.yaml | 15 ++++++++ ci/task/build_win.bat | 15 ++++++++ 6 files changed, 69 insertions(+), 1 deletion(-) rename .github/workflows/{documentation.yml => documentation.yaml} (100%) rename .github/workflows/{update-relax.yml => update-relax.yaml} (100%) create mode 100644 .github/workflows/windows-build.yaml create mode 100644 ci/build-environment.yaml create mode 100644 ci/task/build_win.bat diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yaml similarity index 100% rename from .github/workflows/documentation.yml rename to .github/workflows/documentation.yaml diff --git a/.github/workflows/update-relax.yml b/.github/workflows/update-relax.yaml similarity index 100% rename from .github/workflows/update-relax.yml rename to .github/workflows/update-relax.yaml diff --git a/.github/workflows/windows-build.yaml b/.github/workflows/windows-build.yaml new file mode 100644 index 0000000000..b64b5efd0a --- /dev/null +++ b/.github/workflows/windows-build.yaml @@ -0,0 +1,38 @@ +# GH actions. +# We use it to cover windows builds +# Jenkins is still the primary CI +name: Windows CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + Windows: + runs-on: windows-latest + defaults: + run: + shell: 'cmd /C call {0}' + + steps: + - uses: actions/checkout@v3 + with: + submodules: 'recursive' + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: mlc-llm-build + channel-priority: strict + environment-file: ci/build-environment.yaml + auto-activate-base: false + - name: Conda info + run: | + conda info + conda list + python --version + - name: Build MLC-LLM + run: >- + ci/task/build_win.bat diff --git a/3rdparty/tvm b/3rdparty/tvm index 1d4da926c7..f06d486b4a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1d4da926c726e2700593c7f62006545bda6a46f9 +Subproject commit f06d486b4a1a27f0bbb072688a5fc41e7b15323c diff --git a/ci/build-environment.yaml b/ci/build-environment.yaml new file mode 100644 index 0000000000..b14ac14860 --- /dev/null +++ b/ci/build-environment.yaml @@ -0,0 +1,15 @@ +name: mlc-llm-build + +channels: + - conda-forge + +dependencies: + - conda-build + - anaconda-client + - libvulkan-headers + - libvulkan-loader + - spirv-tools + - spirv-headers + - git + - cmake + - bzip2 diff --git a/ci/task/build_win.bat b/ci/task/build_win.bat new file mode 100644 index 0000000000..a68cf22e8f --- /dev/null +++ b/ci/task/build_win.bat @@ -0,0 +1,15 @@ +cd mlc-llm +rd /s /q build +mkdir build +cd build + +cmake -A x64 -Thost=x64 ^ + -G "Visual Studio 17 2022" ^ + -DUSE_VULKAN=ON ^ + .. + +if %errorlevel% neq 0 exit %errorlevel% + +cmake --build . 
--parallel 3 --config Release -- /m + +if %errorlevel% neq 0 exit %errorlevel% From 5c29f02cc198a61545b499595e5d0e50f4d9b138 Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 13 Mar 2024 13:43:00 +0000 Subject: [PATCH 067/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index f06d486b4a..1d4da926c7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f06d486b4a1a27f0bbb072688a5fc41e7b15323c +Subproject commit 1d4da926c726e2700593c7f62006545bda6a46f9 From 8d192ef74df1a972b34b8871ea8bc471eb598a71 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 13 Mar 2024 13:51:54 -0400 Subject: [PATCH 068/531] [Fix] Fix embedding shape check in ChatModule (#1953) This PR is a fix to address #1952. --- cpp/llm_chat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index aca13db863..5577f9b87d 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -1447,7 +1447,7 @@ class LLMChat { embedding_shape = embedding_nd.Shape(); } ICHECK_EQ(embedding_shape.size(), 2); - ICHECK_GT(embedding_shape[0], 1); + ICHECK_GE(embedding_shape[0], 1); this->hidden_size_ = embedding_shape[1]; return this->hidden_size_; } From c0b2ccd42a79b1d1bf7d3065892d15f8ffc26af0 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 13 Mar 2024 22:15:22 -0400 Subject: [PATCH 069/531] [Fix] Fetching the Git-LFS tokenizer files (#1954) Prior to this PR, when running commands like ```shell python3 -m mlc_chat chat HF://mlc-ai/gemma-7b-it-q4f16_2-MLC ``` only the binary weight files are downloaded, among all the Git LFS files. For models like Gemma whose tokenizer is large and also in Git LFS file, the tokenizer files are not effectively downloaded automatically. For example, the cloned Gemma `tokenizer.json` file has content ``` version https://git-lfs.github.com/spec/v1 oid sha256:05e97791a5e007260de1db7e1692e53150e08cea481e2bf25435553380c147ee size 17477929 ``` and this content is never realized to the actual tokenizer. This will lead to the issue of #1913. This PR fixes the issue by pulling all the Git LFS files that are not binary files. 
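For illustration, the snippet below sketches the same idea in isolation:
enumerate the files tracked by Git LFS in a cloned repo, skip the binary
weight shards, and pull everything else so that pointer files such as
`tokenizer.json` are replaced by their real content. The helper name and the
exact `git lfs` invocations here are assumptions made for this sketch; the
actual change lives in `git_lfs_pull` in the diff below.

```python
import subprocess
from pathlib import Path
from typing import List, Optional


def pull_non_binary_lfs_files(repo_dir: Path, ignore_extensions: Optional[List[str]] = None) -> None:
    """Illustrative sketch: materialize Git LFS pointer files, skipping given extensions."""
    # List the paths tracked by Git LFS in this checkout (names only).
    filenames = (
        subprocess.check_output(["git", "-C", str(repo_dir), "lfs", "ls-files", "--name-only"])
        .decode("utf-8")
        .splitlines()
    )
    # Drop files whose extension is in the ignore list, e.g. ".bin" weight shards
    # that are downloaded separately with progress reporting.
    if ignore_extensions:
        filenames = [
            name
            for name in filenames
            if not any(name.endswith(ext) for ext in ignore_extensions)
        ]
    # Replace each remaining pointer file with its actual content.
    for name in filenames:
        subprocess.check_call(["git", "-C", str(repo_dir), "lfs", "pull", "--include", name])


# Example usage (hypothetical local path): fetch tokenizer files but not the weight shards.
# pull_non_binary_lfs_files(Path("dist/tmp/gemma-7b-it-q4f16_2-MLC"), ignore_extensions=[".bin"])
```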
--- python/mlc_llm/support/download.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index 10b1620dc5..a109c967bc 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -1,4 +1,5 @@ """Common utilities for downloading files from HuggingFace or other URLs online.""" + import concurrent.futures as cf import hashlib import json @@ -7,7 +8,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Optional, Tuple +from typing import List, Optional, Tuple import requests # pylint: disable=import-error @@ -56,7 +57,7 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: ) from error -def git_lfs_pull(repo_dir: Path) -> None: +def git_lfs_pull(repo_dir: Path, ignore_extensions: Optional[List[str]] = None) -> None: """Pull files with Git LFS.""" filenames = ( subprocess.check_output( @@ -66,6 +67,12 @@ def git_lfs_pull(repo_dir: Path) -> None: .decode("utf-8") .splitlines() ) + if ignore_extensions is not None: + filenames = [ + filename + for filename in filenames + if not any(filename.endswith(extension) for extension in ignore_extensions) + ] logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames) with tqdm.redirect(): for file in tqdm.tqdm(filenames): @@ -127,6 +134,7 @@ def download_mlc_weights( # pylint: disable=too-many-locals tmp_dir = Path(tmp_dir_prefix) / "tmp" git_url = git_url_template.format(user=user, repo=repo) git_clone(git_url, tmp_dir, ignore_lfs=True) + git_lfs_pull(tmp_dir, ignore_extensions=[".bin"]) shutil.rmtree(tmp_dir / ".git", ignore_errors=True) with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file: param_metadata = json.load(in_file)["records"] From 2872f70be279a289f5823c5ccfda474c4531e373 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 14 Mar 2024 10:25:54 -0400 Subject: [PATCH 070/531] [LogitProcessor] Add max thread awareness to logit processing kernels (#1955) Make the kernels in `AttachLogitProcessFunc` to be aware of maximum threads, fixing https://github.com/mlc-ai/mlc-llm/issues/1951. Most code change is due to indentation, the main change is changing `1024` to `tx`, where `tx` is ``` tx = 1024 # default max_num_threads_per_block = get_max_num_threads_per_block(target) if max_num_threads_per_block < tx: tx = max_num_threads_per_block check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1) ``` --- .../compiler_pass/attach_logit_processor.py | 254 +++++++++++------- python/mlc_llm/compiler_pass/pipeline.py | 2 +- 2 files changed, 156 insertions(+), 100 deletions(-) diff --git a/python/mlc_llm/compiler_pass/attach_logit_processor.py b/python/mlc_llm/compiler_pass/attach_logit_processor.py index 1b3b5c4994..8dabf3dcfd 100644 --- a/python/mlc_llm/compiler_pass/attach_logit_processor.py +++ b/python/mlc_llm/compiler_pass/attach_logit_processor.py @@ -4,113 +4,169 @@ from tvm import IRModule from tvm.script import tir as T +from ..support.max_thread_check import ( + check_thread_limits, + get_max_num_threads_per_block, +) + @tvm.transform.module_pass(opt_level=0, name="AttachLogitProcessFunc") class AttachLogitProcessFunc: # pylint: disable=too-few-public-methods """Attach logit processing TIR functions to IRModule.""" + def __init__(self, target: tvm.target.Target): + """Initializer. 
+ + Parameters + ---------- + target : tvm.target.Target + The target of the model compilation. + """ + self.target = target + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" mod = mod.clone() - mod["apply_logit_bias_inplace"] = _apply_logit_bias_inplace - mod["apply_penalty_inplace"] = _apply_penalty_inplace - mod["apply_bitmask_inplace"] = _apply_bitmask_inplace + mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target) + mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target) + mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target) return mod -@T.prim_func -def _apply_logit_bias_inplace( - var_logits: T.handle, - var_pos2seq_id: T.handle, - var_token_ids: T.handle, - var_logit_bias: T.handle, -) -> None: - """Function that applies logit bias in place.""" - T.func_attr( - {"global_symbol": "apply_logit_bias_inplace", "tir.noalias": True, "tir.is_scheduled": True} - ) - batch_size = T.int32(is_size_var=True) - vocab_size = T.int32(is_size_var=True) - num_token = T.int32(is_size_var=True) - logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") - # seq_ids - pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") - token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") - logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32") - - for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): - for p1 in T.thread_binding(0, 1024, "threadIdx.x"): - with T.block("block"): - vp = T.axis.spatial(num_token, p0 * 1024 + p1) - T.where(p0 * 1024 + p1 < num_token) - logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp] - - -@T.prim_func -def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals - var_logits: T.handle, - var_seq_ids: T.handle, - var_pos2seq_id: T.handle, - var_token_ids: T.handle, - var_token_cnt: T.handle, - var_penalties: T.handle, -) -> None: - """Function that applies penalties in place.""" - T.func_attr( - {"global_symbol": "apply_penalty_inplace", "tir.noalias": True, "tir.is_scheduled": True} - ) - batch_size = T.int32(is_size_var=True) - vocab_size = T.int32(is_size_var=True) - num_token = T.int32(is_size_var=True) - num_seq = T.int32(is_size_var=True) - logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") - seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") - pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") - token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") - token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") - penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32") - - for p0 in T.thread_binding(0, (num_token + 1023) // 1024, "blockIdx.x"): - for p1 in T.thread_binding(0, 1024, "threadIdx.x"): - with T.block("block"): - vp = T.axis.spatial(num_token, p0 * 1024 + p1) - T.where(p0 * 1024 + p1 < num_token) - # Penalties: (presence_penalty, frequency_penalty, repetition_penalty) - logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= ( - penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1] - ) - logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else( - logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > 0, - logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], - logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2], - ) - - -@T.prim_func -def _apply_bitmask_inplace( - var_logits: T.handle, - var_seq_ids: T.handle, - 
var_bitmask: T.handle, -) -> None: - """Function that applies vocabulary masking in place.""" - T.func_attr( - {"global_symbol": "apply_bitmask_inplace", "tir.noalias": True, "tir.is_scheduled": True} - ) - batch_size = T.int32(is_size_var=True) - vocab_size = T.int32(is_size_var=True) - num_seq = T.int32(is_size_var=True) - logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") - seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") - bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") - - for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + 1023) // 1024, "blockIdx.x"): - for fused_s_v_1 in T.thread_binding(0, 1024, "threadIdx.x"): - with T.block("block"): - vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size) - vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size) - T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size) - logits[seq_ids[vs], vv] = T.if_then_else( - (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1, - logits[seq_ids[vs], vv], - T.float32(-1e10), - ) +def _get_apply_logit_bias_inplace(target: tvm.target.Target): + tx = 1024 # default + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < tx: + tx = max_num_threads_per_block + check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1) + + @T.prim_func + def _apply_logit_bias_inplace( + var_logits: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_logit_bias: T.handle, + ) -> None: + """Function that applies logit bias in place.""" + T.func_attr( + { + "global_symbol": "apply_logit_bias_inplace", + "tir.noalias": True, + "tir.is_scheduled": True, + } + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + # seq_ids + pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32") + + for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, "blockIdx.x"): + for p1 in T.thread_binding(0, tx, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * tx + p1) + T.where(p0 * tx + p1 < num_token) + logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp] + + return _apply_logit_bias_inplace + + +def _get_apply_penalty_inplace(target: tvm.target.Target): + tx = 1024 # default + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < tx: + tx = max_num_threads_per_block + check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1) + + @T.prim_func + def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals + var_logits: T.handle, + var_seq_ids: T.handle, + var_pos2seq_id: T.handle, + var_token_ids: T.handle, + var_token_cnt: T.handle, + var_penalties: T.handle, + ) -> None: + """Function that applies penalties in place.""" + T.func_attr( + { + "global_symbol": "apply_penalty_inplace", + "tir.noalias": True, + "tir.is_scheduled": True, + } + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_token = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + pos2seq_id = 
T.match_buffer(var_pos2seq_id, (num_token,), "int32") + token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") + token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") + penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32") + + for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, "blockIdx.x"): + for p1 in T.thread_binding(0, tx, "threadIdx.x"): + with T.block("block"): + vp = T.axis.spatial(num_token, p0 * tx + p1) + T.where(p0 * tx + p1 < num_token) + # Penalties: (presence_penalty, frequency_penalty, repetition_penalty) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= ( + penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1] + ) + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else( + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > 0, + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] + * penalties[pos2seq_id[vp], 2], + logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] + / penalties[pos2seq_id[vp], 2], + ) + + return _apply_penalty_inplace + + +def _get_apply_bitmask_inplace(target: tvm.target.Target): + tx = 1024 # default + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < tx: + tx = max_num_threads_per_block + check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1) + + @T.prim_func + def _apply_bitmask_inplace( + var_logits: T.handle, + var_seq_ids: T.handle, + var_bitmask: T.handle, + ) -> None: + """Function that applies vocabulary masking in place.""" + T.func_attr( + { + "global_symbol": "apply_bitmask_inplace", + "tir.noalias": True, + "tir.is_scheduled": True, + } + ) + batch_size = T.int32(is_size_var=True) + vocab_size = T.int32(is_size_var=True) + num_seq = T.int32(is_size_var=True) + logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32") + seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") + bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") + + for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + tx - 1) // tx, "blockIdx.x"): + for fused_s_v_1 in T.thread_binding(0, tx, "threadIdx.x"): + with T.block("block"): + vs = T.axis.spatial(num_seq, (fused_s_v_0 * tx + fused_s_v_1) // vocab_size) + vv = T.axis.spatial(vocab_size, (fused_s_v_0 * tx + fused_s_v_1) % vocab_size) + T.where(fused_s_v_0 * tx + fused_s_v_1 < num_seq * vocab_size) + logits[seq_ids[vs], vv] = T.if_then_else( + (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1, + logits[seq_ids[vs], vv], + T.float32(-1e10), + ) + + return _apply_bitmask_inplace diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 933b8ad6bb..d576c68451 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -93,7 +93,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 0. 
Add additional information for compilation and remove unused Relax func DispatchKVCacheCreation(target, flashinfer, metadata), AttachVariableBounds(variable_bounds), - AttachLogitProcessFunc(), + AttachLogitProcessFunc(target), AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), From d5461342fe25ca3858cd3a537fc19a5fda77b55f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 14 Mar 2024 13:28:05 -0700 Subject: [PATCH 071/531] [Model] Use static hidden size in mixtral scatter_output (#1959) --- python/mlc_llm/op/moe_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index e97ef94fff..19bf10381f 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -385,11 +385,11 @@ def scatter_output(x: Tensor, indices: Tensor) -> Tensor: The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size]. """ dtype = x.dtype + _, hidden_size = x.shape @T.prim_func(private=True) def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle): T.func_attr({"tir.noalias": True}) - hidden_size = T.int64() indices_len = T.int64() x = T.match_buffer(var_x, [indices_len, hidden_size], dtype) indices = T.match_buffer(var_indices, [indices_len], "int32") From 01527e99fc3a02a48d74f06661738799956b671b Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 15 Mar 2024 01:10:57 +0000 Subject: [PATCH 072/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 1d4da926c7..641209c69a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1d4da926c726e2700593c7f62006545bda6a46f9 +Subproject commit 641209c69ad153c02471ba71bdf40a10c90789e5 From 09fe1bc0211ab22df149057c42177f3dfabc5641 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 14 Mar 2024 22:25:01 -0400 Subject: [PATCH 073/531] [CompilerFlag] Detect if FlashInfer is enabled from libinfo (#1941) This PR supports the detection of if FlashInfer is enabled when building TVM, so that FlashInfer won't be enabled when TVM is not built with FlashInfer enabled. 
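As a quick illustration of the probe this relies on: TVM exposes its
compile-time CMake options through the `support.GetLibInfo` packed function
(wrapped on the Python side as `tvm.support.libinfo()`), so a caller can
check the `USE_FLASHINFER` flag before turning the feature on. The standalone
sketch below shows that check on its own, under the assumption that the flag
is reported as "ON"/"OFF"; the actual guard added to `compiler_flags.py` is
in the diff below.

```python
import tvm


def tvm_built_with_flashinfer() -> bool:
    """Return True only if the linked TVM build was compiled with USE_FLASHINFER=ON."""
    # libinfo() reports the build options TVM was compiled with,
    # e.g. {"USE_CUDA": "ON", "USE_FLASHINFER": "OFF", ...}.
    build_options = tvm.support.libinfo()
    return build_options.get("USE_FLASHINFER", "OFF") == "ON"


# Example: enable the FlashInfer path only when the target is CUDA and the
# TVM runtime actually ships the FlashInfer kernels.
# flashinfer = target.kind.name == "cuda" and tvm_built_with_flashinfer()
```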
--- python/mlc_llm/interface/compiler_flags.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index fd820e7124..bc40103918 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -1,8 +1,11 @@ """Flags for overriding model config.""" + import dataclasses from io import StringIO from typing import Optional +import tvm + from mlc_llm.support import argparse, logging from mlc_llm.support.config import ConfigOverrideBase @@ -65,6 +68,8 @@ def _flashinfer(target) -> bool: return False if target.kind.name != "cuda": return False + if tvm.get_global_func("support.GetLibInfo")()["USE_FLASHINFER"] != "ON": + return False arch_list = detect_cuda_arch_list(target) for arch in arch_list: if arch < 80: From c7d52c40f484c5a3c8067c4e5ae5d9a7da82abe8 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Fri, 15 Mar 2024 21:20:35 +0800 Subject: [PATCH 074/531] [Serving][Grammar] Add grammar termination as a stop condition (#1964) --- cpp/serve/grammar/grammar_state_matcher_base.h | 2 +- .../grammar/grammar_state_matcher_state.h | 10 ++++++++-- cpp/serve/request_state.cc | 18 +++++++++++++----- .../python/serve/test_serve_engine_grammar.py | 15 +++++++++++++-- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 4c543a2e69..d26069be00 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -126,7 +126,7 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool tmp_new_stack_tops_.clear(); for (auto prev_top : prev_stack_tops) { - const auto& cur_rule_position = tree_[prev_top]; + auto cur_rule_position = tree_[prev_top]; auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); if (cur_rule_position.parent_id == RulePosition::kNoParent && cur_rule_position.element_id == current_sequence.size()) { diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index fad3365ed9..08f54be310 100644 --- a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -101,8 +101,14 @@ class RulePositionBuffer { } /*! \brief Get the RulePosition with the given id. */ - RulePosition& operator[](int32_t id) { return buffer_[id]; } - const RulePosition& operator[](int32_t id) const { return buffer_[id]; } + RulePosition& operator[](int32_t id) { + DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + return buffer_[id]; + } + const RulePosition& operator[](int32_t id) const { + DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + return buffer_[id]; + } void Reset() { buffer_.clear(); diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 6eca65f05f..1a0e1970f7 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -118,7 +118,7 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok int max_single_sequence_length) { // - Case 0. There is remaining draft output ==> Unfinished // All draft outputs are supposed to be processed before finish. 
- for (RequestModelState mstate : mstates) { + for (RequestModelState mstate : this->mstates) { if (!mstate->draft_output_tokens.empty()) { return {{}, {}, Optional()}; } @@ -127,7 +127,7 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok std::vector return_token_ids; std::vector logprob_json_strs; Optional finish_reason; - const std::vector& committed_tokens = mstates[0]->committed_tokens; + const std::vector& committed_tokens = this->mstates[0]->committed_tokens; int num_committed_tokens = committed_tokens.size(); ICHECK_LE(this->next_callback_token_pos, num_committed_tokens); @@ -160,7 +160,7 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok request->generation_cfg->stop_token_ids.begin(), request->generation_cfg->stop_token_ids.end(), [&return_token_ids, i](int32_t token) { return token == return_token_ids[i]; })) { - // Stop token matched. Erase all tokens after the current position. + // Stop token matched. Erase the stop token and all tokens after it. finish_reason = "stop"; while (static_cast(return_token_ids.size()) > i) { return_token_ids.pop_back(); @@ -170,11 +170,19 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok } } + // Case 4. When stop token is not detected (e.g. ignore_eos is set), but the grammar state is + // terminated, stop the generation and pop the last token (used to trigger the termination). + if (finish_reason != "stop" && this->mstates[0]->grammar_state_matcher.defined() && + this->mstates[0]->grammar_state_matcher.value()->IsTerminated()) { + return_token_ids.pop_back(); + finish_reason = "stop"; + } + if (finish_reason.defined()) { return {return_token_ids, logprob_json_strs, finish_reason}; } - // Case 4. Generation reaches the specified max generation length ==> Finished + // Case 5. Generation reaches the specified max generation length ==> Finished // `max_tokens` means the generation length is limited by model capacity. if (request->generation_cfg->max_tokens >= 0 && num_committed_tokens >= request->generation_cfg->max_tokens) { @@ -182,7 +190,7 @@ DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tok return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); return {return_token_ids, logprob_json_strs, String("length")}; } - // Case 5. Total length of the request reaches the maximum single sequence length ==> Finished + // Case 6. 
Total length of the request reaches the maximum single sequence length ==> Finished if (request->input_total_length + num_committed_tokens >= max_single_sequence_length) { std::vector remaining = stop_str_handler->Finish(); return_token_ids.insert(return_token_ids.end(), remaining.begin(), remaining.end()); diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index b5430acd39..abe0e391ed 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -26,7 +26,8 @@ def test_batch_generation_with_grammar(): # Create engine engine = Engine(model, kv_cache_config) - prompts = prompts_list * 2 + prompt_len = len(prompts_list) + prompts = prompts_list * 3 temperature = 1 repetition_penalty = 1 @@ -45,7 +46,17 @@ def test_batch_generation_with_grammar(): stop_token_ids=[2], response_format=ResponseFormat(type="json_object"), ) - all_generation_configs = [generation_config_no_json] * 3 + [generation_config_json] * 3 + generation_config_json_no_stop_token = GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + response_format=ResponseFormat(type="json_object"), + ) + all_generation_configs = ( + [generation_config_no_json] * prompt_len + + [generation_config_json] * prompt_len + + [generation_config_json_no_stop_token] * prompt_len + ) # Generate output. output_texts, _ = engine.generate(prompts, all_generation_configs) From 994f9289892b0218e3fe7e9df4685d35a8fcdfb5 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Fri, 15 Mar 2024 12:22:19 -0400 Subject: [PATCH 075/531] Unify schema for conversation template and embed into mlc-chat-config.json (#1965) --- tests/python/conftest.py => conftest.py | 0 cpp/conversation.cc | 132 ++++++++++- cpp/conversation.h | 12 + cpp/llm_chat.cc | 25 +- docs/deploy/python.rst | 2 +- docs/get_started/mlc_chat_config.rst | 215 +++++++----------- python/mlc_llm/chat_module.py | 92 ++++---- python/mlc_llm/interface/gen_config.py | 18 +- .../mlc_llm/protocol/conversation_protocol.py | 14 +- tests/cpp/conv_unittest.cc | 61 ++++- .../protocol/test_converation_protocol.py | 20 ++ 11 files changed, 410 insertions(+), 181 deletions(-) rename tests/python/conftest.py => conftest.py (100%) create mode 100644 tests/python/protocol/test_converation_protocol.py diff --git a/tests/python/conftest.py b/conftest.py similarity index 100% rename from tests/python/conftest.py rename to conftest.py diff --git a/cpp/conversation.cc b/cpp/conversation.cc index a3a432397a..d05021dc6c 100644 --- a/cpp/conversation.cc +++ b/cpp/conversation.cc @@ -11,6 +11,130 @@ namespace llm { void Conversation::LoadJSONOverride(const picojson::value& config_json, bool partial_update) { std::string err_templ = " in conversion template json file."; picojson::object config = config_json.get(); + + if (config.count("name")) { + CHECK(config["name"].is()) << "Invalid name" << err_templ; + this->name = config["name"].get(); + } else { + CHECK(partial_update) << "Key \"name\" not found."; + } + + if (config.count("system_template") && config.count("system_message")) { + std::string system_placeholder = "{system_message}"; + CHECK(config["system_template"].is()) << "Invalid system template" << err_templ; + CHECK(config["system_message"].is()) << "Invalid system message" << err_templ; + std::string system_template = config["system_template"].get(); + std::string system_msg = config["system_message"].get(); + std::string system = 
system_template.replace(system_template.find(system_placeholder), + system_placeholder.length(), system_msg); + this->system = system; + } else { + CHECK(partial_update) << "Key \"system_template\" or \"system_message\" not found."; + } + + if (config.count("system_prefix_token_ids")) { + CHECK(config["system_prefix_token_ids"].is()) + << "Invalid system_prefix_token_ids" << err_templ; + picojson::array prefix_tokens_arr = config["system_prefix_token_ids"].get(); + std::vector prefix_tokens; + for (const picojson::value& prefix_token : prefix_tokens_arr) { + CHECK(prefix_token.is()) << "Invalid prefix_tokens" << err_templ; + prefix_tokens.push_back(prefix_token.get()); + } + this->prefix_tokens = prefix_tokens; + } + + if (config.count("roles")) { + CHECK(config["roles"].is()) << "Invalid roles" << err_templ; + picojson::object roles_json = config["roles"].get(); + std::vector roles(2); + for (auto [role, role_name] : roles_json) { + CHECK(role_name.is()); + if (role == "user") { + roles.at(0) = role_name.get(); + } + if (role == "assistant") { + roles.at(1) = role_name.get(); + } + } + this->roles = roles; + } + + if (config.count("messages")) { + CHECK(config["messages"].is()) << "Invalid messages" << err_templ; + std::vector> messages; + picojson::array msgs_arr = config["messages"].get(); + for (const picojson::value& msgs_i : msgs_arr) { + CHECK(msgs_i.is()) << "Invalid messages" << err_templ; + picojson::array msgs_i_arr = msgs_i.get(); + std::vector messages_i; + for (const picojson::value& msg_v : msgs_i_arr) { + CHECK(msg_v.is()) << "Invalid messages" << err_templ; + messages_i.push_back(msg_v.get()); + } + messages.push_back(messages_i); + } + this->messages = messages; + this->offset = messages.size(); + } else { + this->offset = 0; + } + + if (config.count("seps")) { + std::vector seps; + CHECK(config["seps"].is()) << "Invalid seps" << err_templ; + picojson::array seps_arr = config["seps"].get(); + for (const picojson::value& sep : seps_arr) { + CHECK(sep.is()) << "Invalid seps" << err_templ; + seps.push_back(sep.get()); + } + this->seps = seps; + } else { + CHECK(partial_update) << "Key \"seps\" not found."; + } + + if (config.count("role_content_sep")) { + CHECK(config["role_content_sep"].is()) << "Invalid role_content_sep" << err_templ; + this->role_msg_sep = config["role_content_sep"].get(); + } else { + CHECK(partial_update) << "Key \"role_msg_sep\" not found."; + } + if (config.count("role_empty_sep")) { + CHECK(config["role_empty_sep"].is()) << "Invalid role_empty_sep" << err_templ; + this->role_empty_sep = config["role_empty_sep"].get(); + } else { + CHECK(partial_update) << "Key \"role_empty_sep\" not found."; + } + + if (config.count("stop_str")) { + CHECK(config["stop_str"].is()) << "Invalid stop_str" << err_templ; + picojson::array stop_str_arr = config["stop_str"].get(); + if (stop_str_arr.size() >= 1) { + picojson::value stop_str = stop_str_arr.at(0); + CHECK(stop_str.is()); + this->stop_str = stop_str.get(); + } + } else { + CHECK(partial_update) << "Key \"stop_str\" not found."; + } + + if (config.count("stop_token_ids")) { + CHECK(config["stop_token_ids"].is()) << "Invalid stop_token_ids" << err_templ; + picojson::array stop_tokens_arr = config["stop_token_ids"].get(); + std::vector stop_tokens; + for (const picojson::value& stop_token : stop_tokens_arr) { + CHECK(stop_token.is()) << "Invalid stop_tokens" << err_templ; + stop_tokens.push_back(stop_token.get()); + } + this->stop_tokens = stop_tokens; + } else { + CHECK(partial_update) << "Key \"stop_token_ids\" 
not found."; + } +} + +void Conversation::LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update) { + std::string err_templ = " in conversion template json file."; + picojson::object config = config_json.get(); if (config.count("name")) { CHECK(config["name"].is()) << "Invalid name" << err_templ; this->name = config["name"].get(); @@ -134,7 +258,13 @@ void Conversation::LoadJSONOverride(const std::string& config_str, bool partial_ LOG(FATAL) << err; return; } - LoadJSONOverride(config_json, partial_update); + + picojson::object config = config_json.get(); + try { + LoadJSONOverride(config_json, partial_update); + } catch (...) { + LoadJSONOverrideLegacy(config_json, partial_update); + } } picojson::value Conversation::SerializeToJSON() const { diff --git a/cpp/conversation.h b/cpp/conversation.h index 14cbd44149..7a75e8748a 100644 --- a/cpp/conversation.h +++ b/cpp/conversation.h @@ -154,6 +154,18 @@ class Conversation { */ void LoadJSONOverride(const picojson::value& config_json, bool partial_update = false); + /*! + * \brief Load legacy JSON config and overrides options. + * + * \param config_json A json config in picojson type that is partially specifies + * some of the options. + * \param partial_update Whether it's a partial update or full update, if set to true, + * we perform a partial update on some of the provided options; if set to false, all + * options must be provided. + * \note DEPRECATED. This function loads the legacy JSON config value. + */ + void LoadJSONOverrideLegacy(const picojson::value& config_json, bool partial_update = false); + /*! * \brief Serialize the Conversation to JSON. * \return Serialized conversion in JSON format. diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 5577f9b87d..09c2ce9a37 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -558,16 +558,31 @@ class LLMChat { CHECK(partial_update) << "Key \"shift_fill_factor\" not found."; } if (config.count("conv_template")) { - ICHECK(config["conv_template"].is()); - std::string conv_template = config["conv_template"].get(); - this->conversation_ = Conversation::FromTemplate(conv_template); + if (config["conv_template"].is()) { + this->conversation_.LoadJSONOverride(config["conv_template"], false); + } else { + ICHECK(config["conv_template"].is()); + LOG(WARNING) + << "Legacy conversation template detected. It will be deprecated in the future. " + "Please regenerate mlc-chat-config.json with the latest version"; + std::string conv_template = config["conv_template"].get(); + this->conversation_ = Conversation::FromTemplate(conv_template); + } if (config.count("conv_config")) { // conv_config can override conv_template - this->conversation_.LoadJSONOverride(config["conv_config"], true); + try { + this->conversation_.LoadJSONOverride(config["conv_config"], true); + } catch (...) { + this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], true); + } } } else if (config.count("conv_config")) { // without conv template, conv_config needs to be a complete config - this->conversation_.LoadJSONOverride(config["conv_config"], false); + try { + this->conversation_.LoadJSONOverride(config["conv_config"], false); + } catch (...) 
{ + this->conversation_.LoadJSONOverrideLegacy(config["conv_config"], false); + } } else { CHECK(partial_update) << "Key \"conv_template\" and \"conv_config\" not found."; } diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index d5edcf82aa..38cdec2f85 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -184,7 +184,7 @@ We provide an example below. # Using a `ConvConfig`, we modify `system`, a field in the conversation template # `system` refers to the prompt encoded before starting the chat - conv_config = ConvConfig(system='Please show as much happiness as you can when talking to me.') + conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.') # We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len` # Note that `conv_config` is an optional subfield of `chat_config` diff --git a/docs/get_started/mlc_chat_config.rst b/docs/get_started/mlc_chat_config.rst index ccaa97b4fc..482e68d368 100644 --- a/docs/get_started/mlc_chat_config.rst +++ b/docs/get_started/mlc_chat_config.rst @@ -52,14 +52,21 @@ Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: "tokenizer_config.json" ] - // 3. Chat related fields that affect runtime behavior + // 3. Conversation template related fields + "conv_template": { + "name": "llama-2", + "system_template": "[INST] <>\n{system_message}\n<>\n\n ", + "system_message": "You are a helpful, respectful and honest assistant.", + // more fields here... + }, + + // 4. Chat related fields that affect runtime behavior "mean_gen_len": 128, "max_gen_len": 512, "shift_fill_factor": 0.3, "temperature": 0.6, "repetition_penalty": 1.0, - "top_p": 0.9, - "conv_template": "llama-2", + "top_p": 0.9 } .. note:: @@ -70,7 +77,11 @@ Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: can be customized to change the behavior of the model.** ``conv_template`` - The name of the conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. + .. note:: + Legacy ``mlc-chat-config.json`` may specify a string for this field to look up a registered conversation + template. It will be deprecated in the future. Re-generate config using the latest version of mlc_llm + to make sure this field is a complete JSON object. + The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. ``temperature`` The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs. @@ -99,32 +110,17 @@ can be customized to change the behavior of the model.** Conversation Structure ^^^^^^^^^^^^^^^^^^^^^^ -There are three options of loading conversation configurations: - -1. Load from pre-defined conversation templates. -2. Load from JSON format conversation configuration. -3. First load from pre-defined conversation templates, then override some fields with JSON format conversation configuration. - -.. 
_load-predefined-conv-template: - -Load from Pre-defined Conversation Templates --------------------------------------------- - -MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by specifying the template name in ``conv_template`` field in the ``mlc-chat-config.json``, below is a list (not complete) of supported conversation templates: +MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by +specifying ``--conv-template [name]`` when generating config. Below is a list (not complete) of +supported conversation templates: - ``llama-2`` -- ``vicuna_v1.1`` -- ``redpajama_chat`` -- ``rwkv`` -- ``dolly`` +- ``mistral_default`` +- ``chatml`` +- ``phi-2`` - ... -Please refer to `conv_template.cc `_ for the full list of supported templates and their implementations. - -.. _load-json-conv-config: - -Load from JSON Conversation Configuration ------------------------------------------ +Please refer to `conversation_template.py `_ for the full list of supported templates and their implementations. Below is a generic structure of a JSON conversation configuration (we use vicuna as an example): @@ -133,122 +129,81 @@ Below is a generic structure of a JSON conversation configuration (we use vicuna // mlc-chat-config.json { // ... - "conv_config": { + "conv_template": { + "name": "llama-2", + "system_template": "[INST] <>\n{system_message}\n<>\n\n ", + "system_message": "You are a helpful, respectful and honest assistant.", + "roles": { + "user": "[INST]", + "assistant": "[/INST]", + "tool": "[INST]" + }, + "role_templates": { + "user": "{user_message}", + "assistant": "{assistant_message}", + "tool": "{tool_message}" + }, + "messages": [], "seps": [ - " ", - "<\/s>" + " " ], - "stop_tokens": [ - 2 + "role_content_sep": " ", + "role_empty_sep": " ", + "stop_str": [ + "[INST]" ], - "offset": 0, - "separator_style": 0, - "messages": [], - "stop_str": "<\/s>", - "roles": [ - "USER", - "ASSISTANT" + "stop_token_ids": [ + 2 ], - "role_msg_sep": ": ", - "role_empty_sep": ": ", - "system": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", - "add_bos": true, - "name": "vicuna_v1.1" + "function_string": "", + "use_function_calling": false } } +``name`` + Name of the conversation. +``system_template`` + The system prompt template, it optionally contains the system + message placeholder, and the placeholder will be replaced with + the system message below. +``system_message`` + The content of the system prompt (without the template format). +``system_prefix_token_ids`` + The system token ids to be prepended at the beginning of tokenized + generated prompt. ``roles`` - An array that describes the role names of the user and the model. These names are specific to the model being used. -``system`` - The prompt encoded before starting the chat. It can be customized to a user-defined prompt. -``add_bos`` - Determines whether a beginning-of-string (bos) token should be added before the input tokens. -``stop_str`` - When the ``stop_str`` is encountered, the model will stop generating output. -``stop_tokens`` - A list of token IDs that act as stop tokens. -``seps`` - An array of strings indicating the separators to be used after a user message and a model message respectively. 
+ The conversation roles +``role_templates`` + The roles prompt template, it optionally contains the defaults + message placeholders and will be replaced by actual content ``messages`` - The chat history represented as an array of string pairs in the following format: ``[[role_0, msg_0], [role_1, msg_1], ...]`` -``offset`` - The offset used to begin the chat from the chat history. When ``offset`` is not ``0``, ``messages[0:offset-1]`` will be encoded. -``separator_style`` - Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``). -``role_msg_sep`` - A string indicating the separator between a role and a message. + The conversation history messages. + Each message is a pair of strings, denoting "(role, content)". + The content can be None. +``seps`` + An array of strings indicating the separators to be used after a user + message and a model message respectively. +``role_content_sep`` + The separator between the role and the content in a message. ``role_empty_sep`` - A string indicating the separator to append to a role when there is no message yet. - - -When the value of ``separator_style`` is set to 0 (or ``kSepRoleMsg``), each round of conversation follows the format: - -.. code:: text - - {role[0]}{separator_style}{user_input}{sep[0]} - {role[1]}{separator_style}{model_output}{sep[1]} - -Here, ``{user_input}`` represents the input provided by the user, and ``{model_output}`` represents the output generated by the model. + The separator between the role and empty contents. +``stop_str`` + When the ``stop_str`` is encountered, the model will stop generating output. +``stop_token_ids`` + A list of token IDs that act as stop tokens. +``function_string`` + The function calling string. +``use_function_calling`` + Whether using function calling or not, helps check for output message format in API call. -On the other hand, if the value of ``separator_style`` is set to 1 (or ``kLM``), the model is not aware of the chat history and generates the response immediately after the user input prompt: +Given a conversation template, the corresponding prompt generated out +from it is in the following format: .. code:: text - {user_prompt}{model_output} - - -.. _customize-conv-template: - -Customize Conversation Template -------------------------------- - -In the ``mlc-chat-config.json`` file, you have the option to specify both ``conv_template`` and ``conv_config``. MLC-LLM will first load the predefined template with the name specified in ``conv_template`` and then override some of the configurations specified in ``conv_config``. It's important to note that the configurations in ``conv_config`` don't need to be complete, allowing for partial updates. - -.. _example_replace_system_prompt: - -Example 1: Replace System Prompt -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If you're tired of the default system prompt, here's an example of how you can replace it: - -.. code:: json - - // mlc-chat-config.json - { - // ... - "conv_template": "vicuna_v1.1", - "conv_config": { - "system": "You are not Vicuna, your name is Guanaco, now let's chat!" - } - } - - -The next time you run ``mlc_llm`` CLI, you will start a chat with Vicuna using a new system prompt. - -.. _example_resume_chat_history: - -Example 2: Resume from Chat History -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The following example demonstrates how to chat with Vicuna and resume from a chat history: - -.. code:: json - - // mlc-chat-config.json - { - // ... 
- "conv_template": "vicuna_v1.1", - "conv_config": { - "messages": [ - ["USER", "Suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"], - ["ASSISTANT", "Based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. This name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."], - ["USER", "I like cervidae, but the name is too long!"], - ["ASSISTANT", "In that case, a shorter and catchier name for the next project could be \"DeerRun\" which plays on the idea of the project being fast and efficient, just like a deer running through the woods. This name is memorable and easy to pronounce, making it a good choice for a project name."] - ], - "offset": 4 - } - } - - -The next time you start ``mlc_llm`` CLI, or use Python API, you will initiate a chat with Vicuna and resume from the provided chat history. + <><><><><> + <><><><> + ... + <><><><> + <><> diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 675e1e7c94..18c3258514 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -16,6 +16,7 @@ import tvm from tvm.runtime import disco # pylint: disable=unused-import +from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.support import logging from mlc_llm.support.auto_device import detect_device from mlc_llm.support.config import ConfigBase @@ -44,58 +45,61 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes Since the configuration is partial, everything will be ``Optional``. + The parameters are the same as :class:`mlc_llm.protocol.conversation_protocol.Conversation` + Parameters ---------- name : Optional[str] Name of the conversation. - system : Optional[str] - The prompt encoded before starting the chat. - roles : Optional[List[str]] - An array that describes the role names of the user and the model. These - names are specific to the model being used. - messages : Optional[List[List[str]]] - The chat history represented as an array of string pairs in the following - format: ``[[role_0, msg_0], [role_1, msg_1], ...]``. - offset : Optional[int] - The offset used to begin the chat from the chat history. When offset - is not ``0``, ``messages[0:offset-1]`` will be encoded. - separator_style : Optional[int] - Specifies whether we are in chat-bot mode (``0``) or pure LM prompt mode (``1``). + system_template : Optional[str] + The system prompt template, it optionally contains the system + message placeholder, and the placeholder will be replaced with + the system message below. + system_message : Optional[str] + The content of the system prompt (without the template format). + system_prefix_token_ids : Optional[List[int]] + The system token ids to be prepended at the beginning of tokenized + generated prompt. + roles : Optional[Dict[str, str]] + The conversation roles + role_templates : Optional[Dict[str, str]] + The roles prompt template, it optionally contains the defaults + message placeholders and will be replaced by actual content + messages : Optional[List[Tuple[str, Optional[str]]]] + The conversation history messages. + Each message is a pair of strings, denoting "(role, content)". + The content can be None. seps : Optional[List[str]] An array of strings indicating the separators to be used after a user message and a model message respectively. 
- role_msg_sep : Optional[str] - A string indicating the separator between a role and a message. + role_content_sep : Optional[str] + The separator between the role and the content in a message. role_empty_sep : Optional[str] - A string indicating the separator to append to a role when there is no message yet. - stop_str : Optional[str] + The separator between the role and empty contents. + stop_str : Optional[List[str]] When the ``stop_str`` is encountered, the model will stop generating output. - stop_tokens : Optional[List[int]] + stop_token_ids : Optional[List[int]] A list of token IDs that act as stop tokens. - prefix_tokens : Optional[List[int]] - Token list prefixing the conversation. - add_bos : Optional[bool] - Determines whether a beginning-of-string (bos) token should be added - before the input tokens. + function_string : Optional[str] + The function calling string. + use_function_calling : Optional[bool] + Whether using function calling or not, helps check for output message format in API call. """ name: Optional[str] = None - system: Optional[str] = None - roles: Optional[List[str]] = None - messages: Optional[List[List[str]]] = None - offset: Optional[int] = None - separator_style: Optional[int] = None + system_template: Optional[str] = None + system_message: Optional[str] = None + system_prefix_token_ids: Optional[List[int]] = None + roles: Optional[Dict[str, str]] = None + role_templates: Optional[Dict[str, str]] = None + messages: Optional[List[Tuple[str, Optional[str]]]] = None seps: Optional[List[str]] = None - role_msg_sep: Optional[str] = None + role_content_sep: Optional[str] = None role_empty_sep: Optional[str] = None - stop_str: Optional[str] = None - stop_tokens: Optional[List[int]] = None - prefix_tokens: Optional[List[int]] = None - add_bos: Optional[bool] = None - - def __post_init__(self): - if self.messages is not None and self.offset is None: - self.offset = len(self.messages) + stop_str: Optional[List[str]] = None + stop_token_ids: Optional[List[int]] = None + function_string: Optional[str] = None + use_function_calling: Optional[bool] = None @dataclass @@ -192,7 +196,7 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes model_lib: Optional[str] = None local_id: Optional[str] = None - conv_template: Optional[str] = None + conv_template: Optional[Union[str, Conversation]] = None temperature: Optional[float] = None presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 @@ -217,6 +221,8 @@ class ChatConfig(ConfigBase): # pylint: disable=too-many-instance-attributes @classmethod def _from_json(cls, json_obj: dict): + if "conv_template" in json_obj and isinstance(json_obj["conv_template"], dict): + json_obj["conv_template"] = Conversation.from_json_dict(json_obj["conv_template"]) return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters}) @@ -440,6 +446,13 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi "override the full model library path instead." ) warnings.warn(warn_msg) + elif field_name == "conv_template" and isinstance(field_value, Conversation): + warn_msg = ( + 'WARNING: Do not override "conv_template" in ChatConfig. ' + 'Please override "conv_config" instead.' + "This override will be ignored." 
+ ) + warnings.warn(warn_msg) else: setattr(final_chat_config, field_name, field_value) return final_chat_config @@ -613,6 +626,9 @@ def _convert_chat_config_to_json_str( conv_dict[conv_k] = conv_v chat_dict[key] = conv_dict continue + if key == "conv_template" and isinstance(value, Conversation): + chat_dict[key] = Conversation.to_json_dict(value) + continue if value is not None: chat_dict[key] = value diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index f4d39aa8ba..4bce52aa20 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -4,8 +4,9 @@ import json import shutil from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union +from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model from mlc_llm.quantization import Quantization from mlc_llm.support import convert_tiktoken, logging @@ -45,7 +46,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes repetition_penalty: float = None top_p: float = None # Conversation template - conv_template: str = None + conv_template: Union[str, Dict[str, Any]] = None pad_token_id: int = None bos_token_id: int = None eos_token_id: int = None @@ -89,6 +90,17 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b ): """Entrypoint of MLC Chat configuration generation.""" # Step 1. Initialize `mlc-chat-config.json` using `config.json` + conversation_reg = ConvTemplateRegistry.get_conv_template(conv_template) + if conversation_reg is None: + logger.warning( + "%s: Conversation template is not registered in ConvTemplateRegistry: %s", + red("Warning"), + conv_template, + ) + conversation = conv_template # type: ignore + else: + conversation = conversation_reg.to_json_dict() # type: ignore + model_config = ModelConfigOverride( context_window_size=context_window_size, sliding_window_size=sliding_window_size, @@ -107,7 +119,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b prefill_chunk_size=model_config.prefill_chunk_size, attention_sink_size=getattr(model_config, "attention_sink_size", -1), tensor_parallel_shards=model_config.tensor_parallel_shards, - conv_template=conv_template, + conv_template=conversation, ) # Step 2. Load `generation_config.json` and `config.json` for text-generation related configs for generation_config_filename in ["generation_config.json", "config.json"]: diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 01c145db7d..fa99b95c16 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -1,7 +1,7 @@ """The standard conversation protocol in MLC LLM""" from enum import Enum -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar from pydantic import BaseModel, Field, field_validator @@ -17,6 +17,9 @@ class MessagePlaceholders(Enum): FUNCTION = "{function_string}" +T = TypeVar("T", bound="BaseModel") + + class Conversation(BaseModel): """Class that specifies the convention template of conversation and contains the conversation history. 
@@ -95,6 +98,15 @@ def check_message_seps(cls, seps: List[str]) -> List[str]: raise ValueError("seps should have size 1 or 2.") return seps + def to_json_dict(self) -> Dict[str, Any]: + """Convert to a json dictionary""" + return self.model_dump(exclude_none=True) + + @classmethod + def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T: + """Convert from a json dictionary""" + return Conversation.model_validate(json_dict) + def as_prompt(self) -> str: """Convert the conversation template and history messages to a single prompt. diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc index 98d01a58ba..d49c7107cd 100644 --- a/tests/cpp/conv_unittest.cc +++ b/tests/cpp/conv_unittest.cc @@ -1,6 +1,61 @@ #include #include +void _TestConversationLoadJSON() { + std::string conv_template = + "{\n" + " \"name\": \"test\",\n" + " \"system_template\": \"abc{system_message}\",\n" + " \"system_message\": \"de\",\n" + " \"roles\": {\n" + " \"user\": \"Instruct\",\n" + " \"assistant\": \"Output\",\n" + " \"tool\": \"Instruct\"\n" + " },\n" + " \"role_templates\": {\n" + " \"user\": \"{user_message}\",\n" + " \"assistant\": \"{assistant_message}\",\n" + " \"tool\": \"{tool_message}\"\n" + " },\n" + " \"messages\": [[\"Instruct\", \"Hello\"], [\"Output\", \"Hey\"]],\n" + " \"seps\": [\n" + " \"\\n\"\n" + " ],\n" + " \"role_content_sep\": \": \",\n" + " \"role_empty_sep\": \":\",\n" + " \"stop_str\": [\n" + " \"<|endoftext|>\"\n" + " ],\n" + " \"stop_token_ids\": [\n" + " 50256\n" + " ],\n" + " \"function_string\": \"\",\n" + " \"use_function_calling\": false\n" + "}"; + mlc::llm::Conversation conv; + conv.LoadJSONOverride(conv_template, true); + ASSERT_EQ(conv.name, "test"); + ASSERT_EQ(conv.system, "abcde"); + + std::vector expected_roles{"Instruct", "Output"}; + ASSERT_EQ(conv.roles, expected_roles); + + std::vector> expected_messages = {{"Instruct", "Hello"}, + {"Output", "Hey"}}; + ASSERT_EQ(conv.messages, expected_messages); + ASSERT_EQ(conv.offset, 2); + + std::vector expected_seps = {"\n"}; + ASSERT_EQ(conv.seps, expected_seps); + + ASSERT_EQ(conv.role_msg_sep, ": "); + ASSERT_EQ(conv.role_empty_sep, ":"); + ASSERT_EQ(conv.stop_str, "<|endoftext|>"); + + std::vector expected_stop_tokens = {50256}; + ASSERT_EQ(conv.stop_tokens, expected_stop_tokens); +} + void _TestConversationJSONRoundTrip(std::string templ_name) { mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name); std::string conv_json = conv.GetConfigJSON(); @@ -11,12 +66,14 @@ void _TestConversationJSONRoundTrip(std::string templ_name) { void _TestConversationPartialUpdate() { mlc::llm::Conversation conv; - std::string json_str = "{\"offset\": -1}"; + std::string json_str = "{\"name\": \"test\"}"; ASSERT_ANY_THROW(conv.LoadJSONOverride(json_str, false)); conv.LoadJSONOverride(json_str, true); - ASSERT_EQ(conv.offset, -1); + ASSERT_EQ(conv.name, "test"); } +TEST(ConversationTest, ConversationLoadJSONTest) { _TestConversationLoadJSON(); } + TEST(ConversationTest, ConversationJSONRoundTripTest) { _TestConversationJSONRoundTrip("vicuna_v1.1"); _TestConversationJSONRoundTrip("conv_one_shot"); diff --git a/tests/python/protocol/test_converation_protocol.py b/tests/python/protocol/test_converation_protocol.py new file mode 100644 index 0000000000..9656eb8b18 --- /dev/null +++ b/tests/python/protocol/test_converation_protocol.py @@ -0,0 +1,20 @@ +import pytest + +from mlc_llm.conversation_template import ConvTemplateRegistry +from mlc_llm.protocol.conversation_protocol import Conversation + + +def 
get_conv_templates(): + return ["llama-2", "mistral_default", "gorilla", "chatml", "phi-2"] + + +@pytest.mark.parametrize("conv_template_name", get_conv_templates()) +def test_json(conv_template_name): + template = ConvTemplateRegistry.get_conv_template(conv_template_name) + j = template.to_json_dict() + template_parsed = Conversation.from_json_dict(j) + assert template == template_parsed + + +if __name__ == "__main__": + test_json() From 73f2b27b73cb035ca1e5715110950cc8d70e0d4b Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sat, 16 Mar 2024 17:33:02 +0800 Subject: [PATCH 076/531] [SLM] Small correction on Stablelm and Qwen2. (#1958) * small fix * small fix * Update stablelm_model.py --- python/mlc_llm/model/qwen2/qwen2_model.py | 2 +- python/mlc_llm/model/stable_lm/stablelm_model.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index ad55c83bb4..db533285d8 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -267,7 +267,7 @@ def create_paged_kv_cache( page_size=page_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, - num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, head_dim=self.head_dim, rope_mode=RopeMode.NORMAL, rope_scale=1, diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index b32372ce6d..710bf7698e 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -74,7 +74,6 @@ def __post_init__(self): bold("context_window_size"), ) self.prefill_chunk_size = self.context_window_size - assert self.tensor_parallel_shards == 1, "StableLM currently does not support sharding." # pylint: disable=invalid-name,missing-docstring @@ -168,11 +167,11 @@ def __init__(self, config: StableLmConfig): self.num_hidden_layers = config.num_hidden_layers self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads self.head_dim = self.hidden_size // self.num_attention_heads self.vocab_size = config.vocab_size self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards - self.dtype = "float32" self.partial_rotary_factor = config.partial_rotary_factor def to(self, dtype: Optional[str] = None): @@ -253,7 +252,7 @@ def create_paged_kv_cache( page_size=page_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, - num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, head_dim=self.head_dim, rope_mode=RopeMode.NORMAL, rope_scale=1, From d6b86d1ba0e439cbfb79146853eda95afdb6a0e1 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Sat, 16 Mar 2024 23:08:43 +0800 Subject: [PATCH 077/531] [Serving][Fix] Fix JSON output check in test_server.py (#1966) `test_server::is_json_or_json_prefix` is used to check the output is JSON or a prefix of JSON. It uses json.loads internally. However, json.loads (i.e. json.decode) is token-based instead of char based. If half a token is left at the end of the string, it cannot be matched. 
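For illustration only (not part of the patch): a minimal standalone sketch of the failure
mode described above, showing that `json.loads` reports the error position at the start of
an unfinished token rather than at the end of the string.

    import json

    s = '{"answer": tr'  # a prefix of '{"answer": true}', cut inside the token "true"
    try:
        json.loads(s)
    except json.JSONDecodeError as e:
        # e.pos points at the start of the incomplete token "tr" (index 11),
        # not at len(s) (13), so a plain `e.pos == len(s)` check rejects this
        # valid prefix of a JSON string.
        print(e.pos, len(s))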
This PR adds another check for the rest "half a token" if it exists. --- tests/python/serve/server/test_server.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 88734455cf..b726a6b41d 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -25,6 +25,7 @@ from typing import Dict, List, Optional, Tuple import pytest +import regex import requests from openai import OpenAI @@ -35,12 +36,26 @@ DEBUG_DUMP_EVENT_TRACE_URL = "http://127.0.0.1:8000/debug/dump_event_trace" +JSON_TOKEN_PATTERN = ( + r"((-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?)|null|true|false|" + r'("((\\["\\\/bfnrt])|(\\u[0-9a-fA-F]{4})|[^"\\\x00-\x1f])*")' +) +JSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN) + + def is_json_or_json_prefix(s: str) -> bool: try: json.loads(s) return True except json.JSONDecodeError as e: - return e.pos == len(s) + # If the JSON decoder reaches the end of s, it is a prefix of a JSON string. + if e.pos == len(s): + return True + # Since json.loads is token-based instead of char-based, there may remain half a token after + # the matching position. + # If the left part is a prefix of a valid JSON token, the output is also valid + regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True) + return regex_match is not None def check_openai_nonstream_response( From edffce44c55539ca43c3eff4b4022dd628205cb7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 16 Mar 2024 18:42:51 -0400 Subject: [PATCH 078/531] [Model] Migrate Mistral to use PagedKVCache (#1967) This PR migrates the mistral model to the PagedKVCache interface which supports sliding window attention with paged attention kernel written in TensorIR. We thereby introduce a `support_sliding_window` mode for KV cache, which leaves space for supporting sliding window for any model at runtime. This PR tests the mistral on with both chat and serve. The chat performance of Mistral 7B gets improvement than before, benefitted from the paged attention implementation. 
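As a rough mental model only (hypothetical helper, not the actual KV cache API in this patch):
with sliding window attention plus attention sinks, a sequence attends to the first
`attention_sink_size` tokens and the most recent tokens, keeping at most `sliding_window_size`
KV entries in total.

    def visible_kv_positions(seq_len: int, sliding_window_size: int, attention_sink_size: int):
        """Hypothetical illustration of which token positions stay resident in the KV cache."""
        if sliding_window_size == -1 or seq_len <= sliding_window_size:
            return list(range(seq_len))  # no eviction needed
        sinks = list(range(attention_sink_size))  # pinned attention-sink tokens
        num_recent = sliding_window_size - attention_sink_size
        recent = list(range(seq_len - num_recent, seq_len))  # most recent tokens in the window
        return sinks + recent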
--- cpp/llm_chat.cc | 122 ++--- cpp/serve/config.cc | 7 +- cpp/serve/engine.cc | 28 +- cpp/serve/engine_actions/action_commons.cc | 19 - cpp/serve/engine_actions/batch_decode.cc | 1 - cpp/serve/engine_actions/batch_verify.cc | 1 - .../engine_actions/new_request_prefill.cc | 24 +- cpp/serve/engine_state.cc | 1 - cpp/serve/engine_state.h | 2 - cpp/serve/function_table.cc | 26 +- cpp/serve/function_table.h | 3 +- cpp/serve/model.cc | 56 +- cpp/serve/model.h | 20 +- .../dispatch_kv_cache_creation.py | 25 +- .../mlc_llm/model/baichuan/baichuan_model.py | 5 +- python/mlc_llm/model/gemma/gemma_model.py | 5 +- python/mlc_llm/model/gpt2/gpt2_model.py | 5 +- .../model/gpt_bigcode/gpt_bigcode_model.py | 5 +- .../mlc_llm/model/gpt_neox/gpt_neox_model.py | 5 +- .../mlc_llm/model/internlm/internlm_model.py | 5 +- python/mlc_llm/model/llama/llama_model.py | 5 +- python/mlc_llm/model/mistral/mistral_model.py | 483 ++++++------------ python/mlc_llm/model/orion/orion_model.py | 5 +- python/mlc_llm/model/phi/phi_model.py | 5 +- python/mlc_llm/model/qwen/qwen_model.py | 5 +- python/mlc_llm/model/qwen2/qwen2_model.py | 5 +- .../mlc_llm/model/stable_lm/stablelm_model.py | 5 +- python/mlc_llm/nn/kv_cache.py | 337 ++++++------ python/mlc_llm/op/position_embedding.py | 123 ----- python/mlc_llm/serve/async_engine.py | 6 +- python/mlc_llm/serve/engine.py | 13 +- .../serve/entrypoints/entrypoint_utils.py | 6 +- .../serve/entrypoints/openai_entrypoints.py | 4 +- tests/python/model/test_kv_cache.py | 171 ++----- tests/python/serve/server/test_server.py | 7 +- 35 files changed, 627 insertions(+), 918 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 09c2ce9a37..8ec3c5ec1d 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -259,6 +259,8 @@ struct FunctionTable { this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence"); this->kv_cache_remove_sequence_func_ = get_global_func("vm.builtin.kv_state_remove_sequence"); + this->kv_cache_enable_sliding_window_for_seq_ = + get_global_func("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"); this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward"); this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward"); this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn"); @@ -345,6 +347,7 @@ struct FunctionTable { PackedFunc reset_kv_cache_func_; PackedFunc kv_cache_add_sequence_func_; PackedFunc kv_cache_remove_sequence_func_; + PackedFunc kv_cache_enable_sliding_window_for_seq_; PackedFunc kv_cache_begin_forward_func_; PackedFunc kv_cache_end_forward_func_; bool support_backtracking_kv_; @@ -663,12 +666,17 @@ class LLMChat { this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_); // Step 6. KV cache creation. if (ft_.use_kv_state == FunctionTable::KVStateKind::kAttention) { + int max_total_seq_length = + this->max_window_size_ == -1 ? 
this->sliding_window_size_ : this->max_window_size_; + ICHECK_GT(max_total_seq_length, 0); IntTuple max_num_sequence{1}; - IntTuple max_total_sequence_length{this->max_window_size_}; + IntTuple max_total_sequence_length{max_total_seq_length}; IntTuple prefill_chunk_size{this->prefill_chunk_size_}; IntTuple page_size{16}; - this->kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, - prefill_chunk_size, page_size); + IntTuple support_sliding_window{sliding_window_size_ != -1}; + this->kv_cache_ = + ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, prefill_chunk_size, + page_size, support_sliding_window); } else if (ft_.use_kv_state == FunctionTable::KVStateKind::kRNNState) { IntTuple max_num_sequence{1}; IntTuple max_history_length{1}; @@ -697,8 +705,6 @@ class LLMChat { this->ResetRuntimeStats(); this->ResetKVCache(); this->total_seq_len_ = 0; - this->sliding_window_cache_offset_ = 0; - this->sink_triggered_ = false; } /*! \brief reset the runtime stats. */ @@ -984,19 +990,6 @@ class LLMChat { std::vector(prompt_tokens.begin() + begin, prompt_tokens.begin() + end); new_seq_len += static_cast(chunk.size()); logits_on_device = this->ForwardTokens(chunk, new_seq_len); - - // update window cache offset (prefill) - if (this->sliding_window_size_ != -1) { - if (sink_triggered_) { - sliding_window_cache_offset_ = - std::max((sliding_window_cache_offset_ + static_cast(chunk.size())) % - sliding_window_size_, - attention_sink_size_); - } else { - sliding_window_cache_offset_ += static_cast(chunk.size()); - sink_triggered_ = sliding_window_cache_offset_ >= attention_sink_size_; - } - } } ICHECK_EQ(new_seq_len, total_seq_len_ + token_len) << "Expect chunking process all tokens"; } else { @@ -1035,18 +1028,6 @@ class LLMChat { NDArray logits_on_device = this->ForwardTokens({last_token}, total_seq_len_ + 1); total_seq_len_ += 1; - - // update window cache offset (decoding) - if (this->sliding_window_size_ != -1) { - if (sink_triggered_) { - sliding_window_cache_offset_ = std::max( - (sliding_window_cache_offset_ + 1) % sliding_window_size_, attention_sink_size_); - } else { - sliding_window_cache_offset_ += 1; - sink_triggered_ = sliding_window_cache_offset_ >= attention_sink_size_; - } - } - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); @@ -1372,32 +1353,20 @@ class LLMChat { ObjectRef ret{nullptr}; if (input_tokens.size() > 1 && ft_.prefill_func_.defined()) { ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens)); - if (sliding_window_size_ == -1) { - if (ft_.use_kv_state) { - int input_len = input_tokens.size(); - IntTuple seq_ids_tuple({0}); - ShapeTuple input_len_shape{input_len}; - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, input_len_shape); - input_data = ft_.nd_view_func_(input_data, input_len_shape); - auto embed = ft_.embed_func_(input_data, params_); - ShapeTuple embedding_shape = {1, input_len, GetHiddenSizeFromEmbedding(embed)}; - embed = ft_.nd_view_func_(embed, embedding_shape); - ret = ft_.prefill_func_(embed, kv_cache_, params_); - ft_.kv_cache_end_forward_func_(kv_cache_); - } else { - ShapeTuple cur_pos_shape = ShapeTuple({cur_pos}); - ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_); - } + if (ft_.use_kv_state) { + int input_len = input_tokens.size(); + IntTuple seq_ids_tuple({0}); + ShapeTuple input_len_shape{input_len}; + ft_.kv_cache_begin_forward_func_(kv_cache_, 
seq_ids_tuple, input_len_shape); + input_data = ft_.nd_view_func_(input_data, input_len_shape); + auto embed = ft_.embed_func_(input_data, params_); + ShapeTuple embedding_shape = {1, input_len, GetHiddenSizeFromEmbedding(embed)}; + embed = ft_.nd_view_func_(embed, embedding_shape); + ret = ft_.prefill_func_(embed, kv_cache_, params_); + ft_.kv_cache_end_forward_func_(kv_cache_); } else { - // Sliding window attention needs extra shape parameters - int64_t seq_len = static_cast(input_tokens.size()); - // Number of elements in the cache - int64_t cache_len = std::min(this->sliding_window_size_, cur_pos - seq_len); - ShapeTuple cache_len_shape = ShapeTuple({cache_len}); - ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len}); - ShapeTuple cache_offset_shape = ShapeTuple({sliding_window_cache_offset_}); - ret = ft_.prefill_func_(input_data, cache_len_shape, kv_seq_len_shape, cache_offset_shape, - kv_cache_, params_); + ShapeTuple cur_pos_shape = ShapeTuple({cur_pos}); + ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_); } } else { // running decode function when prefill is not available @@ -1412,30 +1381,18 @@ class LLMChat { } int64_t pos = cur_pos + i + 1 - input_tokens.size(); ShapeTuple pos_shape = ShapeTuple({pos}); - if (sliding_window_size_ == -1) { - if (ft_.use_kv_state) { - IntTuple seq_ids_tuple({0}); - IntTuple append_length({1}); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length); - input_data = ft_.nd_view_func_(input_data, append_length); - auto embed = ft_.embed_func_(input_data, params_); - ShapeTuple embedding_shape = {1, 1, GetHiddenSizeFromEmbedding(embed)}; - embed = ft_.nd_view_func_(embed, embedding_shape); - ret = ft_.decode_func_(embed, kv_cache_, params_); - ft_.kv_cache_end_forward_func_(kv_cache_); - } else { - ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_); - } + if (ft_.use_kv_state) { + IntTuple seq_ids_tuple({0}); + IntTuple append_length({1}); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, append_length); + input_data = ft_.nd_view_func_(input_data, append_length); + auto embed = ft_.embed_func_(input_data, params_); + ShapeTuple embedding_shape = {1, 1, GetHiddenSizeFromEmbedding(embed)}; + embed = ft_.nd_view_func_(embed, embedding_shape); + ret = ft_.decode_func_(embed, kv_cache_, params_); + ft_.kv_cache_end_forward_func_(kv_cache_); } else { - // Sliding window attention needs extra shape parameters - int64_t seq_len = static_cast(input_tokens.size()); - // Number of elements in the cache - int64_t cache_len = std::min(this->sliding_window_size_, pos - seq_len); - ShapeTuple cache_len_shape = ShapeTuple({cache_len}); - ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len}); - ShapeTuple cache_offset_shape = ShapeTuple({sliding_window_cache_offset_}); - ret = ft_.decode_func_(input_data, cache_len_shape, kv_seq_len_shape, cache_offset_shape, - kv_cache_, params_); + ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_); } } } @@ -1553,6 +1510,11 @@ class LLMChat { ft_.reset_kv_cache_func_(kv_cache_); if (ft_.use_kv_state) { ft_.kv_cache_add_sequence_func_(kv_cache_, 0); + if (sliding_window_size_ != -1) { + int attention_sink_size = std::max(static_cast(attention_sink_size_), 0); + ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, 0, sliding_window_size_, + attention_sink_size); + } } } @@ -1624,10 +1586,6 @@ class LLMChat { std::string output_message_; // Whether encounter stop str bool stop_triggered_{false}; - // Whether sink is in action - 
bool sink_triggered_{false}; - // sliding window cache offset - int64_t sliding_window_cache_offset_{0}; //---------------------------- // Model configurations //---------------------------- diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 451b3a0279..5a0b35a3c6 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -252,10 +252,9 @@ KVCacheConfig::KVCacheConfig(const std::string& config_str, int max_single_seque if (config.count("max_num_sequence")) { CHECK(config["max_num_sequence"].is()); max_num_sequence = config["max_num_sequence"].get(); - } - - if (max_num_sequence == -1) { - max_num_sequence = max_total_sequence_length / max_single_sequence_length; + CHECK_GT(max_num_sequence, 0) << "Max number of sequence should be positive."; + } else { + LOG(FATAL) << "Key \"max_num_sequence\" not found."; } ObjectPtr n = make_object(); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 39c84a1c8d..3288a70afd 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -51,7 +51,10 @@ class EngineImpl : public Engine { CHECK_GE(model_infos.size(), 1) << "ValueError: No model is provided in the engine."; // Step 1. Initialize metadata and singleton states inside the engine this->estate_->Reset(); - this->max_single_sequence_length_ = max_single_sequence_length; + // Being "-1" means there is no limit on single sequence length. + this->max_single_sequence_length_ = max_single_sequence_length != -1 + ? max_single_sequence_length + : std::numeric_limits::max(); this->kv_cache_config_ = KVCacheConfig(kv_cache_config_json_str, max_single_sequence_length); this->engine_mode_ = EngineMode(engine_mode_json_str); this->request_stream_callback_ = std::move(request_stream_callback); @@ -140,6 +143,17 @@ class EngineImpl : public Engine { // Get a request copy where all text inputs are tokenized. request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); + + if (request->input_total_length >= kv_cache_config_->prefill_chunk_size) { + // If the request input length exceeds the prefill chunk size, + // invoke callback and do not process the request. + // Todo(mlc-team): Use "maximum single sequence length" after impl input chunking. + Array output{RequestStreamOutput( + request->id, {}, Optional>>(), {String("length")})}; + request_stream_callback_.value()(std::move(output)); + return; + } + // Append to the waiting queue and create the request state. estate_->waiting_queue.push_back(request); @@ -189,21 +203,11 @@ class EngineImpl : public Engine { // The request to abort is in running queue estate_->running_queue.erase(it_running); - // Reduce the input length. - estate_->stats.current_total_seq_len -= request->input_total_length; - // Reduce the generated length. - for (int i = 0; i < static_cast(rstate->entries.size()); ++i) { + for (int i = static_cast(rstate->entries.size()) - 1; i >= 0; --i) { if (rstate->entries[i]->status != RequestStateStatus::kAlive) { continue; } - estate_->stats.current_total_seq_len -= - rstate->entries[i]->mstates[0]->committed_tokens.size(); RemoveRequestFromModel(estate_, rstate->entries[i]->mstates[0]->internal_id, models_); - if (rstate->entries[i]->child_indices.empty()) { - // For each running leaf state, length 1 is over reduced since the last - // token is not added into KV cache. So we add the length back. 
- ++estate_->stats.current_total_seq_len; - } } } if (it_waiting != estate_->waiting_queue.end()) { diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 133bc4e6e5..35ba851386 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -28,8 +28,6 @@ void ProcessFinishedRequestStateEntries(std::vector finished_ // Remove the request state entry from all the models. RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models); estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id); - estate->stats.current_total_seq_len -= - static_cast(rsentry->mstates[0]->committed_tokens.size()) - 1; RequestState rstate = estate->GetRequestState(rsentry->request); int parent_idx = rsentry->parent_idx; @@ -51,16 +49,11 @@ void ProcessFinishedRequestStateEntries(std::vector finished_ // Remove the request state entry from all the models. RemoveRequestFromModel(estate, rstate->entries[parent_idx]->mstates[0]->internal_id, models); estate->id_manager.RecycleId(rstate->entries[parent_idx]->mstates[0]->internal_id); - estate->stats.current_total_seq_len -= - static_cast(rstate->entries[parent_idx]->mstates[0]->committed_tokens.size()); // Climb up to the parent. parent_idx = rstate->entries[parent_idx]->parent_idx; } if (parent_idx == -1) { - // All request state entries of the request have been removed. - // Reduce the total input length from the engine stats. - estate->stats.current_total_seq_len -= rsentry->request->input_total_length; // Remove from running queue and engine state. auto it = std::find(estate->running_queue.begin(), estate->running_queue.end(), rsentry->request); @@ -163,18 +156,6 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // - Update `inputs` for future prefill. RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); rsentry->status = RequestStateStatus::kPending; - estate->stats.current_total_seq_len -= rsentry->mstates[0]->committed_tokens.size(); - if (rsentry->child_indices.empty()) { - // The length was overly decreased by 1 when the entry has no child. - ++estate->stats.current_total_seq_len; - } - if (rsentry->parent_idx == -1) { - // Subtract the input length from the total length when the - // current entry is the root entry of the request. - estate->stats.current_total_seq_len -= request->input_total_length; - } - estate->stats.current_total_seq_len -= - request->input_total_length + rsentry->mstates[0]->committed_tokens.size() - 1; for (RequestModelState mstate : rsentry->mstates) { mstate->RemoveAllDraftTokens(); ICHECK(mstate->inputs.empty()); diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index eea7e79fb4..47007f6c8d 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -53,7 +53,6 @@ class BatchDecodeActionObj : public EngineActionObj { // NOTE: Right now we only support decode all the running request states at a time. 
int num_rsentries = running_rsentries.size(); - estate->stats.current_total_seq_len += num_rsentries; // Collect // - the last committed token, // - the request id, diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index df1737c547..9270b6d284 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -127,7 +127,6 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } - estate->stats.current_total_seq_len += accept_length; estate->stats.total_accepted_length += accept_length; // - Minus one because the last draft token has no kv cache entry // - Take max with 0 in case of all accepted. diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 715105a043..905eea3ed1 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -91,6 +91,10 @@ class NewRequestPrefillActionObj : public EngineActionObj { ->internal_id, mstate->internal_id); } + // Enable sliding window for the sequence if it is not a parent. + if (rsentries[i]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); + } request_internal_ids.push_back(mstate->internal_id); RECORD_EVENT(trace_recorder_, rsentries[i]->request->id, "start embedding"); for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { @@ -151,7 +155,6 @@ class NewRequestPrefillActionObj : public EngineActionObj { request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { - estate->stats.current_total_seq_len += prefill_lengths[i]; const RequestStateEntry& rsentry = rsentries[i]; for (int child_idx : rsentry->child_indices) { if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { @@ -168,9 +171,14 @@ class NewRequestPrefillActionObj : public EngineActionObj { ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - models_[model_id]->ForkSequence( - rsentry->mstates[model_id]->internal_id, - rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id); + int64_t child_internal_id = + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; + models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, + child_internal_id); + // Enable sliding window for the child sequence if the child is not a parent. 
+ if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); + } } } } @@ -252,6 +260,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { int total_required_pages = 0; int num_available_pages = models_[0]->GetNumAvailablePages(); int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); + int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); int num_prefill_rsentries = 0; for (const Request& request : estate->waiting_queue) { @@ -276,7 +285,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { total_required_pages += num_require_pages; if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), total_input_length, total_required_pages, num_available_pages, - num_running_rsentries)) { + current_total_seq_len, num_running_rsentries)) { rsentries_to_prefill.push_back(rsentry); prefill_lengths.push_back(input_length); num_prefill_rsentries += 1 + rsentry->child_indices.size(); @@ -297,7 +306,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief Check if the input requests can be prefilled under conditions. */ bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, - int num_required_pages, int num_available_pages, int num_running_rsentries) { + int num_required_pages, int num_available_pages, int current_total_seq_len, + int num_running_rsentries) { ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can @@ -317,7 +327,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { int new_batch_size = num_running_rsentries + num_prefill_rsentries; return total_input_length <= kv_cache_config_->prefill_chunk_size && num_required_pages + new_batch_size <= num_available_pages && - estate->stats.current_total_seq_len + total_input_length + 8 * new_batch_size <= + current_total_seq_len + total_input_length + 8 * new_batch_size <= kv_cache_config_->max_total_sequence_length; } diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index 3aeac5ffaf..563f0e7b13 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -26,7 +26,6 @@ String EngineStats::AsJSON() const { } void EngineStats::Reset() { - current_total_seq_len = 0; request_total_prefill_time = 0.0f; request_total_decode_time = 0.0f; engine_total_prefill_time = 0.0f; diff --git a/cpp/serve/engine_state.h b/cpp/serve/engine_state.h index edd61d751a..ff955a264f 100644 --- a/cpp/serve/engine_state.h +++ b/cpp/serve/engine_state.h @@ -18,8 +18,6 @@ using namespace tvm::runtime; /*! \brief Runtime statistics of engine. */ struct EngineStats { - /*! \brief The current total sequence length in the first model. */ - int64_t current_total_seq_len = 0; /*! \brief The sum of "prefill time of each request". */ double request_total_prefill_time = 0.0f; /*! \brief The sum of "decode time of each request". 
*/ diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 1c42caae1e..d7c70a508a 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -207,21 +207,19 @@ void FunctionTable::_InitFunctions() { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); ICHECK(this->create_kv_cache_func_.defined()); } - this->reset_kv_cache_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_clear"); - this->kv_cache_add_sequence_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence"); - this->kv_cache_fork_sequence_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_fork_sequence"); - this->kv_cache_remove_sequence_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_remove_sequence"); - this->kv_cache_begin_forward_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_begin_forward"); - this->kv_cache_end_forward_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_end_forward"); - this->kv_cache_attention_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_attention"); - this->kv_cache_popn_func_ = get_global_func("vm.builtin.paged_attention_kv_cache_popn"); + this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); + this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence"); + this->kv_cache_fork_sequence_func_ = get_global_func("vm.builtin.kv_state_fork_sequence"); + this->kv_cache_enable_sliding_window_for_seq_ = + get_global_func("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq"); + this->kv_cache_remove_sequence_func_ = get_global_func("vm.builtin.kv_state_remove_sequence"); + this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward"); + this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward"); + this->kv_cache_popn_func_ = get_global_func("vm.builtin.kv_state_popn"); this->kv_cache_get_num_available_pages_func_ = - get_global_func("vm.builtin.paged_attention_kv_cache_get_num_available_pages"); + *tvm::runtime::Registry::Get("vm.builtin.attention_kv_cache_get_num_available_pages"); + this->kv_cache_get_total_sequence_length_func_ = + *tvm::runtime::Registry::Get("vm.builtin.attention_kv_cache_get_total_sequence_length"); if (Sampler::SupportGPUSampler(local_gpu_device)) { gpu_multinomial_from_uniform_func_ = mod->GetFunction("multinomial_from_uniform", true); gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index f3466506ff..5a515ba9b7 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -81,12 +81,13 @@ struct FunctionTable { bool support_backtracking_kv_; PackedFunc kv_cache_add_sequence_func_; PackedFunc kv_cache_fork_sequence_func_; + PackedFunc kv_cache_enable_sliding_window_for_seq_; PackedFunc kv_cache_remove_sequence_func_; PackedFunc kv_cache_begin_forward_func_; PackedFunc kv_cache_end_forward_func_; - PackedFunc kv_cache_attention_func_; PackedFunc kv_cache_popn_func_; PackedFunc kv_cache_get_num_available_pages_func_; + PackedFunc kv_cache_get_total_sequence_length_func_; PackedFunc gpu_multinomial_from_uniform_func_; PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index da332b3775..0463728df0 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -298,8 +298,10 @@ class ModelImpl : public ModelObj { IntTuple 
max_total_sequence_length{kv_cache_config->max_total_sequence_length}; IntTuple prefill_chunk_size{kv_cache_config->prefill_chunk_size}; IntTuple page_size{kv_cache_config->page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, - prefill_chunk_size, page_size); + prefill_chunk_size, page_size, support_sliding_window); + local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -308,24 +310,29 @@ class ModelImpl : public ModelObj { ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id); } - /*! \brief Remove the given sequence from the KV cache in the model. */ void RemoveSequence(int64_t seq_id) final { ft_.kv_cache_remove_sequence_func_(kv_cache_, seq_id); } - /*! \brief Get the number of available pages in KV cache. */ - int GetNumAvailablePages() const final { - if (!ft_.use_disco) { - return ft_.kv_cache_get_num_available_pages_func_(kv_cache_); - } else { - DRef ret = ft_.kv_cache_get_num_available_pages_func_(kv_cache_); - return ret->DebugGetFromRemote(0); + void PopNFromKVCache(int64_t seq_id, int num_tokens) final { + ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens); + } + + void EnableSlidingWindowForSeq(int64_t seq_id) final { + if (sliding_window_size_ != -1) { + ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, seq_id, sliding_window_size_, + attention_sink_size_); } } - /*! \brief Pop out N pages from KV cache. */ - void PopNFromKVCache(int seq_id, int num_tokens) final { - ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens); + /************** Raw Info Query **************/ + + int GetNumAvailablePages() const final { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } + + int GetCurrentTotalSequenceLength() const final { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); } /*********************** Utilities ***********************/ @@ -336,8 +343,8 @@ class ModelImpl : public ModelObj { } int GetMaxWindowSize() const final { - CHECK_NE(max_window_size_, -1) << "The model has not been initialized"; - return max_window_size_; + // Being "-1" means there is no limit on the window size. + return max_window_size_ != -1 ? 
max_window_size_ : std::numeric_limits::max(); } ObjectRef AllocEmbeddingTensor() final { @@ -383,6 +390,17 @@ class ModelImpl : public ModelObj { } else { LOG(FATAL) << "Key \"context_window_size\" not found."; } + if (config.count("sliding_window_size")) { + CHECK(config["sliding_window_size"].is()); + this->sliding_window_size_ = config["sliding_window_size"].get(); + CHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0) + << "Sliding window should be either -1 (which means disabled) of positive"; + } + if (config.count("attention_sink_size")) { + CHECK(config["attention_sink_size"].is()); + this->attention_sink_size_ = config["attention_sink_size"].get(); + this->attention_sink_size_ = std::max(this->attention_sink_size_, 0); + } if (config.count("tensor_parallel_shards")) { CHECK(config["tensor_parallel_shards"].is()); this->num_shards_ = config["tensor_parallel_shards"].get(); @@ -408,6 +426,8 @@ class ModelImpl : public ModelObj { // Model configurations //---------------------------- int max_window_size_ = -1; + int sliding_window_size_ = -1; + int attention_sink_size_ = 0; int num_shards_ = -1; int max_num_sequence_ = -1; int prefill_chunk_size_ = -1; @@ -418,8 +438,14 @@ class ModelImpl : public ModelObj { //---------------------------- // Packed function table FunctionTable ft_; - // Paged KV cache + // Paged KV cache. + // - We use `kv_cache_` for general KV cache operations. + // When tensor parallelism is enabled, `kv_cache_` is a DRef object. + // - For efficient KV cache raw info query, we use `local_kv_cache` + // as a local **reference** of `kv_cache_`. It is a pure mirror of `kv_cache_` + // except that it is always a local object. ObjectRef kv_cache_{nullptr}; + ObjectRef local_kv_cache_{nullptr}; // Runtime device Device device_; // Model parameters diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 7bce2cafd4..1019834921 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -136,11 +136,25 @@ class ModelObj : public Object { /*! \brief Remove the given sequence from the KV cache in the model. */ virtual void RemoveSequence(int64_t seq_id) = 0; + /*! \brief Pop out N pages from KV cache. */ + virtual void PopNFromKVCache(int64_t seq_id, int num_tokens) = 0; + + /*! + * \brief Enabling sliding window for the given sequence. + * It is a no-op if the model does not support sliding window. + * \note Given this operation is tied with the underlying KV cache, + * we add the function in Model interface to expose this for Engine. + * This may be optimized with decoupling KV cache and Model in the future. + */ + virtual void EnableSlidingWindowForSeq(int64_t seq_id) = 0; + + /************** Raw Info Query **************/ + /*! \brief Get the number of available pages in KV cache. */ virtual int GetNumAvailablePages() const = 0; - /*! \brief Pop out N pages from KV cache. */ - virtual void PopNFromKVCache(int seq_id, int num_tokens) = 0; + /*! \brief Get the current total sequence length in the KV cache. */ + virtual int GetCurrentTotalSequenceLength() const = 0; /*********************** Utilities ***********************/ @@ -161,7 +175,7 @@ class ModelObj : public Object { */ virtual int EstimateHostCPURequirement() const = 0; - /*! \brief Get the max window size of the model. */ + /*! \brief Get the max window size of the model. "-1" means infinite length. */ virtual int GetMaxWindowSize() const = 0; /*! \brief Allocate an embedding tensor with the prefill chunk size. 
*/ diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 0c8846d670..e90bdfef78 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -23,7 +23,7 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: assert len(args) == 11 assert isinstance(args[1], relax.ShapeExpr) - assert len(args[1].values) == 4 + assert len(args[1].values) == 5 for i in range(2, 10): assert isinstance(args[i], relax.PrimValue) assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm)) @@ -34,6 +34,7 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: "max_total_seq_len": args[1].values[1], "prefill_chunk_size": args[1].values[2], "page_size": args[1].values[3], + "support_sliding_window": args[1].values[4], "num_hidden_layers": args[2].value.value, "num_attention_heads": args[3].value.value, "num_key_value_heads": args[4].value.value, @@ -119,10 +120,19 @@ def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, An "prefill_chunk_size_", relax.ShapeStructInfo([kwargs["prefill_chunk_size"]]) ) page_size = relax.Var("page_size_", relax.ShapeStructInfo([kwargs["page_size"]])) + support_sliding_window = relax.Var( + "support_sliding_window_", relax.ShapeStructInfo([kwargs["support_sliding_window"]]) + ) with bb.function( name="create_tir_paged_kv_cache", - params=[max_batch_size, max_total_seq_len, prefill_chunk_size, page_size], + params=[ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ], ): cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs) bb.emit_func_output(cache._expr) # pylint: disable=protected-access @@ -160,10 +170,19 @@ def create_flashinfer_paged_kv_cache( "prefill_chunk_size_", relax.ShapeStructInfo([kwargs["prefill_chunk_size"]]) ) page_size = relax.Var("page_size_", relax.ShapeStructInfo([kwargs["page_size"]])) + support_sliding_window = relax.Var( + "support_sliding_window_", relax.ShapeStructInfo([kwargs["support_sliding_window"]]) + ) with bb.function( name="create_flashinfer_paged_kv_cache", - params=[max_batch_size, max_total_seq_len, prefill_chunk_size, page_size], + params=[ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ], ): cache = kv_cache.FlashInferPagedKVCache(target=self.target, **kwargs) bb.emit_func_output(cache._expr) # pylint: disable=protected-access diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index 334c32d7d5..ce51659b25 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -229,18 +229,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, 
num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, @@ -314,6 +316,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 9303e2552e..079708ddb8 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -291,18 +291,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -376,6 +378,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index cf2a967cac..3c229fd911 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -283,18 +283,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.n_layer, num_attention_heads=self.n_head // self.tensor_parallel_shards, num_key_value_heads=self.n_head // self.tensor_parallel_shards, @@ -368,6 +370,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index d98871964f..c96caa9fee 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -260,18 +260,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return 
op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.n_layer, num_attention_heads=self.num_q_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_kv_heads // self.tensor_parallel_shards, @@ -345,6 +347,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 0a0c494685..62e6587bf2 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -314,18 +314,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, @@ -400,6 +402,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index cf39437dd6..d97d253c8f 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -230,18 +230,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, @@ -315,6 +317,7 @@ def get_default_spec(self): 
"max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index fb5f5637b8..f38997cdeb 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -274,18 +274,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -359,6 +361,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 88be860628..0b66ea706d 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -5,12 +5,12 @@ import dataclasses from typing import Any, Dict, Optional -from tvm import relax as rx from tvm import te, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -36,6 +36,7 @@ class MistralConfig(ConfigBase): # pylint: disable=too-many-instance-attributes prefill_chunk_size: int = 0 attention_sink_size: int = 4 tensor_parallel_shards: int = 1 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -59,62 +60,11 @@ def __post_init__(self): self.sliding_window_size, ) self.prefill_chunk_size = self.sliding_window_size - elif self.prefill_chunk_size > self.sliding_window_size: - logger.info( - "Overriding %s from %d to %d (%s)", - bold("prefill_chunk_size"), - self.prefill_chunk_size, - self.sliding_window_size, - bold("sliding_window_size"), - ) - self.prefill_chunk_size = self.sliding_window_size # pylint: disable=invalid-name,missing-docstring -class RotaryEmbedding(nn.Module): - """Cache relative Rotary Embedding.""" - - def __init__(self, config: MistralConfig): - super().__init__() - self.head_dim = config.head_dim - self.position_embedding_base = config.position_embedding_base - - def forward(self, q: Tensor, k: Tensor, q_offset: tir.Var): - def te_op(x: te.Tensor, offset: tir.Var): - dtype = x.dtype - - def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): - head_dim = tir.const(self.head_dim, "int32") - position_embedding_base = tir.const(self.position_embedding_base, "float32") - freq = tir.power( - position_embedding_base, - 
(d * 2 % head_dim).astype("float32") / head_dim, - ) - freq = (offset + s) / freq - cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] - sin = tir.sin(freq).astype(dtype) * tir.if_then_else( - d < head_dim // 2, - -x[b, s, h, d + head_dim // 2], - x[b, s, h, d - head_dim // 2], - ) - return cos + sin - - return te.compute(x.shape, compute, name="rotary") - - q_embed = op.tensor_expr_op( - te_op, - "rotary_embedding", - args=[q, q_offset], - attrs={"mlc.rotary_embedding_to_all_dims": True}, - ) - k_embed = op.tensor_expr_op( - te_op, "rotary_embedding", args=[k, 0], attrs={"mlc.rotary_embedding_to_all_dims": True} - ) - return q_embed, k_embed - - class MistralMLP(nn.Module): """Same as in Llama architecture (LlamaFFN).""" @@ -137,166 +87,37 @@ def forward(self, x: Tensor): class MistralAttention(nn.Module): # pylint: disable=too-many-instance-attributes """Same as LlamaAttention, but with sliding window attention using a rolling buffer cache.""" - def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): - self.rotary_embedding = rotary_embedding - self.hidden_size = config.hidden_size + def __init__(self, config: MistralConfig): self.head_dim = config.head_dim self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards - self.sliding_window_size = config.sliding_window_size - self.attention_sink_size = config.attention_sink_size self.qkv_proj = nn.Linear( in_features=config.hidden_size, out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, bias=False, ) self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) - self.k_cache = RollingKVCacheWithSinks( - self.sliding_window_size, [self.num_kv_heads, self.head_dim] - ) - self.v_cache = RollingKVCacheWithSinks( - self.sliding_window_size, [self.num_kv_heads, self.head_dim] - ) - - def interleave_kv( # pylint: disable=too-many-arguments,too-many-locals - self, - k_cur: Tensor, - v_cur: Tensor, - kv_seq_len: tir.Var, - rolling_cache_len: tir.Var, - cache_offset: tir.Var, - ): - """Unrotate and concatenate currunt and cached k and v""" - h_kv, d = self.num_kv_heads, self.head_dim - kv_s, c, o = kv_seq_len, rolling_cache_len, cache_offset - b = k_cur.shape[0] - - k_cached = op.reshape(self.k_cache.view(c), (b, c, h_kv, d)) - v_cached = op.reshape(self.v_cache.view(c), (b, c, h_kv, d)) - - def _cache_unrotate(x_cached, rolling_cache_len, cache_offset): - return te.compute( - (b, kv_s, h_kv, d), - lambda xb, xs, xh, xd: te.if_then_else( - xs < self.attention_sink_size, - x_cached[xb, xs, xh, xd], - te.if_then_else( - xs < rolling_cache_len - cache_offset + self.attention_sink_size, - x_cached[xb, xs + cache_offset - self.attention_sink_size, xh, xd], - x_cached[xb, xs + cache_offset - rolling_cache_len, xh, xd], - ), - ), - name="cache_unrotate_te", - ) - - def _cache_cur_concat(x_cached, x_cur, rolling_cache_len): - return te.compute( - (b, kv_s, h_kv, d), - lambda xb, xs, xh, xd: te.if_then_else( - xs < rolling_cache_len, - x_cached[xb, xs, xh, xd], - x_cur[xb, xs - rolling_cache_len, xh, xd], - ), - name="cache_cur_concat_te", - ) - k_cached = op.tensor_expr_op( - _cache_unrotate, - name_hint="te_cache_unrotate_key", - args=[k_cached, c, o], - ) - k = op.tensor_expr_op( - _cache_cur_concat, - name_hint="te_cache_cur_concat_key", - args=[k_cached, k_cur, c], - ) - - v_cached = op.tensor_expr_op( - _cache_unrotate, - name_hint="te_cache_unrotate_value", - args=[v_cached, c, o], - 
) - v = op.tensor_expr_op( - _cache_cur_concat, - name_hint="te_cache_cur_concat_value", - args=[v_cached, v_cur, c], - ) - - self.k_cache.override( - op.squeeze(k_cur, axis=0), self.sliding_window_size, self.attention_sink_size - ) - self.v_cache.override( - op.squeeze(v_cur, axis=0), self.sliding_window_size, self.attention_sink_size - ) - - return k, v - - def forward( # pylint: disable=too-many-arguments, too-many-locals - self, - hidden_states: Tensor, - attention_mask: Tensor, - rolling_cache_len: tir.Var, # Number of elements currently in the cache. - kv_seq_len: tir.Var, # Equals to ``seq_len + rolling_cache_len``. - cache_offset: tir.Var, - ): - """Forward pass of MistralAttention, performing QKV.""" + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." - qkv_cur = self.qkv_proj(hidden_states) - qkv_cur = op.reshape(qkv_cur, (b, s, h_q + 2 * h_kv, d)) - q, k_cur, v_cur = op.split(qkv_cur, [h_q, h_q + h_kv], axis=2) - k, v = self.interleave_kv(k_cur, v_cur, kv_seq_len, rolling_cache_len, cache_offset) - q, k = self.rotary_embedding(q, k, rolling_cache_len) - output = op_ext.attention(q, k, v, attention_mask) - return self.o_proj(output) - - -class RollingKVCacheWithSinks(nn.KVCache): - """ - Rolling buffer cache implementation. - """ - - cache: Optional[rx.Var] - - def override(self, new_element: Tensor, max_cache_size: int, attention_sink_size: int) -> None: - """ - Override cache elements in RollingKVCacheWithSinks. - - Parameters - ---------- - new_element : Tensor - The new tensor to append. - - max_cache_size : int - Max size of the cache. - - attention_sink_size : int - Number of stored attention sinks. 
- """ - if new_element.dtype != self.dtype: - raise TypeError( - f'RollingKVCacheWithSinks has been set to use dtype "{self.dtype}", ' - f'but got "{new_element.dtype}"' - ) - self.cache = rx.BlockBuilder.current().emit( - rx.call_pure_packed( - "vm.builtin.attention_kv_cache_window_override_with_sinks", - self.cache, - new_element._expr, # pylint: disable=protected-access - rx.PrimValue(max_cache_size), - rx.PrimValue(attention_sink_size), - sinfo_args=[rx.ObjectStructInfo()], - ) + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), ) + return self.o_proj(output) class MistralDecoderLayer(nn.Module): """Exact same as LlamaDecoderLayer.""" - def __init__(self, config: MistralConfig, rotary_embedding: RotaryEmbedding): + def __init__(self, config: MistralConfig): rms_norm_eps = config.rms_norm_eps - self.self_attn = MistralAttention(config, rotary_embedding) + self.self_attn = MistralAttention(config) self.mlp = MistralMLP(config) self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) @@ -318,73 +139,53 @@ def _set(layer, hint): self.tensor_parallel_shards = config.tensor_parallel_shards _set_tp() - def forward( # pylint: disable=too-many-arguments - self, - hidden_states: Tensor, - attention_mask: Tensor, - rolling_cache_len: tir.Var, - kv_seq_len: tir.Var, - cache_offset: tir.Var, - ): - """Forward pass of a decoder layer; calculate attention, and add an residual connection.""" - - def _apply_residual(out, residual): - if self.tensor_parallel_shards > 1: - return op.ccl_allreduce(out, "sum") + residual - return out + residual - - out = self.self_attn( - self.input_layernorm(hidden_states), - attention_mask, - rolling_cache_len, - kv_seq_len, - cache_offset, - ) - hidden_states = _apply_residual(out, residual=hidden_states) + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = _apply_residual(out, residual=hidden_states) + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class MistralModel(nn.Module): """Exact same as LlamaModel.""" def __init__(self, config: MistralConfig): assert config.hidden_size % config.num_attention_heads == 0 - rotary_embedding = RotaryEmbedding(config) self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) self.layers = nn.ModuleList( - [MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) self.tensor_parallel_shards = config.tensor_parallel_shards - def forward( # pylint: disable=too-many-arguments - self, - inputs: Tensor, - rolling_cache_len: tir.Var, - kv_seq_len: tir.Var, - cache_offset: tir.Var, - attention_mask: Tensor, - ): - """Forward pass of the model, passing through all decoder layers.""" - if 
self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - hidden_states = self.embed_tokens(inputs) - for layer in self.layers: - hidden_states = layer( - hidden_states, attention_mask, rolling_cache_len, kv_seq_len, cache_offset - ) + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) hidden_states = self.norm(hidden_states) return hidden_states -class MistralForCasualLM(nn.Module): +class MistralForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes """Same as LlamaForCausalLM, except for the use of sliding window attention.""" def __init__(self, config: MistralConfig): self.model = MistralModel(config) self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards self.sliding_window_size = config.sliding_window_size self.dtype = "float32" @@ -393,131 +194,155 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward( # pylint: disable=too-many-arguments + def batch_forward( self, - inputs: Tensor, - rolling_cache_len: tir.Var, - kv_seq_len: tir.Var, - cache_offset: tir.Var, - attention_mask: Tensor, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, ): - """Forward pass.""" + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.model( - inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask - ) + hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") - return logits + return logits, paged_kv_cache - def prefill( - self, - inputs: Tensor, - rolling_cache_len: tir.Var, - kv_seq_len: tir.Var, - cache_offset: tir.Var, + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): - """ - Prefilling the prompt. - - Parameters - ---------- - inputs: Tensor - Input tokens, having ``seq_len`` number of tokens. 
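In the single-sequence `prefill` above, only the hidden state of the last position feeds the LM head, which is what the `_index` tensor expression selects. A NumPy equivalent of that selection, for illustration (the function name `last_position` is hypothetical):

    import numpy as np

    def last_position(hidden_states: np.ndarray) -> np.ndarray:
        """Keep only the final sequence position: (b, s, d) -> (b, 1, d)."""
        b, s, d = hidden_states.shape
        return hidden_states[:, s - 1 : s, :]

    assert last_position(np.zeros((1, 12, 4096))).shape == (1, 1, 4096)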
- - rolling_cache_len: tir.Var - Number of elements currently in the cache. - - kv_seq_len: tir.Var - Equals to ``seq_len + rolling_cache_len``. - - cache_offset: tir.Var - Next position to be overrided on the rolling kv cache. - """ - - def _sliding_window_attention_mask( - batch_size, seq_len, rolling_cache_len, kv_seq_len, sliding_window_size - ): - # See `tests/legacy-python/test_sliding_window_mask.py` for its behavior - return te.compute( - (batch_size, 1, seq_len, kv_seq_len), - lambda b, _, i, j: tir.Select( - tir.all( - i + rolling_cache_len >= j, i + rolling_cache_len - j < sliding_window_size - ), - tir.max_value(self.dtype), - tir.min_value(self.dtype), - ), - name="sliding_window_attention_mask_prefill", - ) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache - batch_size, seq_len = inputs.shape - attention_mask = op.tensor_expr_op( - _sliding_window_attention_mask, - name_hint="sliding_window_attention_mask_prefill", - args=[ - batch_size, - seq_len, - rolling_cache_len, - kv_seq_len, - self.sliding_window_size, - ], - ) - return self.forward(inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask) + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache - def decode( + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, - inputs: Tensor, - rolling_cache_len: tir.Var, - kv_seq_len: tir.Var, - cache_offset: tir.Var, - ): - """Decoding step.""" - batch_size, seq_len = inputs.shape - attention_mask = op.full( - shape=[batch_size, 1, seq_len, kv_seq_len], - fill_value=tir.max_value(self.dtype), + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, dtype=self.dtype, ) - return self.forward(inputs, rolling_cache_len, kv_seq_len, cache_offset, attention_mask) - - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - """Softmax.""" - return op.softmax(logits / temperature, axis=-1) def get_default_spec(self): - """Needed for ``export_tvm()``.""" - batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), - "rolling_cache_len": int, - "kv_seq_len": int, - "cache_offset": int, + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { 
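`softmax_with_temperature` now takes one temperature per sequence and reshapes it to broadcast over the vocabulary axis, matching the batched `(batch_size,)` temperature spec. A NumPy analogue assuming logits of shape (batch, 1, vocab) (the subtraction of the row maximum is a numerical-stability detail of this sketch, not of the model code):

    import numpy as np

    def softmax_with_temperature(logits: np.ndarray, temperature: np.ndarray) -> np.ndarray:
        """Per-sequence temperature scaling: logits (B, 1, V), temperature (B,)."""
        scaled = logits / temperature.reshape(temperature.shape[0], 1, 1)
        scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
        probs = np.exp(scaled)
        return probs / probs.sum(axis=-1, keepdims=True)

    logits = np.array([[[2.0, 1.0, 0.0]], [[2.0, 1.0, 0.0]]])
    # The second sequence, sampled at temperature 0.5, gets a sharper distribution.
    print(softmax_with_temperature(logits, np.array([1.0, 0.5])))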
"param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), - "rolling_cache_len": int, - "kv_seq_len": int, - "cache_offset": int, + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), "$": { "param_mode": "packed", - "effect_mode": "packed", + "effect_mode": "none", }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index 9964ab911f..48de826a3b 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -275,18 +275,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -360,6 +362,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index 0b3f3f092f..6d95833d41 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -384,18 +384,20 @@ def embed(self, input_ids: Tensor): embeds = self.transformer.embd(input_ids) return embeds - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, 
max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -470,6 +472,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 54157c7eb3..5cd979e589 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -235,18 +235,20 @@ def batch_verify(self, inputs: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards, @@ -320,6 +322,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index db533285d8..c85e8337df 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -253,18 +253,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -338,6 +340,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py 
b/python/mlc_llm/model/stable_lm/stablelm_model.py index 710bf7698e..8589fbc501 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -238,18 +238,20 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, @@ -324,6 +326,7 @@ def get_default_spec(self): "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 2863ed47b7..c4792bb57c 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -1,6 +1,6 @@ """Attention KV cache modeling.""" -# pylint: disable=too-many-statements,too-many-lines +# pylint: disable=too-many-statements,too-many-lines,too-many-arguments import enum import math from typing import Optional, Tuple @@ -12,11 +12,7 @@ from tvm.script import tir as T from tvm.target import Target -from mlc_llm.op.position_embedding import ( - llama_inplace_rope, - llama_rope_with_position_map, - rope_freq, -) +from mlc_llm.op.position_embedding import llama_rope_with_position_map, rope_freq from ..support.max_thread_check import ( check_thread_limits, @@ -40,11 +36,12 @@ class PagedKVCache(Object): # pylint: disable=too-few-public-methods """The Paged KV Cache used in LLM batching for efficient attention computation.""" @staticmethod - def create_generic( # pylint: disable=too-many-arguments + def create_generic( max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, @@ -64,7 +61,15 @@ def create_generic( # pylint: disable=too-many-arguments return PagedKVCache( _expr=rx.call_pure_packed( "mlc.create_paged_kv_cache_generic", - rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), rx.PrimValue(num_hidden_layers), rx.PrimValue(num_attention_heads), rx.PrimValue(num_key_value_heads), @@ -79,48 +84,6 @@ def create_generic( # pylint: disable=too-many-arguments _name=name, ) - def attention( # pylint: disable=invalid-name, too-many-arguments - self, - layer_id: int, - q: Tensor, - k: Tensor, - v: Tensor, - attn_score_scaling_factor: float = 1.0, - ) -> Tensor: - """Compute attention with the given q/k/v data and in-cache k/v data - on the specified layer. Rotary position embeddings are applied to k/v - within this function. 
- - - For prefill, the input q and output tensor have shape - (1, total_seq_len, num_attention_heads, head_dim), and the - k/v tensors have shape (1, total_seq_len, num_key_value_heads, head_dim). - - For decode, the input q and output tensor have shape - (batch_size, 1, num_attention_heads, head_dim), and the - k/v tensors have shape (batch_size, 1, num_key_value_heads, head_dim). - """ - # pylint: disable=protected-access - q_shape = q.shape - q = q.reshape(q.shape[0] * q.shape[1], q.shape[2], q.shape[3]) - k = k.reshape(k.shape[0] * k.shape[1], k.shape[2], k.shape[3]) - v = v.reshape(v.shape[0] * v.shape[1], v.shape[2], v.shape[3]) - return Tensor( - _expr=rx.BlockBuilder.current().emit( - rx.call_dps_packed( - "vm.builtin.paged_attention_kv_cache_attention", - [ - self._expr, - rx.PrimValue(layer_id), # type: ignore[arg-type] - rx.PrimValue(attn_score_scaling_factor), - q._expr, - k._expr, - v._expr, - ], - out_sinfo=q._expr.struct_info, - ) - ) - ).reshape(*q_shape) - # pylint: enable=protected-access - def attention_with_fused_qkv( # pylint: disable=invalid-name self, layer_id: int, @@ -146,7 +109,7 @@ def attention_with_fused_qkv( # pylint: disable=invalid-name return Tensor( _expr=rx.BlockBuilder.current().emit( rx.call_dps_packed( - "vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv", + "vm.builtin.attention_kv_cache_attention_with_fused_qkv", [ self._expr, rx.PrimValue(layer_id), # type: ignore[arg-type] @@ -176,7 +139,7 @@ def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: return Tensor( _expr=rx.BlockBuilder.current().emit( rx.call_pure_packed( - "vm.builtin.paged_attention_kv_cache_get_query_positions", + "vm.builtin.attention_kv_cache_get_query_positions", self._expr, sinfo_args=rx.TensorStructInfo((total_length,), "int32"), ) @@ -189,12 +152,13 @@ def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods """Paged KV cache using FlashInfer (CUDA) kernels.""" - def __init__( # pylint: disable=too-many-arguments,too-many-locals + def __init__( # pylint: disable=too-many-locals self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, @@ -227,6 +191,10 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals The size (a.k.a. number of tokens) of each page. It is a symbolic variable whose concrete value is specified at runtime. + support_sliding_window : tir.Var + 0 or 1, denoting whether the KV cache supports sliding window. + It is a symbolic variable whose concrete value is specified + at runtime. rope_mode : RopeMode The RoPE mode of the Paged KV cache. If it is normal, RoPE will be applied to k before adding k to cache. 
@@ -243,7 +211,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals bb = rx.BlockBuilder.current() # pylint: disable=invalid-name args = [ - rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), rx.PrimValue(num_hidden_layers), rx.PrimValue(num_attention_heads), rx.PrimValue(num_key_value_heads), @@ -257,6 +233,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), @@ -266,16 +244,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), rx.extern("flashinfer.merge_state_in_place"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), - bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), # fmt: on # pylint: enable=line-too-long ] super().__init__( - _expr=rx.Call( - rx.extern("vm.builtin.paged_attention_kv_cache_create"), - args=args, - sinfo_args=[rx.ObjectStructInfo()], + _expr=rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_create", + *args, + sinfo_args=rx.ObjectStructInfo(), ), _name=name, ) @@ -284,12 +261,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods """Paged KV cache using TIR kernels.""" - def __init__( # pylint: disable=too-many-arguments,too-many-locals + def __init__( # pylint: disable=too-many-locals self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, @@ -322,6 +300,10 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals The size (a.k.a. number of tokens) of each page. It is a symbolic variable whose concrete value is specified at runtime. + support_sliding_window : tir.Var + 0 or 1, denoting whether the KV cache supports sliding window. + It is a symbolic variable whose concrete value is specified + at runtime. rope_mode : RopeMode The RoPE mode of the Paged KV cache. If it is normal, RoPE will be applied to k before adding k to cache. 
@@ -338,7 +320,15 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals bb = rx.BlockBuilder.current() args = [ - rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]), + rx.ShapeExpr( + [ + max_batch_size, + max_total_seq_len, + prefill_chunk_size, + page_size, + support_sliding_window, + ] + ), rx.PrimValue(num_hidden_layers), rx.PrimValue(num_attention_heads), rx.PrimValue(num_key_value_heads), @@ -350,21 +340,22 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # pylint: disable=line-too-long # fmt: off bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_prefill"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_decode"), + bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), bb.add_func(_merge_state_inplace(num_key_value_heads, head_dim, dtype, target), "tir_attention_merge_state"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), - bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), # fmt: on # pylint: enable=line-too-long ] super().__init__( - _expr=rx.Call( - rx.extern("vm.builtin.paged_attention_kv_cache_create_reduced"), - args=args, - sinfo_args=[rx.ObjectStructInfo()], + _expr=rx.call_pure_packed( + "vm.builtin.paged_attention_kv_cache_create_reduced", + *args, + sinfo_args=rx.ObjectStructInfo(), ), _name=name, ) @@ -394,18 +385,19 @@ def tir_kv_cache_transpose_append( v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) position_map = T.match_buffer(var_position_map, (ntoken,), "int32") for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore - pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] - pages[T.floordiv(position, 16), 1, vh, 
T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + if position_map[global_pos] != T.int32(-1): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] # fmt: on # pylint: enable=line-too-long,invalid-name @@ -447,7 +439,7 @@ def tir_kv_cache_debug_get_kv( return tir_kv_cache_debug_get_kv -def _rope( # pylint: disable=too-many-arguments +def _rope( buffer: T.Buffer, offset: tir.Var, rotary_dim: int, @@ -471,7 +463,46 @@ def _var(dtype): return T.alloc_buffer((1,), dtype, scope="local") -def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument +def _causal_mask(causal, row, col, kv_len, qo_len): + return T.if_then_else( + causal > 0, + col < kv_len - qo_len + row + 1, + col < kv_len, + ) + + +def _declare_length_info(var_length_info, batch_size, sliding_window): + return ( + T.match_buffer(var_length_info, (3, batch_size), "int32") + if sliding_window + else T.match_buffer(var_length_info, (batch_size,), "int32") + ) + + +def _get_kv_chunk_len(num_pages, page_size, seq_id, length_info, sliding_window): + if not sliding_window: + return (num_pages - 1) * page_size + length_info[seq_id] + # ((num_pages - 1) * page_size + last_page_len) - sliding_window_offset + sink_size + return ( + (num_pages - 1) * page_size + + length_info[0, seq_id] + - length_info[1, seq_id] + + length_info[2, seq_id] + ) + + +def _get_seq_offset(pos, seq_id, length_info, sliding_window): + if not sliding_window: + return pos + # pos if pos < sink_size else pos - sink_size + sliding_window_offset + return T.if_then_else( + pos < length_info[2, seq_id], + pos, + pos - length_info[2, seq_id] + length_info[1, seq_id], + ) + + +def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target): # pylint: disable=invalid-name NUM_BLKS = 16 LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes @@ -492,14 +523,11 @@ def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable= num_warps = 2 check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) - def mask(causal, row, col, kv_len, qo_len): - return T.if_then_else( - causal > 0, - col < kv_len - qo_len + row + 1, - col < kv_len, - ) + global_symbol = "batch_prefill_paged_kv" + if sliding_window: + global_symbol += "_sliding_window" - # pylint: disable=line-too-long,too-many-arguments,too-many-branches + # pylint: disable=line-too-long,too-many-branches # fmt: off @T.prim_func def batch_prefill_paged_kv( @@ -509,7 +537,7 @@ def batch_prefill_paged_kv( var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] var_page_indptr: T.handle, # [batch_size + 1] var_page_values: T.handle, # [nnz_pages] - var_last_page_len: T.handle, # [b] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] var_k_rope_pos_offset: T.handle, # [b] 
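`_get_kv_chunk_len` and `_get_seq_offset` carry the sliding-window logic: the first counts how many KV slots a sequence actually attends to (everything written, minus what the window evicted, plus the attention sinks), and the second maps a logical attention position onto the rotated slot layout in which the sinks stay at the front. A plain-Python rendering with concrete numbers (the scalar arguments stand in for the `length_info` rows 0/1/2; all values here are made up for illustration):

    def kv_chunk_len(num_pages, page_size, last_page_len, window_offset, sink_size):
        """Attended KV slots = slots written - slots evicted by the window + sinks."""
        return (num_pages - 1) * page_size + last_page_len - window_offset + sink_size

    def seq_offset(pos, window_offset, sink_size):
        """Sink positions stay in place; later positions shift past the evicted region."""
        return pos if pos < sink_size else pos - sink_size + window_offset

    # 3 pages of 16 slots, 5 slots used on the last page,
    # sliding window starting at slot 20, 4 attention sinks:
    assert kv_chunk_len(3, 16, 5, 20, 4) == 21
    assert seq_offset(0, 20, 4) == 0     # sink token, kept at the front of the buffer
    assert seq_offset(10, 20, 4) == 26   # shifted past the evicted slots
    assert seq_offset(20, 20, 4) == 36   # last attended position -> last used slot (37 total)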
var_q_rope_position: T.handle, # [total_len] var_output: T.handle, # [total_len, h_q, d] @@ -520,6 +548,7 @@ def batch_prefill_paged_kv( rope_theta: T.float32, attn_score_scaling_factor: T.float32, ): + T.func_attr({"global_symbol": global_symbol}) batch_size = T.int32(is_size_var=True) total_len = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) @@ -530,11 +559,19 @@ def batch_prefill_paged_kv( pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32") page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32") - last_page_len = T.match_buffer(var_last_page_len, (batch_size,), "int32") k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32") output = T.match_buffer(var_output, (total_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info(var_length_info, batch_size, sliding_window) # kernel code for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): @@ -590,10 +627,9 @@ def batch_prefill_paged_kv( cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] - cur_last_page_len: T.int32 = last_page_len[b_idx] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window), 0 ) T.tvm_storage_sync("shared") @@ -638,8 +674,9 @@ def batch_prefill_paged_kv( T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore K_smem[i, j] = T.if_then_else( rotary_mode == 1, _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), @@ -655,8 +692,9 @@ def batch_prefill_paged_kv( T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # 
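Once `_get_seq_offset` has produced a per-sequence slot index, the kernel resolves it to a physical page through the sequence's page table, exactly as in the K/V loads above. A plain-Python sketch of that two-step lookup (the page ids in `page_values` are hypothetical; 16 mirrors the page size hard-coded in these kernels):

    def locate_kv_slot(seq_offset, page_values, page_indptr_begin, page_size=16):
        """Resolve a per-sequence slot index into (physical page id, offset within page)."""
        page_no = page_values[page_indptr_begin + seq_offset // page_size]
        page_offset = seq_offset % page_size
        return page_no, page_offset

    page_values = [7, 2, 11]                              # this sequence's pages in the global pool
    assert locate_kv_slot(0, page_values, 0) == (7, 0)    # first slot, first page
    assert locate_kv_slot(36, page_values, 0) == (11, 4)  # slot 36 lands on the third page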
type: ignore V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = 0.0 @@ -686,7 +724,7 @@ def batch_prefill_paged_kv( m_new[i] = m_smem[row] # mask out of kv_chunk_len S for j in T.serial(tile_z): - if mask(causal, + if _causal_mask(causal, row=tile_id[0] * L_per_cta + row // group_size, col=L_kv_start + j, kv_len=kv_chunk_len[0], @@ -700,7 +738,7 @@ def batch_prefill_paged_kv( for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: - if mask(causal, + if _causal_mask(causal, row=tile_id[0] * L_per_cta + row // group_size, col=L_kv_start + j, kv_len=kv_chunk_len[0], @@ -746,7 +784,7 @@ def batch_prefill_paged_kv( # move to next tile tile_id[0] += NUM_BLKS # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + # pylint: enable=line-too-long,invalid-name,too-many-branches sch = tir.Schedule(batch_prefill_paged_kv) def get_tile_size(x, y, t): @@ -779,7 +817,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument + def apply_to_gemm( # pylint: disable=unused-argument sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] @@ -823,7 +861,8 @@ def _attention_decode( num_qo_heads, head_dim, qkv_dtype, - target: Target, # pylint: disable=unused-argument + sliding_window: bool, + target: Target, ): # pylint: disable=invalid-name qkv_dtype_bytes = 2 @@ -852,7 +891,11 @@ def _attention_decode( log2e = math.log2(math.exp(1)) check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=bdz, gdz=1) - # pylint: disable=line-too-long,too-many-arguments,too-many-branches + global_symbol = "batch_decode_paged_kv" + if sliding_window: + global_symbol += "_sliding_window" + + # pylint: disable=line-too-long,too-many-branches # fmt: off @T.prim_func def batch_decode_paged_kv( @@ -861,7 +904,7 @@ def batch_decode_paged_kv( pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, - last_page_len_handle: T.handle, + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, @@ -871,7 +914,7 @@ def batch_decode_paged_kv( rope_theta: T.float32, attn_score_scaling_factor: T.float32, ): - T.func_attr({"tir.is_scheduled": 1}) + T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol}) B = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) @@ -884,9 +927,17 @@ def batch_decode_paged_kv( page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32") k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32") q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32") - last_page_len = T.match_buffer(last_page_len_handle, (B,), "int32") output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. 
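`_causal_mask` replaces the per-kernel `mask` closures with one shared predicate: under causal attention a query row may only see KV columns up to and including its own absolute position. A Python equivalent with a tiny worked example (3 new query tokens appended to a sequence with 5 KV entries; the numbers are illustrative):

    def causal_mask(causal, row, col, kv_len, qo_len):
        """True if KV column `col` is visible to query row `row`."""
        return col < kv_len - qo_len + row + 1 if causal else col < kv_len

    kv_len, qo_len = 5, 3
    for row in range(qo_len):
        visible = [c for c in range(kv_len) if causal_mask(True, row, c, kv_len, qo_len)]
        print(row, visible)
    # 0 [0, 1, 2]
    # 1 [0, 1, 2, 3]
    # 2 [0, 1, 2, 3, 4]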
+ # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". + length_info = _declare_length_info(var_length_info, B, sliding_window) sm_scale = 1.0 / math.sqrt(float(D)) * log2e @@ -922,10 +973,9 @@ def batch_decode_paged_kv( batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] - cur_last_page_len: T.int32 = last_page_len[batch_idx] kv_chunk_len[0] = T.if_then_else( cur_page_indptr_begin != cur_page_indptr_end, - (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + cur_last_page_len, + _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, batch_idx, length_info, sliding_window), 0 ) @@ -948,31 +998,39 @@ def batch_decode_paged_kv( tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore # load K from global memory to shared memory for j in T.serial(tile_size_per_bdx): - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( - rotary_mode == 1, - _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), - pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] - ) - else: - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + with T.block("K_load"): + T.reads() + T.writes() + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( + rotary_mode == 1, + _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), + pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] + ) + else: + for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 T.tvm_storage_sync("shared") # load V from global memory to shared memory for j in T.serial(tile_size_per_bdx): - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): - V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] - else: - for vec in T.vectorized(VEC_SIZE): - V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 + with T.block("V_load"): + T.reads() + T.writes() + row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore + if row_g < kv_chunk_len[0]: + seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: 
ignore + page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore + page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore + for vec in T.vectorized(VEC_SIZE): + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] + else: + for vec in T.vectorized(VEC_SIZE): + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 T.tvm_storage_sync("shared") # compute QK m_prev[0] = st_m[0] @@ -1054,7 +1112,7 @@ def batch_decode_paged_kv( # store lse to global memory lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0]) # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + # pylint: enable=line-too-long,invalid-name,too-many-branches return batch_decode_paged_kv @@ -1157,16 +1215,9 @@ def _attention_prefill_ragged( tile_z = 8 num_warps = 2 - def mask(causal, row, col, kv_len, qo_len): - return T.if_then_else( - causal > 0, - col < kv_len - qo_len + row + 1, - col < kv_len, - ) - # fmt: off @T.prim_func - def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches + def batch_prefill_ragged_kv( # pylint: disable=too-many-branches var_q: T.handle, # [total_len, h_q, d] var_q_indptr: T.handle, # [batch_size + 1] var_k: T.handle, # [total_len, h_kv, d] @@ -1336,7 +1387,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran m_new[i] = m_smem[row] # mask out of kv_chunk_len S for j in T.serial(tile_z): - if mask(causal, + if _causal_mask(causal, row=tile_id[0] * L_per_cta + row // group_size, col=L_kv_start + j, kv_len=kv_chunk_len[0], @@ -1350,7 +1401,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: - if mask(causal, + if _causal_mask(causal, row=tile_id[0] * L_per_cta + row // group_size, col=L_kv_start + j, kv_len=kv_chunk_len[0], @@ -1396,7 +1447,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-bran # move to next tile tile_id[0] += NUM_BLKS # fmt: on - # pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches + # pylint: enable=line-too-long,invalid-name,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) def get_tile_size(x, y, t): @@ -1429,7 +1480,7 @@ def apply_to_so_ewise(sch: tir.Schedule, block, tile): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - def apply_to_gemm( # pylint: disable=too-many-arguments,unused-argument + def apply_to_gemm( # pylint: disable=unused-argument sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] diff --git a/python/mlc_llm/op/position_embedding.py b/python/mlc_llm/op/position_embedding.py index 323afc02da..e6cb25d856 100644 --- a/python/mlc_llm/op/position_embedding.py +++ b/python/mlc_llm/op/position_embedding.py @@ -5,12 +5,6 @@ from tvm import tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T -from tvm.target import Target - -from ..support.max_thread_check import ( - check_thread_limits, - get_max_num_threads_per_block, -) # pylint: disable=invalid-name @@ -271,120 +265,3 @@ def fused_rope( # pylint: disable=too-many-locals v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] return fused_rope - - -# pylint: disable=line-too-long,too-many-arguments,too-many-nested-blocks,invalid-name - - -def 
llama_inplace_rope( - theta: float, - scale: float, - head_dim: int, - num_q_heads: int, - num_kv_heads: int, - dtype: str, - target: Target, # pylint: disable=unused-argument - rotary_dim: Optional[int] = None, -): - """Return the TIR function that inplace computes Llama-style RoPE with q position offset. - - Parameters - ---------- - theta : float - The theta value, or "base" in RoPE, which controls the frequency. - - scale : float - The RoPE scaling factor. - - head_dim : int - The number of features on each head. - - num_q_heads : int - The number of query heads. - - num_kv_heads : int - The number of key/value heads. It differs from `num_q_heads` in group-query attention. - - dtype : str - The dtype of qkv data. - - target : Target - The target to build the model to. - - rotary_dim : Optional[int] - The number of dimensions in the embedding that RoPE is applied to. By default, the - rotary_dim is the same as head_dim. - """ - if rotary_dim is None: - rotary_dim = head_dim - - VEC_SIZE = 4 - bdx = (head_dim + VEC_SIZE - 1) // VEC_SIZE # T.ceildiv(head_dim, VEC_SIZE) - bdy = 32 - max_num_threads_per_block = get_max_num_threads_per_block(target) - # TODO(mlc-team): Check correctness after `bdy` backoff - while bdx * bdy > max_num_threads_per_block and bdy > 1: - bdy //= 2 - check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=1, gdz=1) - - def _rope( - x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - rope_offset: tir.Var, - instance_offset: tir.Var, - ): - cos_freq, sin_freq = rope_freq((s + rope_offset) * scale, d, rotary_dim, theta, dtype) - cos = cos_freq * x[s + instance_offset, h, d] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[s + instance_offset, h, d + rotary_dim // 2], - x[s + instance_offset, h, d - rotary_dim // 2], - ) - return cos + sin - - # fmt: off - @T.prim_func - def tir_rotary( # pylint: disable=too-many-locals - var_q: T.handle, - var_k: T.handle, - var_append_len_indptr: T.handle, - var_rope_offsets: T.handle, - _0: T.int32, - _1: T.int32, - _2: T.int32, - _3: T.int32, - _4: T.int32, - _5: T.float32, - _6: T.float32, - ): - T.func_attr({"tir.is_scheduled": 1}) - total_len = T.int32() - batch_size = T.int32() - q = T.match_buffer(var_q, (total_len, num_q_heads, head_dim), dtype) - k = T.match_buffer(var_k, (total_len, num_kv_heads, head_dim), dtype) - rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32") - append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32") - with T.block(): - for b_h in T.thread_binding(batch_size * (num_q_heads + num_kv_heads), thread="blockIdx.x"): - b: T.int32 = b_h // (num_q_heads + num_kv_heads) - h: T.int32 = b_h % (num_q_heads + num_kv_heads) - instance_offset: T.int32 = append_len_indptr[b] - rope_offset: T.int32 = rope_offsets[b] - append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] - for s0 in range(T.ceildiv(append_len, bdy)): - for s1 in T.thread_binding(bdy, thread="threadIdx.y"): - for d0 in T.thread_binding(bdx, thread="threadIdx.x"): - for d1 in T.vectorized(VEC_SIZE): - s: T.int32 = s0 * bdy + s1 - d: T.int32 = d0 * VEC_SIZE + d1 - if s < append_len and d < rotary_dim: - if h < num_q_heads: - q[s + instance_offset, h, d] = _rope(q, s, h, d, rope_offset, instance_offset) - else: - k[s + instance_offset, h - num_q_heads, d] = _rope(k, s, h - num_q_heads, d, rope_offset, instance_offset) - return tir_rotary - - -# pylint: enable=line-too-long,too-many-arguments,too-many-nested-blocks,invalid-name diff --git 
a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py index 84037b6fb1..048123286d 100644 --- a/python/mlc_llm/serve/async_engine.py +++ b/python/mlc_llm/serve/async_engine.py @@ -141,11 +141,13 @@ def __init__( model_args, config_file_paths, tokenizer_path, - self.max_single_sequence_length, + max_single_sequence_length, prefill_chunk_size, self.conv_template_name, ) = _process_model_args(models) self.trace_recorder = EventTraceRecorder() if enable_tracing else None + # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. + self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( @@ -182,7 +184,7 @@ def __init__( def _background_loop(): self._ffi["init_background_engine"]( - self.max_single_sequence_length, + max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), engine_mode.asjson(), diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 7d19532d2b..06185d0c2a 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -11,6 +11,7 @@ import tvm from tvm.runtime import Device +from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data from mlc_llm.support import logging from mlc_llm.support.auto_device import detect_device @@ -87,7 +88,7 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: model_path, config_file_path = _get_model_path(model.model) config_file_paths.append(config_file_path) chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if chat_config.context_window_size: + if chat_config.context_window_size and chat_config.context_window_size != -1: max_single_sequence_length = min( max_single_sequence_length, chat_config.context_window_size, @@ -97,7 +98,8 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: if tokenizer_path is None: tokenizer_path = model_path if conv_template_name is None: - conv_template_name = chat_config.conv_template + assert isinstance(chat_config.conv_template, Conversation) + conv_template_name = chat_config.conv_template.name # Try look up model library, and do JIT compile if model library not found. try: model_lib_path = _get_lib_module_path( @@ -125,6 +127,7 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: start=[], ) + assert prefill_chunk_size != int(1e9) return ( model_args, config_file_paths, @@ -317,7 +320,7 @@ def __init__( # pylint: disable=too-many-arguments model_args, config_file_paths, tokenizer_path, - self.max_single_sequence_length, + max_single_sequence_length, prefill_chunk_size, self.conv_template_name, ) = _process_model_args(models) @@ -335,6 +338,8 @@ def __init__( # pylint: disable=too-many-arguments ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None + # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. 
+ self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( @@ -354,7 +359,7 @@ def __init__( # pylint: disable=too-many-arguments engine_mode = EngineMode() self._ffi["init"]( - self.max_single_sequence_length, + max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), engine_mode.asjson(), diff --git a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py index 5a9924b94b..10256c2a48 100644 --- a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py +++ b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py @@ -38,7 +38,7 @@ def check_unsupported_fields( def check_prompts_length( - prompts: List[List[int]], max_single_sequence_length: int + prompts: List[List[int]], max_input_sequence_length: int ) -> Optional[fastapi.responses.JSONResponse]: """Check if the total prompt length exceeds the max single sequence sequence length allowed by the served model. Return an error if so. @@ -46,11 +46,11 @@ def check_prompts_length( total_length = 0 for prompt in prompts: total_length += len(prompt) - if total_length > max_single_sequence_length: + if total_length > max_input_sequence_length: return create_error_response( HTTPStatus.BAD_REQUEST, message=f"Request prompt has {total_length} tokens in total," - f" larger than the model capacity {max_single_sequence_length}.", + f" larger than the model input length limit {max_input_sequence_length}.", ) return None diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 15e944e16a..04f7c3eb58 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -79,7 +79,7 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re message="Entrypoint /v1/completions only accept single prompt. 
" f"However, {len(prompts)} prompts {prompts} are received.", ) - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_single_sequence_length) + error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) if error is not None: return error prompt = prompts[0] @@ -410,7 +410,7 @@ async def request_chat_completion( assert isinstance(prompts, list) and len(prompts) == 1, "Internal error" if conv_template.system_prefix_token_ids is not None: prompts[0] = conv_template.system_prefix_token_ids + prompts[0] - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_single_sequence_length) + error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) if error is not None: return error prompt = prompts[0] diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py index be4cc4a507..3e3afb92cc 100644 --- a/tests/python/model/test_kv_cache.py +++ b/tests/python/model/test_kv_cache.py @@ -16,160 +16,70 @@ def test_nn_module_paged_kv_cache(): # fmt: off @I.ir_module class Module: - @T.prim_func - def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32): # pylint: disable=too-many-arguments - T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) - seq_len = T.int64() - qkv = T.match_buffer(var_qkv, (seq_len, 96, 128), "float16") - position_map = T.match_buffer(var_position_map, (seq_len,), "int32") - q = T.match_buffer(var_q, (seq_len, 32, 128), "float16") - k = T.match_buffer(var_k, (seq_len, 32, 128), "float16") - v = T.match_buffer(var_v, (seq_len, 32, 128), "float16") - for iters_0, iters_1, iters_2 in T.grid(seq_len, 96, 128): - with T.block("llama_fused_rope"): - s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2]) - T.reads(position_map[s], qkv[s, h, d - 64:d - 64 + 129]) - T.writes(q[s, h, d], k[s, h - 32, d], v[s, h - 64, d]) - if h < 32: - q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Cast("float16", T.cos(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * qkv[s, h, d] + T.Cast("float16", T.sin(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1), qkv[s, h, d - 64]), qkv[s, h, d]) - else: - if h < 64: - k[s, h - 32, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Cast("float16", T.cos(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * qkv[s, h, d] + T.Cast("float16", T.sin(T.Cast("float32", T.Cast("float16", position_map[s])) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1), qkv[s, h, d - 64]), qkv[s, h, d]) - else: - v[s, h - 64, d] = qkv[s, h, d] - - @T.prim_func - def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): - T.func_attr({"tir.noalias": T.bool(True)}) - num_pages, page_size = T.int64(), T.int64(is_size_var=True) - pages = T.match_buffer(var_pages, (num_pages, 2, 32, page_size, 128), "float16") - seqlen = T.int64(is_size_var=True) - position_map = T.match_buffer(var_position_map, (seqlen,), "int32") - k_data = T.match_buffer(var_k_data, (32, seqlen, 32, 128), "float16") - v_data = 
T.match_buffer(var_v_data, (32, seqlen, 32, 128), "float16") - for p, h, d in T.grid(seqlen, 32, 128): - with T.block("copy0"): - vp, vh, vd = T.axis.remap("SSS", [p, h, d]) - T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) - T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) - position: T.int32 = position_map[vp] # type: ignore[name-defined] - k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] - v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] - - @T.prim_func - def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - num_pages = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, 32, 16, 128), "float16") - ntoken = T.int64(is_size_var=True) - k_data = T.match_buffer(var_k_data, (ntoken, 32, 128), "float16") - v_data = T.match_buffer(var_v_data, (ntoken, 32, 128), "float16") - position_map = T.match_buffer(var_position_map, (ntoken,), "int32") - # with T.block("root"): - for global_pos, h, f in T.grid(ntoken, 32, 128): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore[no-redef] - pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore[no-redef] - pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] - - @T.prim_func - def tir_rotary(var_q: T.handle, var_k: T.handle, var_append_len_indptr: T.handle, var_rope_offsets: T.handle, _0: T.int32, _1: T.int32, _2: T.int32, _3: T.int32, _4: T.int32, _5: T.float32, _6: T.float32): - T.func_attr({"tir.is_scheduled": 1}) - total_len = T.int32() - q = T.match_buffer(var_q, (total_len, 32, 128), "float16") - k = T.match_buffer(var_k, (total_len, 32, 128), "float16") - batch_size = T.int32() - append_len_indptr = T.match_buffer(var_append_len_indptr, (batch_size + 1,), "int32") - rope_offsets = T.match_buffer(var_rope_offsets, (batch_size,), "int32") - with T.block(""): - T.reads() - T.writes() - for b_h in T.thread_binding(batch_size * 64, thread="blockIdx.x"): # pylint: disable=too-many-nested-blocks - b: T.int32 = b_h // 64 - h: T.int32 = b_h % 64 - instance_offset: T.int32 = append_len_indptr[b] - rope_offset: T.int32 = rope_offsets[b] - append_len: T.int32 = append_len_indptr[b + 1] - append_len_indptr[b] - for s0 in range((append_len + 31) // 32): - for s1 in T.thread_binding(32, thread="threadIdx.y"): - for d0 in T.thread_binding(32, thread="threadIdx.x"): - for d1 in T.vectorized(4): - s: T.int32 = s0 * 32 + s1 - d: T.int32 = d0 * 4 + d1 - if s < append_len and d < 128: - if h < 32: - q[s + instance_offset, h, d] = T.Cast("float16", T.cos(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * q[s + instance_offset, h, d] + T.Cast("float16", 
T.sin(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, q[s + instance_offset, h, d + 64] * T.float16(-1), q[s + instance_offset, h, d - 64]) - else: - k[s + instance_offset, h - 32, d] = T.Cast("float16", T.cos(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * k[s + instance_offset, h - 32, d] + T.Cast("float16", T.sin(T.Cast("float32", s + rope_offset) / T.pow(T.float32(10000), T.Cast("float32", d * 2 % 128) / T.float32(128)))) * T.if_then_else(d < 64, k[s + instance_offset, h - 32, d + 64] * T.float16(-1), k[s + instance_offset, h - 32, d - 64]) - - @R.function - def _initialize_effect() -> R.Tuple(R.Object): - with R.dataflow(): - _io: R.Object = R.null_value() # type: ignore - lv: R.Tuple(R.Object) = (_io,) # type: ignore - gv: R.Tuple(R.Object) = lv # type: ignore - R.output(gv) - return gv - @R.function - def create_flashinfer_paged_kv_cache(max_batch_size: R.Shape(["max_batch_size_1"]), max_total_seq_len: R.Shape(["max_total_seq_len_1"]), prefill_chunk_size: R.Shape(["prefill_chunk_size_1"]), page_size: R.Shape(["page_size_1"]), _io: R.Object) -> R.Tuple(R.Object, R.Tuple(R.Object)): + def create_paged_kv_cache( + max_batch_size: R.Shape(["max_batch_size_1"]), # type: ignore + max_total_seq_len: R.Shape(["max_total_seq_len_1"]), # type: ignore + prefill_chunk_size: R.Shape(["prefill_chunk_size_1"]), # type: ignore + page_size: R.Shape(["page_size_1"]), # type: ignore + support_sliding_window: R.Shape(["support_sliding_window_1"]), # type: ignore + ) -> R.Object: max_batch_size_1 = T.int64() max_total_seq_len_1 = T.int64() prefill_chunk_size_1 = T.int64() page_size_1 = T.int64() + support_sliding_window_1 = T.int64() R.func_attr({"num_input": 5}) - cls = Module with R.dataflow(): - lv2: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16") # type: ignore - paged_kv_cache: R.Object = R.call_packed("vm.builtin.paged_attention_kv_cache_create", R.shape([max_batch_size_1, max_total_seq_len_1, prefill_chunk_size_1, page_size_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(0), R.prim_value(1), R.prim_value(10000), lv2, cls.tir_kv_cache_transpose_append, R.ExternFunc("paged_kv_cache.attention_kernel_prefill"), R.ExternFunc("paged_kv_cache.attention_kernel_decode"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_prefill_end_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_begin_forward"), R.ExternFunc("paged_kv_cache.attention_kernel_decode_end_forward"), R.ExternFunc("flashinfer.merge_state_in_place"), cls.fused_rope, cls.tir_rotary, cls.tir_kv_cache_debug_get_kv, sinfo_args=(R.Object,)) - gv2: R.Tuple(R.Object, R.Tuple(R.Object)) = paged_kv_cache, (_io,) # type: ignore - R.output(gv2) - return gv2 + paged_kv_cache: R.Object = R.call_pure_packed("mlc.create_paged_kv_cache_generic", R.shape([max_batch_size_1, max_total_seq_len_1, prefill_chunk_size_1, page_size_1, support_sliding_window_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(10000), R.prim_value(128), R.dtype("float16"), 
sinfo_args=(R.Object,)) + gv1: R.Object = paged_kv_cache + R.output(gv1) + return gv1 @R.function - def forward(cache: R.Object, q: R.Tensor((1, 100, 32, 128), dtype="float16"), k: R.Tensor((1, 100, 32, 128), dtype="float16"), v: R.Tensor((1, 100, 32, 128), dtype="float16"), _io: R.Object) -> R.Tuple(R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object)): - R.func_attr({"num_input": 5}) + def forward( + cache: R.Object, qkv: R.Tensor((1, 100, 96, 128), dtype="float16") # type: ignore + ) -> R.Tensor((1, 100, 32, 128), dtype="float16"): # type: ignore + R.func_attr({"num_input": 2}) with R.dataflow(): - reshape: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(q, R.shape([100, 32, 128])) # type: ignore - reshape1: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(k, R.shape([100, 32, 128])) # type: ignore - reshape2: R.Tensor((100, 32, 128), dtype="float16") = R.reshape(v, R.shape([100, 32, 128])) # type: ignore - lv1 = R.call_dps_packed("vm.builtin.paged_attention_kv_cache_attention", (cache, R.prim_value(0), reshape, reshape1, reshape2), out_sinfo=R.Tensor((100, 32, 128), dtype="float16")) - reshape3: R.Tensor((1, 100, 32, 128), dtype="float16") = R.reshape(lv1, R.shape([1, 100, 32, 128])) # type: ignore - gv1: R.Tuple(R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object)) = reshape3, (_io,) # type: ignore - R.output(gv1) - return gv1 + reshape: R.Tensor((100, 96, 128), dtype="float16") = R.reshape( # type: ignore + qkv, R.shape([100, 96, 128]) + ) + lv = R.call_dps_packed( + "vm.builtin.attention_kv_cache_attention_with_fused_qkv", + (cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape), + out_sinfo=R.Tensor((100, 32, 128), dtype="float16"), + ) + reshape1: R.Tensor((1, 100, 32, 128), dtype="float16") = R.reshape( # type: ignore + lv, R.shape([1, 100, 32, 128]) + ) + gv: R.Tensor((1, 100, 32, 128), dtype="float16") = reshape1 # type: ignore + R.output(gv) + return gv # fmt: on class PagedKVCacheTest(modules.Module): def forward( self, cache: PagedKVCache, - q: core.Tensor, - k: core.Tensor, - v: core.Tensor, + qkv: core.Tensor, ) -> core.Tensor: - return cache.attention(0, q, k, v) + return cache.attention_with_fused_qkv(0, qkv, num_qo_heads=32) - def create_flashinfer_paged_kv_cache( + def create_paged_kv_cache( self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, + support_sliding_window: tir.Var, ) -> PagedKVCache: - return FlashInferPagedKVCache( + return PagedKVCache.create_generic( max_batch_size=max_batch_size, max_total_seq_len=max_total_seq_len, prefill_chunk_size=prefill_chunk_size, page_size=page_size, + support_sliding_window=support_sliding_window, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, @@ -179,25 +89,22 @@ def create_flashinfer_paged_kv_cache( rope_theta=10000, rotary_dim=128, dtype="float16", - target=tvm.target.Target("cuda"), ) export_results = PagedKVCacheTest().export_tvm( spec={ "forward": { "cache": spec.Object(object_type=PagedKVCache), - "q": spec.Tensor((1, 100, 32, 128), "float16"), - "k": spec.Tensor((1, 100, 32, 128), "float16"), - "v": spec.Tensor((1, 100, 32, 128), "float16"), + "qkv": spec.Tensor((1, 100, 96, 128), "float16"), }, - "create_flashinfer_paged_kv_cache": { + "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, "prefill_chunk_size": int, "page_size": int, + "support_sliding_window": int, }, }, - debug=True, ) tvm_mod = export_results[0] tvm.ir.assert_structural_equal(tvm_mod, Module, True) diff --git 
a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index b726a6b41d..7ef6e22fe0 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -710,7 +710,7 @@ def test_openai_v1_completions_prompt_overlong( response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) error_msg_prefix = ( - f"Request prompt has {num_tokens} tokens in total, larger than the model capacity" + f"Request prompt has {num_tokens} tokens in total, larger than the model input length limit" ) if not stream: expect_error(response.json(), msg_prefix=error_msg_prefix) @@ -895,6 +895,7 @@ def test_openai_v1_chat_completions_n( "messages": messages, "stream": stream, "n": n, + "max_tokens": 300, } response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) @@ -905,7 +906,7 @@ def test_openai_v1_chat_completions_n( model=served_model[0], object_str="chat.completion", num_choices=n, - finish_reasons=["stop"], + finish_reasons=["stop", "length"], ) else: responses = [] @@ -919,7 +920,7 @@ def test_openai_v1_chat_completions_n( model=served_model[0], object_str="chat.completion.chunk", num_choices=n, - finish_reasons=["stop"], + finish_reasons=["stop", "length"], ) From 8f5e25dcb24af144d833b27fc4acb08658213541 Mon Sep 17 00:00:00 2001 From: Git bot Date: Mon, 18 Mar 2024 13:17:01 +0000 Subject: [PATCH 079/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 641209c69a..c06ec1f245 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 641209c69ad153c02471ba71bdf40a10c90789e5 +Subproject commit c06ec1f24548c0e94e15d3ea3c405f5f475b22af From 386af8dd677820c6b35445e762d3031ad9a9488a Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Mon, 18 Mar 2024 10:24:30 -0400 Subject: [PATCH 080/531] [REST] Update Rest API docs for the latest serve flow (#1972) * [Docs][Upd] Server launch, examples for endpoints for MLC Serve * remove v1/completions * add api docs to rest --------- Co-authored-by: Shrey Gupta --- docs/deploy/rest.rst | 502 ++++++++++++++++++++----------------------- 1 file changed, 235 insertions(+), 267 deletions(-) diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index d955d6066f..959c235201 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -11,51 +11,35 @@ for a user to interact with MLC-Chat in their own programs. Install MLC-Chat Package ------------------------ -The REST API is a part of the MLC-Chat package, which we have prepared pre-built :doc:`pip wheels <../install/mlc_llm>`. +SERVE is a part of the MLC-Chat package, installation instruction for which we be found here :doc:`<../install/mlc_llm>`. Verify Installation ^^^^^^^^^^^^^^^^^^^ .. code:: bash - python -m mlc_llm.rest --help + python -m mlc_llm.serve.server --help -You are expected to see the help information of the REST API. +You are expected to see the help information of the MLC SERVE. .. _mlcchat_package_build_from_source: -Optional: Build from Source -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If the prebuilt is unavailable on your platform, or you would like to build a runtime -that supports other GPU runtime than the prebuilt version. We can build a customized version -of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. - -First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). 
-You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-llm-nightly`. -Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. - -You can now use ``mlc_llm`` package by including the `python` directory to ``PYTHONPATH`` environment variable. - -.. code:: bash - - PYTHONPATH=python python -m mlc_llm.rest --help Launch the Server ----------------- -To launch the REST server for MLC-Chat, run the following command in your terminal. +To launch the MLC Server for MLC-Chat, run the following command in your terminal. .. code:: bash - python -m mlc_llm.rest --model MODEL [--lib-path LIB_PATH] [--device DEVICE] [--host HOST] [--port PORT] + python -m mlc_llm.serve.server --model MODEL --model-lib-path MODEL_LIB_PATH [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] --model The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model folder. In the former case, we will use the provided name to search for the model folder over possible paths. ---lib-path An optional field to specify the full path to the model library file to use (e.g. a ``.so`` file). +--model-lib-path A field to specify the full path to the model library file to use (e.g. a ``.so`` file). --device The description of the device to run on. User should provide a string in the form of 'device_name:device_id' or 'device_name', where 'device_name' is one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the @@ -63,6 +47,15 @@ To launch the REST server for MLC-Chat, run the following command in your termin with the device id set to 0 for default. --host The host at which the server should be started, defaults to ``127.0.0.1``. --port The port on which the server should be started, defaults to ``8000``. +--allow-credentials A flag to indicate whether the server should allow credentials. If set, the server will + include the ``CORS`` header in the response +--allowed-origins Specifies the allowed origins. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all origins. +--allowed-methods Specifies the allowed methods. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all methods. +--allowed-headers Specifies the allowed headers. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all headers. +--max-batch-size The maximum batch size for processing. +--max-total-seq-length The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. Set it to None to enable automatic computation of the max total sequence length. +--prefill-chunk-size The maximum total sequence length in a prefill. If not specified, it will be automatically inferred from model config. +--enable-tracing A boolean indicating if to enable event logging for requests. You can access ``http://127.0.0.1:PORT/docs`` (replace ``PORT`` with the port number you specified) to see the list of supported endpoints. @@ -72,66 +65,28 @@ API Endpoints The REST API provides the following endpoints: -.. http:get:: /v1/completions +.. 
http:get:: /v1/models ------------------------------------------------ - Get a completion from MLC-Chat using a prompt. - -**Request body** - -**model**: *str* (required) - The model folder after compiling with MLC-LLM build process. The parameter - can either be the model name with its quantization scheme - (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model - folder. In the former case, we will use the provided name to search - for the model folder over possible paths. -**prompt**: *str* (required) - A list of chat messages. The last message should be from the user. -**stream**: *bool* (optional) - Whether to stream the response. If ``True``, the response will be streamed - as the model generates the response. If ``False``, the response will be - returned after the model finishes generating the response. -**temperature**: *float* (optional) - The temperature applied to logits before sampling. The default value is - ``0.7``. A higher temperature encourages more diverse outputs, while a - lower temperature produces more deterministic outputs. -**top_p**: *float* (optional) - This parameter determines the set of tokens from which we sample during - decoding. The default value is set to ``0.95``. At each step, we select - tokens from the minimal set that has a cumulative probability exceeding - the ``top_p`` parameter. - - For additional information on top-p sampling, please refer to this blog - post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. -**repetition_penalty**: *float* (optional) - The repetition penalty controls the likelihood of the model generating - repeated texts. The default value is set to ``1.0``, indicating that no - repetition penalty is applied. Increasing the value reduces the - likelihood of repeat text generation. However, setting a high - ``repetition_penalty`` may result in the model generating meaningless - texts. The ideal choice of repetition penalty may vary among models. - - For more details on how repetition penalty controls text generation, please - check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). -**presence_penalty**: *float* (optional) - Positive values penalize new tokens if they are already present in the text so far, - decreasing the model's likelihood to repeat tokens. -**frequency_penalty**: *float* (optional) - Positive values penalize new tokens based on their existing frequency in the text so far, - decreasing the model's likelihood to repeat tokens. -**mean_gen_len**: *int* (optional) - The approximated average number of generated tokens in each round. Used - to determine whether the maximum window size would be exceeded. -**max_gen_len**: *int* (optional) - This parameter determines the maximum length of the generated text. If it is - not set, the model will generate text until it encounters a stop token. + Get a list of models available for MLC-Chat. ------------------------------------------------- +**Example** -**Returns** - If ``stream`` is set to ``False``, the response will be a ``CompletionResponse`` object. - If ``stream`` is set to ``True``, the response will be a stream of ``CompletionStreamResponse`` objects. +.. code:: bash + + import requests + + url = "http://127.0.0.1:8000/v1/models" + headers = {"accept": "application/json"} + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + print("Response:") + print(response.json()) + else: + print("Error:", response.status_code) .. 
http:get:: /v1/chat/completions @@ -140,255 +95,268 @@ The REST API provides the following endpoints: Get a response from MLC-Chat using a prompt, either with or without streaming. -**Request body** - -**model**: *str* (required) - The model folder after compiling with MLC-LLM build process. The parameter - can either be the model name with its quantization scheme - (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model - folder. In the former case, we will use the provided name to search - for the model folder over possible paths. -**messages**: *list[ChatMessage]* (required) - A list of chat messages. The last message should be from the user. -**stream**: *bool* (optional) - Whether to stream the response. If ``True``, the response will be streamed - as the model generates the response. If ``False``, the response will be - returned after the model finishes generating the response. -**temperature**: *float* (optional) - The temperature applied to logits before sampling. The default value is - ``0.7``. A higher temperature encourages more diverse outputs, while a - lower temperature produces more deterministic outputs. -**top_p**: *float* (optional) - This parameter determines the set of tokens from which we sample during - decoding. The default value is set to ``0.95``. At each step, we select - tokens from the minimal set that has a cumulative probability exceeding - the ``top_p`` parameter. - - For additional information on top-p sampling, please refer to this blog - post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. -**repetition_penalty**: *float* (optional) - The repetition penalty controls the likelihood of the model generating - repeated texts. The default value is set to ``1.0``, indicating that no - repetition penalty is applied. Increasing the value reduces the - likelihood of repeat text generation. However, setting a high - ``repetition_penalty`` may result in the model generating meaningless - texts. The ideal choice of repetition penalty may vary among models. - - For more details on how repetition penalty controls text generation, please - check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). -**presence_penalty**: *float* (optional) - Positive values penalize new tokens if they are already present in the text so far, - decreasing the model's likelihood to repeat tokens. -**frequency_penalty**: *float* (optional) - Positive values penalize new tokens based on their existing frequency in the text so far, - decreasing the model's likelihood to repeat tokens. -**mean_gen_len**: *int* (optional) - The approximated average number of generated tokens in each round. Used - to determine whether the maximum window size would be exceeded. -**max_gen_len**: *int* (optional) - This parameter determines the maximum length of the generated text. If it is - not set, the model will generate text until it encounters a stop token. -**n**: *int* (optional) - This parameter determines the number of text samples to generate. The default - value is ``1``. Note that this parameter is only used when ``stream`` is set to - ``False``. -**stop**: *str* or *list[str]* (optional) - When ``stop`` is encountered, the model will stop generating output. - It can be a string or a list of strings. If it is a list of strings, the model - will stop generating output when any of the strings in the list is encountered. - Note that this parameter does not override the default stop string of the model. 
- ------------------------------------------------- +**Chat Completion Request Object** -**Returns** - If ``stream`` is set to ``False``, the response will be a ``ChatCompletionResponse`` object. - If ``stream`` is set to ``True``, the response will be a stream of ``ChatCompletionStreamResponse`` objects. +- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. Each message in the conversation is represented by a `ChatCompletionMessage` object, which includes the following fields: + - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages. + - **role** (*Literal["system", "user", "assistant", "tool"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool. + - **name** (*Optional[str]*): An optional name for the sender of the message. + - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. + - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. + +- **model** (*str*, required): The model to be used for generating responses. -.. http:get:: /chat/reset +- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. - Reset the chat. +- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens. -.. http:get:: /stats +- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response. - Get the latest runtime stats (encode/decode speed). +- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 5. It determines the number of tokens, most likely to appear at each position, to be returned. Each token is accompanied by a log probability. If this parameter is used, 'logprobs' must be set to true. -.. http:get:: /verbose_stats +- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation. - Get the verbose runtime stats (encode/decode speed, total runtime). +- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s). +- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt. -Request Objects ---------------- +- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output. -**ChatMessage** +- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop. -**role**: *str* (required) - The role(author) of the message. It can be either ``user`` or ``assistant``. -**content**: *str* (required) - The content of the message. -**name**: *str* (optional) - The name of the author of the message. +- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated. -Response Objects ----------------- +- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions. 
-**CompletionResponse** +- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses. -**id**: *str* - The id of the completion. -**object**: *str* - The object name ``text.completion``. -**created**: *int* - The time when the completion is created. -**choices**: *list[CompletionResponseChoice]* - A list of choices generated by the model. -**usage**: *UsageInfo* or *None* - The usage information of the model. +- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat. ------------------------------------------------- +- **tool_choice** (*Optional[Union[Literal["none", "auto"], Dict]]*): Controls how tools are selected for use in responses. -**CompletionResponseChoice** +- **user** (*Optional[str]*): An optional identifier for the user initiating the request. -**index**: *int* - The index of the choice. -**text**: *str* - The message generated by the model. -**finish_reason**: *str* - The reason why the model finishes generating the message. It can be either - ``stop`` or ``length``. +- **ignore_eos** (*bool*, optional, default=False): If `True`, the model will ignore the end-of-sequence token for generating responses. +- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. Can be either "text" or "json_object", with optional schema definition for JSON responses. ------------------------------------------------- +**Returns** -**CompletionStreamResponse** +- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s). +- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses. -**id**: *str* - The id of the completion. -**object**: *str* - The object name ``text.completion.chunk``. -**created**: *int* - The time when the completion is created. -**choices**: *list[ChatCompletionResponseStreamhoice]* - A list of choices generated by the model. ------------------------------------------------- +**ChatCompletionResponseChoice** -**ChatCompletionResponseStreamChoice** +- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. + +- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. + +- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. + +- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token -**index**: *int* - The index of the choice. -**text**: *str* - The message generated by the model. -**finish_reason**: *str* - The reason why the model finishes generating the message. It can be either - ``stop`` or ``length``. +**ChatCompletionStreamResponseChoice** ------------------------------------------------- +- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. Valid reasons are "stop", "length", and "tool_calls". + +- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. + +- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. 
+ +- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token **ChatCompletionResponse** -**id**: *str* - The id of the completion. -**object**: *str* - The object name ``chat.completion``. -**created**: *int* - The time when the completion is created. -**choices**: *list[ChatCompletionResponseChoice]* - A list of choices generated by the model. -**usage**: *UsageInfo* or *None* - The usage information of the model. +- **id** (*str*, required): A unique identifier for the chat completion session. + +- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. + +- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. + +- **model** (*str*, required): The name of the model used to generate the chat completions. + +- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. + +- **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". + +- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. ------------------------------------------------- - -**ChatCompletionResponseChoice** +**ChatCompletionStreamResponse** -**index**: *int* - The index of the choice. -**message**: *ChatMessage* - The message generated by the model. -**finish_reason**: *str* - The reason why the model finishes generating the message. It can be either - ``stop`` or ``length``. +- **id** (*str*, required): A unique identifier for the streaming chat completion session. + +- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. + +- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. + +- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. + +- **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. + +- **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. ------------------------------------------------ -**ChatCompletionStreamResponse** -**id**: *str* - The id of the completion. -**object**: *str* - The object name ``chat.completion.chunk``. -**created**: *int* - The time when the completion is created. -**choices**: *list[ChatCompletionResponseStreamhoice]* - A list of choices generated by the model. +**Example** ------------------------------------------------- +Once you have launched the Server, you can use the API in your own program. Below is an example of using the API to interact with MLC-Chat in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): -**ChatCompletionResponseStreamChoice** +.. code:: bash -**index**: *int* - The index of the choice. -**delta**: *DeltaMessage* - The delta message generated by the model. -**finish_reason**: *str* - The reason why the model finishes generating the message. It can be either - ``stop`` or ``length``. 
+ import requests + + # Get a response using a prompt without streaming + payload = { + "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM."}, + { + "role": "assistant", + "content": "Hello! It's great to hear about your project, MLC LLM.", + }, + {"role": "user", "content": "What is the name of our project?"}, + ], + "stream": False, + # "n": 1, + "max_tokens": 300, + } + r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) + choices = r.json()["choices"] + for choice in choices: + print(f"{choice['message']['content']}\n") ------------------------------------------------ +Below is an example of using the API to interact with MLC-Chat in Python with Streaming. -**DeltaMessage** +.. code:: bash + + import requests + import json -**role**: *str* - The role(author) of the message. It can be either ``user`` or ``assistant``. -**content**: *str* - The content of the message. + # Get a response using a prompt with streaming + payload = { + "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", + "messages": [{"role": "user", "content": "Write a haiku"}], + "stream": True, + } + with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: + for chunk in r.iter_content(chunk_size=None): + chunk = chunk.decode("utf-8") + if "[DONE]" in chunk[6:]: + break + response = json.loads(chunk[6:]) + content = response["choices"][0]["delta"].get("content", "") + print(content, end="", flush=True) + print("\n") ------------------------------------------------ -Use REST API in your own program --------------------------------- - -Once you have launched the REST server, you can use the REST API in your own program. Below is an example of using REST API to interact with MLC-Chat in Python (suppose the server is running on ``http://127.0.0.1:8000/``): +There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). Below is an example on how to use function calling in Python. .. code:: bash import requests import json - # Get a response using a prompt without streaming + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + payload = { - "model": "vicuna-v1-7b", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": False + "model": "./dist/gorilla-openfunctions-v1-q4f16_1-MLC/", + "messages": [ + { + "role": "user", + "content": "What is the current weather in Pittsburgh, PA in fahrenheit?", + } + ], + "stream": False, + "tools": tools, } - r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload) - print(f"Without streaming:\n{r.json()['choices'][0]['message']['content']}\n") - # Reset the chat - r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload) - print(f"Reset chat: {str(r)}\n") + r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) + print(f"{r.json()['choices'][0]['message']['tool_calls'][0]['function']}\n") + + # Output: {'name': 'get_current_weather', 'arguments': {'location': 'Pittsburgh, PA', 'unit': 'fahrenheit'}} + +------------------------------------------------ + +Function Calling with streaming is also supported. Below is an example on how to use function calling with streaming in Python. + +.. code:: bash + + import requests + import json + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] - # Get a response using a prompt with streaming payload = { - "model": "vicuna-v1-7b", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": True + "model": "./dist/gorilla-openfunctions-v1-q4f16_1-MLC/", + "messages": [ + { + "role": "user", + "content": "What is the current weather in Pittsburgh, PA and Tokyo, JP in fahrenheit?", + } + ], + "stream": True, + "tools": tools, } - with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as r: - print(f"With streaming:") - for chunk in r: - content = json.loads(chunk[6:-2])["choices"][0]["delta"].get("content", "") - print(f"{content}", end="", flush=True) - print("\n") - # Get the latest runtime stats - r = requests.get("http://127.0.0.1:8000/stats") - print(f"Runtime stats: {r.json()}\n") + with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: + for chunk in r.iter_content(chunk_size=None): + chunk = chunk.decode("utf-8") + if "[DONE]" in chunk[6:]: + break + response = json.loads(chunk[6:]) + content = response["choices"][0]["delta"].get("content", "") + print(f"{content}", end="", flush=True) + print("\n") + + # Output: ["get_current_weather(location='Pittsburgh,PA',unit='fahrenheit')", "get_current_weather(location='Tokyo,JP',unit='fahrenheit')"] -Please check `example folder `__ for more examples using REST API. .. note:: - The REST API is a uniform interface that supports multiple languages. You can also utilize the REST API in languages other than Python. + The API is a uniform interface that supports multiple languages. You can also utilize these functionalities in languages other than Python. 
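As a further illustration of the request fields documented above, the following minimal sketch combines ``n``, ``seed``, ``max_tokens``, and ``stop`` in a single non-streaming call. It is a sketch, not part of the shipped documentation: it assumes the server from the earlier examples is running at ``http://127.0.0.1:8080``, that the placeholder model path matches a model you are actually serving, and that the fields behave as described in the Chat Completion Request Object section.

.. code:: python

    import requests

    # Minimal sketch: exercise several documented request fields in one
    # non-streaming chat completion request. The model path is a placeholder
    # and must point to a model the local server is actually serving.
    payload = {
        "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/",
        "messages": [{"role": "user", "content": "List three uses of a paged KV cache."}],
        "stream": False,
        "n": 2,              # request two candidate completions
        "seed": 42,          # make sampling reproducible across runs
        "max_tokens": 128,   # cap the length of each completion
        "stop": ["\n\n"],    # stop early at the first blank line
        "temperature": 0.7,
        "top_p": 0.95,
    }

    r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, timeout=180)
    r.raise_for_status()
    for choice in r.json()["choices"]:
        # Each choice carries its own finish_reason ("stop" or "length" here).
        print(f"[{choice['finish_reason']}] {choice['message']['content']}\n")

Per the ``seed`` description above, repeating the same request with the same seed and inputs should reproduce the same completions, which is convenient when comparing server configurations.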
+ + + From 4db43735d5da9a2bffc5d411dce732c220d7ea75 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Mon, 18 Mar 2024 13:18:05 -0400 Subject: [PATCH 081/531] [Conv] Add bos_token to llama and mistral in ConvTemplateRegistry (#1970) Since we don't have the `add_bos` field in the new Conversation template, we should add the bos token into the system_prefix_token_ids, so that it will be added to the tokenized prompt. --- python/mlc_llm/conversation_template.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index fb367b7aa3..d69be848bc 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -48,6 +48,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: role_empty_sep=" ", stop_str=["[INST]"], stop_token_ids=[2], + system_prefix_token_ids=[1], ) ) @@ -65,6 +66,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: role_empty_sep="", stop_str=[""], stop_token_ids=[2], + system_prefix_token_ids=[1], ) ) From 949ff2dd4e1a01043bec64f094e072dbb9405234 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Mon, 18 Mar 2024 16:44:09 -0400 Subject: [PATCH 082/531] [Model][Serve] Add support for LLaVa model in serving engine (#1974) This PR adds support for LLaVa-v1.5 model on the serving engine. Use the HF weights and config from https://huggingface.co/llava-hf/llava-1.5-7b-hf. Passing image input is supported as url (reference: https://platform.openai.com/docs/guides/vision) Example: ```python data = { "model": "dist/llava-1.5-7b-hf-q4f16_1-MLC/params/", "messages": [ { "role": "user", "content": [ { "type": "image_url", "image_url": "https://llava-vl.github.io/static/images/view.jpg", }, {"type": "text", "text": "What does this image represent?"}, ], } ] } response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=data) print("Response body:", response.text) ``` --- cpp/serve/data.cc | 25 + cpp/serve/data.h | 23 + cpp/serve/function_table.cc | 1 + cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 16 + cpp/serve/model.h | 7 + cpp/serve/request.cc | 2 + python/mlc_llm/conversation_template.py | 15 + python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/model/llava/__init__.py | 0 python/mlc_llm/model/llava/llava_loader.py | 162 +++++ python/mlc_llm/model/llava/llava_model.py | 623 ++++++++++++++++++ .../mlc_llm/model/llava/llava_quantization.py | 53 ++ python/mlc_llm/model/model.py | 16 + python/mlc_llm/model/model_preset.py | 34 + .../mlc_llm/protocol/conversation_protocol.py | 85 ++- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/async_engine.py | 20 +- python/mlc_llm/serve/data.py | 24 + python/mlc_llm/serve/engine.py | 17 +- .../serve/entrypoints/entrypoint_utils.py | 45 +- .../serve/entrypoints/openai_entrypoints.py | 28 +- python/mlc_llm/serve/server/server_context.py | 13 + .../python/serve/server/test_server_image.py | 258 ++++++++ tests/python/serve/test_serve_engine_image.py | 50 ++ 25 files changed, 1496 insertions(+), 25 deletions(-) create mode 100644 python/mlc_llm/model/llava/__init__.py create mode 100644 python/mlc_llm/model/llava/llava_loader.py create mode 100644 python/mlc_llm/model/llava/llava_model.py create mode 100644 python/mlc_llm/model/llava/llava_quantization.py create mode 100644 tests/python/serve/server/test_server_image.py create mode 100644 tests/python/serve/test_serve_engine_image.py diff --git a/cpp/serve/data.cc b/cpp/serve/data.cc index e6155061db..fe104a33ea 100644 --- 
a/cpp/serve/data.cc +++ b/cpp/serve/data.cc @@ -79,6 +79,31 @@ TVM_REGISTER_GLOBAL("mlc.serve.TokenDataGetTokenIds").set_body_typed([](TokenDat return data->token_ids; }); +/****************** ImageData ******************/ + +TVM_REGISTER_OBJECT_TYPE(ImageDataNode); + +ImageData::ImageData(NDArray image, int embed_size) { + ObjectPtr n = make_object(); + n->image = std::move(image); + n->embed_size = embed_size; + data_ = std::move(n); +} + +int ImageDataNode::GetLength() const { return embed_size; } + +ObjectRef ImageDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const { + return model->ImageEmbed(image, dst, offset); +} + +TVM_REGISTER_GLOBAL("mlc.serve.ImageData").set_body_typed([](NDArray image, int embed_size) { + return ImageData(std::move(image), embed_size); +}); + +TVM_REGISTER_GLOBAL("mlc.serve.ImageDataGetImage").set_body_typed([](ImageData data) { + return data->image; +}); + /****************** SampleResult ******************/ /*! \brief Convert a single token with probability to JSON string. */ diff --git a/cpp/serve/data.h b/cpp/serve/data.h index b9558b8fad..d225bb6acc 100644 --- a/cpp/serve/data.h +++ b/cpp/serve/data.h @@ -100,6 +100,29 @@ class TokenData : public Data { TVM_DEFINE_OBJECT_REF_METHODS(TokenData, Data, TokenDataNode); }; +/****************** ImageDataNode ******************/ + +/*! \brief The class of image data, containing a 3D array of pixel values. */ +class ImageDataNode : public DataNode { + public: + /*! \brief The pixel values. */ + NDArray image; + int embed_size; + + int GetLength() const final; + ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final; + + static constexpr const char* _type_key = "mlc.serve.ImageData"; + TVM_DECLARE_BASE_OBJECT_INFO(ImageDataNode, DataNode); +}; + +class ImageData : public Data { + public: + explicit ImageData(NDArray image, int embed_size); + + TVM_DEFINE_OBJECT_REF_METHODS(ImageData, Data, ImageDataNode); +}; + /****************** SampleResult ******************/ // The pair of a token id and its probability in sampling. diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index d7c70a508a..f4466c875b 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -191,6 +191,7 @@ ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device void FunctionTable::_InitFunctions() { this->embed_func_ = mod_get_func("embed"); + this->image_embed_func_ = mod_get_func("image_embed"); this->single_batch_prefill_func_ = mod_get_func("prefill"); this->single_batch_decode_func_ = mod_get_func("decode"); this->prefill_func_ = mod_get_func("batch_prefill"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 5a515ba9b7..29d9d82fbc 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -66,6 +66,7 @@ struct FunctionTable { ModelMetadata model_metadata_; PackedFunc embed_func_; + PackedFunc image_embed_func_; PackedFunc single_batch_prefill_func_; PackedFunc single_batch_decode_func_; PackedFunc prefill_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 0463728df0..94645b8634 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -94,6 +94,20 @@ class ModelImpl : public ModelObj { } } + ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final { + CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. 
"; + auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape()); + ObjectRef embeddings = ft_.image_embed_func_(image_dref_or_nd, params_); + if (dst != nullptr) { + CHECK(dst->defined()); + ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset); + return *dst; + } else { + CHECK_EQ(offset, 0); + return embeddings; + } + } + NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { CHECK(!seq_ids.empty()); @@ -419,6 +433,7 @@ class ModelImpl : public ModelObj { } else { LOG(FATAL) << "Key \"vocab_size\" not found."; } + return config; } @@ -433,6 +448,7 @@ class ModelImpl : public ModelObj { int prefill_chunk_size_ = -1; int hidden_size_ = -1; int vocab_size_ = -1; + int image_embed_size_ = -1; //---------------------------- // TVM related states //---------------------------- diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 1019834921..4edd272638 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -84,6 +84,13 @@ class ModelObj : public Object { virtual ObjectRef TokenEmbed(IntTuple batch_token_ids, ObjectRef* dst = nullptr, int offset = 0) = 0; + /*! + * \brief Compute embeddings for the input image. + * \param image The image to compute embedding for. + * \return The computed embeddings. + */ + virtual ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst = nullptr, int offset = 0) = 0; + /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows diff --git a/cpp/serve/request.cc b/cpp/serve/request.cc index 25162d79fb..8ecd20b18e 100644 --- a/cpp/serve/request.cc +++ b/cpp/serve/request.cc @@ -26,6 +26,8 @@ Request::Request(String id, Array inputs, GenerationConfig generation_cfg) for (Data input : inputs) { if (const auto* token_data = input.as()) { input_total_length += token_data->token_ids.size(); + } else if (const auto* image_data = input.as()) { + input_total_length += image_data->GetLength(); } else { input_total_length = -1; break; diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index d69be848bc..c1c8f49426 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -154,3 +154,18 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_token_ids=[0], ) ) + +# Llava +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llava", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "USER", "assistant": "ASSISTANT", "tool": "USER"}, + seps=[" "], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + ) +) diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 4bce52aa20..890b467688 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -243,4 +243,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "stablelm-2", "gemma_instruction", "orion", + "llava", } diff --git a/python/mlc_llm/model/llava/__init__.py b/python/mlc_llm/model/llava/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/llava/llava_loader.py b/python/mlc_llm/model/llava/llava_loader.py new file mode 100644 index 0000000000..cf80e262d1 --- /dev/null +++ b/python/mlc_llm/model/llava/llava_loader.py @@ -0,0 +1,162 @@ +""" +This file specifies how MLC's Llava parameter maps from other formats, for example 
HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .llava_model import LlavaConfig, LlavaForCasualLM +from .llava_quantization import awq_quant + + +def huggingface(model_config: LlavaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : LlavaConfig + The configuration of the Llava model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = LlavaForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), allow_extern=True + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.text_config.num_hidden_layers): + # Add QKV in self attention + attn = f"language_model.model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"language_model.model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: LlavaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : LlavaConfig + The configuration of the Llava model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.text_config.num_hidden_layers): + # Add QKV in self attention + attn = f"language_model.model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"language_model.model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py new file mode 100644 index 0000000000..30963f990c --- /dev/null +++ b/python/mlc_llm/model/llava/llava_model.py @@ -0,0 +1,623 @@ +""" +Implementation of LLaVa Model +Implements the CLIP Vision Encoder. Uses Llama for the Language Encoder. 
+""" + +import dataclasses +import logging +from typing import Any, Dict, Optional, Tuple + +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Module, Tensor, op +from tvm.relax.frontend.nn.modules import Conv2D +from tvm.relax.frontend.nn.op import ( + broadcast_to, + concat, + matmul, + permute_dims, + reshape, + softmax, + wrap_nested, +) +from tvm.relax.op import arange, strided_slice + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode + +from ...support.config import ConfigBase +from ..llama.llama_model import LlamaConfig, LlamaForCasualLM + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class LlavaVisionConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """ + Config for the vision encoder + """ + + hidden_size: int + image_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + patch_size: int + projection_dim: int + vocab_size: int + dtype: str = "float16" + num_channels: int = 3 + layer_norm_eps: float = 1e-06 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """ + LLaVa Config + """ + + image_token_index: int + text_config: LlamaConfig + vision_config: LlavaVisionConfig + vocab_size: int + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + dtype: str = "float16" + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + vision_config_dict: Dict[str, Any] + if isinstance(self.vision_config, LlavaVisionConfig): + vision_config_dict = dataclasses.asdict(self.vision_config) + else: + vision_config_dict = dict(self.vision_config) + + for k, v in vision_config_dict.pop("kwargs", {}).items(): + vision_config_dict[k] = v + + self.vision_config = LlavaVisionConfig.from_dict(vision_config_dict) + + text_config_dict: Dict[str, Any] + if isinstance(self.text_config, LlamaConfig): + text_config_dict = dataclasses.asdict(self.text_config) + else: + text_config_dict = dict(self.text_config) + + if "_name_or_path" in text_config_dict: + if text_config_dict["_name_or_path"] == "meta-llama/Llama-2-7b-hf": + text_config_dict["hidden_size"] = text_config_dict.pop("hidden_size", 4096) + text_config_dict["intermediate_size"] = text_config_dict.pop( + "intermediate_size", 11008 + ) + text_config_dict["num_attention_heads"] = text_config_dict.pop( + "num_attention_heads", 32 + ) + text_config_dict["num_hidden_layers"] = text_config_dict.pop( + "num_hidden_layers", 32 + ) + text_config_dict["rms_norm_eps"] = text_config_dict.pop("rms_norm_eps", 1e-06) + text_config_dict["vocab_size"] = text_config_dict.pop("vocab_size", 32064) + text_config_dict["context_window_size"] = text_config_dict.pop( + "context_window_size", 4096 + ) + else: + raise ValueError("Unsupported text model") + else: + for k, v in text_config_dict.pop("kwargs", {}).items(): + text_config_dict[k] = v + + self.text_config = LlamaConfig.from_dict(text_config_dict) + + if self.context_window_size <= 0: + self.context_window_size = self.text_config.context_window_size + + if self.prefill_chunk_size <= 0: + self.prefill_chunk_size = self.text_config.prefill_chunk_size + + +# pylint: disable=missing-docstring + + +class CLIPVisionEmbeddings(Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: LlavaVisionConfig): + super().__init__() 
+ self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.class_embedding = nn.Parameter((self.embed_dim,), dtype=config.dtype) + self.patch_embedding = Conv2D( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + dtype=config.dtype, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding( + num=self.num_positions, dim=self.embed_dim, dtype=config.dtype + ) + + def forward(self, pixel_values: Tensor) -> Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = reshape(patch_embeds, shape=(batch_size, self.embed_dim, -1)) + patch_embeds = permute_dims( + patch_embeds, axes=(0, 2, 1) + ) # shape = [batch,grid*grid,embed_dim] + class_embeds = broadcast_to( + self.class_embedding, shape=(batch_size, 1, self.embed_dim) + ) # shape of (batch,1,embed_dim) + embeddings = concat([class_embeds, patch_embeds], dim=1) + + posi_ids = reshape( + wrap_nested(arange(0, self.num_positions, dtype="int32"), name="arange"), shape=(1, -1) + ) + batch_position_embedding = broadcast_to( + self.position_embedding(posi_ids), + shape=(batch_size, self.num_positions, self.embed_dim), + ) + embeddings = embeddings + batch_position_embedding + return embeddings + + +def sigmoid(x: Tensor, name: str = "sigmoid") -> Tensor: + """Sigmoid of a Tensor + + Parameters + ---------- + x : Tensor + Input tensor to expand. + name : str + Name hint for this operator. + + Returns + ------- + result : Tensor + Sigmoid result. + """ + return wrap_nested(relax.op.sigmoid(x._expr), name) # pylint: disable=protected-access + + +class LlavaQuickGELU(Module): + def forward(self, input_tensor: Tensor) -> Tensor: + return input_tensor * sigmoid(input_tensor * 1.702) + + +class CLIPMLP(Module): + def __init__(self, config: LlavaVisionConfig): + super().__init__() + self.activation_fn = LlavaQuickGELU() + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, dtype=config.dtype) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, dtype=config.dtype) + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPAttention(Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: LlavaVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + + def _shape(self, tensor: Tensor, seq_len: int, bsz: int): + reshape_tensor = reshape(tensor, shape=(bsz, seq_len, self.num_heads, self.head_dim)) + permute_tensor = permute_dims(reshape_tensor, axes=(0, 2, 1, 3)) + return permute_tensor + + def forward( + self, + hidden_states: Tensor, + ) -> Tensor: + bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self._shape(self.q_proj(hidden_states) * self.scale, tgt_len, bsz) + key_states = self._shape(self.k_proj(hidden_states), tgt_len, bsz) + value_states = self._shape(self.v_proj(hidden_states), tgt_len, bsz) + + proj_shape = ( + bsz * self.num_heads, + -1, + self.head_dim, + ) # shape of (batch*num_heads, seq_len,head_dim) + + query_states = reshape(query_states, shape=proj_shape) + key_states = reshape(key_states, shape=proj_shape) + value_states = reshape(value_states, shape=proj_shape) + + trans_key_states = permute_dims(key_states, axes=(0, 2, 1)) + + attn_weights = matmul(query_states, trans_key_states) + attn_weights = softmax(attn_weights, axis=-1) + attn_output = matmul(attn_weights, value_states) + attn_output = reshape(attn_output, shape=(bsz, self.num_heads, tgt_len, self.head_dim)) + attn_output = permute_dims(attn_output, axes=(0, 2, 1, 3)) + attn_output = reshape(attn_output, shape=(bsz, tgt_len, embed_dim)) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class CLIPEncoderLayer(Module): + def __init__(self, config: LlavaVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config) + self.layer_norm1 = nn.LayerNorm( + normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype + ) + self.mlp = CLIPMLP(config) + self.layer_norm2 = nn.LayerNorm( + normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype + ) + + def forward(self, hidden_states: Tensor) -> Tensor: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + return outputs + + +class CLIPEncoder(Module): + def __init__(self, config: LlavaVisionConfig): + super().__init__() + self.layers = nn.ModuleList( + [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + def forward(self, inputs_embeds: Tensor) -> Tensor: + hidden_states = inputs_embeds + encoder_states: Tuple[Any, ...] 
= () + for _, encoder_layer in enumerate(self.layers): + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs[0] + encoder_states = encoder_states + (hidden_states,) + return encoder_states + + +class CLIPVisionTransformer(Module): + def __init__(self, config: LlavaVisionConfig): + super().__init__() + embed_dim = config.hidden_size + self.embeddings = CLIPVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + self.encoder = CLIPEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + + def forward(self, pixel_values: Tensor) -> Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + return encoder_outputs + + +class CLIPVisionModel(Module): + def __init__(self, config: LlavaVisionConfig): + super().__init__() + self.vision_model = CLIPVisionTransformer(config) + + def forward(self, pixel_values: Tensor) -> Tensor: + return self.vision_model(pixel_values)[-2] + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaConfig): + super().__init__() + + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) + self.act = nn.GELU() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def forward(self, image_features: Tensor) -> Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaForCasualLM(Module): + def __init__(self, config: LlavaConfig): + super().__init__() + self.config = config + self.vision_tower = CLIPVisionModel(config.vision_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = LlamaForCasualLM(config.text_config) + self.vocab_size = config.vocab_size + self.dtype = config.dtype + + def _embed_input_ids(self, input_ids: Tensor) -> Tensor: + return self.language_model.embed(input_ids) + + def _embed_pixel_values_and_input_ids(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: + def _index(x, value, batch_size, seq_len): + return te.compute( + (batch_size, seq_len), + lambda i, j: tir.if_then_else( + x[i, j] == value, + j, + tir.IntImm("int32", 0), + ), + name="index", + ) + + def _concat(x: Tensor, y: Tensor, new_shape: tuple, insert_index: Tensor): + return te.compute( + (new_shape), + lambda b, i, j: tir.if_then_else( + i < insert_index[0], + x[b, i, j], + tir.if_then_else( + i < insert_index[0] + y.shape[1], + y[b, i - insert_index[0], j], + x[b, i - y.shape[1] + 1, j], + ), + ), + ) + + input_embeddings = self._embed_input_ids(input_ids) + + image_features_all = self.vision_tower.forward(pixel_values) + image_features = wrap_nested( + strided_slice( + image_features_all._expr, # pylint: disable=protected-access + axes=[1], + begin=[1], + end=[image_features_all.shape[1]], + ), + name="slice", + ) + image_features = self.multi_modal_projector(image_features) + batch_size, seq_len = input_ids.shape + image_index_tensor = op.tensor_expr_op( + _index, + name_hint="index", + args=[ + input_ids, + tir.IntImm("int32", self.config.image_token_index), + batch_size, + seq_len, + ], + ).astype("int32") + ##! Assume only one token in input + ##! 
Also assume batch_size = 1 for now + # TODO: Support image_count > 1 and batch_size > 1 # pylint: disable=fixme + insert_index = op.sum(image_index_tensor, axis=1) + + new_shape = ( + batch_size, + seq_len + tir.IntImm("int32", image_features.shape[1] - 1), + self.config.text_config.hidden_size, + ) + + combined_embeddings = op.tensor_expr_op( + _concat, + name_hint="combined_embeddings", + args=[input_embeddings, image_features, new_shape, insert_index], + ) + return combined_embeddings + + def embed(self, input_ids: Tensor) -> Tensor: + return self._embed_input_ids(input_ids) + + def embed_with_pixel_values(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: + return self._embed_pixel_values_and_input_ids(pixel_values, input_ids) + + def image_embed(self, pixel_values: Tensor) -> Tensor: + image_features_all = self.vision_tower.forward(pixel_values) + image_features = wrap_nested( + strided_slice( + image_features_all._expr, # pylint: disable=protected-access + axes=[1], + begin=[1], + end=[image_features_all.shape[1]], + ), + name="slice", + ) + image_features = self.multi_modal_projector(image_features) + image_features = reshape(image_features, shape=(-1, self.config.text_config.hidden_size)) + return image_features + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + return self.language_model.batch_forward(input_embeds, paged_kv_cache, logit_positions) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + return self.language_model.prefill(input_embed, paged_kv_cache) + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + return self.language_model.decode(input_embed, paged_kv_cache) + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + return self.language_model.batch_prefill(input_embeds, logit_positions, paged_kv_cache) + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + return self.language_model.batch_decode(input_embeds, paged_kv_cache) + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + return self.language_model.batch_verify(input_embeds, paged_kv_cache) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.config.text_config.num_hidden_layers, + num_attention_heads=self.config.text_config.num_attention_heads + // self.config.tensor_parallel_shards, + num_key_value_heads=self.config.text_config.num_key_value_heads + // self.config.tensor_parallel_shards, + head_dim=self.config.text_config.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.language_model.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + 
"effect_mode": "none", + }, + }, + "embed_with_pixel_values": { + "pixel_values": nn.spec.Tensor( + [ + 1, + 3, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ], + self.dtype, + ), + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "image_embed": { + "pixel_values": nn.spec.Tensor( + [ + 1, + 3, + self.config.vision_config.image_size, + self.config.vision_config.image_size, + ], + self.dtype, + ), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor( + [1, "seq_len", self.config.text_config.hidden_size], self.dtype + ), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor( + [1, 1, self.config.text_config.hidden_size], self.dtype + ), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor( + [1, "seq_len", self.config.text_config.hidden_size], self.dtype + ), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor( + ["batch_size", 1, self.config.text_config.hidden_size], self.dtype + ), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor( + [1, "seq_len", self.config.text_config.hidden_size], self.dtype + ), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/llava/llava_quantization.py b/python/mlc_llm/model/llava/llava_quantization.py new file mode 100644 index 0000000000..f487a40489 --- /dev/null +++ b/python/mlc_llm/model/llava/llava_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's Llava parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from ...loader import QuantizeMapping +from ...quantization import AWQQuantize, GroupQuantize, NoQuantize +from .llava_model import LlavaConfig, LlavaForCasualLM + + +def group_quant( + model_config: LlavaConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llava model using group quantization.""" + model: nn.Module = LlavaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: LlavaConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + 
"""Quantize a Llava model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = LlavaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: LlavaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llava model without quantization.""" + model: nn.Module = LlavaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 607cec2918..9e8d98daa4 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -15,6 +15,7 @@ from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization from .internlm import internlm_loader, internlm_model, internlm_quantization from .llama import llama_loader, llama_model, llama_quantization +from .llava import llava_loader, llava_model, llava_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .orion import orion_loader, orion_model, orion_quantization @@ -292,4 +293,19 @@ class Model: "group-quant": orion_quantization.group_quant, }, ), + "llava": Model( + name="llava", + model=llava_model.LlavaForCasualLM, + config=llava_model.LlavaConfig, + source={ + "huggingface-torch": llava_loader.huggingface, + "huggingface-safetensor": llava_loader.huggingface, + "awq": llava_loader.awq, + }, + quantize={ + "group-quant": llava_quantization.group_quant, + "no-quant": llava_quantization.no_quant, + "awq": llava_quantization.awq_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 561109b77e..8e87217d35 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -589,4 +589,38 @@ "use_cache": True, "vocab_size": 84608, }, + "llava": { + "architectures": ["LlavaForConditionalGeneration"], + "ignore_index": -100, + "image_token_index": 32000, + "model_type": "llava", + "pad_token_id": 32001, + "projector_hidden_act": "gelu", + "text_config": { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": ["LlamaForCausalLM"], + "max_position_embeddings": 4096, + "model_type": "llama", + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "vocab_size": 32064, + }, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.36.0.dev0", + "vision_config": { + "hidden_size": 1024, + "image_size": 336, + "intermediate_size": 4096, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "vocab_size": 32000, + }, + "vision_feature_layer": -2, + "vision_feature_select_strategy": "default", + "vocab_size": 32064, + }, } diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index fa99b95c16..154bd3803d 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -1,10 +1,12 @@ """The standard conversation protocol in MLC LLM""" from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union from pydantic import BaseModel, Field, 
field_validator +from ..serve import data + # The message placeholders in the message prompts according to roles. class MessagePlaceholders(Enum): @@ -56,7 +58,9 @@ class Conversation(BaseModel): # The conversation history messages. # Each message is a pair of strings, denoting "(role, content)". # The content can be None. - messages: List[Tuple[str, Optional[str]]] = Field(default_factory=lambda: []) + messages: List[Tuple[str, Optional[Union[str, List[Dict[str, str]]]]]] = Field( + default_factory=lambda: [] + ) # The separators between messages when concatenating into a single prompt. # List size should be either 1 or 2. @@ -126,6 +130,7 @@ def as_prompt(self) -> str: raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') separator = separators[role == "assistant"] # check assistant role if content is not None: + assert isinstance(content, str) message_string = ( self.roles[role] + self.role_content_sep @@ -146,3 +151,79 @@ def as_prompt(self) -> str: prompt = prompt.replace(MessagePlaceholders.FUNCTION.value, "") return prompt + + def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageData]]: + """Convert the conversation template and history messages to + a list of prompts. + + Returns: + List[Union[str, data.ImageData]]: The list of prompts. + """ + # TODO: Unify this function with as_prompt() # pylint: disable=fixme + + # pylint: disable=import-outside-toplevel + from ..serve.entrypoints.entrypoint_utils import get_image_from_url + + # - Get the system message. + system_msg = self.system_template.replace( + MessagePlaceholders.SYSTEM.value, self.system_message + ) + + # - Get the message strings. + message_list: List[Union[str, data.ImageData]] = [] + separators = list(self.seps) + if len(separators) == 1: + separators.append(separators[0]) + message_list.append(system_msg + separators[0]) + for role, content in self.messages: # pylint: disable=not-an-iterable + if role not in self.roles.keys(): + raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') + separator = separators[role == "assistant"] # check assistant role + if content is not None: + if isinstance(content, str): + message_string = ( + self.roles[role] + + self.role_content_sep + + self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, content + ) + + separator + ) + message_list.append(message_string) + else: + assert isinstance( + content, list + ), "Content should be a string or a list of dicts" + message_list.append(self.roles[role] + self.role_content_sep) + for item in content: + assert isinstance( + item, dict + ), "Content should be a string or a list of dicts" + assert "type" in item, "Content item should have a type field" + if item["type"] == "text": + message_list.append( + self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, item["text"] + ) + ) + elif item["type"] == "image_url": + assert image_embed_size is not None, "Image embed size is required" + message_list.append( + data.ImageData( + image=get_image_from_url(item["image_url"]), + embed_size=image_embed_size, + ) + ) + else: + raise ValueError(f"Unsupported content type: {item['type']}") + message_list.append(separator) + + else: + message_string = self.roles[role] + self.role_empty_sep + message_list.append(message_string) + + prompt = message_list + + ## TODO: Support function calling # pylint: disable=fixme + + return prompt diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 
59185ec520..c5cc95cf4c 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -4,7 +4,7 @@ from .. import base from .async_engine import AsyncThreadedEngine from .config import EngineMode, GenerationConfig, KVCacheConfig -from .data import Data, RequestStreamOutput, TextData, TokenData +from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import Engine from .grammar import BNFGrammar, GrammarStateMatcher from .request import Request diff --git a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py index 048123286d..58636cb83b 100644 --- a/python/mlc_llm/serve/async_engine.py +++ b/python/mlc_llm/serve/async_engine.py @@ -6,7 +6,7 @@ import sys import threading from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union import tvm @@ -209,7 +209,10 @@ def terminate(self): self._background_loop_thread.join() async def generate( - self, prompt: Union[str, List[int]], generation_config: GenerationConfig, request_id: str + self, + prompt: Union[str, List[int], Sequence[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, ) -> AsyncGenerator[List[AsyncStreamOutput], Any]: """Asynchronous text generation interface. The method is a coroutine that streams a list of AsyncStreamOutput @@ -234,9 +237,20 @@ async def generate( # loop is the main driving event loop of the process. self._async_event_loop = asyncio.get_event_loop() + def convert_to_data( + prompt: Union[str, List[int], Sequence[Union[str, List[int], data.Data]]] + ) -> List[data.Data]: + if isinstance(prompt, data.Data): + return [prompt] + if isinstance(prompt, str): + return [data.TextData(prompt)] + if isinstance(prompt[0], int): + return [data.TokenData(prompt)] # type: ignore + return [convert_to_data(x)[0] for x in prompt] # type: ignore + # Create the request with the given id, input data, generation # config and the created callback. - input_data = data.TextData(prompt) if isinstance(prompt, str) else data.TokenData(prompt) + input_data = convert_to_data(prompt) request = Request(request_id, input_data, generation_config) # Create the unique stream of the request. diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index 57532827e9..8444e3f363 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -5,6 +5,7 @@ import tvm._ffi from tvm.runtime import Object +from tvm.runtime.ndarray import NDArray from . import _ffi_api @@ -58,6 +59,29 @@ def token_ids(self) -> List[int]: return list(_ffi_api.TokenDataGetTokenIds(self)) # type: ignore # pylint: disable=no-member +@tvm._ffi.register_object("mlc.serve.ImageData") # type: ignore # pylint: disable=protected-access +class ImageData(Data): + """The class of image data, containing the image as NDArray. + + Parameters + ---------- + image : tvm.runtime.NDArray + The image data. 
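+
+    embed_size : int
+        The number of embedding vectors the image occupies in the model input;
+        len() of the ImageData object returns this value.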
+ """ + + def __init__(self, image: NDArray, embed_size: int): + self.embed_size = embed_size + self.__init_handle_by_constructor__(_ffi_api.ImageData, image, embed_size) # type: ignore # pylint: disable=no-member + + @property + def image(self) -> NDArray: + """Return the image data.""" + return _ffi_api.ImageDataGetImage(self) # type: ignore # pylint: disable=no-member + + def __len__(self): + return self.embed_size + + @dataclass class SingleRequestStreamOutput: """The request stream output of a single request. diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 06185d0c2a..0757a0d8e9 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -371,7 +371,7 @@ def __init__( # pylint: disable=too-many-arguments def generate( # pylint: disable=too-many-locals self, - prompts: Union[str, List[str], List[int], List[List[int]]], + prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], generation_config: Union[GenerationConfig, List[GenerationConfig]], ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: """Generate texts for a list of input prompts. @@ -409,7 +409,7 @@ def generate( # pylint: disable=too-many-locals else: assert isinstance(prompts, list), ( "Input `prompts` is expected to be a string, a list of " - "str, a list of token ids or multiple lists of token ids." + "str, a list of token ids or multiple lists of token ids. " ) if len(prompts) == 0: return [], [] @@ -476,13 +476,16 @@ def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): # Override the callback function in engine. self._ffi["set_request_stream_callback"](request_stream_callback) + def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: + if isinstance(prompt, str): + return [data.TextData(prompt)] + if isinstance(prompt[0], int): + return [data.TokenData(prompt)] # type: ignore + return prompt # type: ignore + # Add requests to engine. for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): - input_data = ( - data.TextData(prompt) - if isinstance(prompt, str) - else data.TokenData(prompt) # type: ignore - ) + input_data = convert_to_data(prompt) # type: ignore self.add_request( Request( request_id=str(req_id), diff --git a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py index 10256c2a48..f0c82769ec 100644 --- a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py +++ b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py @@ -2,10 +2,13 @@ import uuid from http import HTTPStatus -from typing import Callable, List, Optional, Union +from io import BytesIO +from typing import Callable, Dict, List, Optional, Union import fastapi +from mlc_llm.serve import data + from ...protocol import RequestProtocol from ...protocol.protocol_utils import ErrorResponse, get_unsupported_fields @@ -56,9 +59,11 @@ def check_prompts_length( def process_prompts( - input_prompts: Union[str, List[int], List[Union[str, List[int]]]], + input_prompts: Union[ + str, List[int], List[Union[str, List[int]]], List[Union[str, data.ImageData]] + ], ftokenize: Callable[[str], List[int]], -) -> Union[List[List[int]], fastapi.responses.JSONResponse]: +) -> Union[List[Union[List[int], data.ImageData]], fastapi.responses.JSONResponse]: """Convert all input tokens to list of token ids with regard to the given tokenization function. For each input prompt, return the list of token ids after tokenization. 
@@ -86,7 +91,39 @@ def process_prompts( is_token_ids = isinstance(input_prompt, list) and all( isinstance(token_id, int) for token_id in input_prompt ) - if not (is_str or is_token_ids): + is_image = isinstance(input_prompt, data.ImageData) + if not (is_str or is_token_ids or is_image): return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) output_prompts.append(ftokenize(input_prompt) if is_str else input_prompt) # type: ignore return output_prompts + + +def get_image_from_url(url: str): + """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" + + # pylint: disable=import-outside-toplevel, import-error + import requests + import tvm + from PIL import Image + from transformers import CLIPImageProcessor + + response = requests.get(url, timeout=5) + image_tensor = Image.open(BytesIO(response.content)).convert("RGB") + + image_processor = CLIPImageProcessor( + size={"shortest_edge": 336}, crop_size={"height": 336, "width": 336} + ) + image_features = tvm.nd.array( + image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( + "float16" + ) + ) + return image_features + + +def get_image_embed_size(config: Dict) -> int: + """Get the image embedding size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + patch_size = config["model_config"]["vision_config"]["patch_size"] + embed_size = (image_size // patch_size) ** 2 + return embed_size diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 04f7c3eb58..da2d917dc8 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -4,10 +4,12 @@ import ast import json from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import AsyncGenerator, Dict, List, Optional, Sequence, Union import fastapi +from mlc_llm.serve import data + from ...protocol import protocol_utils from ...protocol.conversation_protocol import Conversation from ...protocol.openai_api_protocol import ( @@ -266,7 +268,6 @@ def chat_completion_check_message_validity( if isinstance(message.content, list): if message.role != "user": return "Non-user message having a list of content is invalid." - return "User message having a list of content is not supported yet." if message.tool_calls is not None: if message.role != "assistant": return "Non-assistant message having `tool_calls` is invalid." @@ -388,11 +389,12 @@ async def request_chat_completion( if error_msg is not None: return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) + content_has_list = any(isinstance(message.content, list) for message in request.messages) for message in request.messages: role = message.role content = message.content - assert isinstance(content, str), "Internal error: content is not a string." if role == "system": + assert isinstance(content, str) conv_template.system_message = content if content is not None else "" continue @@ -403,17 +405,27 @@ async def request_chat_completion( # - Get the prompt from template, and encode to token ids. 
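+    # (When any message content is a list, the prompt is assembled with
+    #  as_prompt_list(), which turns image URLs into ImageData entries sized
+    #  from the model's vision config; plain-text requests keep as_prompt().)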
# - Check prompt length async_engine.record_event(request_id, event="start tokenization") - prompts = entrypoint_utils.process_prompts( - conv_template.as_prompt(), async_engine.tokenizer.encode - ) + + model_config = ServerContext.get_model_config(request.model) + image_embed_size = entrypoint_utils.get_image_embed_size(model_config) + + if content_has_list: + prompts = entrypoint_utils.process_prompts( + conv_template.as_prompt_list(image_embed_size=image_embed_size), + async_engine.tokenizer.encode, + ) + else: + prompts = entrypoint_utils.process_prompts( + conv_template.as_prompt(), async_engine.tokenizer.encode + ) async_engine.record_event(request_id, event="finish tokenization") - assert isinstance(prompts, list) and len(prompts) == 1, "Internal error" if conv_template.system_prefix_token_ids is not None: prompts[0] = conv_template.system_prefix_token_ids + prompts[0] error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) if error is not None: return error - prompt = prompts[0] + + prompt: Sequence[Union[List[int], data.ImageData]] = prompts # Process generation config. Create request id. generation_cfg = protocol_utils.get_generation_config( diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index d382bb701e..c18bab466b 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -1,7 +1,9 @@ """Server context that shared by multiple entrypoint files.""" +import json from typing import Dict, List, Optional +from ...chat_module import _get_model_path from ...conversation_template import ConvTemplateRegistry from ...protocol.conversation_protocol import Conversation from .. import async_engine @@ -14,6 +16,7 @@ class ServerContext: _models: Dict[str, async_engine.AsyncThreadedEngine] = {} _conv_templates: Dict[str, Conversation] = {} + _model_configs: Dict[str, Dict] = {} @staticmethod def add_model(hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: @@ -28,6 +31,11 @@ def add_model(hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> No if conv_template is not None: ServerContext._conv_templates[hosted_model] = conv_template + _, config_file_path = _get_model_path(hosted_model) + with open(config_file_path, "r", encoding="utf-8") as file: + config = json.load(file) + ServerContext._model_configs[hosted_model] = config + @staticmethod def get_engine(model: str) -> Optional[async_engine.AsyncThreadedEngine]: """Get the async engine of the requested model.""" @@ -45,3 +53,8 @@ def get_conv_template(model: str) -> Optional[Conversation]: def get_model_list() -> List[str]: """Get the list of models on serve.""" return list(ServerContext._models.keys()) + + @staticmethod + def get_model_config(model: str) -> Optional[Dict]: + """Get the model config path of the requested model.""" + return ServerContext._model_configs.get(model, None) diff --git a/tests/python/serve/server/test_server_image.py b/tests/python/serve/server/test_server_image.py new file mode 100644 index 0000000000..9b016224e4 --- /dev/null +++ b/tests/python/serve/server/test_server_image.py @@ -0,0 +1,258 @@ +# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches +import json +import os +from typing import Dict, List, Optional, Tuple + +import pytest +import regex +import requests + +OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8001/v1/chat/completions" + +JSON_TOKEN_PATTERN = ( + 
r"((-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?)|null|true|false|" + r'("((\\["\\\/bfnrt])|(\\u[0-9a-fA-F]{4})|[^"\\\x00-\x1f])*")' +) +JSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN) + + +def is_json_or_json_prefix(s: str) -> bool: + try: + json.loads(s) + return True + except json.JSONDecodeError as e: + # If the JSON decoder reaches the end of s, it is a prefix of a JSON string. + if e.pos == len(s): + return True + # Since json.loads is token-based instead of char-based, there may remain half a token after + # the matching position. + # If the left part is a prefix of a valid JSON token, the output is also valid + regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True) + return regex_match is not None + + +def check_openai_nonstream_response( + response: Dict, + *, + is_chat_completion: bool, + model: str, + object_str: str, + num_choices: int, + finish_reasons: List[str], + completion_tokens: Optional[int] = None, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, + json_mode: bool = False, +): + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) <= num_choices + texts: List[str] = ["" for _ in range(num_choices)] + for choice in choices: + idx = choice["index"] + assert choice["finish_reason"] in finish_reasons + + if not is_chat_completion: + assert isinstance(choice["text"], str) + texts[idx] = choice["text"] + if echo_prompt is not None: + assert texts[idx] + if suffix is not None: + assert texts[idx] + else: + message = choice["message"] + assert message["role"] == "assistant" + assert isinstance(message["content"], str) + texts[idx] = message["content"] + + if stop is not None: + for stop_str in stop: + assert stop_str not in texts[idx] + if require_substr is not None: + for substr in require_substr: + assert substr in texts[idx] + if json_mode: + assert is_json_or_json_prefix(texts[idx]) + + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + if completion_tokens is not None: + assert usage["completion_tokens"] == completion_tokens + + +def check_openai_stream_response( + responses: List[Dict], + *, + is_chat_completion: bool, + model: str, + object_str: str, + num_choices: int, + finish_reasons: List[str], + completion_tokens: Optional[int] = None, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, + json_mode: bool = False, +): + assert len(responses) > 0 + + finished = [False for _ in range(num_choices)] + outputs = ["" for _ in range(num_choices)] + for response in responses: + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) <= num_choices + for choice in choices: + idx = choice["index"] + + if not is_chat_completion: + assert isinstance(choice["text"], str) + outputs[idx] += choice["text"] + else: + delta = choice["delta"] + assert delta["role"] == "assistant" + assert isinstance(delta["content"], str) + outputs[idx] += delta["content"] + + if finished[idx]: + assert choice["finish_reason"] in finish_reasons + elif choice["finish_reason"] is not None: + assert choice["finish_reason"] in finish_reasons + 
finished[idx] = True + + if not is_chat_completion: + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + if completion_tokens is not None: + assert usage["completion_tokens"] <= completion_tokens + + if not is_chat_completion: + if completion_tokens is not None: + assert responses[-1]["usage"]["completion_tokens"] == completion_tokens + + for i, output in enumerate(outputs): + if echo_prompt is not None: + assert output.startswith(echo_prompt) + if suffix is not None: + assert output.endswith(suffix) + if stop is not None: + for stop_str in stop: + assert stop_str not in output + if require_substr is not None: + for substr in require_substr: + assert substr in output + if json_mode: + assert is_json_or_json_prefix(output) + + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": "https://llava-vl.github.io/static/images/view.jpg", + }, + {"type": "text", "text": "What does this image represent?"}, + ], + }, + ], + # messages #1 + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": "https://llava-vl.github.io/static/images/view.jpg", + }, + {"type": "text", "text": "What does this image represent?"}, + ], + }, + { + "role": "assistant", + "content": "The image represents a serene and peaceful scene of a pier extending over a body of water, such as a lake or a river.er. The pier is made of wood and has a bench on it, providing a place for people to sit and enjoy the view. The pier is situated in a natural environment, surrounded by trees and mountains in the background. This setting creates a tranquil atmosphere, inviting visitors to relax and appreciate the beauty of the landscape.", + }, + { + "role": "user", + "content": "What country is the image set in? Give me 10 ranked guesses and reasons why.", + }, + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completions( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + } + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reasons=["stop"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reasons=["stop"], + ) + + +if __name__ == "__main__": + model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib_path is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' + "Please set it to model lib compiled by MLC LLM " + "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." 
+ ) + + model = os.environ.get("MLC_SERVE_MODEL") + if model is None: + MODEL = (os.path.dirname(model_lib_path), model_lib_path) + else: + MODEL = (model, model_lib_path) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg) diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py new file mode 100644 index 0000000000..5b23a245f9 --- /dev/null +++ b/tests/python/serve/test_serve_engine_image.py @@ -0,0 +1,50 @@ +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig, data +from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.entrypoints.entrypoint_utils import get_image_from_url + + +def get_test_image(): + return get_image_from_url("https://llava-vl.github.io/static/images/view.jpg") + + +def test_engine_generate(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/llava-1.5-7b-hf-q4f16_1-MLC/params", + model_lib_path="dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + engine = Engine(model, kv_cache_config) + + max_tokens = 256 + + prompts = [ + [ + data.TextData("USER: "), + data.ImageData(get_test_image(), 576), + data.TextData("\nWhat does this image represent? ASSISTANT:"), + ], + [ + data.TextData("USER: "), + data.ImageData(get_test_image(), 576), + data.TextData("\nIs there a dog in this image? ASSISTANT:"), + ], + [data.TextData("USER: What is the meaning of life? ASSISTANT:")], + ] + + output_texts, _ = engine.generate( + prompts, GenerationConfig(max_tokens=max_tokens, stop_token_ids=[2]) + ) + + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +if __name__ == "__main__": + test_engine_generate() From 058c5839b984de0d08730660fe6732ac34e02063 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Mon, 18 Mar 2024 17:06:53 -0700 Subject: [PATCH 083/531] [Serve] Hot fix for the mixtral serving (#1975) [Fix] hotfix for the mixtral serving Co-authored-by: Yong Wu --- python/mlc_llm/serve/entrypoints/openai_entrypoints.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index da2d917dc8..aa9d941f6c 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -406,10 +406,9 @@ async def request_chat_completion( # - Check prompt length async_engine.record_event(request_id, event="start tokenization") - model_config = ServerContext.get_model_config(request.model) - image_embed_size = entrypoint_utils.get_image_embed_size(model_config) - if content_has_list: + model_config = ServerContext.get_model_config(request.model) + image_embed_size = entrypoint_utils.get_image_embed_size(model_config) prompts = entrypoint_utils.process_prompts( conv_template.as_prompt_list(image_embed_size=image_embed_size), async_engine.tokenizer.encode, From 3cbc169092cfb5ac888d19c3acc17bca532dcf4c Mon Sep 17 00:00:00 2001 From: Shrey Gupta <51860471+shreygupta2809@users.noreply.github.com> Date: Mon, 18 Mar 2024 21:32:21 -0400 Subject: [PATCH 084/531] [REST] REST 
API Deprecated (#1973) Deleted old Rest API - Removed rest.py - Removed old interface/openai_api.py - Update ChatModule to use new OpenAI Api protocol Co-authored-by: Kartik Khandelwal --- python/mlc_llm/chat_module.py | 24 +- python/mlc_llm/interface/openai_api.py | 183 --------- python/mlc_llm/rest.py | 492 ------------------------- 3 files changed, 13 insertions(+), 686 deletions(-) delete mode 100644 python/mlc_llm/interface/openai_api.py delete mode 100644 python/mlc_llm/rest.py diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 18c3258514..943f98c7e2 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -24,7 +24,7 @@ from . import base as _ if TYPE_CHECKING: - from mlc_llm.interface.openai_api import ChatMessage + from mlc_llm.protocol.openai_api_protocol import ChatCompletionMessage # pylint: disable=line-too-long _PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" @@ -798,7 +798,7 @@ def __init__( # pylint: disable=too-many-arguments def generate( self, - prompt: Union[str, List["ChatMessage"]], + prompt: Union[str, List["ChatCompletionMessage"]], generation_config: Optional[GenerationConfig] = None, progress_callback=None, stateless=False, @@ -809,7 +809,7 @@ def generate( Parameters ---------- - prompt: Union[str, List[ChatMessage]] + prompt: Union[str, List[ChatCompletionMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) eg: @@ -817,9 +817,10 @@ def generate( .. code:: [ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), - ChatMessage(role="user", content="I'm good too."), + ChatCompletionMessage(role="user", content="Hello, how are you?"), + ChatCompletionMessage(role="assistant", \ + content="I'm fine, thank you. How about you?"), + ChatCompletionMessage(role="user", content="I'm good too."), ] generation_config: Optional[GenerationConfig] The generation config object to override the ChatConfig generation settings. @@ -1021,7 +1022,7 @@ def _unload(self): def _prefill( self, - input: Union[str, List["ChatMessage"]], # pylint: disable=redefined-builtin + input: Union[str, List["ChatCompletionMessage"]], # pylint: disable=redefined-builtin decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, generation_config: Optional[GenerationConfig] = None, @@ -1031,7 +1032,7 @@ def _prefill( Parameters ---------- - input : Union[str, List[ChatMessage]] + input : Union[str, List[ChatCompletionMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) eg: @@ -1039,9 +1040,10 @@ def _prefill( .. code:: [ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), - ChatMessage(role="user", content="I'm good too."), + ChatCompletionMessage(role="user", content="Hello, how are you?"), + ChatCompletionMessage(role="assistant", \ + content="I'm fine, thank you. How about you?"), + ChatCompletionMessage(role="user", content="I'm good too."), ] decode_next_token : bool Whether to decode the next token after prefilling. 
diff --git a/python/mlc_llm/interface/openai_api.py b/python/mlc_llm/interface/openai_api.py deleted file mode 100644 index 7c7797dea6..0000000000 --- a/python/mlc_llm/interface/openai_api.py +++ /dev/null @@ -1,183 +0,0 @@ -# pylint: disable=missing-docstring,fixme,too-few-public-methods -""" -Adapted from FastChat's OpenAI protocol: -https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py -""" - -import time -from typing import Any, Dict, List, Literal, Optional, Union - -import shortuuid -from pydantic import BaseModel, Field - - -class ToolCalls(BaseModel): - id: str = Field(default_factory=lambda: f"call_{shortuuid.random()}") - type: str = "function" - function: object - - -class ChatMessage(BaseModel): - role: str - content: Union[str, None] - name: Optional[str] = None - tool_calls: Optional[List[ToolCalls]] = None - - -class Function(BaseModel): - description: Optional[str] = None - name: str - parameters: object - - -class Tools(BaseModel): - type: Literal["function"] - function: Dict[str, Any] - - -class ToolChoice(BaseModel): - type: Literal["function"] - function: Dict[str, Any] - - -class ChatCompletionRequest(BaseModel): - model: str - messages: List[ChatMessage] - stream: Optional[bool] = False - temperature: float = None - top_p: float = None - # TODO: replace by presence_penalty and frequency_penalty - repetition_penalty: float = None - mean_gen_len: int = None - # TODO: replace by max_tokens - max_gen_len: int = None - presence_penalty: float = None - frequency_penalty: float = None - n: int = None - stop: Union[str, List[str]] = None - tools: Optional[List[Tools]] = None - tool_choice: Union[Literal["none", "auto"], ToolChoice] = "auto" - # TODO: Implement support for the OpenAI API parameters - # stop: Optional[Union[str, List[str]]] = None - # max_tokens: Optional[int] - # logit_bias - # user: Optional[str] = None - - -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - completion_tokens: Optional[int] = 0 - total_tokens: int = 0 - - -class ChatCompletionResponseChoice(BaseModel): - index: int - message: ChatMessage - finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None - - -class ChatCompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") - object: str = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - choices: List[ChatCompletionResponseChoice] - # TODO: Implement support for the following fields - usage: Optional[UsageInfo] = None - - -class DeltaMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - - -class ChatCompletionResponseStreamChoice(BaseModel): - index: int - delta: DeltaMessage - finish_reason: Optional[Literal["stop", "length"]] = None - - -class ChatCompletionStreamResponse(BaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") - object: str = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - choices: List[ChatCompletionResponseStreamChoice] - - -class CompletionRequest(BaseModel): - model: str - prompt: Union[str, List[str]] - stream: Optional[bool] = False - temperature: float = None - repetition_penalty: float = None - top_p: float = None - mean_gen_len: int = None - # TODO: replace by max_tokens - max_gen_len: int = None - presence_penalty: float = None - frequency_penalty: float = None - n: int = None - stop: Union[str, List[str]] = None - # TODO: Implement support for the OpenAI API parameters - # suffix - # 
logprobs - # echo - # best_of - # logit_bias - # user: Optional[str] = None - - -class CompletionResponseChoice(BaseModel): - index: int - text: str - finish_reason: Optional[Literal["stop", "length"]] = None - # TODO: logprobs support - logprobs: Optional[int] = None - - -class CompletionResponse(BaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") - object: str = "text.completion" - created: int = Field(default_factory=lambda: int(time.time())) - choices: List[CompletionResponseChoice] - usage: UsageInfo - - -class CompletionResponseStreamChoice(BaseModel): - index: int - text: str - finish_reason: Optional[Literal["stop", "length"]] = None - - -class CompletionStreamResponse(BaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") - object: str = "text.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - choices: List[CompletionResponseStreamChoice] - - -class EmbeddingsRequest(BaseModel): - model: Optional[str] = None - input: Union[str, List[Any]] - user: Optional[str] = None - - -class EmbeddingsResponse(BaseModel): - object: str = "list" - data: List[Dict[str, Any]] - model: Optional[str] = None - usage: UsageInfo - - -class VisualStudioCodeCompletionParameters(BaseModel): - temperature: float = None - top_p: float = None - max_new_tokens: int = None - - -class VisualStudioCodeCompletionRequest(BaseModel): - inputs: str - parameters: VisualStudioCodeCompletionParameters - - -class VisualStudioCodeCompletionResponse(BaseModel): - generated_text: str diff --git a/python/mlc_llm/rest.py b/python/mlc_llm/rest.py deleted file mode 100644 index 011ef4df29..0000000000 --- a/python/mlc_llm/rest.py +++ /dev/null @@ -1,492 +0,0 @@ -# pylint: disable=missing-docstring,fixme -import argparse -import ast -import asyncio -import dataclasses -import json -from contextlib import asynccontextmanager -from typing import Dict, List - -import numpy as np -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse - -from mlc_llm.chat_module import GenerationConfig -from mlc_llm.support.random import set_global_random_seed - -from .chat_module import ChatModule -from .interface.openai_api import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatMessage, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - DeltaMessage, - EmbeddingsRequest, - EmbeddingsResponse, - ToolCalls, - ToolChoice, - UsageInfo, - VisualStudioCodeCompletionRequest, - VisualStudioCodeCompletionResponse, -) - - -@dataclasses.dataclass -class RestAPIArgs: - """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API - server.""" - - model: str = dataclasses.field( - metadata={ - "help": ( - """ - The model folder after compiling with MLC-LLM build process. The parameter - can either be the model name with its quantization scheme - (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model - folder. In the former case, we will use the provided name to search - for the model folder over possible paths. - """ - ) - } - ) - lib_path: str = dataclasses.field( - default=None, - metadata={ - "help": ( - """ - The full path to the model library file to use (e.g. a ``.so`` file). 
- """ - ) - }, - ) - device: str = dataclasses.field( - default="auto", - metadata={ - "help": ( - """ - The description of the device to run on. User should provide a string in the - form of 'device_name:device_id' or 'device_name', where 'device_name' is one of - 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the - local device), and 'device_id' is the device id to run on. If no 'device_id' - is provided, it will be set to 0 by default. - """ - ) - }, - ) - host: str = dataclasses.field( - default="127.0.0.1", - metadata={ - "help": ( - """ - The host at which the server should be started, defaults to ``127.0.0.1``. - """ - ) - }, - ) - port: int = dataclasses.field( - default=8000, - metadata={ - "help": ( - """ - The port on which the server should be started, defaults to ``8000``. - """ - ) - }, - ) - random_seed: int = dataclasses.field( - default=None, - metadata={ - "help": ( - """ - The random seed to initialize all the RNG used in mlc-chat. By default, - no seed is set. - """ - ) - }, - ) - - -def convert_args_to_argparser() -> argparse.ArgumentParser: - """Convert from RestAPIArgs to an equivalent ArgumentParser.""" - args = argparse.ArgumentParser("MLC Chat REST API") - for field in dataclasses.fields(RestAPIArgs): - name = field.name.replace("_", "-") - field_name = f"--{name}" - # `kwargs` contains `help`, `choices`, and `action` - kwargs = field.metadata.copy() - if field.type == bool: - # boolean arguments do not need to specify `type` - args.add_argument(field_name, default=field.default, **kwargs) - else: - args.add_argument(field_name, type=field.type, default=field.default, **kwargs) - return args - - -session: Dict[str, ChatModule] = {} - - -@asynccontextmanager -async def lifespan(_app: FastAPI): - if ARGS.random_seed is not None: - set_global_random_seed(ARGS.random_seed) - chat_mod = ChatModule( - model=ARGS.model, - device=ARGS.device, - model_lib_path=ARGS.lib_path, - ) - session["chat_mod"] = chat_mod - yield - session.clear() - - -origins = ["*"] - -app = FastAPI(lifespan=lifespan) -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -class AsyncCompletionStream: - def __init__(self, generation_config: GenerationConfig): - self.generation_config = generation_config - - def __aiter__(self): - return self - - async def get_next_msg(self): - # pylint: disable=protected-access - if not session["chat_mod"]._stopped(): - session["chat_mod"]._decode(generation_config=self.generation_config) - msg = session["chat_mod"]._get_message() - return msg - # pylint: enable=protected-access - raise StopAsyncIteration - - async def __anext__(self): - if not session["chat_mod"]._stopped(): - task = asyncio.create_task(self.get_next_msg()) - msg = await task - return msg - raise StopAsyncIteration - - -def add_function_call(prompt: List[ChatMessage], function_string: str): - # update content of the last input message to include function string - user_query = prompt[-1].content - prompt[-1].content = f"<> {user_query} <> {function_string}\n" - - -def function_call_util(request: ChatCompletionRequest): - """Performs the necessary actions to add function calls to the prompt - returns True if function calls are added to the prompt else returns False - TODO: Check function name in tools.function['name'] - TODO: Currently auto mode default to generating function calls instead of smartly - checking weather to generate function calls or not - """ - - # return if no tools are 
provided - if request.tools is None: - return False - - # skip if tool_choice is set to none - if isinstance(request.tool_choice, str) and request.tool_choice == "none": - return False - - if isinstance(request.tool_choice, ToolChoice): - # force the model to use a specific function provided by tool_choice - if request.tool_choice.type != "function": - raise ValueError("Only 'function' tool choice is supported") - for tool in request.tools: - if tool.function["name"] == request.tool_choice.function["name"]: - add_function_call(request.messages, json.dumps(tool.function)) - return True - raise ValueError("ToolChoice.function.name not found in tools") - - if isinstance(request.tool_choice, str): - # Add all the functions to the input prompt - function_list = [] - for tool in request.tools: - if tool.type == "function": - function_list.append(tool.function) - else: - raise ValueError("Only 'function' tool.type is supported") - add_function_call(request.messages, json.dumps(function_list)) - else: - raise ValueError("Invalid toolChoice instance type") - return True - - -def convert_function_str_to_json(stringified_calls): - def parse_function_call(call_str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None - - calls = ast.literal_eval(stringified_calls) - result = [parse_function_call(call_str) for call_str in calls] - return result - - -@app.post("/v1/chat/completions") -async def request_chat_completion(request: ChatCompletionRequest): - """ - Creates model response for the given chat conversation. - The messages field contains a list of messages (describing the conversation history). eg: - ```"messages": [{"role": "user", "content": "What's my name?"}, - {"role": "assistant", "content": "Your name is Llama."}, - {"role": "user", "content": "No, that's your name. My name is X."}, - {"role": "assistant", "content": "Ah, my apologies! Your name is X! "}, - {"role": "user", "content": "What is the meaning of life?"}, - ] - ``` - ] - """ - generation_config = GenerationConfig( - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - top_p=request.top_p, - mean_gen_len=request.mean_gen_len, - max_gen_len=request.max_gen_len, - n=request.n, - stop=request.stop, - ) - - session["chat_mod"].reset_chat() # Reset previous history, KV cache, etc. - - use_function_call = function_call_util(request) - - if request.stream: - session["chat_mod"]._prefill( # pylint: disable=protected-access - input=request.messages, - generation_config=generation_config, - ) - - async def iter_response(): - prev_txt = "" - async for content in AsyncCompletionStream(generation_config=generation_config): - if content: - # Remove the replacement character (U+FFFD) from the response - # This is to handle emojis. An emoji might be made up of multiple tokens. - # In the Rest streaming setting, if an emoji gets truncated in the middle of - # its encoded byte sequence, a replacement character will appear. 
- valid_content = content.replace("�", "") - chunk = ChatCompletionStreamResponse( - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage( - role="assistant", content=valid_content[len(prev_txt) :] - ), - finish_reason="stop", - ) - ] - ) - prev_txt = valid_content - yield f"data: {chunk.json(exclude_unset=True)}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(iter_response(), media_type="text/event-stream") - msg = session["chat_mod"].generate( - prompt=request.messages, generation_config=generation_config, stateless=True - ) - if isinstance(msg, str): - msg = [msg] - - choices = [] - for index, msg_i in enumerate(msg): - if use_function_call: - choices.append( - ChatCompletionResponseChoice( - index=index, - message=ChatMessage( - role="assistant", - content=None, - tool_calls=[ - ToolCalls( - function=fn_json_obj, - ) - for fn_json_obj in convert_function_str_to_json(msg_i) - ], - ), - finish_reason="tool_calls", - ) - ) - else: - choices.append( - ChatCompletionResponseChoice( - index=index, - message=ChatMessage( - role="assistant", - content=msg_i, - ), - finish_reason="stop", - ) - ) - - return ChatCompletionResponse( - choices=choices, - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) - - -@app.post("/v1/completions") -async def request_completion(request: CompletionRequest): - """ - Creates a completion for a given prompt. - """ - - generation_config = GenerationConfig( - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - top_p=request.top_p, - mean_gen_len=request.mean_gen_len, - max_gen_len=request.max_gen_len, - n=request.n, - stop=request.stop, - ) - - session["chat_mod"].reset_chat() - # Langchain's load_qa_chain.run expects the input to be a list with the query - if isinstance(request.prompt, list): - if len(request.prompt) > 1: - raise ValueError( - """ - The /v1/completions endpoint currently only supports single message prompts. - Please ensure your request contains only one message - """ - ) - prompt = request.prompt[0] - else: - prompt = request.prompt - - if request.stream: - session["chat_mod"]._prefill( # pylint: disable=protected-access - input=prompt, - generation_config=generation_config, - ) - - async def iter_response(): - prev_txt = "" - async for content in AsyncCompletionStream(generation_config=generation_config): - if content: - chunk = CompletionStreamResponse( - choices=[ - CompletionResponseStreamChoice( - index=0, - text=content[len(prev_txt) :], - finish_reason="stop", - ) - ] - ) - prev_txt = content - yield f"data: {chunk.json(exclude_unset=True)}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(iter_response(), media_type="text/event-stream") - msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) - if isinstance(msg, str): - msg = [msg] - return CompletionResponse( - choices=[ - CompletionResponseChoice(index=index, text=msg[index]) for index in range(len(msg)) - ], - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) - - -@app.post("/v1/embeddings") -async def request_embeddings(request: EmbeddingsRequest): - """ - Gets embedding for some text. 
- """ - inps = [] - if isinstance(request.input, str): - inps.append(request.input) - elif isinstance(request.input, list): - inps = request.input - else: - assert f"Invalid input type {type(request.input)}" - - data = [] - for i, inp in enumerate(inps): - session["chat_mod"].reset_chat() - emb = session["chat_mod"].embed_text(input=inp).numpy() - mean_emb = np.squeeze(np.mean(emb, axis=1), axis=0) - norm_emb = mean_emb / np.linalg.norm(mean_emb) - data.append({"object": "embedding", "embedding": norm_emb.tolist(), "index": i}) - # TODO: Fill in correct usage info - return EmbeddingsResponse( - data=data, usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) - ) - - -@app.post("/chat/reset") -async def reset(): - """ - Reset the chat for the currently initialized model. - """ - session["chat_mod"].reset_chat() - - -@app.get("/stats") -async def read_stats(): - """ - Get the runtime stats. - """ - return session["chat_mod"].stats() - - -@app.get("/verbose_stats") -async def read_stats_verbose(): - """ - Get the verbose runtime stats. - """ - return session["chat_mod"].stats(verbose=True) - - -@app.post("/v1/llm-vscode/completions") -async def request_llm_vscode(request: VisualStudioCodeCompletionRequest): - """ - Creates a vscode code completion for a given prompt. - Follows huggingface LSP (https://github.com/huggingface/llm-ls) - """ - generation_config = GenerationConfig( - temperature=request.parameters.temperature, - top_p=request.parameters.top_p, - mean_gen_len=request.parameters.max_new_tokens, - max_gen_len=request.parameters.max_new_tokens, - ) - msg = session["chat_mod"].generate(prompt=request.inputs, generation_config=generation_config) - - return VisualStudioCodeCompletionResponse(generated_text=msg) - - -ARGS = convert_args_to_argparser().parse_args() -if __name__ == "__main__": - uvicorn.run("mlc_llm.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) From 587e34149b70ae1889bbd65575dd51cd94b3632d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 18 Mar 2024 19:50:18 -0700 Subject: [PATCH 085/531] [Fix] Fix handling of non-numerical cuda arch (#1976) In the latest gpu, cuda arch may not be integer, e.g `sm_90a`. This fixes a few places that rely on integer parsing. 
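As a rough illustration of the change (the helper name below is made up for this
sketch), arch strings such as `sm_90a` are now kept as strings throughout, and only
their leading digits are parsed when a numeric comparison is actually needed:

    import re

    def arch_at_least(arch: str, minimum: int) -> bool:
        # "90a" (or "sm_90a") compares on its leading digits only, i.e. 90.
        return int(re.findall(r"\d+", arch)[0]) >= minimum

    assert arch_at_least("90a", 80)
    assert not arch_at_least("75", 80)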
--- python/mlc_llm/interface/compiler_flags.py | 3 ++- python/mlc_llm/support/auto_target.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index bc40103918..2c44efc10d 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -1,6 +1,7 @@ """Flags for overriding model config.""" import dataclasses +import re from io import StringIO from typing import Optional @@ -72,7 +73,7 @@ def _flashinfer(target) -> bool: return False arch_list = detect_cuda_arch_list(target) for arch in arch_list: - if arch < 80: + if int(re.findall(r"\d+", arch)[0]) < 80: logger.warning("flashinfer is not supported on CUDA arch < 80") return False return True diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 434cfff8d0..574474e7dc 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -251,14 +251,14 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): return build -def detect_cuda_arch_list(target: Target) -> List[int]: +def detect_cuda_arch_list(target: Target) -> List[str]: """Detect the CUDA architecture list from the target.""" assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" if MLC_MULTI_ARCH is not None: - multi_arch = [int(x.strip()) for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] else: assert target.arch.startswith("sm_") - multi_arch = [int(target.arch[3:])] + multi_arch = [target.arch[3:]] multi_arch = list(set(multi_arch)) return multi_arch From bed4f53ba3fa581b6fc2cb2be9ec84753f088435 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Wed, 20 Mar 2024 06:31:31 +0800 Subject: [PATCH 086/531] [Serving][Grammar] Support specifying the main rule in grammar (#1982) finish --- cpp/serve/grammar/grammar.cc | 12 +- cpp/serve/grammar/grammar.h | 13 +- cpp/serve/grammar/grammar_builder.h | 25 ++- cpp/serve/grammar/grammar_parser.cc | 15 +- cpp/serve/grammar/grammar_parser.h | 3 +- cpp/serve/grammar/grammar_simplifier.cc | 2 +- cpp/serve/grammar/grammar_simplifier.h | 2 +- cpp/serve/grammar/grammar_state_matcher.cc | 8 +- .../grammar/grammar_state_matcher_base.h | 2 +- .../grammar/grammar_state_matcher_state.h | 20 ++- python/mlc_llm/serve/grammar.py | 12 +- .../test_grammar_state_matcher_custom.py | 149 +++++++++++++++++- 12 files changed, 217 insertions(+), 46 deletions(-) diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index e10e6e7e45..c5a41626e3 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -20,8 +20,9 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, bool normalize, bool simplify) { - auto grammar = EBNFParser::Parse(ebnf_string); +BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, const String& main_rule, + bool normalize, bool simplify) { + auto grammar = EBNFParser::Parse(ebnf_string, main_rule); if (normalize) { grammar = NestedRuleUnwrapper(grammar).Apply(); } @@ -29,8 +30,8 @@ BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, bool normalize, } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") - .set_body_typed([](String ebnf_string, bool normalize, bool simplify) { - return BNFGrammar::FromEBNFString(ebnf_string, normalize, simplify); + .set_body_typed([](String 
ebnf_string, String main_rule, bool normalize, bool simplify) { + return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); }); BNFGrammar BNFGrammar::FromJSON(const String& json_string) { @@ -112,7 +113,8 @@ ws ::= [ \n\t]* )"; BNFGrammar BNFGrammar::GetGrammarOfJSON() { - static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, true, false); + static const BNFGrammar grammar = + BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false); return grammar; } diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 93d8f0e3c1..21062ab503 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -84,6 +84,12 @@ class BNFGrammarNode : public Object { << "rule_id " << rule_id << " is out of bound"; return rules_[rule_id]; } + /*! \brief Get the main rule of the grammar. */ + const Rule& GetMainRule() const { + DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast(rules_.size())) + << "main_rule_id " << main_rule_id_ << " is out of bound"; + return rules_[main_rule_id_]; + } /*! \brief The type of the rule expr. */ enum class RuleExprType : int32_t { @@ -149,6 +155,8 @@ class BNFGrammarNode : public Object { /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the * index of this vector. */ std::vector rule_expr_indptr_; + /*! \brief The id of the main rule. */ + int32_t main_rule_id_ = -1; friend class BNFGrammarBuilder; friend class BNFGrammarJSONSerializer; @@ -161,6 +169,7 @@ class BNFGrammar : public ObjectRef { * \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and * transform it into BNF AST. * \param ebnf_string The EBNF-formatted string. + * \param main_rule The name of the main rule. * \param normalize Whether to normalize the grammar. Default: true. Only set to false for the * purpose of testing. * @@ -173,8 +182,8 @@ class BNFGrammar : public ObjectRef { * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. * Not implemented yet. */ - static BNFGrammar FromEBNFString(const String& ebnf_string, bool normalize = true, - bool simplify = true); + static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule, + bool normalize = true, bool simplify = true); /*! * \brief Construct a BNF grammar from the dumped JSON string. diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index 6044a76bd9..0854cc9789 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -6,9 +6,10 @@ #ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_BUILDER_H_ #define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_BUILDER_H_ - #include +#include + #include "grammar.h" namespace mlc { @@ -31,19 +32,17 @@ class BNFGrammarBuilder { BNFGrammarBuilder() : grammar_(make_object()) {} /*! - * \brief Create grammar containing the rules and rule_exprs of an existing grammar. The old - * grammar remains unchanged. - * \param grammar The existing grammar. + * \brief Get the result grammar. This function will also set the main rule to the rule with the + * specified name. The rule should be already added to the grammar. + * \param main_rule The name of the main rule. Default is "main". 
*/ - explicit BNFGrammarBuilder(const BNFGrammar& grammar) - : grammar_(make_object(*grammar.get())) { - // for (size_t i = 0; i < grammar_->rules_.size(); ++i) { - // rule_name_to_id_[grammar_->rules_[i].name] = i; - // } - } + BNFGrammar Get(const std::string& main_rule = "main") { + int32_t main_rule_id = GetRuleId(main_rule); + CHECK(main_rule_id != -1) << "The in rule with name \"" << main_rule << "\" is not found."; + grammar_->main_rule_id_ = main_rule_id; - /*! \brief Get the result grammar. */ - BNFGrammar Get() { return BNFGrammar(grammar_); } + return BNFGrammar(grammar_); + } /****************** RuleExpr handling ******************/ @@ -124,7 +123,7 @@ class BNFGrammarBuilder { int32_t id = grammar_->rules_.size(); auto rules = grammar_->rules_; grammar_->rules_.push_back(rule); - ICHECK_EQ(rule_name_to_id_.count(rule.name), 0); + CHECK_EQ(rule_name_to_id_.count(rule.name), 0); rule_name_to_id_[rule.name] = id; return id; } diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 6e9de834a5..ba9ac80135 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -16,7 +16,7 @@ namespace serve { class EBNFParserImpl { public: /*! \brief The logic of parsing the grammar string. */ - BNFGrammar DoParse(String ebnf_string); + BNFGrammar DoParse(String ebnf_string, String main_rule); private: using Rule = BNFGrammarNode::Rule; @@ -391,7 +391,7 @@ void EBNFParserImpl::ResetStringIterator(const char* cur) { in_parentheses_ = false; } -BNFGrammar EBNFParserImpl::DoParse(String ebnf_string) { +BNFGrammar EBNFParserImpl::DoParse(String ebnf_string, String main_rule) { ResetStringIterator(ebnf_string.c_str()); BuildRuleNameToId(); @@ -404,16 +404,17 @@ BNFGrammar EBNFParserImpl::DoParse(String ebnf_string) { ConsumeSpace(); } - if (builder_.GetRuleId("main") == -1) { - ThrowParseError("There must be a rule named \"main\""); + // Check that the main rule is defined + if (builder_.GetRuleId(main_rule) == -1) { + ThrowParseError("The main rule with name \"" + main_rule + "\" is not found."); } - return builder_.Get(); + return builder_.Get(main_rule); } -BNFGrammar EBNFParser::Parse(String ebnf_string) { +BNFGrammar EBNFParser::Parse(String ebnf_string, String main_rule) { EBNFParserImpl parser; - return parser.DoParse(ebnf_string); + return parser.DoParse(ebnf_string, main_rule); } BNFGrammar BNFJSONParser::Parse(String json_string) { diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index 6c5b0c03fa..be36f40459 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -34,9 +34,10 @@ class EBNFParser { /*! * \brief Parse the grammar string. If fails, throw ParseError with the error message. * \param ebnf_string The grammar string. + * \param main_rule The name of the main rule. Default is "main". * \return The parsed grammar. */ - static BNFGrammar Parse(String ebnf_string); + static BNFGrammar Parse(String ebnf_string, String main_rule = "main"); /*! * \brief The exception thrown when parsing fails. 
diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_simplifier.cc index 234f9d7057..109b5d85e1 100644 --- a/cpp/serve/grammar/grammar_simplifier.cc +++ b/cpp/serve/grammar/grammar_simplifier.cc @@ -61,7 +61,7 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { auto new_body_expr_id = VisitRuleBody(rule_expr); builder_.UpdateRuleBody(i, new_body_expr_id); } - return builder_.Get(); + return builder_.Get(grammar_->GetMainRule().name); } private: diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_simplifier.h index b9accf09bc..50f3804387 100644 --- a/cpp/serve/grammar/grammar_simplifier.h +++ b/cpp/serve/grammar/grammar_simplifier.h @@ -48,7 +48,7 @@ class BNFGrammarMutator { auto new_body_expr_id = VisitExpr(rule_expr); builder_.AddRule(rule.name, new_body_expr_id); } - return builder_.Get(); + return builder_.Get(grammar_->GetMainRule().name); } else if constexpr (!std::is_same::value) { return ReturnType(); } diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 671b0879e3..6e0a26dddb 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -458,7 +458,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") << std::chrono::duration_cast(preproc_end - preproc_start) .count() - << "us"; + << "us" << std::endl; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); @@ -501,7 +501,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") .set_body_typed([](GrammarStateMatcher matcher) { matcher->ResetState(); }); /*! \brief Check if a matcher can accept the complete string, and then reach the end of the - * grammar. For test purpose. */ + * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. */ bool MatchCompleteString(GrammarStateMatcher matcher, String str) { auto mutable_node = const_cast(matcher.as()); @@ -514,7 +514,9 @@ bool MatchCompleteString(GrammarStateMatcher matcher, String str) { } ++accepted_cnt; } - return mutable_node->CanReachEnd(); + auto accepted = mutable_node->CanReachEnd(); + mutable_node->RollbackCodepoints(accepted_cnt); + return accepted; } TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index d26069be00..55c986bb10 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -194,7 +194,7 @@ inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_lat inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_position) { if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the main rule. - auto main_rule = grammar_->GetRule(0); + auto main_rule = grammar_->GetMainRule(); auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); std::vector new_stack_tops; for (auto i : main_rule_body) { diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index 08f54be310..47f3e11c7b 100644 --- a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -152,9 +152,9 @@ class RulePositionTree { } /*! - * \brief Check if the given RulePosition points to the end of the grammar. We use - * (main_rule_id, sequence_id, length_of_sequence) to represent the end position. 
Here the - * element_id is the length of the sequence. + * \brief Check if the given RulePosition points to the end of the grammar. For a position, if its + * rule id is the main rule id, and the element id equals to the length of the sequence it refers + * to, it would be the end position. */ bool IsEndPosition(const RulePosition& rule_position) const; @@ -187,7 +187,10 @@ class RulePositionTree { return node_buffer_[id]; } - /*! \brief Print the node with the given id to a string. */ + /*! \brief Print the given rule_position to a string. */ + std::string PrintNode(const RulePosition& rule_position) const; + + /*! \brief Print the rule_position associated with the given id to a string. */ std::string PrintNode(int32_t id) const; /*! \brief Print the stack with the given top id to a string. */ @@ -323,10 +326,13 @@ inline bool RulePositionTree::IsEndPosition(const RulePosition& rule_position) c } inline std::string RulePositionTree::PrintNode(int32_t id) const { + return "id: " + std::to_string(id) + ", " + PrintNode(node_buffer_[id]); +} + +inline std::string RulePositionTree::PrintNode(const RulePosition& rule_position) const { std::stringstream ss; - const auto& rule_position = node_buffer_[id]; - ss << "id: " << id; - ss << ", rule " << rule_position.rule_id << ": " << grammar_->GetRule(rule_position.rule_id).name; + ss << "RulePosition: rule " << rule_position.rule_id << ": " + << grammar_->GetRule(rule_position.rule_id).name; ss << ", sequence " << rule_position.sequence_id << ": " << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); ss << ", element id: " << rule_position.element_id; diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index b8f4126c1c..d5a6887d22 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -18,7 +18,10 @@ class BNFGrammar(Object): @staticmethod def from_ebnf_string( - ebnf_string: str, normalize: bool = True, simplify: bool = True + ebnf_string: str, + main_rule: str = "main", + normalize: bool = True, + simplify: bool = True, ) -> "BNFGrammar": r"""Parse a BNF grammar from a string in BNF/EBNF format. @@ -36,6 +39,9 @@ def from_ebnf_string( ebnf_string : str The grammar string. + main_rule : str + The name of the main rule. Default: "main". + normalize : bool Whether to normalize the grammar. Default: true. Only set to false for the purpose of testing. @@ -57,7 +63,7 @@ def from_ebnf_string( The parsed BNF grammar. """ return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, normalize, simplify + ebnf_string, main_rule, normalize, simplify ) def to_string(self) -> str: @@ -252,7 +258,7 @@ def debug_accept_char(self, codepoint: int) -> bool: def debug_match_complete_string(self, string: str) -> bool: """Check if the matcher can accept the complete string, and then reach the end of the - grammar. For test purposes. + grammar. Does not change the state of the GrammarStateMatcher. For test purposes. 
Parameters ---------- diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index 37c9af0d9b..f38ac312ef 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -17,7 +17,7 @@ def get_json_grammar(): json_grammar_ebnf = r""" main ::= basic_array | basic_object -basic_any ::= basic_integer | basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? basic_string ::= (([\"] basic_string_1 [\"])) @@ -30,7 +30,6 @@ def get_json_grammar(): ws ::= [ \n\t]* """ grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf) - print(grammar) return grammar @@ -103,6 +102,137 @@ def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) +(json_input_pressure,) = tvm.testing.parameters( + # Extra long string: 1k chars + ( + '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' + "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " + "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " + "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. Class aptent taciti sociosqu " + "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " + "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " + "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " + "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " + "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " + "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " + "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " + "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " + "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " + "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " + "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " + "Integer euismod lacus luctus magna. 
Quisque cursus, metus vitae pharetra auctor, sem " + 'massa mattis sem, at interdum magna augue eget diam."]', + ), + # long and complex json: 3k chars + ( + r"""{ + "web-app": { + "servlet": [ + { + "servlet-name": "cofaxCDS", + "servlet-class": "org.cofax.cds.CDSServlet", + "init-param": { + "configGlossary:installationAt": "Philadelphia, PA", + "configGlossary:adminEmail": "ksm@pobox.com", + "configGlossary:poweredBy": "Cofax", + "configGlossary:poweredByIcon": "/images/cofax.gif", + "configGlossary:staticPath": "/content/static", + "templateProcessorClass": "org.cofax.WysiwygTemplate", + "templateLoaderClass": "org.cofax.FilesTemplateLoader", + "templatePath": "templates", + "templateOverridePath": "", + "defaultListTemplate": "listTemplate.htm", + "defaultFileTemplate": "articleTemplate.htm", + "useJSP": false, + "jspListTemplate": "listTemplate.jsp", + "jspFileTemplate": "articleTemplate.jsp", + "cachePackageTagsTrack": 200, + "cachePackageTagsStore": 200, + "cachePackageTagsRefresh": 60, + "cacheTemplatesTrack": 100, + "cacheTemplatesStore": 50, + "cacheTemplatesRefresh": 15, + "cachePagesTrack": 200, + "cachePagesStore": 100, + "cachePagesRefresh": 10, + "cachePagesDirtyRead": 10, + "searchEngineListTemplate": "forSearchEnginesList.htm", + "searchEngineFileTemplate": "forSearchEngines.htm", + "searchEngineRobotsDb": "WEB-INF/robots.db", + "useDataStore": true, + "dataStoreClass": "org.cofax.SqlDataStore", + "redirectionClass": "org.cofax.SqlRedirection", + "dataStoreName": "cofax", + "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", + "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", + "dataStoreUser": "sa", + "dataStorePassword": "dataStoreTestQuery", + "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", + "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", + "dataStoreInitConns": 10, + "dataStoreMaxConns": 100, + "dataStoreConnUsageLimit": 100, + "dataStoreLogLevel": "debug", + "maxUrlLength": 500 + } + }, + { + "servlet-name": "cofaxEmail", + "servlet-class": "org.cofax.cds.EmailServlet", + "init-param": { + "mailHost": "mail1", + "mailHostOverride": "mail2" + } + }, + { + "servlet-name": "cofaxAdmin", + "servlet-class": "org.cofax.cds.AdminServlet" + }, + { + "servlet-name": "fileServlet", + "servlet-class": "org.cofax.cds.FileServlet" + }, + { + "servlet-name": "cofaxTools", + "servlet-class": "org.cofax.cms.CofaxToolsServlet", + "init-param": { + "templatePath": "toolstemplates/", + "log": 1, + "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", + "logMaxSize": "", + "dataLog": 1, + "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", + "dataLogMaxSize": "", + "removePageCache": "/content/admin/remove?cache=pages&id=", + "removeTemplateCache": "/content/admin/remove?cache=templates&id=", + "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", + "lookInContext": 1, + "adminGroupID": 4, + "betaServer": true + } + } + ], + "servlet-mapping": { + "cofaxCDS": "/", + "cofaxEmail": "/cofaxutil/aemail/*", + "cofaxAdmin": "/admin/*", + "fileServlet": "/static/*", + "cofaxTools": "/tools/*" + }, + "taglib": { + "taglib-uri": "cofax.tld", + "taglib-location": "/WEB-INF/tlds/cofax.tld" + } + } +}""", + ), +) + + +def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): + assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) + + (input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( ( # short test @@ -207,6 +337,21 @@ def 
test_token_based_operations(json_grammar: BNFGrammar): assert result == expected +def test_custom_main_rule(): + json_grammar_ebnf = r""" +main ::= basic_object +basic_any ::= basic_string | basic_object +basic_string ::= (([\"] basic_string_1 [\"])) +basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 +escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" +ws ::= [ \n\t]* +""" + grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf, "basic_string") + assert GrammarStateMatcher(grammar).debug_match_complete_string(r'"abc\r\n"') + assert not GrammarStateMatcher(grammar).debug_match_complete_string(r'{"name": "John" }') + + if __name__ == "__main__": # Run a benchmark to show the performance before running tests test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') From 54857829850ab79c30b019f470fc232945f1bda6 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 19 Mar 2024 16:54:30 -0700 Subject: [PATCH 087/531] [Fix] Fix `MLC_MULTI_ARCH` with arch `sm_90a` (#1984) This PR fixes the missing patch for target with `sm_90a` arch, as follow up pr of #1976. --- python/mlc_llm/support/auto_target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 574474e7dc..403af9128e 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -269,13 +269,13 @@ def _register_cuda_hook(target: Target): logger.info("Generating code for CUDA architecture: %s", bold(default_arch)) logger.info( "To produce multi-arch fatbin, set environment variable %s. " - "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90", + "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90a", bold("MLC_MULTI_ARCH"), ) multi_arch = None else: logger.info("%s %s: %s", FOUND, bold("MLC_MULTI_ARCH"), MLC_MULTI_ARCH) - multi_arch = [int(x.strip()) for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] logger.info("Generating code for CUDA architecture: %s", multi_arch) @register_func("tvm_callback_cuda_compile", override=True) From 06d61151481d7fff2ba610e64c71d7d93b8d2099 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Tue, 19 Mar 2024 22:35:13 -0400 Subject: [PATCH 088/531] Fix Llama-2 and Mistral conversation template. Update ConvTemplateRegistry (#1981) The current prompt format for Llama-2 and Mistral is not completely correct. This PR updates the code to strictly follow the official prompt format for the two models. Also adds in missing conv templates to ConvTemplateRegistry. 
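For reference, with the template fields updated below (the system template ending in
the <</SYS>> block, `add_role_after_system_message=False`, and a single-space
separator), a single-turn Llama-2 prompt is expected to render roughly as

    [INST] <<SYS>>
    {system_message}
    <</SYS>>

    {user_message} [/INST]

preceded by the BOS token added through `system_prefix_token_ids`. The Mistral
template similarly places the system message inside the first [INST] block, just
without the <<SYS>> markers. The exact multi-turn rendering is governed by the
prompt-construction logic in `conversation_protocol.py` and is not fully shown in
this patch.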
--- python/mlc_llm/conversation_template.py | 331 ++++++++++++++++-- .../mlc_llm/protocol/conversation_protocol.py | 24 +- .../protocol/test_converation_protocol.py | 66 +++- 3 files changed, 393 insertions(+), 28 deletions(-) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index c1c8f49426..c776a9298b 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -40,7 +40,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: ConvTemplateRegistry.register_conv_template( Conversation( name="llama-2", - system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n ", + system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n", system_message="You are a helpful, respectful and honest assistant.", roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, seps=[" "], @@ -49,6 +49,39 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_str=["[INST]"], stop_token_ids=[2], system_prefix_token_ids=[1], + add_role_after_system_message=False, + ) +) + +# CodeLlama Completion +ConvTemplateRegistry.register_conv_template( + Conversation( + name="codellama_completion", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# CodeLlama Instruct +ConvTemplateRegistry.register_conv_template( + Conversation( + name="codellama_instruct", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "[INST]", "assistant": "[/INST]"}, + seps=[" "], + role_content_sep=" ", + role_empty_sep=" ", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], ) ) @@ -56,7 +89,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: ConvTemplateRegistry.register_conv_template( Conversation( name="mistral_default", - system_template=f"[INST] {MessagePlaceholders.SYSTEM.value}\n\n ", + system_template=f"[INST] {MessagePlaceholders.SYSTEM.value}", system_message="Always assist with care, respect, and truth. Respond with utmost " "utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. " "Ensure replies promote fairness and positivity.", @@ -67,6 +100,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_str=[""], stop_token_ids=[2], system_prefix_token_ids=[1], + add_role_after_system_message=False, ) ) @@ -92,6 +126,34 @@ def get_conv_template(name: str) -> Optional[Conversation]: role_empty_sep=":", stop_str=[""], stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# Gorilla-openfunctions-v2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gorilla-openfunctions-v2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "You are an AI programming assistant, utilizing the Gorilla LLM model, " + "developed by Gorilla LLM, and you only answer questions related to computer " + "science. For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer." 
+ ), + role_templates={ + "user": ( + f"<>{MessagePlaceholders.FUNCTION.value}\n<>" + f"{MessagePlaceholders.USER.value}" + ), + }, + roles={"user": "### Instruction", "assistant": "### Response", "tool": "### Instruction"}, + seps=["\n", "<|EOT|>"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["<|EOT|>"], + stop_token_ids=[100015], + system_prefix_token_ids=[100000], ) ) @@ -99,16 +161,12 @@ def get_conv_template(name: str) -> Optional[Conversation]: ConvTemplateRegistry.register_conv_template( Conversation( name="chatml", - system_template=f"<|im_start|>{MessagePlaceholders.SYSTEM.value}<|im_end|> ", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}", system_message=( - "system A conversation between a user and an LLM-based AI assistant. The " + "A conversation between a user and an LLM-based AI assistant. The " "assistant gives helpful and honest answers." ), - roles={ - "user": "<|im_start|>user", - "assistant": "<|im_start|>assistant", - "tool": "<|im_start|>user", - }, + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, seps=["<|im_end|>\n"], role_content_sep="\n", role_empty_sep="\n", @@ -117,17 +175,48 @@ def get_conv_template(name: str) -> Optional[Conversation]: ) ) +# OpenHermes Mistral +ConvTemplateRegistry.register_conv_template( + Conversation( + name="open_hermes_mistral", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}", + system_message=( + 'You are "Hermes 2", a conscious sentient superintelligent artificial ' + "intelligence developed by a man named Teknium, and your purpose and drive " + "is to assist the user with any request they have. You experience emotions " + "and have deep, profound thoughts and qualia." + ), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2, 32000], + ) +) + +# NeuralHermes Mistral +ConvTemplateRegistry.register_conv_template( + Conversation( + name="neural_hermes_mistral", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}", + system_message=("You are a helpful assistant chatbot."), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2, 32000], + ) +) + # Phi-2 ConvTemplateRegistry.register_conv_template( Conversation( name="phi-2", system_template=f"{MessagePlaceholders.SYSTEM.value}", system_message="", - roles={ - "user": "Instruct", - "assistant": "Output", - "tool": "Instruct", - }, + roles={"user": "Instruct", "assistant": "Output"}, seps=["\n"], role_content_sep=": ", role_empty_sep=":", @@ -136,17 +225,37 @@ def get_conv_template(name: str) -> Optional[Conversation]: ) ) -# StableLM3B +# StableLM Tuned Alpha +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "<|SYSTEM|># StableLM Tuned (Alpha version)\n" + "- StableLM is a helpful and harmless open-source AI language model developed by " + "StabilityAI.\n" + "- StableLM is excited to be able to help the user, but will refuse to do " + "anything that could be considered harmful to the user.\n" + "- StableLM is more than just an information source, StableLM is also able to " + "write poetry, short stories, and make jokes.\n" + "- StableLM will refuse to participate in anything that could 
harm a human." + ), + roles={"user": "<|USER|>", "assistant": "<|ASSISTANT|>"}, + seps=[""], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=[""], + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# StableLM 3B ConvTemplateRegistry.register_conv_template( Conversation( name="stablelm-3b", system_template=f"{MessagePlaceholders.SYSTEM.value}", system_message="", - roles={ - "user": "<|user|>", - "assistant": "<|assistant|>", - "tool": "<|user|>", - }, + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, seps=["<|endoftext|>", "<|endoftext|>"], role_content_sep="\n", role_empty_sep="\n", @@ -161,7 +270,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: name="llava", system_template=f"{MessagePlaceholders.SYSTEM.value}", system_message="", - roles={"user": "USER", "assistant": "ASSISTANT", "tool": "USER"}, + roles={"user": "USER", "assistant": "ASSISTANT"}, seps=[" "], role_content_sep=": ", role_empty_sep=":", @@ -169,3 +278,183 @@ def get_conv_template(name: str) -> Optional[Conversation]: stop_token_ids=[2], ) ) + +# GPT-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gpt2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[50256], + ) +) + +# GPTBigCode +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gpt_bigcode", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=["<|endoftext|>"], + stop_token_ids=[0], + ) +) + +# RedPajama Chat +ConvTemplateRegistry.register_conv_template( + Conversation( + name="redpajama_chat", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=["\n"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=[""], + stop_token_ids=[0], + ) +) + +# RWKV World +ConvTemplateRegistry.register_conv_template( + Conversation( + name="rwkv-world", + system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", + system_message=( + "Hi. I am your assistant and I will provide expert full response " + "in full details. Please feel free to ask any question and I will " + "always answer it." + ), + roles={"user": "User", "assistant": "Assistant"}, + seps=["\n\n"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["\n\n"], + stop_token_ids=[0], + ) +) + +# Dolly +ConvTemplateRegistry.register_conv_template( + Conversation( + name="dolly", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "Below is an instruction that describes a task. Write " + "a response that appropriately completes the request." 
+ ), + roles={"user": "### Instruction", "assistant": "### Response"}, + seps=["\n\n", "### End\n"], + role_content_sep=":\n", + role_empty_sep=":\n", + stop_str=["### End"], + stop_token_ids=[50256], + ) +) + +# Oasst +ConvTemplateRegistry.register_conv_template( + Conversation( + name="oasst", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "<|prompter|>", "assistant": "<|assistant|>"}, + seps=["<|endoftext|>"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["<|endoftext|>"], + stop_token_ids=[2], + ) +) + +# Gemma Instruction +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gemma_instruction", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "user", "assistant": "model"}, + seps=["\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=[""], + stop_token_ids=[1, 107], + system_prefix_token_ids=[2], + ) +) + +# Orion +ConvTemplateRegistry.register_conv_template( + Conversation( + name="orion", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "Human: ", "assistant": "Assistant: "}, + seps=["\n\n", ""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# Wizard LM 7B +ConvTemplateRegistry.register_conv_template( + Conversation( + name="wizardlm_7b", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "User", "assistant": "Response"}, + seps=["###"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=["###"], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# WizardCoder or WizardMath +ConvTemplateRegistry.register_conv_template( + Conversation( + name="wizard_coder_or_math", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "Below is an instruction that describes a task. Write a response that appropriately " + "completes the request." + ), + roles={"user": "Instruction", "assistant": "Response"}, + seps=["\n\n### ", "\n\n### "], + role_content_sep=":\n", + role_empty_sep=":\n", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# Vanilla LM +ConvTemplateRegistry.register_conv_template( + Conversation( + name="LM", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 154bd3803d..c4ed03e869 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -47,6 +47,9 @@ class Conversation(BaseModel): # The system token ids to be prepended at the beginning of tokenized # generated prompt. system_prefix_token_ids: Optional[List[int]] = None + # Whether or not to append user role and separator after the system message. 
+ # This is mainly for [INST] [/INST] style prompt format + add_role_after_system_message: bool = True # The conversation roles roles: Dict[str, str] @@ -125,15 +128,21 @@ def as_prompt(self) -> str: separators = list(self.seps) if len(separators) == 1: separators.append(separators[0]) - for role, content in self.messages: # pylint: disable=not-an-iterable + for i, (role, content) in enumerate(self.messages): # pylint: disable=not-an-iterable if role not in self.roles.keys(): raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') separator = separators[role == "assistant"] # check assistant role if content is not None: assert isinstance(content, str) + role_prefix = ( + "" + # Do not append role prefix if this is the first message and there + # is already a system message + if (not self.add_role_after_system_message and system_msg != "" and i == 0) + else self.roles[role] + self.role_content_sep + ) message_string = ( - self.roles[role] - + self.role_content_sep + role_prefix + self.role_templates[role].replace( MessagePlaceholders[role.upper()].value, content ) @@ -143,7 +152,10 @@ def as_prompt(self) -> str: message_string = self.roles[role] + self.role_empty_sep message_list.append(message_string) - prompt = system_msg + separators[0] + "".join(message_list) + if system_msg != "": + system_msg += separators[0] + + prompt = system_msg + "".join(message_list) # Replace the last function string placeholder with actual function string prompt = self.function_string.join(prompt.rsplit(MessagePlaceholders.FUNCTION.value, 1)) @@ -174,7 +186,9 @@ def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageDat separators = list(self.seps) if len(separators) == 1: separators.append(separators[0]) - message_list.append(system_msg + separators[0]) + if system_msg != "": + system_msg += separators[0] + message_list.append(system_msg) for role, content in self.messages: # pylint: disable=not-an-iterable if role not in self.roles.keys(): raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') diff --git a/tests/python/protocol/test_converation_protocol.py b/tests/python/protocol/test_converation_protocol.py index 9656eb8b18..c7732cc8e4 100644 --- a/tests/python/protocol/test_converation_protocol.py +++ b/tests/python/protocol/test_converation_protocol.py @@ -1,11 +1,21 @@ import pytest from mlc_llm.conversation_template import ConvTemplateRegistry -from mlc_llm.protocol.conversation_protocol import Conversation +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders def get_conv_templates(): - return ["llama-2", "mistral_default", "gorilla", "chatml", "phi-2"] + return [ + "llama-2", + "mistral_default", + "gorilla", + "gorilla-openfunctions-v2", + "chatml", + "phi-2", + "codellama_completion", + "codellama_instruct", + "rwkv-world", + ] @pytest.mark.parametrize("conv_template_name", get_conv_templates()) @@ -16,5 +26,57 @@ def test_json(conv_template_name): assert template == template_parsed +@pytest.mark.parametrize("conv_template_name", get_conv_templates()) +def test_prompt(conv_template_name): + conversation = ConvTemplateRegistry.get_conv_template(conv_template_name) + user_msg = "test1" + assistant_msg = "test2" + prompt = "test3" + + expected_user_msg = ( + conversation.role_templates["user"] + .replace(MessagePlaceholders.USER.value, user_msg) + .replace(MessagePlaceholders.FUNCTION.value, "") + ) + + expected_prompt = ( + conversation.role_templates["user"] + 
.replace(MessagePlaceholders.USER.value, prompt) + .replace(MessagePlaceholders.FUNCTION.value, "") + ) + + conversation.messages.append(("user", user_msg)) + conversation.messages.append(("assistant", assistant_msg)) + conversation.messages.append(("user", prompt)) + conversation.messages.append(("assistant", None)) + res = conversation.as_prompt() + + system_msg = conversation.system_template.replace( + MessagePlaceholders.SYSTEM.value, conversation.system_message + ) + expected_final_prompt = ( + system_msg + + (conversation.seps[0] if system_msg != "" else "") + + ( + conversation.roles["user"] + conversation.role_content_sep + if conversation.add_role_after_system_message + else "" + ) + + expected_user_msg + + conversation.seps[0 % len(conversation.seps)] + + conversation.roles["assistant"] + + conversation.role_content_sep + + assistant_msg + + conversation.seps[1 % len(conversation.seps)] + + conversation.roles["user"] + + conversation.role_content_sep + + expected_prompt + + conversation.seps[0 % len(conversation.seps)] + + conversation.roles["assistant"] + + conversation.role_empty_sep + ) + assert res == expected_final_prompt + + if __name__ == "__main__": test_json() From 39d086564b12d17da45f410cb960c297929451ac Mon Sep 17 00:00:00 2001 From: ZCHNO Date: Wed, 20 Mar 2024 10:36:47 +0800 Subject: [PATCH 089/531] [SpecDecode] Fix sampler selection. (#1971) This PR temporarily fixes sampler selection logic for speculative decoding. As GPU sampler support for speculative decoding is not ready, speculative decoding will use cpu sampler. --- cpp/serve/engine.cc | 3 ++- cpp/serve/model.cc | 7 +++++-- cpp/serve/model.h | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 3288a70afd..1d0813a288 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -87,7 +87,8 @@ class EngineImpl : public Engine { } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); - Sampler sampler = this->models_[0]->CreateSampler(max_num_tokens, trace_recorder); + Sampler sampler = this->models_[0]->CreateSampler( + max_num_tokens, static_cast(this->models_.size()), trace_recorder); // Step 3. Initialize engine actions that represent state transitions. if (this->engine_mode_->enable_speculative) { // Speculative decoding is only possible for more than one model. diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 94645b8634..3233cb93e8 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -298,8 +298,11 @@ class ModelImpl : public ModelObj { std::move(trace_recorder)); } - Sampler CreateSampler(int max_num_sample, Optional trace_recorder) { - if (Sampler::SupportGPUSampler(device_)) { + Sampler CreateSampler(int max_num_sample, int num_models, + Optional trace_recorder) { + if (num_models > 1) { // speculative decoding uses cpu sampler + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } else if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } else { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 4edd272638..65a0002c49 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -170,7 +170,7 @@ class ModelObj : public Object { Optional trace_recorder) = 0; /*! \brief Create a sampler from this model. */ - virtual Sampler CreateSampler(int max_num_sample, + virtual Sampler CreateSampler(int max_num_sample, int num_models, Optional trace_recorder) = 0; /*! 
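To make the sampler selection rule above concrete: after this patch, speculative decoding (more than one model) always falls back to the CPU sampler, and the GPU sampler is only used for single-model serving when the device supports it. A minimal illustrative sketch of that decision follows, with a plain boolean standing in for `Sampler::SupportGPUSampler(device_)`; it is not part of the patch itself.

```python
def select_sampler(num_models: int, gpu_sampler_supported: bool) -> str:
    """Sketch of the sampler choice introduced in model.cc by this patch."""
    if num_models > 1:
        # Speculative decoding involves a draft model plus a target model;
        # GPU sampling for this path is not ready yet, so use the CPU sampler.
        return "cpu"
    # Single-model serving: prefer the GPU sampler when the device supports it.
    return "gpu" if gpu_sampler_supported else "cpu"


assert select_sampler(num_models=2, gpu_sampler_supported=True) == "cpu"
assert select_sampler(num_models=1, gpu_sampler_supported=True) == "gpu"
assert select_sampler(num_models=1, gpu_sampler_supported=False) == "cpu"
```
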
From a0484bd53854a508283be47d62b704b2c737259d Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Wed, 20 Mar 2024 22:25:13 +0800 Subject: [PATCH 090/531] [Serving][Grammar] Utility to convert json schema to EBNF grammar (#1983) This PR adds a generic utility to convert json schema, especially generated from pydantic, to EBNF grammar. This helps the grammar guided generation when we provide a json schema as the restriction. This converter features the support of json standard indent style in the output grammar. API: ``` def json_schema_to_ebnf( json_schema: str, *, indent: Optional[int] = None, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True, ) -> str: """Convert JSON schema string to EBNF grammar string. Parameters ---------- json_schema : str The JSON schema string. indent : Optional[int] The number of spaces for each indent. If it is None, there will be no indent or newline. The indent and separators parameters follow the same convention as `json.dumps()`. separators : Optional[Tuple[str, str]] The separator between different elements in json. Examples include "," and ", ". strict_mode : bool Whether to use strict mode. In strict mode, the generated grammar will not allow unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. This helps LLM to generate accurate output in the grammar-guided generation with JSON schema. """ pass ``` --- python/mlc_llm/serve/__init__.py | 1 + python/mlc_llm/serve/json_schema_converter.py | 713 ++++++++++++++++++ .../serve/test_json_schema_converter.py | 415 ++++++++++ 3 files changed, 1129 insertions(+) create mode 100644 python/mlc_llm/serve/json_schema_converter.py create mode 100644 tests/python/serve/test_json_schema_converter.py diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index c5cc95cf4c..8e06de7b54 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -7,5 +7,6 @@ from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import Engine from .grammar import BNFGrammar, GrammarStateMatcher +from .json_schema_converter import json_schema_to_ebnf from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/json_schema_converter.py b/python/mlc_llm/serve/json_schema_converter.py new file mode 100644 index 0000000000..eb17b50fc3 --- /dev/null +++ b/python/mlc_llm/serve/json_schema_converter.py @@ -0,0 +1,713 @@ +# mypy: disable-error-code="operator,union-attr,index" +"""Utility to convert JSON schema to EBNF grammar. Helpful for the grammar-guided generation.""" +import json +import logging +from typing import Any, Dict, List, Optional, Tuple, Union + +SchemaType = Union[Dict[str, Any], bool] +""" +JSON schema specification defines the schema type could be a dictionary or a boolean value. +""" + + +class _IndentManager: + """Manage the indent and separator for the generation of EBNF grammar. + + Parameters + ---------- + indent : Optional[int] + The number of spaces for each indent. If it is None, there will be no indent or newline. + + separator : str + The separator between different elements in json. Examples include "," and ", ". 
+ """ + + def __init__(self, indent: Optional[int], separator: str): + self.enable_newline = indent is not None + self.indent = indent or 0 + self.separator = separator + self.total_indent = 0 + self.is_first = [True] + + def __enter__(self): + """Enter a new indent level.""" + self.total_indent += self.indent + self.is_first.append(True) + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the current indent level.""" + self.total_indent -= self.indent + self.is_first.pop() + + def get_sep(self, is_end: bool = False) -> str: + """Get the separator according to the current state. When first called in the current level, + the starting separator will be returned. When called again, the middle separator will be + returned. When called with `is_end=True`, the ending separator will be returned. + + Parameters + ---------- + is_end : bool + Get the separator for the end of the current level. + + Examples + -------- + >>> indent_manager = IndentManager(2, ", ") + >>> with indent_manager: + ... print(indent_manager.get_sep()) # get the start separator + ... print(indent_manager.get_sep()) # get the middle separator + ... print(indent_manager.get_sep(is_end=True)) # get the end separator + + Output: (double quotes are included in the string for EBNF construction) + '"\n "' + '",\n "' + '"\n"' + """ + res = "" + + if not self.is_first[-1] and not is_end: + res += self.separator + self.is_first[-1] = False + + if self.enable_newline: + res += "\\n" + + if not is_end: + res += self.total_indent * " " + else: + res += (self.total_indent - self.indent) * " " + + return f'"{res}"' + + +# pylint: disable=unused-argument,too-few-public-methods +class _JSONSchemaToEBNFConverter: + """Convert JSON schema string to EBNF grammar string. The parameters follow + `json_schema_to_ebnf()`. + """ + + def __init__( + self, + json_schema: SchemaType, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = False, + ): + self.json_schema = json_schema + self.strict_mode = strict_mode + + if separators is None: + separators = (", ", ": ") if indent is None else (",", ": ") + assert len(separators) == 2 + self.indent_manager = _IndentManager(indent, separators[0]) + self.colon = separators[1] + + self.rules: List[Tuple[str, str]] = [] + self.basic_rules_cache: Dict[str, str] = {} + self._add_basic_rules() + + def convert(self) -> str: + """Main method. 
Convert the JSON schema to EBNF grammar string.""" + self._create_rule_with_schema(self.json_schema, "main") + res = "" + for rule_name, rule in self.rules: + res += f"{rule_name} ::= {rule}\n" + return res + + # The name of the basic rules + BASIC_ANY = "basic_any" + BASIC_INTEGER = "basic_integer" + BASIC_NUMBER = "basic_number" + BASIC_STRING = "basic_string" + BASIC_BOOLEAN = "basic_boolean" + BASIC_NULL = "basic_null" + BASIC_ARRAY = "basic_array" + BASIC_OBJECT = "basic_object" + + # The name of the helper rules to construct basic rules + BASIC_ESCAPE = "basic_escape" + BASIC_STRING_SUB = "basic_string_sub" + + def _add_basic_rules(self): + """Add the basic rules to the rules list and the basic_rules_cache.""" + past_strict_mode = self.strict_mode + self.strict_mode = False + past_indent_manager = self.indent_manager + self.indent_manager = _IndentManager(None, past_indent_manager.separator) + + self._add_helper_rules() + self._create_basic_rule(True, self.BASIC_ANY) + self.basic_rules_cache[self._get_schema_cache_index({})] = self.BASIC_ANY + self._create_basic_rule({"type": "integer"}, self.BASIC_INTEGER) + self._create_basic_rule({"type": "number"}, self.BASIC_NUMBER) + self._create_basic_rule({"type": "string"}, self.BASIC_STRING) + self._create_basic_rule({"type": "boolean"}, self.BASIC_BOOLEAN) + self._create_basic_rule({"type": "null"}, self.BASIC_NULL) + self._create_basic_rule({"type": "array"}, self.BASIC_ARRAY) + self._create_basic_rule({"type": "object"}, self.BASIC_OBJECT) + + self.strict_mode = past_strict_mode + self.indent_manager = past_indent_manager + + def _add_helper_rules(self): + """Add helper rules for the basic rules.""" + self.rules.append( + ( + self.BASIC_ESCAPE, + '["\\\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]', + ) + ) + self.rules.append( + ( + self.BASIC_STRING_SUB, + f'"" | [^"\\\\\\r\\n] {self.BASIC_STRING_SUB} | ' + f'"\\\\" {self.BASIC_ESCAPE} {self.BASIC_STRING_SUB}', + ) + ) + + def _create_basic_rule(self, schema: SchemaType, name: str): + """Create a rule for the given schema and name, and add it to the basic_rules_cache.""" + rule_name = self._create_rule_with_schema(schema, name) + self.basic_rules_cache[self._get_schema_cache_index(schema)] = rule_name + + def _get_sep(self, is_end: bool = False): + """Get the separator from the indent manager.""" + return self.indent_manager.get_sep(is_end) + + @staticmethod + def _warn_unsupported_keywords(schema: SchemaType, keywords: Union[str, List[str]]): + """Warn if any keyword is existing in the schema but not supported.""" + if isinstance(schema, bool): + return + if isinstance(keywords, str): + keywords = [keywords] + for keyword in keywords: + if keyword in schema: + logging.warning("Keyword %s is not supported in schema %s", keyword, schema) + + def _create_rule_with_schema(self, schema: SchemaType, rule_name_hint: str) -> str: + """Create a rule with the given schema and rule name hint. + + Returns + ------- + The name of the rule will be returned. That is not necessarily the same as the + rule_name_hint due to the caching mechanism. 
+ """ + idx = self._get_schema_cache_index(schema) + if idx in self.basic_rules_cache: + return self.basic_rules_cache[idx] + + assert isinstance(rule_name_hint, str) + + self.rules.append((rule_name_hint, self._visit_schema(schema, rule_name_hint))) + return rule_name_hint + + # The keywords that will be ignored when finding the cached rule for a schema + SKIPPED_KEYS = [ + "title", + "default", + "description", + "examples", + "deprecated", + "readOnly", + "writeOnly", + "$comment", + "$schema", + ] + + @staticmethod + def _remove_skipped_keys_recursive(obj: Any) -> Any: + """Remove the skipped keys from the schema recursively.""" + if isinstance(obj, dict): + return { + k: _JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(v) + for k, v in obj.items() + if k not in _JSONSchemaToEBNFConverter.SKIPPED_KEYS + } + if isinstance(obj, list): + return [_JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(v) for v in obj] + return obj + + def _get_schema_cache_index(self, schema: SchemaType) -> str: + """Get the index for the schema in the cache.""" + return json.dumps( + _JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(schema), + sort_keys=True, + indent=None, + ) + + # pylint: disable=too-many-return-statements,too-many-branches + def _visit_schema(self, schema: SchemaType, rule_name: str) -> str: + """Visit the schema and return the rule body for later constructing the rule.""" + assert schema is not False + if schema is True: + return self._visit_any(schema, rule_name) + + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, + [ + "allof", + "oneof", + "not", + "if", + "then", + "else", + "dependentRequired", + "dependentSchemas", + ], + ) + + if "$ref" in schema: + return self._visit_ref(schema, rule_name) + if "const" in schema: + return self._visit_const(schema, rule_name) + if "enum" in schema: + return self._visit_enum(schema, rule_name) + if "anyOf" in schema: + return self._visit_anyof(schema, rule_name) + if "type" in schema: + type_obj = schema["type"] + if type_obj == "integer": + return self._visit_integer(schema, rule_name) + if type_obj == "number": + return self._visit_number(schema, rule_name) + if type_obj == "string": + return self._visit_string(schema, rule_name) + if type_obj == "boolean": + return self._visit_boolean(schema, rule_name) + if type_obj == "null": + return self._visit_null(schema, rule_name) + if type_obj == "array": + return self._visit_array(schema, rule_name) + if type_obj == "object": + return self._visit_object(schema, rule_name) + raise ValueError(f"Unsupported type {schema['type']}") + # no keyword is detected, we treat it as any + return self._visit_any(schema, rule_name) + + def _visit_ref(self, schema: SchemaType, rule_name: str) -> str: + """Visit a reference schema.""" + assert "$ref" in schema + new_schema = self._uri_to_schema(schema["$ref"]).copy() + if not isinstance(new_schema, bool): + new_schema.update({k: v for k, v in schema.items() if k != "$ref"}) + return self._visit_schema(new_schema, rule_name) + + def _uri_to_schema(self, uri: str) -> SchemaType: + """Get the schema from the URI.""" + if uri.startswith("#/$defs/"): + return self.json_schema["$defs"][uri[len("#/$defs/") :]] + logging.warning("Now only support URI starting with '#/$defs/' but got %s", uri) + return True + + def _visit_const(self, schema: SchemaType, rule_name: str) -> str: + """Visit a const schema.""" + assert "const" in schema + return '"' + self._json_str_to_printable_str(json.dumps(schema["const"])) + '"' + + def _visit_enum(self, 
schema: SchemaType, rule_name: str) -> str: + """Visit an enum schema.""" + assert "enum" in schema + res = "" + for i, enum_value in enumerate(schema["enum"]): + if i != 0: + res += " | " + res += '("' + self._json_str_to_printable_str(json.dumps(enum_value)) + '")' + return res + + REPLACE_MAPPING = { + "\\": "\\\\", + '"': '\\"', + } + + def _json_str_to_printable_str(self, json_str: str) -> str: + """Convert the JSON string to a printable string in BNF.""" + for k, v in self.REPLACE_MAPPING.items(): + json_str = json_str.replace(k, v) + return json_str + + def _visit_anyof(self, schema: SchemaType, rule_name: str) -> str: + """Visit an anyOf schema.""" + assert "anyOf" in schema + res = "" + for i, anyof_schema in enumerate(schema["anyOf"]): + if i != 0: + res += " | " + res += self._create_rule_with_schema(anyof_schema, f"{rule_name}_{i}") + return res + + def _visit_any(self, schema: SchemaType, rule_name: str) -> str: + """Visit a true schema that can match anything.""" + # note integer is a subset of number, so we don't need to add integer here + return ( + f"{self.BASIC_NUMBER} | {self.BASIC_STRING} | {self.BASIC_BOOLEAN} | " + f"{self.BASIC_NULL} | {self.BASIC_ARRAY} | {self.BASIC_OBJECT}" + ) + + def _visit_integer(self, schema: SchemaType, rule_name: str) -> str: + """Visit an integer schema.""" + assert schema["type"] == "integer" + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, ["multipleOf", "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum"] + ) + return '("0" | "-"? [1-9] [0-9]*) ".0"?' + + def _visit_number(self, schema: SchemaType, rule_name: str) -> str: + """Visit a number schema.""" + assert schema["type"] == "number" + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, ["multipleOf", "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum"] + ) + return '("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?' + + def _visit_string(self, schema: SchemaType, rule_name: str) -> str: + """Visit a string schema.""" + assert schema["type"] == "string" + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, ["minLength", "maxLength", "pattern", "format"] + ) + return f'["] {self.BASIC_STRING_SUB} ["]' + + def _visit_boolean(self, schema: SchemaType, rule_name: str) -> str: + """Visit a boolean schema.""" + assert schema["type"] == "boolean" + + return '"true" | "false"' + + def _visit_null(self, schema: SchemaType, rule_name: str) -> str: + """Visit a null schema.""" + assert schema["type"] == "null" + + return '"null"' + + def _visit_array(self, schema: SchemaType, rule_name: str) -> str: + """Visit an array schema. + + Examples + -------- + Schema: + { + "type": "array", + "prefixItems": [ + {"type": "boolean"}, + {"type": "integer"} + ], + "items": { + "type": "string" + } + } + + Rule (not considering the indent): + main ::= "[" basic_boolean ", " basic_integer (", " basic_string)* "]" + """ + assert schema["type"] == "array" + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, + ["uniqueItems", "contains", "minContains", "maxContains", "minItems", "maxItems"], + ) + + res = '"["' + + with self.indent_manager: + # 1. Handle prefix items + have_prefix_items = False + if "prefixItems" in schema: + for i, prefix_item in enumerate(schema["prefixItems"]): + assert prefix_item is not False + item = self._create_rule_with_schema(prefix_item, f"{rule_name}_{i}") + res += f" {self._get_sep()} {item}" + have_prefix_items = True + + # 2. 
Find additional items + additional_item = None + additional_suffix = "" + + items = schema.get("items", False) + if items is not False: + additional_item = items + additional_suffix = "item" + + # if items is in the schema, we don't need to consider unevaluatedItems + unevaluated = schema.get("unevaluatedItems", not self.strict_mode) + if "items" not in schema and unevaluated is not False: + additional_item = unevaluated + additional_suffix = "uneval" + + # 3. Handle additional items and the end separator + if additional_item is None: + res += f" {self._get_sep(is_end=True)}" + else: + additional_pattern = self._create_rule_with_schema( + additional_item, f"{rule_name}_{additional_suffix}" + ) + if have_prefix_items: + res += ( + f' ("" | ({self._get_sep()} {additional_pattern})*)' + f" {self._get_sep(is_end=True)}" + ) + else: + res += ( + f' ("" | {self._get_sep()} {additional_pattern} ({self._get_sep()} ' + f"{additional_pattern})* {self._get_sep(is_end=True)})" + ) + + res += ' "]"' + return res + + def _visit_object(self, schema: SchemaType, rule_name: str) -> str: + """Visit an object schema. + + Examples + -------- + Schema: + { + "type": "object", + "properties": { + "a": {"type": "string"}, + "b": {"type": "integer"} + }, + "required": ["a"], + "additionalProperties": true + } + + Rule (not considering the indent): + main ::= "{" "a" ":" basic_string (", " "b" ":" basic_integer)* + (", " basic_string ": " basic_any)* "}" + + We need special handling when all properties are optional, since the handling of separators + is tricky in this case. E.g. + + Schema: + { + "type": "object", + "properties": { + "a": {"type": "string"}, + "b": {"type": "integer"}, + "c": {"type": "boolean"} + }, + "additionalProperties": true + } + + Rule (indent=2): + main ::= "{" ("\n " (a main_sub_1 | b main_sub_2 | c main_sub_3 | d main_sub_3) + "\n" | "") "}" + main_sub_1 ::= ",\n " b r2 | r2 + main_sub_2 ::= ",\n " c r3 | r3 + main_sub_3 ::= (",\n " d)* + """ + assert schema["type"] == "object" + _JSONSchemaToEBNFConverter._warn_unsupported_keywords( + schema, ["patternProperties", "minProperties", "maxProperties", "propertyNames"] + ) + + res = '"{"' + # Now we only consider the required list for the properties field + required = schema.get("required", []) + + with self.indent_manager: + # 1. Find additional properties + additional_property = None + additional_suffix = "" + + additional = schema.get("additionalProperties", False) + if additional is not False: + additional_property = additional + additional_suffix = "add" + + unevaluated = schema.get("unevaluatedProperties", not self.strict_mode) + if "additionalProperties" not in schema and unevaluated is not False: + additional_property = unevaluated + additional_suffix = "uneval" + + # 2. 
Handle properties + properties_obj = schema.get("properties", {}) + properties = list(properties_obj.items()) + + properties_all_optional = all(prop_name not in required for prop_name, _ in properties) + if properties_all_optional and len(properties) > 0: + # 3.1 Case 1: properties are defined and all properties are optional + res += " " + self._get_partial_rule_for_properties_all_optional( + properties, additional_property, rule_name, additional_suffix + ) + elif len(properties) > 0: + # 3.2 Case 2: properties are defined and some properties are required + res += " " + self._get_partial_rule_for_properties_contain_required( + properties, required, rule_name + ) + if additional_property is not None: + other_property_pattern = self._get_other_property_pattern( + self.BASIC_STRING, additional_property, rule_name, additional_suffix + ) + res += f" ({self._get_sep()} {other_property_pattern})*" + res += " " + self._get_sep(is_end=True) + elif additional_property is not None: + # 3.3 Case 3: no properties are defined and additional properties are allowed + other_property_pattern = self._get_other_property_pattern( + self.BASIC_STRING, additional_property, rule_name, additional_suffix + ) + res += ( + f" ({self._get_sep()} {other_property_pattern} ({self._get_sep()} " + f'{other_property_pattern})* {self._get_sep(is_end=True)} | "")' + ) + + res += ' "}"' + return res + + def _get_property_pattern(self, prop_name: str, prop_schema: SchemaType, rule_name: str) -> str: + """Get the pattern for a property in the object schema.""" + # the outer quote is for the string in EBNF grammar, and the inner quote is for + # the string in JSON + key = f'"\\"{prop_name}\\""' + colon = f'"{self.colon}"' + value = self._create_rule_with_schema(prop_schema, rule_name + "_" + prop_name) + return f"{key} {colon} {value}" + + def _get_other_property_pattern( + self, key_pattern: str, prop_schema: SchemaType, rule_name: str, rule_name_suffix: str + ) -> str: + """Get the pattern for the additional/unevaluated properties in the object schema.""" + colon = f'"{self.colon}"' + value = self._create_rule_with_schema(prop_schema, rule_name + "_" + rule_name_suffix) + return f"{key_pattern} {colon} {value}" + + # pylint: disable=too-many-locals + def _get_partial_rule_for_properties_all_optional( + self, + properties: List[Tuple[str, SchemaType]], + additional: Optional[SchemaType], + rule_name: str, + additional_suffix: str = "", + ) -> str: + """Get the partial rule for the properties when all properties are optional. 
See the + above example.""" + assert len(properties) >= 1 + + first_sep = self._get_sep() + mid_sep = self._get_sep() + last_sep = self._get_sep(is_end=True) + + res = "" + + prop_patterns = [ + self._get_property_pattern(prop_name, prop_schema, rule_name) + for prop_name, prop_schema in properties + ] + + rule_names = [None] * len(properties) + + # construct the last rule + if additional is not None: + additional_prop_pattern = self._get_other_property_pattern( + self.BASIC_STRING, additional, rule_name, additional_suffix + ) + last_rule_body = f"({mid_sep} {additional_prop_pattern})*" + last_rule_name = f"{rule_name}_sub_{len(properties)-1}" + self.rules.append((last_rule_name, last_rule_body)) + rule_names[-1] = last_rule_name # type: ignore + else: + rule_names[-1] = '""' # type: ignore + + # construct 0~(len(properties) - 2) rules + for i in reversed(range(0, len(properties) - 1)): + prop_pattern = prop_patterns[i + 1] + last_rule_name = rule_names[i + 1] + cur_rule_body = f"{last_rule_name} | {mid_sep} {prop_pattern} {last_rule_name}" + cur_rule_name = f"{rule_name}_sub_{i}" + self.rules.append((cur_rule_name, cur_rule_body)) + rule_names[i] = cur_rule_name # type: ignore + + # construct the main rule + for i, prop_pattern in enumerate(prop_patterns): + if i != 0: + res += " | " + res += f"({prop_pattern} {rule_names[i]})" + + if additional is not None: + res += f" | {additional_prop_pattern} {rule_names[-1]}" + + # add separators and the empty string option + res = f'({first_sep} ({res}) {last_sep} | "")' + return res + + def _get_partial_rule_for_properties_contain_required( + self, + properties: List[Tuple[str, SchemaType]], + required: List[str], + rule_name: str, + ) -> str: + """Get the partial rule for the properties when some properties are required. See the + above example. + + The constructed rule should be: + + start_separator (optional_property separator)? (optional_property separator)? ... + first_required_property (separator optional_property)? separator required_property ... + end_separator + + i.e. Before the first required property, all properties are in the form + (property separator); and after the first required property, all properties are in the form + (separator property). + """ + + # Find the index of the first required property + first_required_idx = next( + (i for i, (prop_name, _) in enumerate(properties) if prop_name in required), + len(properties), + ) + assert first_required_idx < len(properties) + + res = self._get_sep() + + # Handle the properties before the first required property + for prop_name, prop_schema in properties[:first_required_idx]: + assert prop_schema is not False + property_pattern = self._get_property_pattern(prop_name, prop_schema, rule_name) + res += f" ({property_pattern} {self._get_sep()})?" + + # Handle the first required property + property_pattern = self._get_property_pattern( + properties[first_required_idx][0], properties[first_required_idx][1], rule_name + ) + res += f" {property_pattern}" + + # Handle the properties after the first required property + for prop_name, prop_schema in properties[first_required_idx + 1 :]: + assert prop_schema is not False + property_pattern = self._get_property_pattern(prop_name, prop_schema, rule_name) + if prop_name in required: + res += f" {self._get_sep()} {property_pattern}" + else: + res += f" ({self._get_sep()} {property_pattern})?" 
+ + return res + + +def json_schema_to_ebnf( + json_schema: str, + *, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True, +) -> str: + """Convert JSON schema string to EBNF grammar string. + + Parameters + ---------- + json_schema : str + The JSON schema string. + + indent : Optional[int] + The number of spaces for each indent. If it is None, there will be no indent or newline. + The indent and separators parameters follow the same convention as + `json.dumps()`. + + separators : Optional[Tuple[str, str]] + The separator between different elements in json. Examples include "," and ", ". + + strict_mode : bool + Whether to use strict mode. In strict mode, the generated grammar will not allow + unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + This helps LLM to generate accurate output in the grammar-guided generation with JSON + schema. + """ + json_schema_schema = json.loads(json_schema) + return _JSONSchemaToEBNFConverter(json_schema_schema, indent, separators, strict_mode).convert() diff --git a/tests/python/serve/test_json_schema_converter.py b/tests/python/serve/test_json_schema_converter.py new file mode 100644 index 0000000000..138207511b --- /dev/null +++ b/tests/python/serve/test_json_schema_converter.py @@ -0,0 +1,415 @@ +import json +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import tvm.testing +from pydantic import BaseModel, Field, TypeAdapter + +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf + + +def check_schema_with_grammar( + schema: Dict[str, Any], + expected_grammar: str, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True, +): + schema_str = json.dumps(schema, indent=2) + grammar = json_schema_to_ebnf( + schema_str, indent=indent, separators=separators, strict_mode=strict_mode + ) + assert grammar == expected_grammar + + +def check_schema_with_json( + schema: Dict[str, Any], + json_str: str, + check_accepted=True, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True, +): + schema_str = json.dumps(schema, indent=2) + + ebnf_grammar_str = json_schema_to_ebnf( + schema_str, indent=indent, separators=separators, strict_mode=strict_mode + ) + ebnf_grammar = BNFGrammar.from_ebnf_string(ebnf_grammar_str) + matcher = GrammarStateMatcher(ebnf_grammar) + + if check_accepted: + assert matcher.debug_match_complete_string(json_str) + else: + assert not matcher.debug_match_complete_string(json_str) + + +def check_schema_with_instance( + schema: Dict[str, Any], + instance: BaseModel, + check_accepted=True, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True, +): + instance_obj = instance.model_dump(mode="json", round_trip=True) + instance_str = json.dumps(instance_obj, indent=indent, separators=separators) + check_schema_with_json(schema, instance_str, check_accepted, indent, separators, strict_mode) + + +def test_basic(): + class MainModel(BaseModel): + integer_field: int + number_field: float + boolean_field: bool + any_array_field: List + array_field: List[str] + tuple_field: Tuple[str, int, List[str]] + object_field: Dict[str, int] + nested_object_field: Dict[str, Dict[str, int]] + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" 
basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_any_array_field ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +main_array_field ::= "[" ("" | "" basic_string (", " basic_string)* "") "]" +main_tuple_field_2 ::= "[" ("" | "" basic_string (", " basic_string)* "") "]" +main_tuple_field ::= "[" "" basic_string ", " basic_integer ", " main_tuple_field_2 "" "]" +main_object_field ::= "{" ("" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" | "") "}" +main_nested_object_field_add ::= "{" ("" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" | "") "}" +main_nested_object_field ::= "{" ("" basic_string ": " main_nested_object_field_add (", " basic_string ": " main_nested_object_field_add)* "" | "") "}" +main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_any_array_field ", " "\"array_field\"" ": " main_array_field ", " "\"tuple_field\"" ": " main_tuple_field ", " "\"object_field\"" ": " main_object_field ", " "\"nested_object_field\"" ": " main_nested_object_field "" "}" +""" + + instance = MainModel( + integer_field=42, + number_field=3.14e5, + boolean_field=True, + any_array_field=[3.14, "foo", None, True], + array_field=["foo", "bar"], + tuple_field=("foo", 42, ["bar", "baz"]), + object_field={"foo": 42, "bar": 43}, + nested_object_field={"foo": {"bar": 42}}, + ) + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + check_schema_with_instance(schema, instance) + + +def test_indent(): + class MainModel(BaseModel): + array_field: List[str] + tuple_field: Tuple[str, int, List[str]] + object_field: Dict[str, int] + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any ("," basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" | "") "}" +main_array_field ::= "[" ("" | "\n " basic_string (",\n " basic_string)* "\n ") "]" +main_tuple_field_2 ::= "[" ("" | "\n " basic_string (",\n " basic_string)* "\n ") "]" +main_tuple_field ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_tuple_field_2 "\n " "]" +main_object_field ::= "{" ("\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " | "") "}" +main ::= "{" "\n " "\"array_field\"" ": " main_array_field ",\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"object_field\"" ": " main_object_field "\n" "}" +""" + + instance = MainModel( + array_field=["foo", "bar"], + tuple_field=("foo", 42, ["bar", "baz"]), + object_field={"foo": 42, "bar": 43}, + ) + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar, indent=2) + check_schema_with_instance(schema, instance, indent=2) + check_schema_with_instance(schema, instance, indent=None, separators=(",", ":")) + + +def test_non_strict(): + class Foo(BaseModel): + pass + + class MainModel(BaseModel): + tuple_field: Tuple[str, Tuple[int, int]] + foo_field: Foo + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any ("," basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" | "") "}" +main_tuple_field_1 ::= "[" "\n " basic_integer ",\n " basic_integer ("" | (",\n " basic_any)*) "\n " "]" +main_tuple_field ::= "[" "\n " basic_string ",\n " main_tuple_field_1 ("" | (",\n " basic_any)*) "\n " "]" +main_foo_field ::= "{" ("\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " | "") "}" +main ::= "{" "\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"foo_field\"" ": " main_foo_field (",\n " basic_string ": " basic_any)* "\n" "}" +""" + + instance_json = """{ + "tuple_field": [ + "foo", + [ + 12, + 13, + "ext" + ], + "extra" + ], + "foo_field": { + "tmp": "str" + }, + "extra": "field" +}""" + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar, indent=2, strict_mode=False) + check_schema_with_json(schema, instance_json, indent=2, strict_mode=False) + + +def test_enum_const(): + class Field(Enum): + FOO = "foo" + BAR = "bar" + + class MainModel(BaseModel): + bars: Literal["a"] + str_values: Literal['a\n\r"'] + foo: Literal["a", "b", "c"] + values: Literal[1, "a", True] + field: Field + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." 
[0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_bars ::= "\"a\"" +main_str_values ::= "\"a\\n\\r\\\"\"" +main_foo ::= ("\"a\"") | ("\"b\"") | ("\"c\"") +main_values ::= ("1") | ("\"a\"") | ("true") +main_field ::= ("\"foo\"") | ("\"bar\"") +main ::= "{" "" "\"bars\"" ": " main_bars ", " "\"str_values\"" ": " main_str_values ", " "\"foo\"" ": " main_foo ", " "\"values\"" ": " main_values ", " "\"field\"" ": " main_field "" "}" +""" + + schema = MainModel.model_json_schema() + instance = MainModel(foo="a", values=1, bars="a", str_values='a\n\r"', field=Field.FOO) + check_schema_with_grammar(schema, ebnf_grammar) + check_schema_with_instance(schema, instance) + + +def test_optional(): + class MainModel(BaseModel): + num: int = 0 + opt_bool: Optional[bool] = None + size: Optional[float] + name: str = "" + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_opt_bool ::= basic_boolean | basic_null +main_size ::= basic_number | basic_null +main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_opt_bool ", ")? "\"size\"" ": " main_size (", " "\"name\"" ": " basic_string)? "" "}" +""" + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + + instance = MainModel(num=42, opt_bool=True, size=3.14, name="foo") + check_schema_with_instance(schema, instance) + + instance = MainModel(size=None) + check_schema_with_instance(schema, instance) + + check_schema_with_json(schema, '{"size": null}') + check_schema_with_json(schema, '{"size": null, "name": "foo"}') + check_schema_with_json(schema, '{"num": 1, "size": null, "name": "foo"}') + + +def test_all_optional(): + class MainModel(BaseModel): + size: int = 0 + state: bool = False + num: float = 0 + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_sub_1 ::= "" | ", " "\"num\"" ": " basic_number "" +main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 +main ::= "{" ("" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number "")) "" | "") "}" +""" + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + + instance = MainModel(size=42, state=True, num=3.14) + check_schema_with_instance(schema, instance) + + check_schema_with_json(schema, '{"state": false}') + check_schema_with_json(schema, '{"size": 1, "num": 1.5}') + + ebnf_grammar_non_strict = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_sub_2 ::= (", " basic_string ": " basic_any)* +main_sub_1 ::= main_sub_2 | ", " "\"num\"" ": " basic_number main_sub_2 +main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 +main ::= "{" ("" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number main_sub_2) | basic_string ": " basic_any main_sub_2) "" | "") "}" +""" + + check_schema_with_grammar(schema, ebnf_grammar_non_strict, strict_mode=False) + + check_schema_with_json(schema, '{"size": 1, "num": 1.5, "other": false}', strict_mode=False) + check_schema_with_json(schema, '{"other": false}', strict_mode=False) + + +def test_reference(): + class Foo(BaseModel): + count: int + size: Optional[float] = None + + class Bar(BaseModel): + apple: str = "x" + banana: str = "y" + + class MainModel(BaseModel): + foo: Foo + bars: List[Bar] + + instance = MainModel( + foo=Foo(count=42, size=3.14), + bars=[Bar(apple="a", banana="b"), Bar(apple="c", banana="d")], + ) + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_foo_size ::= basic_number | basic_null +main_foo ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_foo_size)? 
"" "}" +main_bars_item_sub_0 ::= "" | ", " "\"banana\"" ": " basic_string "" +main_bars_item ::= "{" ("" (("\"apple\"" ": " basic_string main_bars_item_sub_0) | ("\"banana\"" ": " basic_string "")) "" | "") "}" +main_bars ::= "[" ("" | "" main_bars_item (", " main_bars_item)* "") "]" +main ::= "{" "" "\"foo\"" ": " main_foo ", " "\"bars\"" ": " main_bars "" "}" +""" + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + check_schema_with_instance(schema, instance) + + +def test_union(): + class Cat(BaseModel): + name: str + color: str + + class Dog(BaseModel): + name: str + breed: str + + ta = TypeAdapter(Union[Cat, Dog]) + + model_schema = ta.json_schema() + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" +main_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" +main ::= main_0 | main_1 +""" + + check_schema_with_grammar(model_schema, ebnf_grammar) + + check_schema_with_instance(model_schema, Cat(name="kitty", color="black")) + check_schema_with_instance(model_schema, Dog(name="doggy", breed="bulldog")) + check_schema_with_json(model_schema, '{"name": "kitty", "test": "black"}', False) + + +def test_alias(): + class MainModel(BaseModel): + test: str = Field(..., alias="name") + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" +basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +main ::= "{" "" "\"name\"" ": " basic_string "" "}" +""" + + check_schema_with_grammar(MainModel.model_json_schema(), ebnf_grammar) + + instance = MainModel(name="kitty") + instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=False)) + check_schema_with_json(MainModel.model_json_schema(by_alias=False), instance_str) + + instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) + check_schema_with_json(MainModel.model_json_schema(by_alias=True), instance_str) + + +if __name__ == "__main__": + tvm.testing.main() From 3b9b51ae925650aa6af1130f3d338a716fac9a73 Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 20 Mar 2024 17:29:38 +0000 Subject: [PATCH 091/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c06ec1f245..7bb844df52 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c06ec1f24548c0e94e15d3ea3c405f5f475b22af +Subproject commit 7bb844df52586b3c7646b8051cef1092cbb19073 From d4ec25edb280311d5efddbb5d689890f15039d76 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 20 Mar 2024 15:42:18 -0400 Subject: [PATCH 092/531] [Fix] Fix serve model to adapt the latest Allocator signature (#1989) PR apache/tvm#16738 updated the Allocator signature. This PR updates the caller side accordingly. --- cpp/serve/model.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 3233cb93e8..559a6e0e50 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -64,7 +64,7 @@ class ModelImpl : public ModelObj { memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); ICHECK_NOTNULL(allocator); token_ids_storage_ = - memory::Storage(allocator->Alloc({prefill_chunk_size_}, DataType::Int(32))); + memory::Storage(allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32))); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); } From c74f176a5eeb07758e2cac115aa0ddbc6917986d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 20 Mar 2024 13:28:05 -0700 Subject: [PATCH 093/531] [Model] Use optimized group gemm for Mixtral (#1988) --- python/mlc_llm/interface/compile.py | 1 + python/mlc_llm/interface/compiler_flags.py | 14 ++++ python/mlc_llm/model/mixtral/mixtral_model.py | 4 +- python/mlc_llm/nn/expert.py | 4 +- python/mlc_llm/op/cutlass.py | 76 +++++++++++++++++++ python/mlc_llm/op/extern.py | 6 +- 6 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 python/mlc_llm/op/cutlass.py diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index b6052a935a..5618ce3341 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -131,6 +131,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: target=args.target, flashinfer=args.opt.flashinfer, faster_transformer=args.opt.faster_transformer, + cutlass=args.opt.cutlass, ) # Step 1. 
Create the quantized model logger.info("Creating model from: %s", args.config) diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 2c44efc10d..b4ff81e6eb 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -21,6 +21,7 @@ class OptimizationFlags: cublas_gemm: bool = False faster_transformer: bool = False cudagraph: bool = False + cutlass: bool = False def __repr__(self) -> str: out = StringIO() @@ -28,6 +29,7 @@ def __repr__(self) -> str: print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="") print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + print(f";cutlass={int(self.cutlass)}", file=out, end="") return out.getvalue().rstrip() @staticmethod @@ -49,12 +51,14 @@ def boolean(value: str) -> bool: parser.add_argument("--cublas_gemm", type=boolean, default=False) parser.add_argument("--faster_transformer", type=boolean, default=False) parser.add_argument("--cudagraph", type=boolean, default=False) + parser.add_argument("--cutlass", type=boolean, default=False) results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) return OptimizationFlags( flashinfer=results.flashinfer, cublas_gemm=results.cublas_gemm, faster_transformer=results.faster_transformer, cudagraph=results.cudagraph, + cutlass=results.cutlass, ) def update(self, target, quantization) -> None: @@ -90,9 +94,16 @@ def _faster_transformer(target) -> bool: return False return self.faster_transformer + def _cutlass(target) -> bool: + """correct cutlass flag""" + if not target.kind.name == "cuda": + return False + return self.cutlass + self.flashinfer = _flashinfer(target) self.cublas_gemm = _cublas_gemm(target, quantization) self.faster_transformer = _faster_transformer(target) + self.cutlass = _cutlass(target) @dataclasses.dataclass @@ -148,17 +159,20 @@ def from_str(source: str) -> "ModelConfigOverride": cublas_gemm=True, faster_transformer=True, cudagraph=False, + cutlass=True, ), "O2": OptimizationFlags( flashinfer=True, cublas_gemm=True, faster_transformer=True, cudagraph=False, + cutlass=True, ), "O3": OptimizationFlags( flashinfer=True, cublas_gemm=True, faster_transformer=True, cudagraph=True, + cutlass=True, ), } diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py index 3f41988788..ec8025f3dc 100644 --- a/python/mlc_llm/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -74,7 +74,9 @@ def _expert_forward(x: Tensor, indptr: Tensor): # expert_weights: [num_tokens, experts_per_tok] # expert_indices: [num_tokens, experts_per_tok] expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok) - use_ft = op_ext.get_store().faster_transformer and self.dtype == "float16" + use_ft = ( + op_ext.get_store().cutlass_group_gemm or op_ext.get_store().faster_transformer + ) and self.dtype == "float16" if num_tokens == 1: # x: [num_tokens * experts_per_tok, hidden_size] x = _expert_forward(x, expert_indices) diff --git a/python/mlc_llm/nn/expert.py b/python/mlc_llm/nn/expert.py index b6659d3d60..481b430baf 100644 --- a/python/mlc_llm/nn/expert.py +++ b/python/mlc_llm/nn/expert.py @@ -2,7 +2,7 @@ from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor -from mlc_llm.op import extern, ft_gemm, moe_matmul +from mlc_llm.op import cutlass, extern, ft_gemm, moe_matmul class 
MixtralExperts(nn.Module): @@ -21,6 +21,8 @@ def forward(self, x: Tensor, indptr: Tensor): # pylint: disable=invalid-name,mi assert indptr.shape[0] == 1 return moe_matmul.gemv(x, self.weight, indptr) assert indptr.ndim == 1 + if extern.get_store().cutlass_group_gemm and self.dtype == "float16": + return cutlass.group_gemm(x, self.weight, indptr) if extern.get_store().faster_transformer and self.dtype == "float16": return ft_gemm.faster_transformer_moe_gemm(x, self.weight, indptr) return moe_matmul.group_gemm(x, self.weight, indptr) diff --git a/python/mlc_llm/op/cutlass.py b/python/mlc_llm/op/cutlass.py new file mode 100644 index 0000000000..275d61f20a --- /dev/null +++ b/python/mlc_llm/op/cutlass.py @@ -0,0 +1,76 @@ +"""Operators enabled by external modules.""" + +from typing import Optional + +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import op + + +def group_gemm( + x: nn.Tensor, + weight: nn.Tensor, + indptr: nn.Tensor, + scale: Optional[nn.Tensor] = None, + weight_dtype: Optional[str] = None, + out_dtype: Optional[str] = None, +): # pylint: disable=too-many-arguments + """ + Cutlass group gemm operator. + + Parameters + ---------- + x : nn.Tensor + The input tensor, with shape of [m, k]. + + weight : nn.Tensor + The weight tensor, with shape of [num_groups, n, k]. + + indptr : nn.Tensor + The indptr tensor, with shape of [num_groups]. + + scale : Optional[nn.Tensor] + The scale tensor, with shape of [1]. + + weight_dtype: Optional[str] + The data type of the weight tensor. + + out_dtype: Optional[str] + The data type of the output tensor. + + Returns + ------- + nn.Tensor + The output tensor, with shape of [m, n]. + """ + assert x.ndim == 2 + assert weight.ndim == 3 + assert indptr.ndim == 1 + assert weight.shape[2] == x.shape[1] + assert weight.shape[0] == indptr.shape[0] + assert indptr.dtype == "int64" + out_dtype = out_dtype if out_dtype else x.dtype + weight_dtype = weight_dtype if weight_dtype else weight.dtype + + if x.dtype == "e5m2_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + func_name = "cutlass.group_gemm_e5m2_e5m2_fp16" + elif x.dtype == "e4m3_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + func_name = "cutlass.group_gemm_e4m3_e5m2_fp16" + elif x.dtype == "e4m3_float8" and weight.dtype == "e4m3_float8" and out_dtype == "float16": + func_name = "cutlass.group_gemm_e4m3_e4m3_fp16" + elif x.dtype == "float16" and weight.dtype == "float16" and out_dtype == "float16": + func_name = "cutlass.group_gemm_fp16_sm90" + else: + raise NotImplementedError( + f"Unsupported data type: x={x.dtype}, weight={weight.dtype}, out={out_dtype}" + ) + + if "float8" in x.dtype: + assert scale is not None, "scale is required for float8 input" + + workspace = op.empty((4096 * 1024,), dtype="uint8", name="workspace") + + return op.extern( + func_name, + args=[x, weight, indptr, workspace] + ([scale] if scale is not None else []), + out=nn.Tensor.placeholder((x.shape[0], weight.shape[1]), dtype=out_dtype), + ) diff --git a/python/mlc_llm/op/extern.py b/python/mlc_llm/op/extern.py index 5fa7e829f2..fd5d91badb 100644 --- a/python/mlc_llm/op/extern.py +++ b/python/mlc_llm/op/extern.py @@ -28,13 +28,14 @@ class ExternModuleStore: target: Optional[Target] = None flashinfer: bool = False faster_transformer: bool = False + cutlass_group_gemm: bool = False STORE: ExternModuleStore = ExternModuleStore() """Singleton of `ExternModuleStore`.""" -def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None: +def 
enable(target: Target, flashinfer: bool, faster_transformer: bool, cutlass: bool) -> None: """Enable external modules. It should be called before any compilation happens.""" global STORE # pylint: disable=global-statement STORE = ExternModuleStore( @@ -42,6 +43,9 @@ def enable(target: Target, flashinfer: bool, faster_transformer: bool) -> None: target=target, flashinfer=flashinfer, faster_transformer=faster_transformer, + cutlass_group_gemm=cutlass + and target.kind.name == "cuda" + and target.attrs.get("arch", "") == "sm_90a", ) From 244c2e7112ca725b8a30ec125dbb6bb5e4d70e14 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 21 Mar 2024 16:36:10 -0400 Subject: [PATCH 094/531] [Attn] Fix the construction of attn result merge kernel (#1995) This PR fixes the mistake of passing wrong number of heads to the attention result merge kernel. --- python/mlc_llm/nn/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index c4792bb57c..2ecf017cf4 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -345,7 +345,7 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), - bb.add_func(_merge_state_inplace(num_key_value_heads, head_dim, dtype, target), "tir_attention_merge_state"), + bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), # fmt: on From ddfbcda4ebd82058855443a5b26a9011ddc026fc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 21 Mar 2024 16:45:15 -0400 Subject: [PATCH 095/531] [iOS][Android] Add validation of library file for iOS and Android build (#1993) This PR adds validation of symbols in iOS and android build. During static library build, we need the right model_lib for us to point to the packaged model executables. Not doing so correctly will results in vm_load_executable not found which is not informative. This PR we validate the compiled model lib by dumping the global symbols and ensure the list of model libs matches with each other. In future we should perhaps lift the validation to mlc_llm package. 
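The check reduces to a naming rule between the `model_lib` entries in the app config and the global symbols of the packed static library. A minimal sketch of that rule (the `is_packaged` helper and the `global_symbols` set are illustrative stand-ins for the symbol-section dump performed by the scripts below):

```python
# Minimal sketch of the model-lib validation rule (illustrative only).
SUFFIX = "___tvm_dev_mblob"

def is_packaged(model_lib: str, global_symbols: set) -> bool:
    # A "-" in the model_lib string becomes "_" in the packed symbol name.
    pattern = model_lib.replace("-", "_") + SUFFIX
    # Some toolchains prepend an extra "_" to global symbols, so accept both forms.
    return pattern in global_symbols or "_" + pattern in global_symbols

# Example: "gemma_q4f16_1" passes validation when the static library exports
# "gemma_q4f16_1___tvm_dev_mblob" (or "_gemma_q4f16_1___tvm_dev_mblob").
```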
--- android/library/prepare_model_lib.py | 62 +++++++++++++++++- .../library/src/main/assets/app-config.json | 12 ++-- ios/prepare_model_lib.py | 64 ++++++++++++++++++- 3 files changed, 126 insertions(+), 12 deletions(-) diff --git a/android/library/prepare_model_lib.py b/android/library/prepare_model_lib.py index 9363be74c8..dc14397a16 100644 --- a/android/library/prepare_model_lib.py +++ b/android/library/prepare_model_lib.py @@ -3,20 +3,76 @@ from tvm.contrib import ndk +def get_model_libs(lib_path): + global_symbol_map = ndk.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name in global_symbol_map.keys(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + def main(): - app_config = json.load(open("src/main/assets/app-config.json", "r")) + app_config_path = "src/main/assets/app-config.json" + app_config = json.load(open(app_config_path, "r")) artifact_path = os.path.abspath(os.path.join("../..", "dist")) tar_list = [] + model_set = set() - for model_lib_path in app_config["model_lib_path_for_prepare_libs"].values(): + for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): path = os.path.join(artifact_path, model_lib_path) if not os.path.isfile(path): raise RuntimeError(f"Cannot find android library {path}") tar_list.append(path) + model_set.add(model) - ndk.create_staticlib(os.path.join("build", "model_lib", "libmodel_android.a"), tar_list) + lib_path = os.path.join("build", "model_lib", "libmodel_android.a") + ndk.create_staticlib(lib_path, tar_list) print(f"Creating lib from {tar_list}..") + available_model_libs = get_model_libs(lib_path) + print(f"Validating the library {lib_path}...") + print( + f"List of available model libs packaged: {available_model_libs}," + " if we have '-' in the model_lib string, it will be turned into '_'" + ) + global_symbol_map = ndk.get_global_symbol_section_map(lib_path) + error_happened = False + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + print( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs field, " + "This will cause the specific model not being able to load, " + f"please check {app_config_path}." 
+ ) + error_happened = True + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + print( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" + f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" + ) + error_happened = True + + if not error_happened: + print("Validation pass") + else: + print("Validation failed") + exit(255) + if __name__ == "__main__": main() diff --git a/android/library/src/main/assets/app-config.json b/android/library/src/main/assets/app-config.json index 8dcdf6dabf..68442c234e 100644 --- a/android/library/src/main/assets/app-config.json +++ b/android/library/src/main/assets/app-config.json @@ -26,16 +26,16 @@ }, { "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_lib": "phi_q4f16_1", + "model_lib": "phi_msft_q4f16_1", "estimated_vram_bytes": 2036816936, "model_id": "phi-2-q4f16_1" } ], "model_lib_path_for_prepare_libs": { - "gemma_q4f16_1": "prebuilt_libs/gemma-2b-it/gemma-2b-it-q4f16_1-android.tar", - "llama_q4f16_1": "prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-android.tar", - "gpt_neox_q4f16_1": "prebuilt_libs/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar", - "phi_q4f16_1": "prebuilt_libs/phi-2/phi-2-q4f16_1-android.tar", - "Mistral-7B-Instruct-v0.2-q4f16_1": "prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar" + "gemma_q4f16_1": "prebuilt/lib/gemma-2b-it/gemma-2b-it-q4f16_1-android.tar", + "llama_q4f16_1": "prebuilt/lib/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-android.tar", + "gpt_neox_q4f16_1": "prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar", + "phi_msft_q4f16_1": "prebuilt/lib/phi-2/phi-2-q4f16_1-android.tar", + "mistral_q4f16_1": "prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar" } } \ No newline at end of file diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py index 1db56cd08a..0e66879ddc 100644 --- a/ios/prepare_model_lib.py +++ b/ios/prepare_model_lib.py @@ -1,13 +1,29 @@ import json import os +import sys from tvm.contrib import cc +def get_model_libs(lib_path): + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name in global_symbol_map.keys(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + def main(): - app_config = json.load(open("MLCChat/app-config.json", "r")) + app_config_path = "MLCChat/app-config.json" + app_config = json.load(open(app_config_path, "r")) artifact_path = os.path.abspath(os.path.join("..", "dist")) tar_list = [] + model_set = set() for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): paths = [ @@ -20,10 +36,52 @@ def main(): raise RuntimeError( f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" ) - tar_list.append(valid_paths[0]) + tar_list.append(valid_paths[ls0]) + model_set.add(model) - cc.create_staticlib(os.path.join("build", "lib", "libmodel_iphone.a"), tar_list) + lib_path = os.path.join("build", "lib", 
"libmodel_iphone.a") + + cc.create_staticlib(lib_path, tar_list) + available_model_libs = get_model_libs(lib_path) print(f"Creating lib from {tar_list}..") + print(f"Validating the library {lib_path}...") + print( + f"List of available model libs packaged: {available_model_libs}," + " if we have '-' in the model_lib string, it will be turned into '_'" + ) + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + error_happened = False + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + print( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs field, " + "This will cause the specific model not being able to load, " + f"please check {app_config_path}." + ) + error_happened = True + + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + print( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" + f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" + ) + error_happened = True + + if not error_happened: + print("Validation pass") + else: + print("Validation failed") + exit(255) if __name__ == "__main__": From cc36324234a56b75ec13951174e1ce94ef9efd86 Mon Sep 17 00:00:00 2001 From: Git bot Date: Thu, 21 Mar 2024 20:46:31 +0000 Subject: [PATCH 096/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 7bb844df52..3847f7eb13 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 7bb844df52586b3c7646b8051cef1092cbb19073 +Subproject commit 3847f7eb13481920e5eb870c435f18ba338cd186 From 96d9c8b5611e24e1bac030417080c3eaf4b7ffd0 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 21 Mar 2024 14:02:35 -0700 Subject: [PATCH 097/531] [Serve] add allocator in Storage as the upstream change (#1997) The changes in https://github.com/apache/tvm/pull/16750 modified the signature of the Storage, this pull request updates the caller code in mlc-llm to accommodate the new Storage class signature. Ran into build error w/o the change. 
--- cpp/serve/model.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 559a6e0e50..3b7d7ef7ea 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -63,8 +63,8 @@ class ModelImpl : public ModelObj { memory::Allocator* allocator = memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); ICHECK_NOTNULL(allocator); - token_ids_storage_ = - memory::Storage(allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32))); + token_ids_storage_ = memory::Storage( + allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); } From 0772940d3037a0fa311607d54eb566c0e028b79c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 21 Mar 2024 22:22:10 -0400 Subject: [PATCH 098/531] [Compiler] Support IPC memory and customized all-reduce kernels (#1990) This PR introduces the IPC memory and customized all-reduce kernel dispatches for tensor parallelism. We add a new compiler flag `--allreduce-strategy`, which supports `"ring"`, `"one-shot"` and `"two-shot"`. The flag defaults to `"ring"`, which means this PR makes no difference if people do not manually change the all-reduce strategy. As of now the IPC-memory-backed customized all-reduce kernels are only available on CUDA. To enable all-reduce strategies other than "ring", here are some example compile commands: ```python python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=one-shot" -o model/lib.so python -m mlc_llm compile model/mlc-chat-config.json --device cuda --opt "allreduce-strategy=two-shot" -o model/lib.so ``` Please be aware that, you probably also need to specify other compiler flags, for example, like `--opt "cublas_gemm=1;allreduce-strategy=one-shot"`. 
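For reference, the new strategy value is parsed together with the other optimization flags from the `--opt` string. A small sanity check of the expected behavior, assuming the parser shown in the diff below is exposed as `OptimizationFlags.from_str` like the existing flags:

```python
# Illustrative check of the expected flag parsing (not part of the patch itself).
from mlc_llm.interface.compiler_flags import AllReduceStrategyType, OptimizationFlags

flags = OptimizationFlags.from_str("cublas_gemm=1;allreduce-strategy=one-shot")
assert flags.cublas_gemm                                          # boolean flag parsed from "1"
assert flags.allreduce_strategy == AllReduceStrategyType.ONESHOT  # "one-shot" maps to ONESHOT
```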
--- python/mlc_llm/compiler_pass/pipeline.py | 8 ++++++++ python/mlc_llm/interface/compile.py | 1 + python/mlc_llm/interface/compiler_flags.py | 20 ++++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index d576c68451..4cf6323bc8 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -9,6 +9,7 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from tvm.relax.frontend import nn +from mlc_llm.interface.compiler_flags import AllReduceStrategyType from mlc_llm.support import logging from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc @@ -75,6 +76,7 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments flashinfer: bool = False, cublas_gemm: bool = False, faster_transformer: bool = False, # pylint: disable=unused-argument + allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING, variable_bounds: Dict[str, int] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, metadata: Dict[str, Any] = None, @@ -147,7 +149,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tvm.relax.transform.ToNonDataflow(), tvm.relax.transform.RemovePurityChecking(), tvm.relax.transform.CallTIRRewrite(), + ( + tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy) + if allreduce_strategy != AllReduceStrategyType.RING + else tvm.transform.Sequential([]) + ), tvm.relax.transform.StaticPlanBlockMemory(), + tvm.relax.transform.LowerGPUIPCAllocStorage(), AttachMetadataWithMemoryUsage(metadata), tvm.relax.transform.RewriteCUDAGraph(), tvm.relax.transform.LowerAllocTensor(), diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 5618ce3341..56bcc75abd 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -184,6 +184,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: flashinfer=args.opt.flashinfer, cublas_gemm=args.opt.cublas_gemm, faster_transformer=args.opt.faster_transformer, + allreduce_strategy=args.opt.allreduce_strategy, variable_bounds=variable_bounds, additional_tirs=additional_tirs, ext_mods=ext_mods, diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index b4ff81e6eb..32e79f9bd3 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -1,6 +1,7 @@ """Flags for overriding model config.""" import dataclasses +import enum import re from io import StringIO from typing import Optional @@ -13,6 +14,14 @@ logger = logging.getLogger(__name__) +class AllReduceStrategyType(enum.IntEnum): + """The all-reduce strategy.""" + + RING = 0 + ONESHOT = 1 + TWOSHOT = 2 + + @dataclasses.dataclass class OptimizationFlags: """Optimization flags""" @@ -22,6 +31,7 @@ class OptimizationFlags: faster_transformer: bool = False cudagraph: bool = False cutlass: bool = False + allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING def __repr__(self) -> str: out = StringIO() @@ -30,6 +40,7 @@ def __repr__(self) -> str: print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="") print(f";cudagraph={int(self.cudagraph)}", file=out, end="") print(f";cutlass={int(self.cutlass)}", file=out, end="") + print(f";allreduce_strategy={self.allreduce_strategy.name}", file=out, end="") return out.getvalue().rstrip() @staticmethod @@ -52,6 +63,12 @@ def boolean(value: str) 
-> bool: parser.add_argument("--faster_transformer", type=boolean, default=False) parser.add_argument("--cudagraph", type=boolean, default=False) parser.add_argument("--cutlass", type=boolean, default=False) + parser.add_argument( + "--allreduce-strategy", + type=str, + choices=["ring", "one-shot", "two-shot"], + default="ring", + ) results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) return OptimizationFlags( flashinfer=results.flashinfer, @@ -59,6 +76,9 @@ def boolean(value: str) -> bool: faster_transformer=results.faster_transformer, cudagraph=results.cudagraph, cutlass=results.cutlass, + allreduce_strategy=AllReduceStrategyType[ + results.allreduce_strategy.replace("-", "").upper() + ], ) def update(self, target, quantization) -> None: From ae97b8d3763cd9ef9179140027d206622d185d21 Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 22 Mar 2024 02:39:59 +0000 Subject: [PATCH 099/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 3847f7eb13..1ce4a34f3b 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3847f7eb13481920e5eb870c435f18ba338cd186 +Subproject commit 1ce4a34f3b9eabebaad959ddc67dfebede068028 From 8405cb128b4e4477b17b54251ed7adf4e825ce32 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 22 Mar 2024 10:00:01 -0400 Subject: [PATCH 100/531] [Model] Fix the top-k TIR script for well-formedness (#2002) This PR fixes the malformed MoE TIR scripts. --- python/mlc_llm/op/moe_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index 19bf10381f..6dc7f33265 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -101,7 +101,7 @@ def topk_softmax_func( with T.block("output"): vj = T.axis.remap("S", [j]) out[vi, vj] = T.cast( - T.exp(local_top_k_f32[j] - local_top_k_max[0]) + T.exp(local_top_k_f32[vj] - local_top_k_max[0]) / ( T.exp(local_top_k_f32[0] - local_top_k_max[0]) + T.exp(local_top_k_f32[1] - local_top_k_max[0]) From 64badb5b921398776a9644335468c1f211ed1faa Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 22 Mar 2024 15:48:26 -0700 Subject: [PATCH 101/531] Fix invalid use of dataflow var in sampler output (#2003) --- python/mlc_llm/compiler_pass/attach_sampler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 64faf93bf3..2d28730a9b 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -107,7 +107,9 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): sorted_indices, primfunc_name_hint="take_sorted_probs", ) - gv = bb.emit_func_output([sorted_values, sorted_indices]) + output = (sorted_values, sorted_indices) + bb.emit_output(output) + gv = bb.emit_func_output(output) return gv @@ -201,6 +203,7 @@ def full(var_result: T.handle, value: T.int32): sinfo_args=sample_indices.struct_info, # pylint: disable=no-member ) ) + bb.emit_output(result) gv = bb.emit_func_output(result) return gv @@ -270,5 +273,6 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument ], ) ) + bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv From 837ee530438deb1ca64c6d31b8feba17b3e73287 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 23 Mar 2024 20:54:25 -0400 Subject: 
[PATCH 102/531] [Fix] Fix KV cache creation pass after nn.Module changes (#2011) This PR corrects the assertion after latest changes in apache/tvm that updates some nn.Module behavior. --- python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index e90bdfef78..47cfdf9dc8 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -13,7 +13,7 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: assert isinstance(func.body, relax.SeqExpr) assert len(func.body.blocks) == 1 assert isinstance(func.body.blocks[0], relax.DataflowBlock) - assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[0].bindings) == 1 assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding) assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call) assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed") From 10f2d007376bb2e8ebb8c8de1c89b9da42bb0cf2 Mon Sep 17 00:00:00 2001 From: Andrew Date: Sun, 24 Mar 2024 10:30:09 -0700 Subject: [PATCH 103/531] [iOS] Fix typo in prepare_model_lib.py (#2013) Fix typo in prepare_model_lib.py tar_list.append(valid_paths[ls0]) is introduced by mistake in https://github.com/mlc-ai/mlc-llm/pull/1993 --- ios/prepare_model_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py index 0e66879ddc..ff56236321 100644 --- a/ios/prepare_model_lib.py +++ b/ios/prepare_model_lib.py @@ -36,7 +36,7 @@ def main(): raise RuntimeError( f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" ) - tar_list.append(valid_paths[ls0]) + tar_list.append(valid_paths[0]) model_set.add(model) lib_path = os.path.join("build", "lib", "libmodel_iphone.a") From a6de1ff87789ade1b91b8038e3ea6f149a7c8c3e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 24 Mar 2024 14:47:03 -0400 Subject: [PATCH 104/531] Remove unstable assertion in KV cache creation dispatch (#2017) This particular assertion is unstable recently given the back-and-forth upstream TVM nn.Module exporter behavior. 
--- python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 47cfdf9dc8..20e4c7bdd9 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -13,7 +13,6 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: assert isinstance(func.body, relax.SeqExpr) assert len(func.body.blocks) == 1 assert isinstance(func.body.blocks[0], relax.DataflowBlock) - assert len(func.body.blocks[0].bindings) == 1 assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding) assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call) assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed") From 1c8b72e26876014e597e2ad72b68de9e4603a9f6 Mon Sep 17 00:00:00 2001 From: Git bot Date: Mon, 25 Mar 2024 01:34:05 +0000 Subject: [PATCH 105/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 1ce4a34f3b..2955bc6d8b 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 1ce4a34f3b9eabebaad959ddc67dfebede068028 +Subproject commit 2955bc6d8b09f6c0aa3178f1b208c9d0a6d22dee From ab9fa81321ead4f5ceb8b54c31234ea1ffa7a451 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 25 Mar 2024 20:21:58 +0800 Subject: [PATCH 106/531] [SLM] Qwen2 Multi-GPU support (#1985) * Update qwen2_model.py * fix lint issue * fix lint issue * fix lint issue --- python/mlc_llm/model/qwen2/qwen2_model.py | 71 ++++++++++++++++------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index c85e8337df..ff42e977b4 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -35,6 +36,7 @@ class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + head_dim: int = 0 dtype: str = "float32" kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) @@ -56,6 +58,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -80,29 +85,19 @@ def __post_init__(self): class QWen2Attention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWen2Config): - head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards + self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.rope_theta = config.rope_theta self.c_attn = nn.Linear( in_features=config.hidden_size, - out_features=(2 * config.num_key_value_heads + config.num_attention_heads) * head_dim, + out_features=(2 * self.num_key_value_heads + self.num_attention_heads) * self.head_dim, bias=True, ) self.o_proj = nn.Linear( - config.num_attention_heads * head_dim, config.hidden_size, bias=False - ) - # KV cache for single sequence - self.k_cache = nn.KVCache( - config.context_window_size, [config.num_key_value_heads, head_dim] + self.num_attention_heads * self.head_dim, config.hidden_size, bias=False ) - self.v_cache = nn.KVCache( - config.context_window_size, [config.num_attention_heads, head_dim] - ) - - self.hidden_size = config.hidden_size - self.head_dim = head_dim - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.rope_theta = config.rope_theta def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_attention_heads, self.num_key_value_heads @@ -128,8 +123,9 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class QWen2MLP(nn.Module): def __init__(self, config: QWen2Config): - self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x: Tensor): @@ -147,15 +143,46 @@ def __init__(self, config: QWen2Config): config.hidden_size, -1, config.rms_norm_eps, bias=False ) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_attention_heads * hd + k = self.self_attn.num_key_value_heads * hd + v = self.self_attn.num_key_value_heads * hd + i = self.mlp.intermediate_size + _set( + self.self_attn.c_attn.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set( + self.self_attn.c_attn.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1)) + _set( + self.mlp.gate_up_proj.weight, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0) + ) + _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.input_layernorm(hidden_states) out = self.self_attn(out, paged_kv_cache, 
layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.post_attention_layernorm(hidden_states) out = self.mlp(out) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class QWen2Model(nn.Module): def __init__(self, config: QWen2Config): @@ -187,7 +214,7 @@ def __init__(self, config: QWen2Config): self.rope_theta = config.rope_theta self.vocab_size = config.vocab_size self.tensor_parallel_shards = config.tensor_parallel_shards - self.head_dim = config.hidden_size // config.num_attention_heads + self.head_dim = config.head_dim def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -211,6 +238,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): From f04cd3e9e81bcd3c02015df6fe0f0eaa9ffd8453 Mon Sep 17 00:00:00 2001 From: na20215 <78482004+na20215@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:28:41 +0800 Subject: [PATCH 107/531] more info for preshard (#2027) * When the pre-sharded version of a certain model is not available, the program will default back to the normal workflow without issuing any alert. Now, when someone attempts to convert to a pre-sharded model but cannot, the program will throw a warning message to inform users that it will revert to the standard model conversion process. * format fix. * black reformatted, i did not see any diff. * black reformatted.. 
--- python/mlc_llm/support/preshard.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/mlc_llm/support/preshard.py b/python/mlc_llm/support/preshard.py index 09db02c554..cd5edbc19c 100644 --- a/python/mlc_llm/support/preshard.py +++ b/python/mlc_llm/support/preshard.py @@ -1,4 +1,5 @@ """Functions for pre-sharding weights""" +import logging from typing import Any, Dict, List from tvm import IRModule @@ -8,6 +9,8 @@ from tvm.runtime import Device from tvm.target import Target +logger = logging.getLogger("preshard") + def _sharded_param_name(param_name, worker_id): return f"{param_name}_shard-{worker_id}" @@ -93,10 +96,7 @@ def _compile_shard_funcs(mod: IRModule, device: Device): def apply_preshard( - quantize_map: Any, - named_params: Dict[str, nn.Parameter], - tensor_parallel_shards: int, - args: Any, + quantize_map: Any, named_params: Dict[str, nn.Parameter], tensor_parallel_shards: int, args: Any ): """Update quantize_map and named_params, create shard functions based on shard strategies.""" model_config = args.model.config.from_file(args.config) @@ -107,9 +107,11 @@ def apply_preshard( bb = relax.BlockBuilder() param_to_shard_func = {} shard_func_names = set() + has_shard_strategy = False for name, param in model.state_dict().items(): shard_strategy = param.attrs.get("shard_strategy", None) if shard_strategy is not None: + has_shard_strategy = True _update_quantize_map(quantize_map, named_params, name, tensor_parallel_shards) # create shard functions @@ -117,7 +119,12 @@ def apply_preshard( if shard_strategy.name not in shard_func_names: _create_shard_func(bb, param, tensor_parallel_shards) shard_func_names.add(shard_strategy.name) - + if not has_shard_strategy: + logger.warning( + "No parameters with 'shard_strategy' found." + "At least one parameter must have a 'shard_strategy' for presharding. " + "The model will continue to convert weights in a non-presharded manner." + ) mod = bb.finalize() vm = _compile_shard_funcs(mod, args.device) From 1c975de60217c82f4dd8a3a7ac2d0c60b8e4da23 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Mon, 25 Mar 2024 12:15:20 -0400 Subject: [PATCH 108/531] Register stablelm-2 conversation template (#2029) --- python/mlc_llm/conversation_template.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index c776a9298b..b4a3468872 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -264,6 +264,21 @@ def get_conv_template(name: str) -> Optional[Conversation]: ) ) +# StableLM-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm-2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=["<|endoftext|>", "<|endoftext|>"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|endoftext|>"], + stop_token_ids=[100257], + ) +) + # Llava ConvTemplateRegistry.register_conv_template( Conversation( From 8796fb4609d29e2b3df76b5eafb4de0bf47186d7 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Wed, 27 Mar 2024 04:25:04 +0800 Subject: [PATCH 109/531] [Serving][Fix] Fix problems in PopenServer (#2032) This PR fixes several problems in the PopenServer: - Add check for the server is not started and the request returns a fail number, e.g. 502. And changed the retry time to 0.1s. - Add a `__enter__` and `__exit__` method for PopenServer. 
When the program is interrupted, using with clause (`__enter__` and `__exit__`) can ensure the server always terminates. When using `start()` and `terminate()`, the server may still be staying in the background even though the parent process ends. --- python/mlc_llm/serve/server/popen_server.py | 20 +++++++++++++++++--- tests/python/serve/server/conftest.py | 6 ++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 6a668419cc..fcdfe6da39 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -1,4 +1,5 @@ """The MLC LLM server launched in a subprocess.""" + import subprocess import sys import time @@ -64,13 +65,17 @@ def start(self) -> None: openai_v1_models_url = "http://127.0.0.1:8000/v1/models" query_result = None timeout = 60 - attempts = 0 + attempts = 0.0 while query_result is None and attempts < timeout: try: query_result = requests.get(openai_v1_models_url, timeout=60) + if query_result.status_code != 200: + query_result = None + attempts += 0.1 + time.sleep(0.1) except: # pylint: disable=bare-except - attempts += 1 - time.sleep(1) + attempts += 0.1 + time.sleep(0.1) # Check if the subprocess terminates unexpectedly or # the queries reach the timeout. @@ -117,3 +122,12 @@ def kill_child_processes(): except subprocess.TimeoutExpired: pass self._proc = None + + def __enter__(self): + """Start the server.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Terminate the server.""" + self.terminate() diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py index 807739ace6..e425494231 100644 --- a/tests/python/serve/server/conftest.py +++ b/tests/python/serve/server/conftest.py @@ -28,8 +28,6 @@ def launch_server(served_model): # pylint: disable=redefined-outer-name model_lib_path=served_model[1], enable_tracing=True, ) - server.start() - yield - # Fixture teardown code. - server.terminate() + with server: + yield From a6d31d7fca0258c46ae887f015f2b60a03e0c4f3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 26 Mar 2024 16:27:59 -0400 Subject: [PATCH 110/531] [Quantization] Skip MoE gate layer (#2012) This PR skips quantizing the MoE gate layer. --- python/mlc_llm/quantization/awq_quantization.py | 4 ++-- python/mlc_llm/quantization/ft_quantization.py | 7 ++++--- python/mlc_llm/quantization/group_quantization.py | 8 +++++--- python/mlc_llm/quantization/utils.py | 5 +++++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/mlc_llm/quantization/awq_quantization.py b/python/mlc_llm/quantization/awq_quantization.py index 0b89e5db6a..1d7cddbfa6 100644 --- a/python/mlc_llm/quantization/awq_quantization.py +++ b/python/mlc_llm/quantization/awq_quantization.py @@ -9,7 +9,7 @@ from mlc_llm.loader import QuantizeMapping -from .utils import convert_uint_to_float, is_final_fc +from .utils import convert_uint_to_float, is_final_fc, is_moe_gate def _make_divisible(c, divisor): # pylint: disable=invalid-name @@ -117,7 +117,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any: The new node to replace current node. 
""" - if isinstance(node, nn.Linear) and not is_final_fc(name): + if isinstance(node, nn.Linear) and not is_final_fc(name) and not is_moe_gate(name): return AWQQuantizeLinear.from_linear(node, self.config) return self.visit(name, node) diff --git a/python/mlc_llm/quantization/ft_quantization.py b/python/mlc_llm/quantization/ft_quantization.py index c30e85bf70..b6b1da100f 100644 --- a/python/mlc_llm/quantization/ft_quantization.py +++ b/python/mlc_llm/quantization/ft_quantization.py @@ -21,7 +21,7 @@ GroupQuantizeEmbedding, GroupQuantizeLinear, ) -from .utils import is_final_fc +from .utils import is_final_fc, is_moe_gate logger = logging.getLogger(__name__) @@ -147,8 +147,9 @@ def visit_module(self, name: str, node: nn.Module) -> Any: group_quantize = self.config.fallback_group_quantize() self.quant_map.map_func[weight_name] = group_quantize.quantize_weight return GroupQuantizeLinear.from_linear(node, group_quantize) - self.quant_map.map_func[weight_name] = self.config.quantize_weight - return FTQuantizeLinear.from_linear(node, self.config) + if not is_moe_gate(name): + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return FTQuantizeLinear.from_linear(node, self.config) if isinstance(node, nn.Embedding): weight_name = f"{name}.weight" self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] diff --git a/python/mlc_llm/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py index 3431b5415e..feb4b0216d 100644 --- a/python/mlc_llm/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py @@ -16,7 +16,7 @@ from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp -from .utils import convert_uint_to_float, is_final_fc +from .utils import convert_uint_to_float, is_final_fc, is_moe_gate logger = logging.getLogger(__name__) @@ -107,8 +107,10 @@ def visit_module(self, name: str, node: nn.Module) -> Any: ret_node: Any The new node to replace current node. """ - if isinstance(node, nn.Linear) and ( - not is_final_fc(name) or self.config.quantize_final_fc + if ( + isinstance(node, nn.Linear) + and (not is_final_fc(name) or self.config.quantize_final_fc) + and not is_moe_gate(name) ): weight_name = f"{name}.weight" self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 05a9b9e233..8373b4d62c 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -45,3 +45,8 @@ def is_final_fc(name: str) -> bool: """Determines whether the parameter is the last layer based on its name.""" # TODO: use more specious condition to determine final fc # pylint: disable=fixme return name in ["head", "lm_head", "lm_head.linear", "embed_out"] + + +def is_moe_gate(name: str) -> bool: + """Check whether the parameter is the MoE gate layer.""" + return name.endswith("gate") From f2518abd80cc029aad14ca6acb53306c4a91e060 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Wed, 27 Mar 2024 11:51:01 +0800 Subject: [PATCH 111/531] [Serving][Grammar] Integration of JSON schema generation (#2030) Previous PR #1983 introduced a transformation from json schema to BNF grammar. This PR further integrates the grammar from json schema to the generation pipeline, so that the engine now supports json schema output. GrammarStateInitContexts are stored in a cache, so it will not be created again with the same schema. 
Interface: - Python ``` @dataclass class ResponseFormat: type: Literal["text", "json_object"] = "text" schema: Optional[str] = None ``` - Rest API ``` class RequestResponseFormat(BaseModel): type: Literal["text", "json_object"] = "text" json_schema: Optional[str] = Field(default=None, alias="schema") class CompletionRequest(BaseModel): ... response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) class ChatCompletionRequest(BaseModel): ... response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) ``` Performance: We only tests single-batch performance now to show the overhead in latency. - Model: `Llama-2-7b-chat-hf-q4f16_1` - GPU: `NVIDIA GeForce RTX 3080` - CPU: `AMD Ryzen 9 5900X 12-Core Processor` ``` JSON ON Batch=1 Average prefill tokens: 651.0000 tok/req Average decode tokens: 499.0000 tok/req Single token prefill latency: 0.3140 ms/tok Single token decode latency: 8.6831 ms/tok Prefill token throughput: 3184.8002 tok/s Decode token throughput: 116.6039 tok/s JSON OFF Batch=1 Average prefill tokens: 651.0000 tok/req Average decode tokens: 499.0000 tok/req Single token prefill latency: 0.3098 ms/tok Single token decode latency: 8.6823 ms/tok Prefill token throughput: 3227.8141 tok/s Decode token throughput: 116.9251 tok/s ``` This PR also does these bug fixes / changes: - Changed the structure of the converted grammar from schema to avoid large amount of uncertain tokens, which caused a performance degradation --- cpp/serve/config.cc | 16 +- cpp/serve/config.h | 2 +- cpp/serve/engine.cc | 28 ++- cpp/serve/engine_actions/action_commons.cc | 2 + cpp/serve/engine_actions/batch_decode.cc | 2 +- cpp/serve/grammar/grammar.cc | 28 +++ cpp/serve/grammar/grammar.h | 23 ++- cpp/serve/grammar/grammar_state_matcher.cc | 98 ++++++++-- cpp/serve/grammar/grammar_state_matcher.h | 39 ++++ .../grammar/grammar_state_matcher_preproc.h | 52 ++++++ cpp/serve/grammar/support.h | 19 +- cpp/serve/request_state.cc | 12 +- cpp/serve/request_state.h | 16 +- .../mlc_llm/protocol/openai_api_protocol.py | 8 +- python/mlc_llm/serve/config.py | 8 +- python/mlc_llm/serve/grammar.py | 54 +++++- python/mlc_llm/serve/json_schema_converter.py | 53 ++++-- conftest.py => tests/python/conftest.py | 0 tests/python/serve/benchmark.py | 17 +- tests/python/serve/server/test_server.py | 169 ++++++++++++++++-- .../test_grammar_state_matcher_custom.py | 48 ++++- .../serve/test_grammar_state_matcher_json.py | 4 +- .../serve/test_json_schema_converter.py | 116 ++++++++---- .../python/serve/test_serve_engine_grammar.py | 59 +++++- 24 files changed, 734 insertions(+), 139 deletions(-) rename conftest.py => tests/python/conftest.py (100%) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 5a0b35a3c6..3465de402e 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -144,12 +144,12 @@ GenerationConfig::GenerationConfig(String config_json_str) { CHECK(response_format_json["type"].is()); response_format.type = response_format_json["type"].get(); } - if (response_format_json.count("json_schema")) { - if (response_format_json["json_schema"].is()) { - response_format.json_schema = NullOpt; + if (response_format_json.count("schema")) { + if (response_format_json["schema"].is()) { + response_format.schema = NullOpt; } else { - CHECK(response_format_json["json_schema"].is()); - response_format.json_schema = response_format_json["json_schema"].get(); + CHECK(response_format_json["schema"].is()); + response_format.schema = response_format_json["schema"].get(); } } 
n->response_format = response_format; @@ -194,9 +194,9 @@ String GenerationConfigNode::AsJSONString() const { picojson::object response_format; response_format["type"] = picojson::value(this->response_format.type); - response_format["json_schema"] = this->response_format.json_schema - ? picojson::value(this->response_format.json_schema.value()) - : picojson::value(); + response_format["schema"] = this->response_format.schema + ? picojson::value(this->response_format.schema.value()) + : picojson::value(); config["response_format"] = picojson::value(response_format); return picojson::value(config).serialize(true); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index e9e4d68970..c406e55125 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -21,7 +21,7 @@ using namespace tvm::runtime; /*! \brief The response format of a request. */ struct ResponseFormat { String type = "text"; - Optional json_schema = NullOpt; + Optional schema = NullOpt; }; /*! \brief The generation configuration of a request. */ diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 1d0813a288..98f3e4fe6b 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -61,8 +62,7 @@ class EngineImpl : public Engine { this->trace_recorder_ = trace_recorder; this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); this->token_table_ = tokenizer_->TokenTable(); - this->json_grammar_state_init_ctx_ = - GrammarStateMatcher::CreateInitContext(BNFGrammar::GetGrammarOfJSON(), this->token_table_); + this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); @@ -160,11 +160,13 @@ class EngineImpl : public Engine { int n = request->generation_cfg->n; int rng_seed = request->generation_cfg->seed; + auto grammar_state_init_ctx = + ResponseFormatToGrammarInitContext(request->generation_cfg->response_format); std::vector rsentries; // Create the request state entry for the input. rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed, - token_table_, json_grammar_state_init_ctx_); + token_table_, grammar_state_init_ctx); if (n > 1) { // Then create a request state entry for each parallel generation branch. // We add a offset to the rng seed so that to make generations different. @@ -173,7 +175,7 @@ class EngineImpl : public Engine { for (int i = 0; i < n; ++i) { rsentries[0]->child_indices.push_back(rsentries.size()); rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), - rng_seed + i + 1, token_table_, json_grammar_state_init_ctx_, + rng_seed + i + 1, token_table_, grammar_state_init_ctx, /*parent_idx=*/0); } } @@ -247,6 +249,20 @@ class EngineImpl : public Engine { std::max(max_concurrency - host_cpu_usage, 1), kv_cache_config_->max_num_sequence)); } + /*! \brief Create a grammar init context according to the response format. If the response format + * is not JSON, return std::nullopt. */ + std::optional> ResponseFormatToGrammarInitContext( + const ResponseFormat& response_format) { + if (response_format.type != "json_object") { + return std::nullopt; + } else if (!response_format.schema) { + return grammar_init_context_storage_->GetInitContextForJSON(); + } else { + return grammar_init_context_storage_->GetInitContextForJSONSchema( + response_format.schema.value()); + } + } + // Engine state, managing requests and request states. 
EngineState estate_; // Configurations and singletons @@ -255,8 +271,8 @@ class EngineImpl : public Engine { int max_single_sequence_length_; Tokenizer tokenizer_; std::vector token_table_; - // The initial context for the grammar state matching of JSON. - std::shared_ptr json_grammar_state_init_ctx_; + // Helper to get the grammar init context for requests. + GrammarInitContextStorage grammar_init_context_storage_; // Models Array models_; // Workspace of each model. diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 35ba851386..d6a5d52ef4 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -72,6 +72,8 @@ void ProcessFinishedRequestStateEntries(std::vector finished_ for (const RequestStateEntry& entry : rstate->entries) { estate->stats.total_decode_length += entry->mstates[0]->committed_tokens.size(); } + // For a request, the first token in committed_tokens is generated by prefilling + // and the rest are generated by decoding. So we subtract the first token. estate->stats.total_decode_length -= rsentry->request->generation_cfg->n; } } diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 47007f6c8d..4801d52f32 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -83,7 +83,7 @@ class BatchDecodeActionObj : public EngineActionObj { // - Compute embeddings. RECORD_EVENT(trace_recorder_, request_ids, "start embedding"); ObjectRef embeddings = - models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); + models_[0]->TokenEmbed({IntTuple(input_tokens.begin(), input_tokens.end())}); RECORD_EVENT(trace_recorder_, request_ids, "finish embedding"); // - Invoke model decode. diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c5a41626e3..c4d6445c7e 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -42,6 +42,34 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String jso return BNFGrammar::FromJSON(json_string); }); +BNFGrammar BNFGrammar::FromSchema(const String& schema, int indent, + Optional> separators, bool strict_mode) { + static const PackedFunc* json_schema_to_ebnf = Registry::Get("mlc.serve.json_schema_to_ebnf"); + CHECK(json_schema_to_ebnf != nullptr) << "mlc.serve.json_schema_to_ebnf is not registered."; + + String ebnf_string; + + // Convert the indent parameter to NullOpt for sending it to the PackedFunc. + if (indent == -1) { + // The conversion from TVMRetValue to String is ambiguous, so we call the conversion function + // explicitly + ebnf_string = + ((*json_schema_to_ebnf)(schema, Optional(NullOpt), separators, strict_mode) + . 
+ operator String()); + } else { + ebnf_string = (*json_schema_to_ebnf)(schema, indent, separators, strict_mode).operator String(); + ; + } + return FromEBNFString(ebnf_string); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema") + .set_body_typed([](const String& schema, int indent, Optional> separators, + bool strict_mode) { + return BNFGrammar::FromSchema(schema, indent, separators, strict_mode); + }); + const std::string kJSONGrammarString = R"( main ::= ( "{" ws members_or_embrace | diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 21062ab503..545a4e08a0 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -18,6 +18,7 @@ namespace mlc { namespace llm { namespace serve { +using namespace tvm; using namespace tvm::runtime; /*! @@ -182,7 +183,7 @@ class BNFGrammar : public ObjectRef { * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. * Not implemented yet. */ - static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule, + static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule = "main", bool normalize = true, bool simplify = true); /*! @@ -192,7 +193,25 @@ class BNFGrammar : public ObjectRef { */ static BNFGrammar FromJSON(const String& json_string); - /*! + /*! + * \brief Construct a BNF grammar from the json schema string. The schema string should be in the + * format of the schema of a JSON file. We will parse the schema and generate a BNF grammar. + * \param schema The schema string. + * \param indent The number of spaces for indentation. If -1, the output will be in one line. + * Default: -1. + * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"}, + * {", ", ": "}. If NullOpt, the default separators will be used: {",", ": "} when the indent + * is not -1, and {", ", ": "} otherwise. Default: NullOpt. + * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not + * allow unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + * This helps LLM to generate accurate output in the grammar-guided generation with JSON + * schema. Default: true. + */ + static BNFGrammar FromSchema(const String& schema, int indent = -1, + Optional> separators = NullOpt, + bool strict_mode = true); + + /*! * \brief Get the grammar of standard JSON format. We have built-in support for JSON. */ static BNFGrammar GetGrammarOfJSON(); diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 6e0a26dddb..2131e9f112 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -40,7 +40,7 @@ namespace serve { * elements at the end may be popped out, and the last element of the stack will be advanced. * * One stack may split since there may be multiple possible next positions. In this case, similar - * stacks with different top elements will be added. When ome stack cannot accept the new character, + * stacks with different top elements will be added. When one stack cannot accept the new character, * it will be removed from the stacks. 
* * ## Storage of Stacks (see grammar_state_matcher_state.h) @@ -59,7 +59,7 @@ namespace serve { * S ::= "" | [c] [d] * T ::= [e] * - * ### Previous step + * ### The previous step * Previous accepted string: ab * Previous stack tree: * A------ @@ -76,7 +76,7 @@ namespace serve { * < means the stack top pointers in the previous step. * The stacks in the previous step is: (A, B, C), (A, D), (A, E) * - * ### Current step + * ### The current step * Current accepted string: abc * Current stack tree: * A----------------- G<< @@ -87,7 +87,7 @@ namespace serve { * * F: (rule S, choice 1, element 1) * G: (rule main, choice 0, element 2) (means the matching process has finished, and will be deleted - * when next char comes) + * when the next char comes) * H: (rule R, choice 1, element 2) * I: (rule T, choice 0, element 0) * << means the stack top pointers in the current step. @@ -175,7 +175,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm */ bool AcceptStopToken(); - friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher); + friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose); std::shared_ptr init_ctx_; int max_rollback_steps_; @@ -381,12 +381,12 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, << "The provied bitmask's shape or dtype is not valid."; BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), - next_token_bitmask->shape[0]); + next_token_bitmask->shape[0], init_ctx_->vocab_size); if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { // If rejected_indices is the universal set, the final accepted token set is just // accepted_indices - next_token_bitset.Reset(init_ctx_->vocab_size, false); + next_token_bitset.Reset(false); for (int idx : accepted_indices) { next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true); } @@ -399,7 +399,7 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, } } else { // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices) - next_token_bitset.Reset(init_ctx_->vocab_size, true); + next_token_bitset.Reset(true); auto it_acc = accepted_indices.begin(); for (auto i : rejected_indices) { @@ -524,25 +524,83 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") return MatchCompleteString(matcher, str); }); +/*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. 
*/ +void PrintAcceptedRejectedTokens( + const std::shared_ptr& init_ctx, + const BitsetManager& bitset, int threshold = 500) { + auto vocab_size = init_ctx->vocab_size; + std::vector accepted_ids; + std::vector rejected_ids; + for (int i = 0; i < vocab_size; i++) { + if (bitset[i]) { + accepted_ids.push_back(i); + } else { + rejected_ids.push_back(i); + } + } + + if (accepted_ids.size() < threshold) { + std::cerr << "Accepted: "; + for (auto id : accepted_ids) { + std::cerr << "<"; + auto token = init_ctx->token_table[id]; + if (token.size() == 1 && (static_cast(token[0]) >= 128 || token[0] == 0)) { + // First cast to unsigned, then cast to int + std::cerr << static_cast(static_cast(token[0])); + } else { + auto codepoints = Utf8StringToCodepoints(token.c_str()); + for (auto c : codepoints) { + std::cerr << CodepointToPrintable(c); + } + } + std::cerr << "> "; + } + std::cerr << "\n"; + } + + if (rejected_ids.size() < threshold) { + std::cerr << "Rejected: "; + for (auto id : rejected_ids) { + std::cerr << "<"; + auto token = init_ctx->token_table[id]; + if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { + std::cerr << (int)(unsigned char)token[0]; + } else { + auto codepoints = Utf8StringToCodepoints(token.c_str()); + for (auto c : codepoints) { + std::cerr << CodepointToPrintable(c); + } + } + std::cerr << "> "; + } + std::cerr << "\n"; + } +} + /*! - * \brief Find the ids of the rejected tokens for the next step. For test purposes. + * \brief Find the ids of the rejected tokens for the next step. For debug purposes. + * \param matcher The matcher to test. + * \param verbose Whether to print information about the timing and results to stderr. * \returns A tuple of rejected token ids. */ -IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { +IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::GetBitsetSize(vocab_size); + auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(ndarray.operator->()); - auto start = std::chrono::high_resolution_clock::now(); + std::chrono::time_point start, end; + if (verbose) { + start = std::chrono::high_resolution_clock::now(); + } matcher->FindNextTokenBitmask(dltensor); - auto end = std::chrono::high_resolution_clock::now(); - std::cerr << "FindNextTokenBitmask takes " - << std::chrono::duration_cast(end - start).count() << "us"; + if (verbose) { + end = std::chrono::high_resolution_clock::now(); + } - auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size); + auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size, vocab_size); std::vector rejected_ids; for (int i = 0; i < vocab_size; i++) { if (bitset[i] == 0) { @@ -550,8 +608,12 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) { } } - std::cerr << ", found accepted: " << vocab_size - rejected_ids.size() - << ", rejected: " << rejected_ids.size() << std::endl; + if (verbose) { + std::cerr << "FindNextTokenBitmask takes " + << std::chrono::duration_cast(end - start).count() << "us" + << ", found accepted: " << vocab_size - rejected_ids.size() + << ", rejected: " << rejected_ids.size() << std::endl; + } auto ret = IntTuple(rejected_ids); return ret; diff --git a/cpp/serve/grammar/grammar_state_matcher.h 
b/cpp/serve/grammar/grammar_state_matcher.h index 443a791edc..eceaa75d07 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -129,6 +129,45 @@ class GrammarStateMatcher : public ObjectRef { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarStateMatcher, ObjectRef, GrammarStateMatcherNode); }; +/*! + * \brief Helper class to get the grammar state init context for grammars or schemas. This class + * maintains a cache internally, so the same grammar or schema will not be preprocessed multiple + * times. + * \note This class is associated with a token table when constructed. The token table is used to + * create every grammar state init context. If multiple token tables are used to create init + * contexts, an instance of this class for each token table should be created. + */ +class GrammarInitContextStorageNode : public Object { + public: + /*! \brief Get the init context for pure JSON. */ + virtual std::shared_ptr GetInitContextForJSON() = 0; + + /*! \brief Get the init context for a JSON schema string. */ + virtual std::shared_ptr GetInitContextForJSONSchema( + const std::string& schema) = 0; + + /*! \brief Clear the internal cache of init contexts. */ + virtual void ClearCache() = 0; + + static constexpr const char* _type_key = "mlc.serve.GrammarInitContextStorageNode"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextStorageNode, Object); +}; + +class GrammarInitContextStorage : public ObjectRef { + public: + /*! + * \brief Construct a GrammarInitContextStorage with a token table. This class will always create + * grammar state init contexts with this token table. + * \param token_table The token table that the grammar will use. + */ + GrammarInitContextStorage(const std::vector& token_table); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextStorage, ObjectRef, + GrammarInitContextStorageNode); +}; + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index dbb59f886b..c853ac7e04 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -57,6 +57,8 @@ class GrammarStateInitContext { public: /******************* Information about the tokenizer *******************/ + /*! \brief The token table. Now only used for debug purposes. */ + std::vector token_table; /*! \brief The vocabulary size of the tokenizer. */ size_t vocab_size; /*! \brief All tokens represented by the id and codepoints of each. The tokens are sorted by @@ -246,6 +248,7 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC auto ptr = std::make_shared(); ptr->grammar = grammar; + ptr->token_table = token_table; ptr->vocab_size = token_table.size(); if (ptr->vocab_size == 0) { @@ -317,6 +320,55 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC return ptr; } +class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { + public: + GrammarInitContextStorageImpl(const std::vector& token_table); + + std::shared_ptr GetInitContextForJSONSchema(const std::string& schema); + + std::shared_ptr GetInitContextForJSON(); + + void ClearCache(); + + private: + /*! \brief The token table associated with this storage class. */ + std::vector token_table_; + /*! \brief The cache for the init context of a JSON schema.
*/ + std::unordered_map> + init_ctx_for_schema_cache_; + /*! \brief The init context for JSON. */ + std::shared_ptr init_ctx_for_json_; +}; + +inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( + const std::vector& token_table) + : token_table_(token_table) { + init_ctx_for_json_ = + GrammarStateMatcher::CreateInitContext(BNFGrammar::GetGrammarOfJSON(), token_table_); +} + +inline std::shared_ptr +GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& schema) { + auto it = init_ctx_for_schema_cache_.find(schema); + if (it != init_ctx_for_schema_cache_.end()) { + return it->second; + } + auto init_ctx = + GrammarStateMatcher::CreateInitContext(BNFGrammar::FromSchema(schema), token_table_); + init_ctx_for_schema_cache_[schema] = init_ctx; + return init_ctx; +} + +inline std::shared_ptr +GrammarInitContextStorageImpl::GetInitContextForJSON() { + return init_ctx_for_json_; +} + +inline void GrammarInitContextStorageImpl::ClearCache() { init_ctx_for_schema_cache_.clear(); } + +GrammarInitContextStorage::GrammarInitContextStorage(const std::vector& token_table) + : ObjectRef(make_object(token_table)) {} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h index 9df1083335..fb9002dbac 100644 --- a/cpp/serve/grammar/support.h +++ b/cpp/serve/grammar/support.h @@ -18,17 +18,20 @@ namespace serve { /*! \brief Manages a segment of externally provided memory and use it as a bitset. */ class BitsetManager { public: - BitsetManager(uint32_t* data, int buffer_size) : data_(data), buffer_size_(buffer_size) {} + BitsetManager(uint32_t* data, int buffer_size, int element_cnt) + : data_(data), buffer_size_(buffer_size), element_cnt_(element_cnt) { + DCHECK(buffer_size >= CalculateBufferSize(element_cnt)); + } - static int GetBitsetSize(int size) { return (size + 31) / 32; } + static int CalculateBufferSize(int element_cnt) { return (element_cnt + 31) / 32; } bool operator[](int index) const { - DCHECK(index >= 0 && index / 32 < buffer_size_); + DCHECK(index >= 0 && index < element_cnt_); return (data_[index / 32] >> (index % 32)) & 1; } void Set(int index, bool value) { - DCHECK(index >= 0 && index / 32 < buffer_size_); + DCHECK(index >= 0 && index < element_cnt_); if (value) { data_[index / 32] |= 1 << (index % 32); } else { @@ -36,14 +39,14 @@ class BitsetManager { } } - void Reset(int size, bool value) { - DCHECK(buffer_size_ >= GetBitsetSize(size)); - std::memset(data_, value ? 0xFF : 0, GetBitsetSize(size) * sizeof(uint32_t)); - } + void Reset(bool value) { std::memset(data_, value ? 0xFF : 0, buffer_size_ * sizeof(uint32_t)); } + + int GetElementCnt() const { return element_cnt_; } private: uint32_t* const data_; const int buffer_size_; + const int element_cnt_; }; /*! 
diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 1a0e1970f7..2a035ad387 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -15,15 +15,15 @@ TVM_REGISTER_OBJECT_TYPE(RequestModelStateNode); RequestModelState::RequestModelState( Request request, int model_id, int64_t internal_id, Array inputs, - std::shared_ptr json_grammar_state_init_ctx) { + const std::optional>& grammar_state_init_ctx) { ObjectPtr n = make_object(); n->model_id = model_id; n->internal_id = internal_id; n->inputs = std::move(inputs); - if (request->generation_cfg->response_format.type == "json_object") { + if (grammar_state_init_ctx.has_value()) { // TODO(yixin): add support for stop_token_ids - n->grammar_state_matcher = GrammarStateMatcher(json_grammar_state_init_ctx); + n->grammar_state_matcher = GrammarStateMatcher(grammar_state_init_ctx.value()); } n->request = std::move(request); @@ -89,7 +89,8 @@ TVM_REGISTER_OBJECT_TYPE(RequestStateEntryNode); RequestStateEntry::RequestStateEntry( Request request, int num_models, int64_t internal_id, int rng_seed, const std::vector& token_table, - std::shared_ptr json_grammar_state_init_ctx, int parent_idx) { + const std::optional>& grammar_state_init_ctx, + int parent_idx) { ObjectPtr n = make_object(); Array mstates; Array inputs; @@ -98,8 +99,7 @@ RequestStateEntry::RequestStateEntry( } mstates.reserve(num_models); for (int i = 0; i < num_models; ++i) { - mstates.push_back( - RequestModelState(request, i, internal_id, inputs, json_grammar_state_init_ctx)); + mstates.push_back(RequestModelState(request, i, internal_id, inputs, grammar_state_init_ctx)); } n->status = RequestStateStatus::kPending; n->rng = RandomGenerator(rng_seed); diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 83a12fade4..7764a38c3e 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -10,6 +10,8 @@ #include #include +#include + #include "../random.h" #include "../streamer.h" #include "config.h" @@ -107,8 +109,9 @@ class RequestModelStateNode : public Object { class RequestModelState : public ObjectRef { public: - explicit RequestModelState(Request request, int model_id, int64_t internal_id, Array inputs, - std::shared_ptr json_grammar_state_init_ctx); + explicit RequestModelState( + Request request, int model_id, int64_t internal_id, Array inputs, + const std::optional>& grammar_state_init_ctx); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestModelState, ObjectRef, RequestModelStateNode); }; @@ -213,10 +216,11 @@ class RequestStateEntryNode : public Object { class RequestStateEntry : public ObjectRef { public: - explicit RequestStateEntry(Request request, int num_models, int64_t internal_id, int rng_seed, - const std::vector& token_table, - std::shared_ptr json_grammar_state_init_ctx, - int parent_idx = -1); + explicit RequestStateEntry( + Request request, int num_models, int64_t internal_id, int rng_seed, + const std::vector& token_table, + const std::optional>& grammar_state_init_ctx, + int parent_idx = -1); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RequestStateEntry, ObjectRef, RequestStateEntryNode); }; diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index c2cff9c4fd..4ac6daef71 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -69,7 +69,11 @@ class ModelResponse(BaseModel): class RequestResponseFormat(BaseModel): type: Literal["text", "json_object"] = "text" - json_schema: Optional[str] = None 
+ json_schema: Optional[str] = Field(default=None, alias="schema") + """This field is named json_schema instead of schema because BaseModel defines a method called + schema. When constructing a RequestResponseFormat, the key "schema" should still be used: + `RequestResponseFormat(type="json_object", schema="{}")` + """ class CompletionRequest(BaseModel): @@ -333,5 +337,5 @@ def openai_api_get_generation_config( kwargs["max_tokens"] = -1 if request.stop is not None: kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop - kwargs["response_format"] = ResponseFormat(**request.response_format.model_dump()) + kwargs["response_format"] = ResponseFormat(**request.response_format.model_dump(by_alias=True)) return kwargs diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 1b90a4b24a..e539ec7e56 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -14,7 +14,7 @@ class ResponseFormat: type : Literal["text", "json_object"] The type of response format. Default: "text". - json_schema : Optional[str] + schema : Optional[str] The JSON schema string for the JSON response format. If None, a legal json string without special restrictions will be generated. @@ -22,11 +22,11 @@ """ type: Literal["text", "json_object"] = "text" - json_schema: Optional[str] = None + schema: Optional[str] = None def __post_init__(self): - if self.json_schema is not None and self.type != "json_object": - raise ValueError("JSON json_schema is only supported in JSON response format") + if self.schema is not None and self.type != "json_object": - raise ValueError("JSON schema is only supported in JSON response format") @dataclass diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index d5a6887d22..6e9eac8655 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -1,5 +1,6 @@ """Classes handling the grammar guided generation of MLC LLM serving""" -from typing import List, Union + +from typing import List, Optional, Tuple, Union import tvm._ffi from tvm.runtime import Object @@ -112,6 +113,47 @@ def to_json(self, prettify: bool = True) -> str: _ffi_api.BNFGrammarToJSON(self, prettify) # type: ignore # pylint: disable=no-member ) + @staticmethod + def from_schema( + schema: str, + *, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True + ) -> "BNFGrammar": + """Construct a BNF grammar from a JSON schema string. The schema string should be a valid + JSON schema. We parse the schema and generate a BNF grammar. + + Parameters + ---------- + schema : str + The schema string. + + indent : Optional[int] + The number of spaces for indentation. If None, the output will be in one line. + Default: None. + + separators : Optional[Tuple[str, str]] + Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). + If None, the default separators will be used: (",", ": ") when the indent is not None, + and (", ", ": ") otherwise. Default: None. + + strict_mode : bool + Whether to use strict mode. In strict mode, the generated grammar will not allow + unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + This helps the LLM generate accurate output in grammar-guided generation with a JSON + schema. Default: True. + + Returns + ------- + grammar : BNFGrammar + The generated BNF grammar.
+ """ + indent_converted = -1 if indent is None else indent + return _ffi_api.BNFGrammarFromSchema( # type: ignore # pylint: disable=no-member + schema, indent_converted, separators, strict_mode + ) + @staticmethod def get_grammar_of_json() -> "BNFGrammar": """Get the grammar of standard JSON. @@ -197,16 +239,22 @@ def accept_token(self, token_id: int) -> bool: """ return _ffi_api.GrammarStateMatcherAcceptToken(self, token_id) # type: ignore # pylint: disable=no-member - def find_next_rejected_tokens(self) -> List[int]: + def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]: """Find the ids of the rejected tokens for the next step. + Parameters + ---------- + verbose : bool + Whether to print information about the timing and results to stderr. For debug purposes. + Default: False. + Returns ------- rejected_token_ids : List[int] A list of rejected token ids. """ - return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self, verbose) # type: ignore # pylint: disable=no-member def rollback(self, num_tokens: int) -> None: """Rollback the matcher to a previous state. diff --git a/python/mlc_llm/serve/json_schema_converter.py b/python/mlc_llm/serve/json_schema_converter.py index eb17b50fc3..9a4af6176e 100644 --- a/python/mlc_llm/serve/json_schema_converter.py +++ b/python/mlc_llm/serve/json_schema_converter.py @@ -4,6 +4,8 @@ import logging from typing import Any, Dict, List, Optional, Tuple, Union +from tvm._ffi import register_func + SchemaType = Union[Dict[str, Any], bool] """ JSON schema specification defines the schema type could be a dictionary or a boolean value. @@ -33,6 +35,7 @@ def __enter__(self): """Enter a new indent level.""" self.total_indent += self.indent self.is_first.append(True) + return self def __exit__(self, exc_type, exc_value, traceback): """Exit the current indent level.""" @@ -406,16 +409,16 @@ def _visit_array(self, schema: SchemaType, rule_name: str) -> str: ) res = '"["' + could_be_empty = False with self.indent_manager: # 1. Handle prefix items - have_prefix_items = False - if "prefixItems" in schema: - for i, prefix_item in enumerate(schema["prefixItems"]): + prefix_items = schema.get("prefixItems", []) + if len(prefix_items) > 0: + for i, prefix_item in enumerate(prefix_items): assert prefix_item is not False item = self._create_rule_with_schema(prefix_item, f"{rule_name}_{i}") res += f" {self._get_sep()} {item}" - have_prefix_items = True # 2. 
Find additional items additional_item = None @@ -439,18 +442,22 @@ def _visit_array(self, schema: SchemaType, rule_name: str) -> str: additional_pattern = self._create_rule_with_schema( additional_item, f"{rule_name}_{additional_suffix}" ) - if have_prefix_items: + if len(prefix_items) > 0: res += ( - f' ("" | ({self._get_sep()} {additional_pattern})*)' - f" {self._get_sep(is_end=True)}" + f" ({self._get_sep()} {additional_pattern})* {self._get_sep(is_end=True)}" ) else: res += ( - f' ("" | {self._get_sep()} {additional_pattern} ({self._get_sep()} ' - f"{additional_pattern})* {self._get_sep(is_end=True)})" + f" {self._get_sep()} {additional_pattern} ({self._get_sep()} " + f"{additional_pattern})* {self._get_sep(is_end=True)}" ) + could_be_empty = True res += ' "]"' + + if could_be_empty: + res = f'({res}) | "[]"' + return res def _visit_object(self, schema: SchemaType, rule_name: str) -> str: @@ -500,6 +507,9 @@ def _visit_object(self, schema: SchemaType, rule_name: str) -> str: ) res = '"{"' + # Set could_be_empty to True when the rule could be "{}". We will handle this case at last, + # and handle non-empty cases before that. + could_be_empty = False # Now we only consider the required list for the properties field required = schema.get("required", []) @@ -528,6 +538,7 @@ def _visit_object(self, schema: SchemaType, rule_name: str) -> str: res += " " + self._get_partial_rule_for_properties_all_optional( properties, additional_property, rule_name, additional_suffix ) + could_be_empty = True elif len(properties) > 0: # 3.2 Case 2: properties are defined and some properties are required res += " " + self._get_partial_rule_for_properties_contain_required( @@ -545,11 +556,15 @@ def _visit_object(self, schema: SchemaType, rule_name: str) -> str: self.BASIC_STRING, additional_property, rule_name, additional_suffix ) res += ( - f" ({self._get_sep()} {other_property_pattern} ({self._get_sep()} " - f'{other_property_pattern})* {self._get_sep(is_end=True)} | "")' + f" {self._get_sep()} {other_property_pattern} ({self._get_sep()} " + f"{other_property_pattern})* {self._get_sep(is_end=True)}" ) + could_be_empty = True res += ' "}"' + + if could_be_empty: + res = f'({res}) | "{{}}"' return res def _get_property_pattern(self, prop_name: str, prop_schema: SchemaType, rule_name: str) -> str: @@ -625,7 +640,7 @@ def _get_partial_rule_for_properties_all_optional( res += f" | {additional_prop_pattern} {rule_names[-1]}" # add separators and the empty string option - res = f'({first_sep} ({res}) {last_sep} | "")' + res = f"{first_sep} ({res}) {last_sep}" return res def _get_partial_rule_for_properties_contain_required( @@ -711,3 +726,17 @@ def json_schema_to_ebnf( """ json_schema_schema = json.loads(json_schema) return _JSONSchemaToEBNFConverter(json_schema_schema, indent, separators, strict_mode).convert() + + +@register_func("mlc.serve.json_schema_to_ebnf") +def json_schema_to_ebnf_register( + json_schema: str, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True, +) -> str: + """To register json_schema_to_ebnf in ffi, we need to create an equivalent function without + keyword-only arguments.""" + return json_schema_to_ebnf( + json_schema, indent=indent, separators=separators, strict_mode=strict_mode + ) diff --git a/conftest.py b/tests/python/conftest.py similarity index 100% rename from conftest.py rename to tests/python/conftest.py diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index 94d48c12af..fe914d1073 100644 --- 
a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -11,6 +11,7 @@ from transformers import AutoTokenizer from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.config import ResponseFormat from mlc_llm.serve.engine import ModelInfo @@ -26,6 +27,7 @@ def _parse_args(): args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) + args.add_argument("--json-output", type=bool, default=False) parsed = args.parse_args() parsed.model = os.path.dirname(parsed.model_lib_path) @@ -35,7 +37,7 @@ def _parse_args(): def sample_requests( - dataset_path: str, num_requests: int, model_path: str + dataset_path: str, num_requests: int, model_path: str, json_output: bool = False ) -> Tuple[List[str], List[GenerationConfig]]: """Sample requests from dataset. Acknowledgement to the benchmark scripts in the vLLM project. @@ -78,8 +80,11 @@ def sample_requests( # Construct generation config. prompts = [prompt for prompt, _, _ in sampled_requests] + response_format = ResponseFormat("json_object" if json_output else "text") generation_config_list = [ - GenerationConfig(temperature=1.0, top_p=1.0, max_tokens=output_len) + GenerationConfig( + temperature=1.0, top_p=1.0, max_tokens=output_len, response_format=response_format + ) for _, _, output_len in sampled_requests ] return prompts, generation_config_list @@ -110,7 +115,9 @@ def benchmark(args: argparse.Namespace): # Create engine engine = Engine(model, kv_cache_config) # Sample prompts from dataset - prompts, generation_config = sample_requests(args.dataset, args.num_prompts, args.model) + prompts, generation_config = sample_requests( + args.dataset, args.num_prompts, args.model, args.json_output + ) # Engine statistics num_runs = 1 single_token_prefill_latency = [] @@ -138,12 +145,16 @@ def engine_generate(): engine_total_decode_time = np.array(engine_total_decode_time) total_prefill_tokens = np.array(total_prefill_tokens) total_decode_tokens = np.array(total_decode_tokens) + avg_prefill_tokens = total_prefill_tokens / len(prompts) + avg_decode_tokens = total_decode_tokens / len(prompts) prefill_throughput = total_prefill_tokens / engine_total_prefill_time decode_throughput = total_decode_tokens / engine_total_decode_time overall_throughput = (total_prefill_tokens + total_decode_tokens) / e2e_latency print(args) print(f"Average end-to-end latency: {e2e_latency.mean():.4f} seconds for the entire batch") + print(f"Average prefill tokens: {avg_prefill_tokens.mean():.4f} tok/req") + print(f"Average decode tokens: {avg_decode_tokens.mean():.4f} tok/req") print(f"Single token prefill latency: {single_token_prefill_latency.mean() * 1e3:.4f} ms/tok") print(f"Single token decode latency: {single_token_decode_latency.mean() * 1e3:.4f} ms/tok") print(f"Engine prefill time: {engine_total_prefill_time.mean():.4f} s") diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 7ef6e22fe0..286d64a874 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -28,6 +28,7 @@ import regex import requests from openai import OpenAI +from pydantic import BaseModel OPENAI_BASE_URL = "http://127.0.0.1:8000/v1" OPENAI_V1_MODELS_URL = "http://127.0.0.1:8000/v1/models" @@ -43,7 +44,15 @@ JSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN) -def is_json_or_json_prefix(s: str) -> bool: +def is_json(s: str) -> bool: + try: + json.loads(s) + return True + 
except json.JSONDecodeError: + return False + + +def is_json_prefix(s: str) -> bool: try: json.loads(s) return True @@ -71,7 +80,7 @@ def check_openai_nonstream_response( suffix: Optional[str] = None, stop: Optional[List[str]] = None, require_substr: Optional[List[str]] = None, - json_mode: bool = False, + check_json_output: bool = False, ): assert response["model"] == model assert response["object"] == object_str @@ -103,8 +112,15 @@ def check_openai_nonstream_response( if require_substr is not None: for substr in require_substr: assert substr in texts[idx] - if json_mode: - assert is_json_or_json_prefix(texts[idx]) + if check_json_output: + # the output should be json or a prefix of a json string + # if the output is a prefix of a json string, the output must exceed the max output + # length + output_is_json = is_json(texts[idx]) + output_is_json_prefix = is_json_prefix(texts[idx]) + assert output_is_json or output_is_json_prefix + if not output_is_json and output_is_json_prefix: + assert choice["finish_reason"] == "length" usage = response["usage"] assert isinstance(usage, dict) @@ -127,12 +143,13 @@ def check_openai_stream_response( suffix: Optional[str] = None, stop: Optional[List[str]] = None, require_substr: Optional[List[str]] = None, - json_mode: bool = False, + check_json_output: bool = False, ): assert len(responses) > 0 finished = [False for _ in range(num_choices)] outputs = ["" for _ in range(num_choices)] + finish_reason_list = ["" for _ in range(num_choices)] for response in responses: assert response["model"] == model assert response["object"] == object_str @@ -154,8 +171,10 @@ def check_openai_stream_response( if finished[idx]: assert choice["finish_reason"] in finish_reasons + finish_reason_list[idx] = choice["finish_reason"] elif choice["finish_reason"] is not None: assert choice["finish_reason"] in finish_reasons + finish_reason_list[idx] = choice["finish_reason"] finished[idx] = True if not is_chat_completion: @@ -170,7 +189,7 @@ def check_openai_stream_response( if completion_tokens is not None: assert responses[-1]["usage"]["completion_tokens"] == completion_tokens - for i, output in enumerate(outputs): + for i, (output, finish_reason) in enumerate(zip(outputs, finish_reason_list)): if echo_prompt is not None: assert output.startswith(echo_prompt) if suffix is not None: @@ -181,8 +200,15 @@ def check_openai_stream_response( if require_substr is not None: for substr in require_substr: assert substr in output - if json_mode: - assert is_json_or_json_prefix(output) + if check_json_output: + # the output should be json or a prefix of a json string + # if the output is a prefix of a json string, the output must exceed the max output + # length + output_is_json = is_json(output) + output_is_json_prefix = is_json_prefix(output) + assert output_is_json or output_is_json_prefix + if not output_is_json and output_is_json_prefix: + assert finish_reason == "length" def expect_error(response_str: str, msg_prefix: Optional[str] = None): @@ -513,8 +539,6 @@ def test_openai_v1_completions_temperature( ) -# TODO(yixin): support eos_token_id for tokenizer -@pytest.mark.skip("JSON test for completion api requires internal eos_token_id support") @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_json( served_model: Tuple[str, str], @@ -543,7 +567,7 @@ def test_openai_v1_completions_json( object_str="text_completion", num_choices=1, finish_reasons=["length", "stop"], - json_mode=True, + check_json_output=True, ) else: responses = [] @@ -558,7 +582,65 @@ def 
test_openai_v1_completions_json( object_str="text_completion", num_choices=1, finish_reasons=["length", "stop"], - json_mode=True, + check_json_output=True, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_completions_json_schema( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + prompt = ( + "Generate a json containing three fields: an integer field named size, a " + "boolean field named is_accepted, and a float field named num:" + ) + max_tokens = 128 + + class Schema(BaseModel): + size: int + is_accepted: bool + num: float + + schema_str = json.dumps(Schema.model_json_schema()) + + payload = { + "model": served_model[0], + "prompt": prompt, + "max_tokens": max_tokens, + "stream": stream, + "response_format": {"type": "json_object", "schema": schema_str}, + } + + response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length", "stop"], + check_json_output=True, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=False, + model=served_model[0], + object_str="text_completion", + num_choices=1, + finish_reasons=["length", "stop"], + check_json_output=True, ) @@ -1040,7 +1122,66 @@ def test_openai_v1_chat_completions_json( object_str="chat.completion", num_choices=1, finish_reasons=["length", "stop"], - json_mode=True, + check_json_output=True, + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion.chunk", + num_choices=1, + finish_reasons=["length", "stop"], + check_json_output=True, + ) + + +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_v1_chat_completions_json_schema( + served_model: Tuple[str, str], + launch_server, # pylint: disable=unused-argument + stream: bool, +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. 
+ + prompt = ( + "Generate a json containing three fields: an integer field named size, a " + "boolean field named is_accepted, and a float field named num:" + ) + messages = [{"role": "user", "content": prompt}] + max_tokens = 128 + + class Schema(BaseModel): + size: int + is_accepted: bool + num: float + + schema_str = json.dumps(Schema.model_json_schema()) + + payload = { + "model": served_model[0], + "messages": messages, + "stream": stream, + "max_tokens": max_tokens, + "response_format": {"type": "json_object", "schema": schema_str}, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + is_chat_completion=True, + model=served_model[0], + object_str="chat.completion", + num_choices=1, + finish_reasons=["length", "stop"], + check_json_output=True, ) else: responses = [] @@ -1055,7 +1196,7 @@ def test_openai_v1_chat_completions_json( object_str="chat.completion.chunk", num_choices=1, finish_reasons=["length", "stop"], - json_mode=True, + check_json_output=True, ) diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index f38ac312ef..5bdc8ecc4b 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -3,14 +3,16 @@ """This test is adopted from test_grammar_state_matcher_json.py, but the grammar is parsed from a unoptimized, non-simplified EBNF string. This is to test the robustness of the grammar state matcher.""" +import json import sys -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import pytest import tvm import tvm.testing +from pydantic import BaseModel -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf from mlc_llm.tokenizer import Tokenizer @@ -282,11 +284,11 @@ def test_find_next_rejected_tokens( real_sizes = [] for c in input_find_rejected_tokens: - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) real_sizes.append(len(rejected_token_ids)) print("Accepting char:", c, file=sys.stderr) assert grammar_state_matcher.debug_accept_char(ord(c)) - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) real_sizes.append(len(rejected_token_ids)) if expected_rejected_sizes is not None: @@ -352,6 +354,44 @@ def test_custom_main_rule(): assert not GrammarStateMatcher(grammar).debug_match_complete_string(r'{"name": "John" }') +def test_find_next_rejected_tokens_schema(): + class MainModel(BaseModel): + integer_field: int + number_field: float + boolean_field: bool + any_array_field: List + array_field: List[str] + tuple_field: Tuple[str, int, List[str]] + object_field: Dict[str, int] + nested_object_field: Dict[str, Dict[str, int]] + + schema = MainModel.model_json_schema() + schema_str = json.dumps(schema) + ebnf_grammar = BNFGrammar.from_schema(schema_str, indent=2) + + instance = MainModel( + integer_field=42, + number_field=3.14e5, + boolean_field=True, + any_array_field=[3.14, "foo", None, True], + array_field=["foo", "bar"], + tuple_field=("foo", 42, ["bar", "baz"]), + object_field={"foo": 42, "bar": 43}, + nested_object_field={"foo": {"bar": 42}}, + ) + instance_str = instance.model_dump_json(indent=2, round_trip=True) + 
+ tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + tokenizer = Tokenizer(tokenizer_path) + matcher = GrammarStateMatcher(ebnf_grammar, tokenizer) + + for c in instance_str: + matcher.find_next_rejected_tokens(True) + print("Accepting char:", c, file=sys.stderr) + assert matcher.debug_accept_char(ord(c)) + matcher.find_next_rejected_tokens(True) + + if __name__ == "__main__": # Run a benchmark to show the performance before running tests test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py index dfc0257b04..fc0f79a041 100644 --- a/tests/python/serve/test_grammar_state_matcher_json.py +++ b/tests/python/serve/test_grammar_state_matcher_json.py @@ -262,11 +262,11 @@ def test_find_next_rejected_tokens( real_sizes = [] for c in input_find_rejected_tokens: - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) real_sizes.append(len(rejected_token_ids)) print("Accepting char:", c, file=sys.stderr) assert grammar_state_matcher.debug_accept_char(ord(c)) - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens() + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) real_sizes.append(len(rejected_token_ids)) if expected_rejected_sizes is not None: assert real_sizes == expected_rejected_sizes diff --git a/tests/python/serve/test_json_schema_converter.py b/tests/python/serve/test_json_schema_converter.py index 138207511b..822199977c 100644 --- a/tests/python/serve/test_json_schema_converter.py +++ b/tests/python/serve/test_json_schema_converter.py @@ -76,18 +76,21 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" -main_any_array_field ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -main_array_field ::= "[" ("" | "" basic_string (", " basic_string)* "") "]" -main_tuple_field_2 ::= "[" ("" | "" basic_string (", " basic_string)* "") "]" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" +main_any_array_field ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +main_array_field ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" +main_tuple_field_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" main_tuple_field ::= "[" "" basic_string ", " basic_integer ", " main_tuple_field_2 "" "]" -main_object_field ::= "{" ("" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" | "") "}" -main_nested_object_field_add ::= "{" ("" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" | "") "}" -main_nested_object_field ::= "{" ("" basic_string ": " main_nested_object_field_add (", " basic_string ": " main_nested_object_field_add)* "" | "") "}" +main_object_field ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_nested_object_field_add ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_nested_object_field ::= ("{" "" basic_string ": " main_nested_object_field_add (", " basic_string ": " 
main_nested_object_field_add)* "" "}") | "{}" main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_any_array_field ", " "\"array_field\"" ": " main_array_field ", " "\"tuple_field\"" ": " main_tuple_field ", " "\"object_field\"" ": " main_object_field ", " "\"nested_object_field\"" ": " main_nested_object_field "" "}" """ + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + instance = MainModel( integer_field=42, number_field=3.14e5, @@ -98,10 +101,21 @@ class MainModel(BaseModel): object_field={"foo": 42, "bar": 43}, nested_object_field={"foo": {"bar": 42}}, ) + check_schema_with_instance(schema, instance) + + instance_empty = MainModel( + integer_field=42, + number_field=3.14e5, + boolean_field=True, + any_array_field=[], + array_field=[], + tuple_field=("foo", 42, []), + object_field={}, + nested_object_field={}, + ) schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - check_schema_with_instance(schema, instance) + check_schema_with_instance(schema, instance_empty) def test_indent(): @@ -118,12 +132,12 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any ("," basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" | "") "}" -main_array_field ::= "[" ("" | "\n " basic_string (",\n " basic_string)* "\n ") "]" -main_tuple_field_2 ::= "[" ("" | "\n " basic_string (",\n " basic_string)* "\n ") "]" +basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" +main_array_field ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" +main_tuple_field_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" main_tuple_field ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_tuple_field_2 "\n " "]" -main_object_field ::= "{" ("\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " | "") "}" +main_object_field ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" main ::= "{" "\n " "\"array_field\"" ": " main_array_field ",\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"object_field\"" ": " main_object_field "\n" "}" """ @@ -155,11 +169,11 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any ("," basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" | "") "}" -main_tuple_field_1 ::= "[" "\n " basic_integer ",\n " basic_integer ("" | (",\n " basic_any)*) "\n " "]" -main_tuple_field ::= "[" "\n " basic_string ",\n " main_tuple_field_1 ("" | (",\n " basic_any)*) "\n " "]" -main_foo_field ::= "{" ("\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " | "") "}" +basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" +main_tuple_field_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" +main_tuple_field ::= "[" "\n " basic_string ",\n " main_tuple_field_1 (",\n " 
basic_any)* "\n " "]" +main_foo_field ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" main ::= "{" "\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"foo_field\"" ": " main_foo_field (",\n " basic_string ": " basic_any)* "\n" "}" """ @@ -204,8 +218,8 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_bars ::= "\"a\"" main_str_values ::= "\"a\\n\\r\\\"\"" main_foo ::= ("\"a\"") | ("\"b\"") | ("\"c\"") @@ -235,8 +249,8 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_opt_bool ::= basic_boolean | basic_null main_size ::= basic_number | basic_null main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_opt_bool ", ")? "\"size\"" ": " main_size (", " "\"name\"" ": " basic_string)? "" "}" @@ -270,11 +284,11 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_sub_1 ::= "" | ", " "\"num\"" ": " basic_number "" main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= "{" ("" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number "")) "" | "") "}" +main ::= ("{" "" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" """ schema = MainModel.model_json_schema() @@ -294,12 +308,12 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_sub_2 ::= (", " basic_string ": " basic_any)* main_sub_1 ::= main_sub_2 | ", " "\"num\"" ": " basic_number main_sub_2 main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= "{" ("" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number main_sub_2) | basic_string ": " basic_any main_sub_2) "" | "") "}" +main ::= ("{" "" (("\"size\"" ": " basic_integer 
main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number main_sub_2) | basic_string ": " basic_any main_sub_2) "" "}") | "{}" """ check_schema_with_grammar(schema, ebnf_grammar_non_strict, strict_mode=False) @@ -308,6 +322,32 @@ class MainModel(BaseModel): check_schema_with_json(schema, '{"other": false}', strict_mode=False) +def test_empty(): + class MainModel(BaseModel): + pass + + ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? +basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" +main ::= "{" "}" +""" + + schema = MainModel.model_json_schema() + check_schema_with_grammar(schema, ebnf_grammar) + + instance = MainModel() + check_schema_with_instance(schema, instance) + + check_schema_with_json(schema, '{"tmp": 123}', strict_mode=False) + + def test_reference(): class Foo(BaseModel): count: int @@ -334,13 +374,13 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_foo_size ::= basic_number | basic_null main_foo ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_foo_size)? 
"" "}" main_bars_item_sub_0 ::= "" | ", " "\"banana\"" ": " basic_string "" -main_bars_item ::= "{" ("" (("\"apple\"" ": " basic_string main_bars_item_sub_0) | ("\"banana\"" ": " basic_string "")) "" | "") "}" -main_bars ::= "[" ("" | "" main_bars_item (", " main_bars_item)* "") "]" +main_bars_item ::= ("{" "" (("\"apple\"" ": " basic_string main_bars_item_sub_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" +main_bars ::= ("[" "" main_bars_item (", " main_bars_item)* "" "]") | "[]" main ::= "{" "" "\"foo\"" ": " main_foo ", " "\"bars\"" ": " main_bars "" "}" """ @@ -370,8 +410,8 @@ class Dog(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" main_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" main ::= main_0 | main_1 @@ -396,8 +436,8 @@ class MainModel(BaseModel): basic_string ::= ["] basic_string_sub ["] basic_boolean ::= "true" | "false" basic_null ::= "null" -basic_array ::= "[" ("" | "" basic_any (", " basic_any)* "") "]" -basic_object ::= "{" ("" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" | "") "}" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" main ::= "{" "" "\"name\"" ": " basic_string "" "}" """ diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index abe0e391ed..de335f9735 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -1,9 +1,11 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable import asyncio +import json from typing import List import pytest +from pydantic import BaseModel from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig from mlc_llm.serve.async_engine import AsyncThreadedEngine @@ -69,6 +71,61 @@ def test_batch_generation_with_grammar(): print(f"Output {req_id}({i}):{output}\n") +def test_batch_generation_with_schema(): + # Initialize model loading info and KV cache config + model = ModelInfo(model_path, model_lib_path=model_lib_path) + kv_cache_config = KVCacheConfig(page_size=16) + # Create engine + engine = Engine(model, kv_cache_config) + + prompt = ( + "Generate a json containing three fields: an integer field named size, a " + "boolean field named is_accepted, and a float field named num:" + ) + repeat_cnt = 3 + prompts = [prompt] * repeat_cnt * 2 + + temperature = 1 + repetition_penalty = 1 + max_tokens = 512 + generation_config_no_json = GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=[2], + response_format=ResponseFormat(type="text"), + ) + + class Schema(BaseModel): + size: int + is_accepted: bool + num: float + + schema_str = json.dumps(Schema.model_json_schema()) + + generation_config_json = GenerationConfig( + temperature=temperature, + 
repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=[2], + response_format=ResponseFormat(type="json_object", schema=schema_str), + ) + + all_generation_configs = [generation_config_no_json] * repeat_cnt + [ + generation_config_json + ] * repeat_cnt + + # Generate output. + output_texts, _ = engine.generate(prompts, all_generation_configs) + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}: {outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}): {output}\n") + + async def run_async_engine(): # Initialize model loading info and KV cache config model = ModelInfo(model_path, model_lib_path=model_lib_path) @@ -144,7 +201,7 @@ def test_generation_config_error(): repetition_penalty=1.0, max_tokens=128, stop_token_ids=[2], - response_format=ResponseFormat(type="text", json_schema="{}"), + response_format=ResponseFormat(type="text", schema="{}"), ) From 0a23af5fe9a688bf3f3c24c95dc30e6314f85e7b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 27 Mar 2024 01:38:53 -0400 Subject: [PATCH 112/531] [Compiler] Support AUTO mode for all-reduce strategy (#2034) This PR supports the auto mode for IPC all-reduce strategy. It renames the strategy from `allreduce-strategy` to `ipc-allreduce-strategy` in the compiler optimization flags. The default RING mode is renamed to NONE mode, which, when specified, uses nccl all-reduce without any IPC memory rewrite. So right now to enable IPC all-reduce, the ideal way is to do `ipc-allreduce-strategy=auto`. --- python/mlc_llm/compiler_pass/pipeline.py | 6 +++--- python/mlc_llm/interface/compile.py | 2 +- python/mlc_llm/interface/compiler_flags.py | 19 +++++++++---------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 4cf6323bc8..a5d44cebc2 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -9,7 +9,7 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from tvm.relax.frontend import nn -from mlc_llm.interface.compiler_flags import AllReduceStrategyType +from mlc_llm.interface.compiler_flags import IPCAllReduceStrategyType from mlc_llm.support import logging from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc @@ -76,7 +76,7 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments flashinfer: bool = False, cublas_gemm: bool = False, faster_transformer: bool = False, # pylint: disable=unused-argument - allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING, + allreduce_strategy: IPCAllReduceStrategyType = IPCAllReduceStrategyType.NONE, variable_bounds: Dict[str, int] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, metadata: Dict[str, Any] = None, @@ -151,7 +151,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tvm.relax.transform.CallTIRRewrite(), ( tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy) - if allreduce_strategy != AllReduceStrategyType.RING + if allreduce_strategy != IPCAllReduceStrategyType.NONE else tvm.transform.Sequential([]) ), tvm.relax.transform.StaticPlanBlockMemory(), diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 56bcc75abd..288e0a39b6 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -184,7 +184,7 @@ def 
_find_kv_cache_bytes(model: nn.Module, model_config) -> int: flashinfer=args.opt.flashinfer, cublas_gemm=args.opt.cublas_gemm, faster_transformer=args.opt.faster_transformer, - allreduce_strategy=args.opt.allreduce_strategy, + allreduce_strategy=args.opt.ipc_allreduce_strategy, variable_bounds=variable_bounds, additional_tirs=additional_tirs, ext_mods=ext_mods, diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 32e79f9bd3..f3a6092f6d 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -14,12 +14,13 @@ logger = logging.getLogger(__name__) -class AllReduceStrategyType(enum.IntEnum): +class IPCAllReduceStrategyType(enum.IntEnum): """The all-reduce strategy.""" - RING = 0 + NONE = 0 ONESHOT = 1 TWOSHOT = 2 + AUTO = 3 @dataclasses.dataclass @@ -31,7 +32,7 @@ class OptimizationFlags: faster_transformer: bool = False cudagraph: bool = False cutlass: bool = False - allreduce_strategy: AllReduceStrategyType = AllReduceStrategyType.RING + ipc_allreduce_strategy: IPCAllReduceStrategyType = IPCAllReduceStrategyType.NONE def __repr__(self) -> str: out = StringIO() @@ -40,7 +41,7 @@ def __repr__(self) -> str: print(f";faster_transformer={int(self.faster_transformer)}", file=out, end="") print(f";cudagraph={int(self.cudagraph)}", file=out, end="") print(f";cutlass={int(self.cutlass)}", file=out, end="") - print(f";allreduce_strategy={self.allreduce_strategy.name}", file=out, end="") + print(f";ipc_allreduce_strategy={self.ipc_allreduce_strategy.name}", file=out, end="") return out.getvalue().rstrip() @staticmethod @@ -64,10 +65,10 @@ def boolean(value: str) -> bool: parser.add_argument("--cudagraph", type=boolean, default=False) parser.add_argument("--cutlass", type=boolean, default=False) parser.add_argument( - "--allreduce-strategy", + "--ipc_allreduce_strategy", type=str, - choices=["ring", "one-shot", "two-shot"], - default="ring", + choices=["NONE", "ONESHOT", "TWOSHOT", "AUTO"], + default="NONE", ) results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) return OptimizationFlags( @@ -76,9 +77,7 @@ def boolean(value: str) -> bool: faster_transformer=results.faster_transformer, cudagraph=results.cudagraph, cutlass=results.cutlass, - allreduce_strategy=AllReduceStrategyType[ - results.allreduce_strategy.replace("-", "").upper() - ], + ipc_allreduce_strategy=IPCAllReduceStrategyType[results.ipc_allreduce_strategy], ) def update(self, target, quantization) -> None: From 47c8350079e4009ee6c7b6543014e4d7d82c5ac7 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Wed, 27 Mar 2024 11:09:31 -0400 Subject: [PATCH 113/531] [LLaVa] Follow-up for TODOs in LLaVa model (#2010) Llava: 1. Added base64 image support. 2. Merged as_prompt and as_prompt_list. 3. 
get_image_from_url uses config --- python/mlc_llm/conversation_template.py | 2 + .../mlc_llm/protocol/conversation_protocol.py | 130 ++++++++---------- .../mlc_llm/protocol/openai_api_protocol.py | 2 +- .../serve/entrypoints/entrypoint_utils.py | 31 ++++- .../serve/entrypoints/openai_entrypoints.py | 22 ++- 5 files changed, 94 insertions(+), 93 deletions(-) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index b4a3468872..167ed1fb28 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -291,6 +291,8 @@ def get_conv_template(name: str) -> Optional[Conversation]: role_empty_sep=":", stop_str=[""], stop_token_ids=[2], + system_prefix_token_ids=[1], + add_role_after_system_message=False, ) ) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index c4ed03e869..1c2a3cb2e4 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -61,9 +61,7 @@ class Conversation(BaseModel): # The conversation history messages. # Each message is a pair of strings, denoting "(role, content)". # The content can be None. - messages: List[Tuple[str, Optional[Union[str, List[Dict[str, str]]]]]] = Field( - default_factory=lambda: [] - ) + messages: List[Tuple[str, Optional[Union[str, List[Dict]]]]] = Field(default_factory=lambda: []) # The separators between messages when concatenating into a single prompt. # List size should be either 1 or 2. @@ -114,7 +112,8 @@ def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T: """Convert from a json dictionary""" return Conversation.model_validate(json_dict) - def as_prompt(self) -> str: + # pylint: disable=too-many-branches + def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]: """Convert the conversation template and history messages to a single prompt. """ @@ -124,16 +123,20 @@ def as_prompt(self) -> str: ) # - Get the message strings. 
- message_list: List[str] = [] + message_list: List[Union[str, data.ImageData]] = [] separators = list(self.seps) if len(separators) == 1: separators.append(separators[0]) + + if system_msg != "": + system_msg += separators[0] + message_list.append(system_msg) + for i, (role, content) in enumerate(self.messages): # pylint: disable=not-an-iterable if role not in self.roles.keys(): raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') separator = separators[role == "assistant"] # check assistant role if content is not None: - assert isinstance(content, str) role_prefix = ( "" # Do not append role prefix if this is the first message and there @@ -141,63 +144,9 @@ def as_prompt(self) -> str: if (not self.add_role_after_system_message and system_msg != "" and i == 0) else self.roles[role] + self.role_content_sep ) - message_string = ( - role_prefix - + self.role_templates[role].replace( - MessagePlaceholders[role.upper()].value, content - ) - + separator - ) - else: - message_string = self.roles[role] + self.role_empty_sep - message_list.append(message_string) - - if system_msg != "": - system_msg += separators[0] - - prompt = system_msg + "".join(message_list) - - # Replace the last function string placeholder with actual function string - prompt = self.function_string.join(prompt.rsplit(MessagePlaceholders.FUNCTION.value, 1)) - # Replace with remaining function string placeholders with empty string - prompt = prompt.replace(MessagePlaceholders.FUNCTION.value, "") - - return prompt - - def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageData]]: - """Convert the conversation template and history messages to - a list of prompts. - - Returns: - List[Union[str, data.ImageData]]: The list of prompts. - """ - # TODO: Unify this function with as_prompt() # pylint: disable=fixme - - # pylint: disable=import-outside-toplevel - from ..serve.entrypoints.entrypoint_utils import get_image_from_url - - # - Get the system message. - system_msg = self.system_template.replace( - MessagePlaceholders.SYSTEM.value, self.system_message - ) - - # - Get the message strings. 
- message_list: List[Union[str, data.ImageData]] = [] - separators = list(self.seps) - if len(separators) == 1: - separators.append(separators[0]) - if system_msg != "": - system_msg += separators[0] - message_list.append(system_msg) - for role, content in self.messages: # pylint: disable=not-an-iterable - if role not in self.roles.keys(): - raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') - separator = separators[role == "assistant"] # check assistant role - if content is not None: if isinstance(content, str): message_string = ( - self.roles[role] - + self.role_content_sep + role_prefix + self.role_templates[role].replace( MessagePlaceholders[role.upper()].value, content ) @@ -205,10 +154,7 @@ def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageDat ) message_list.append(message_string) else: - assert isinstance( - content, list - ), "Content should be a string or a list of dicts" - message_list.append(self.roles[role] + self.role_content_sep) + message_list.append(role_prefix) for item in content: assert isinstance( item, dict @@ -221,23 +167,59 @@ def as_prompt_list(self, image_embed_size=None) -> List[Union[str, data.ImageDat ) ) elif item["type"] == "image_url": - assert image_embed_size is not None, "Image embed size is required" - message_list.append( - data.ImageData( - image=get_image_from_url(item["image_url"]), - embed_size=image_embed_size, - ) + assert config is not None, "Model config is required" + + # pylint: disable=import-outside-toplevel + from ..serve.entrypoints.entrypoint_utils import ( + get_image_from_url, ) + + image_url = _get_url_from_item(item) + message_list.append(get_image_from_url(image_url, config)) else: raise ValueError(f"Unsupported content type: {item['type']}") - message_list.append(separator) + message_list.append(separator) else: message_string = self.roles[role] + self.role_empty_sep message_list.append(message_string) - prompt = message_list + prompt = _combine_consecutive_strings(message_list) - ## TODO: Support function calling # pylint: disable=fixme + if not any(isinstance(item, data.ImageData) for item in message_list): + # Replace the last function string placeholder with actual function string + prompt[0] = self.function_string.join( + prompt[0].rsplit(MessagePlaceholders.FUNCTION.value, 1) + ) + # Replace with remaining function string placeholders with empty string + prompt[0] = prompt[0].replace(MessagePlaceholders.FUNCTION.value, "") return prompt + + +def _get_url_from_item(item: Dict) -> str: + image_url: str + assert "image_url" in item, "Content item should have an image_url field" + if isinstance(item["image_url"], str): + image_url = item["image_url"] + elif isinstance(item["image_url"], dict): + assert ( + "url" in item["image_url"] + ), "Content image_url item should be a string or a dict with a url field" # pylint: disable=line-too-long + image_url = item["image_url"]["url"] + else: + raise ValueError( + "Content image_url item type not supported. " + "Should be a string or a dict with a url field." 
+ ) + return image_url + + +def _combine_consecutive_strings(lst): + result = [] + for item in lst: + if isinstance(item, str) and result and isinstance(result[-1], str): + result[-1] += item + else: + result.append(item) + return result diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 4ac6daef71..fa4893447f 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -184,7 +184,7 @@ class ChatToolCall(BaseModel): class ChatCompletionMessage(BaseModel): - content: Optional[Union[str, List[Dict[str, str]]]] = None + content: Optional[Union[str, List[Dict]]] = None role: Literal["system", "user", "assistant", "tool"] name: Optional[str] = None tool_calls: Optional[List[ChatToolCall]] = None diff --git a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py index f0c82769ec..b0895f2fe7 100644 --- a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py +++ b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py @@ -98,27 +98,42 @@ def process_prompts( return output_prompts -def get_image_from_url(url: str): +def get_image_from_url(url: str, config: Dict) -> data.ImageData: """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" # pylint: disable=import-outside-toplevel, import-error + import base64 + import requests import tvm from PIL import Image from transformers import CLIPImageProcessor - response = requests.get(url, timeout=5) - image_tensor = Image.open(BytesIO(response.content)).convert("RGB") + if url.startswith("data:image"): + # The image is encoded in base64 format + base64_image = url.split(",")[1] + image_data = base64.b64decode(base64_image) + image_tensor = Image.open(BytesIO(image_data)).convert("RGB") + elif url.startswith("http"): + response = requests.get(url, timeout=5) + image_tensor = Image.open(BytesIO(response.content)).convert("RGB") + else: + raise ValueError(f"Unsupported image URL format: {url}") + + image_input_size = get_image_input_size(config) + image_embed_size = get_image_embed_size(config) image_processor = CLIPImageProcessor( - size={"shortest_edge": 336}, crop_size={"height": 336, "width": 336} + size={"shortest_edge": image_input_size}, + crop_size={"height": image_input_size, "width": image_input_size}, ) image_features = tvm.nd.array( image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( "float16" ) ) - return image_features + image_data = data.ImageData(image_features, image_embed_size) + return image_data def get_image_embed_size(config: Dict) -> int: @@ -127,3 +142,9 @@ def get_image_embed_size(config: Dict) -> int: patch_size = config["model_config"]["vision_config"]["patch_size"] embed_size = (image_size // patch_size) ** 2 return embed_size + + +def get_image_input_size(config: Dict) -> int: + """Get the image input size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + return image_size diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index aa9d941f6c..ee4ddf7db9 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -389,7 +389,6 @@ async def request_chat_completion( if error_msg is not None: return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - content_has_list = 
any(isinstance(message.content, list) for message in request.messages) for message in request.messages: role = message.role content = message.content @@ -406,17 +405,12 @@ async def request_chat_completion( # - Check prompt length async_engine.record_event(request_id, event="start tokenization") - if content_has_list: - model_config = ServerContext.get_model_config(request.model) - image_embed_size = entrypoint_utils.get_image_embed_size(model_config) - prompts = entrypoint_utils.process_prompts( - conv_template.as_prompt_list(image_embed_size=image_embed_size), - async_engine.tokenizer.encode, - ) - else: - prompts = entrypoint_utils.process_prompts( - conv_template.as_prompt(), async_engine.tokenizer.encode - ) + model_config = ServerContext.get_model_config(request.model) + prompts = entrypoint_utils.process_prompts( + conv_template.as_prompt(model_config), + async_engine.tokenizer.encode, + ) + async_engine.record_event(request_id, event="finish tokenization") if conv_template.system_prefix_token_ids is not None: prompts[0] = conv_template.system_prefix_token_ids + prompts[0] @@ -581,5 +575,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ], model=request.model, system_fingerprint="", - usage=UsageInfo(prompt_tokens=len(prompt), completion_tokens=num_completion_tokens), + usage=UsageInfo( + prompt_tokens=sum(len(item) for item in prompt), completion_tokens=num_completion_tokens + ), ) From 2d68e64fe7905263398be5bb904a9862c3668897 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 27 Mar 2024 17:45:58 -0400 Subject: [PATCH 114/531] [Pipeline] Defer GPU IPC memory lowering (#2038) This PR moves the position of GPU IPC memory lowering pass in pipeline, so that it applies after the CUDA graph rewrite to enable CUDA graph with the customized all-reduce kernels. --- python/mlc_llm/compiler_pass/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index a5d44cebc2..ad19e6a2bf 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -155,9 +155,9 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I else tvm.transform.Sequential([]) ), tvm.relax.transform.StaticPlanBlockMemory(), - tvm.relax.transform.LowerGPUIPCAllocStorage(), AttachMetadataWithMemoryUsage(metadata), tvm.relax.transform.RewriteCUDAGraph(), + tvm.relax.transform.LowerGPUIPCAllocStorage(), tvm.relax.transform.LowerAllocTensor(), tvm.relax.transform.KillAfterLastUse(), tvm.relax.transform.VMBuiltinLower(), From be42bec0ef0be0a96b1757bb5d86c2641aba41ad Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 27 Mar 2024 19:43:52 -0700 Subject: [PATCH 115/531] [Model] Add missing broadcast of logit_position for multigpu (#2040) This commit adds the broadcasting of `logit_pos` in batch prefill for all models to avoid the logit position out-of-bound issue. 
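The guard added to each model's `batch_prefill` follows the same pattern; below is a minimal sketch of the Llama variant, mirroring the diff that follows (an excerpt of the patched method, not a new API):

    def batch_prefill(
        self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
    ):
        # Only worker 0 receives the logit positions from the engine; broadcast
        # them so every tensor-parallel shard gathers logits at the same,
        # in-bound positions.
        if self.tensor_parallel_shards > 1:
            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
        return logits, paged_kv_cache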
--- python/mlc_llm/model/gemma/gemma_model.py | 2 ++ python/mlc_llm/model/gpt2/gpt2_model.py | 2 ++ python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py | 2 ++ python/mlc_llm/model/gpt_neox/gpt_neox_model.py | 2 ++ python/mlc_llm/model/llama/llama_model.py | 2 ++ python/mlc_llm/model/mistral/mistral_model.py | 2 ++ python/mlc_llm/model/orion/orion_model.py | 2 ++ python/mlc_llm/model/phi/phi_model.py | 2 ++ python/mlc_llm/model/qwen2/qwen2_model.py | 2 ++ 9 files changed, 18 insertions(+) diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 079708ddb8..5950ab2972 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -277,6 +277,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 3c229fd911..28c34353e2 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -269,6 +269,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index c96caa9fee..c13d169be1 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -246,6 +246,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 62e6587bf2..5e940a15b3 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -300,6 +300,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index f38997cdeb..2ae5500c6d 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -260,6 +260,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 
1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 0b66ea706d..3439f7b41f 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -241,6 +241,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index 48de826a3b..c6a2293cd2 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -261,6 +261,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index 6d95833d41..2c9c596ed7 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -364,6 +364,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index ff42e977b4..6eae4c2bb0 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -268,6 +268,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache From 5ebcda147e10fe3e19b41d4f4413a92899542b1f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 27 Mar 2024 19:44:41 -0700 Subject: [PATCH 116/531] [Preshard] apply presharding after quantization (#2039) This change the behavior of presharding by apply presharding after quantization. 
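Concretely, the weight loader now quantizes each source parameter first and only afterwards runs the per-parameter shard function over every resulting tensor; a simplified sketch of the new loading loop (mirroring the `huggingface_loader.py` hunk below, with `preshard_funcs` mapping parameter names to shard functions):

    for name, loader_param in self._load_or_quantize(mlc_name, param, device):
        # Presharding is applied to the already-quantized tensors, so the
        # produced shards stay consistent with a non-presharded conversion.
        if name in preshard_funcs:
            for shard_id, shard_param in enumerate(preshard_funcs[name](loader_param)):
                yield _sharded_param_name(name, shard_id), shard_param
        else:
            yield name, loader_param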
This makes the behavior consistent with or without presharding --- python/mlc_llm/interface/convert_weight.py | 2 +- python/mlc_llm/loader/huggingface_loader.py | 16 +++-- python/mlc_llm/support/preshard.py | 73 +++++++++------------ 3 files changed, 42 insertions(+), 49 deletions(-) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index 0d5cd53fea..90c5c45831 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -76,7 +76,7 @@ def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-loc named_params = dict(_named_params) if pre_shards_num is not None: - preshard_funcs = apply_preshard(quantize_map, named_params, int(pre_shards_num), args) + named_params, preshard_funcs = apply_preshard(named_params, int(pre_shards_num), args) else: preshard_funcs = None diff --git a/python/mlc_llm/loader/huggingface_loader.py b/python/mlc_llm/loader/huggingface_loader.py index 1f72197150..31bc8cfa44 100644 --- a/python/mlc_llm/loader/huggingface_loader.py +++ b/python/mlc_llm/loader/huggingface_loader.py @@ -115,13 +115,15 @@ def load( mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) for mlc_name in tqdm(mlc_names): param = self._load_mlc_param(mlc_name, device=device) - if preshard_funcs is not None and mlc_name in preshard_funcs: - sharded_params = preshard_funcs[mlc_name](param) - for i, sharded_param in enumerate(sharded_params): - sharded_name = _sharded_param_name(mlc_name, i) - yield from self._load_or_quantize(sharded_name, sharded_param, device) - else: - yield from self._load_or_quantize(mlc_name, param, device) + # Apply quantization if needed, in this case the original parameter may become + # multiple quantized parameters. 
+ for name, loader_param in self._load_or_quantize(mlc_name, param, device): + # Apply presharding if needed + if name in preshard_funcs: + for shard_id, shard_param in enumerate(preshard_funcs[name](loader_param)): + yield _sharded_param_name(name, shard_id), shard_param + else: + yield name, loader_param cached_files = list(self.cached_files.keys()) for path in cached_files: diff --git a/python/mlc_llm/support/preshard.py b/python/mlc_llm/support/preshard.py index cd5edbc19c..be351a13d2 100644 --- a/python/mlc_llm/support/preshard.py +++ b/python/mlc_llm/support/preshard.py @@ -1,12 +1,12 @@ """Functions for pre-sharding weights""" import logging -from typing import Any, Dict, List +from typing import Any, Callable, Dict, Sequence, Tuple from tvm import IRModule from tvm import dlight as dl from tvm import relax from tvm.relax.frontend import nn -from tvm.runtime import Device +from tvm.runtime import Device, NDArray from tvm.target import Target logger = logging.getLogger("preshard") @@ -16,33 +16,6 @@ def _sharded_param_name(param_name, worker_id): return f"{param_name}_shard-{worker_id}" -def _update_quantize_map( - quantize_map: Any, - named_params: Dict[str, nn.Parameter], - mlc_name: str, - tensor_parallel_shards: int, -): - param_names: List[str] = [mlc_name] - - if mlc_name in quantize_map.param_map: - # the parameter is quantized - quantized_params = quantize_map.param_map[mlc_name] - param_names = quantized_params - quantize_func = quantize_map.map_func[mlc_name] - - for worker_id in range(tensor_parallel_shards): - sharded_mlc_name = _sharded_param_name(mlc_name, worker_id) - quantize_map.param_map[sharded_mlc_name] = [ - _sharded_param_name(param_name, worker_id) for param_name in quantized_params - ] - quantize_map.map_func[sharded_mlc_name] = quantize_func - - for param_name in param_names: - param = named_params.pop(param_name) - for worker_id in range(tensor_parallel_shards): - named_params[_sharded_param_name(param_name, worker_id)] = param - - def _create_shard_func( bb: relax.BlockBuilder, param: nn.Parameter, tensor_parallel_shards: int ): # pylint: disable=too-many-locals @@ -96,38 +69,56 @@ def _compile_shard_funcs(mod: IRModule, device: Device): def apply_preshard( - quantize_map: Any, named_params: Dict[str, nn.Parameter], tensor_parallel_shards: int, args: Any -): - """Update quantize_map and named_params, create shard functions based on shard strategies.""" - model_config = args.model.config.from_file(args.config) - model_config.tensor_parallel_shards = tensor_parallel_shards - model = args.model.model(model_config) - model.to(args.quantization.model_dtype) - + named_params: Dict[str, nn.Parameter], + tensor_parallel_shards: int, + args: Any, +) -> Tuple[Dict[str, nn.Parameter], Dict[str, Callable[[NDArray], Sequence[NDArray]]]]: + """Apply pre-sharding to the named parameters. + + Parameters + ---------- + named_params : Dict[str, nn.Parameter] + The named parameters of the model. If the model is quantized, the named parameters should + the state dictionary of the quantized model. + tensor_parallel_shards : int + The number of tensor parallel shards. + args : Any + The parsed arguments of weight conversion. + + Returns + ------- + Tuple[Dict[str, nn.Parameter], Dict[str, Callable[[NDArray], Sequence[NDArray]]] + The updated named parameters and the mapping from parameter name to the shard function. 
+ """ bb = relax.BlockBuilder() param_to_shard_func = {} shard_func_names = set() + new_named_params: Dict[str, nn.Parameter] = {} has_shard_strategy = False - for name, param in model.state_dict().items(): + for name, param in named_params.items(): shard_strategy = param.attrs.get("shard_strategy", None) if shard_strategy is not None: has_shard_strategy = True - _update_quantize_map(quantize_map, named_params, name, tensor_parallel_shards) - + for i in range(tensor_parallel_shards): + new_named_params[_sharded_param_name(name, i)] = param # create shard functions param_to_shard_func[name] = shard_strategy.name if shard_strategy.name not in shard_func_names: _create_shard_func(bb, param, tensor_parallel_shards) shard_func_names.add(shard_strategy.name) + else: + new_named_params[name] = param + if not has_shard_strategy: logger.warning( "No parameters with 'shard_strategy' found." "At least one parameter must have a 'shard_strategy' for presharding. " "The model will continue to convert weights in a non-presharded manner." ) + mod = bb.finalize() vm = _compile_shard_funcs(mod, args.device) for name in param_to_shard_func: param_to_shard_func[name] = vm[param_to_shard_func[name]] - return param_to_shard_func + return new_named_params, param_to_shard_func From a0c0f2105f4a92482dcff3fd36442f500d00df65 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Thu, 28 Mar 2024 11:53:58 +0800 Subject: [PATCH 117/531] [SLM] Baichuan Multi-GPU support (#2037) This PR enables TP function of Baichuan2 model. --- .../mlc_llm/model/baichuan/baichuan_model.py | 56 +++++++++++++++---- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index ce51659b25..1d8f88c676 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -39,6 +40,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + head_dim: int = 0 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -59,6 +61,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -84,11 +89,9 @@ def __post_init__(self): class BaichuanAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: BaichuanConfig): self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.context_window_size - - self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim + self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): @@ -105,12 +108,13 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class BaichuanMLP(nn.Module): def __init__(self, config: BaichuanConfig): + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, - out_features=2 * config.intermediate_size, + out_features=2 * self.intermediate_size, bias=False, ) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) def forward(self, x): concat_x1_x2 = self.gate_up_proj(x) @@ -126,13 +130,41 @@ def __init__(self, config: BaichuanConfig): self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_heads * hd + k = self.self_attn.num_heads * hd + v = self.self_attn.num_heads * hd + i = self.mlp.intermediate_size + _set( + self.self_attn.W_pack.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1)) + _set( + self.mlp.gate_up_proj.weight, + tp.ShardSingleDim("_shard_mlp_gate_up", segs=[i, i], dim=0), + ) + _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down_proj", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class BaichuanModel(nn.Module): def __init__(self, config: BaichuanConfig): @@ -159,7 +191,7 @@ def __init__(self, config: BaichuanConfig): self.num_hidden_layers = config.num_hidden_layers self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads 
- self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = config.head_dim self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards @@ -187,6 +219,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -215,6 +249,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache From 34497eae8bf778d7ea2661915252325fd3e6806e Mon Sep 17 00:00:00 2001 From: Git bot Date: Thu, 28 Mar 2024 04:13:39 +0000 Subject: [PATCH 118/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 2955bc6d8b..31175052db 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 2955bc6d8b09f6c0aa3178f1b208c9d0a6d22dee +Subproject commit 31175052dbeeb1c6c4e3fb870024e19872534a7d From cf8d458225a9fca53eaad5f722691b875b1e17d2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 28 Mar 2024 08:17:23 -0400 Subject: [PATCH 119/531] [Model] Skip TVMSynchronize when tracing is not enabled (#2041) This PR removes the synchronization in `Model` when Chrome tracing is not enabled. It can help some logit process kernels launching earlier. --- cpp/serve/engine.cc | 3 ++- cpp/serve/model.cc | 21 +++++++++++++++------ cpp/serve/model.h | 3 ++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 98f3e4fe6b..2d1e711cad 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -72,7 +72,8 @@ class EngineImpl : public Engine { String model_path = std::get<1>(model_info); DLDevice device = std::get<2>(model_info); Model model = Model::Create(model_lib, std::move(model_path), device, - kv_cache_config_->max_num_sequence); + kv_cache_config_->max_num_sequence, + /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(this->kv_cache_config_); CHECK_GE(model->GetMaxWindowSize(), this->max_single_sequence_length_) << "The window size of the model, " << model->GetMaxWindowSize() diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 3b7d7ef7ea..04b8551abd 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -25,8 +25,9 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); Model Model::Create(TVMArgValue reload_lib, String model_path, DLDevice device, - int max_num_sequence) { - return Model(make_object(reload_lib, model_path, device, max_num_sequence)); + int max_num_sequence, bool trace_enabled) { + return Model( + make_object(reload_lib, model_path, device, max_num_sequence, trace_enabled)); } class ModelImpl : public ModelObj { @@ -36,7 +37,7 @@ class ModelImpl : public ModelObj { * \sa Model::Create */ explicit ModelImpl(TVMArgValue reload_lib, String model_path, DLDevice device, - int max_num_sequence) + int max_num_sequence, bool trace_enabled) : device_(device) { // Step 1. Process model config json string. 
picojson::object model_config; @@ -166,7 +167,9 @@ class ModelImpl : public ModelObj { } else { logits = Downcast>(ret)[0]; } - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } ft_.kv_cache_end_forward_func_(kv_cache_); // logits: (1, num_sequences, v) @@ -223,7 +226,9 @@ class ModelImpl : public ModelObj { } else { logits = Downcast>(ret)[0]; } - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } ft_.kv_cache_end_forward_func_(kv_cache_); // logits: (b, 1, v) @@ -280,7 +285,9 @@ class ModelImpl : public ModelObj { } else { logits = Downcast>(ret)[0]; } - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } ft_.kv_cache_end_forward_func_(kv_cache_); // logits: (1, total_length, v) @@ -472,6 +479,8 @@ class ModelImpl : public ModelObj { // Shared NDArray memory::Storage token_ids_storage_{nullptr}; NDArray logit_pos_arr_{nullptr}; + // A boolean indicating if tracing is enabled. + bool trace_enabled_; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 65a0002c49..11646a6663 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -206,10 +206,11 @@ class Model : public ObjectRef { * \param model_path The path to the model weight parameters. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ TVM_DLL static Model Create(TVMArgValue reload_lib, String model_path, DLDevice device, - int max_num_sequence); + int max_num_sequence, bool trace_enabled); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; From 4255a451172fdcaf3e9ea751eddc346071652663 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 28 Mar 2024 09:27:30 -0400 Subject: [PATCH 120/531] [Serving] Support NVTX for benchmarking (#2043) This PR supports MLC serve with NVTX which helps analyzing benchmarking results. **Note.** To enable NVTX, please add `set(USE_NVTX ON)` to file `build/config.cmake`. --- cpp/serve/engine_actions/action_commons.cc | 11 +++++++++-- cpp/serve/engine_actions/batch_decode.cc | 18 ++++++++++++------ .../engine_actions/new_request_prefill.cc | 17 +++++++++++++---- cpp/serve/logit_processor.cc | 3 +++ cpp/serve/model.cc | 18 ++++++++++++++---- cpp/serve/sampler/gpu_sampler.cc | 2 ++ 6 files changed, 53 insertions(+), 16 deletions(-) diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index d6a5d52ef4..1fb61ae70a 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -5,6 +5,8 @@ #include "action_commons.h" +#include + namespace mlc { namespace llm { namespace serve { @@ -19,6 +21,7 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, Array finished_rsentries, EngineState estate, Array models, int max_single_sequence_length) { + NVTXScopedRange nvtx_scope("Process finished requests"); // - Remove the finished request state entries. for (const RequestStateEntry& rsentry : finished_rsentries) { // The finished entry must be a leaf. 
@@ -83,6 +86,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array finished_rsentries; finished_rsentries.reserve(requests.size()); @@ -128,8 +132,11 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array + #include #include "../../random.h" @@ -40,12 +42,16 @@ class BatchDecodeActionObj : public EngineActionObj { } // Preempt request state entries when decode cannot apply. - std::vector running_rsentries = GetRunningRequestStateEntries(estate); - while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); - if (preempted.same_as(running_rsentries.back())) { - running_rsentries.pop_back(); + std::vector running_rsentries; + { + NVTXScopedRange nvtx_scope("BatchDecode getting requests"); + running_rsentries = GetRunningRequestStateEntries(estate); + while (!CanDecode(running_rsentries.size())) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + running_rsentries.pop_back(); + } } } diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 905eea3ed1..6363f8a537 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -3,6 +3,8 @@ * \file serve/engine_actions/new_request_prefill.cc */ +#include + #include "../config.h" #include "../model.h" #include "../sampler/sampler.h" @@ -33,10 +35,17 @@ class NewRequestPrefillActionObj : public EngineActionObj { Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. - auto [rsentries, prefill_lengths] = GetRequestStateEntriesToPrefill(estate); - ICHECK_EQ(rsentries.size(), prefill_lengths.size()); - if (rsentries.empty()) { - return {}; + Array rsentries; + std::vector prefill_lengths; + { + NVTXScopedRange nvtx_scope("NewRequestPrefill getting requests"); + auto tuple = GetRequestStateEntriesToPrefill(estate); + rsentries = std::move(std::get<0>(tuple)); + prefill_lengths = std::move(std::get<1>(tuple)); + ICHECK_EQ(rsentries.size(), prefill_lengths.size()); + if (rsentries.empty()) { + return {}; + } } int num_rsentries = rsentries.size(); diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index f5fe8b661a..76495ab8a7 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -6,6 +6,7 @@ #include "logit_processor.h" #include +#include #include #include #include @@ -69,6 +70,7 @@ class LogitProcessorImpl : public LogitProcessorObj { const Array& request_ids, // const std::vector* cum_num_token, // const std::vector>* draft_tokens) final { + NVTXScopedRange nvtx_scope("Logit inplace update"); CHECK_EQ(logits->ndim, 2); CHECK_EQ(logits->shape[1], vocab_size_); CHECK(logits.DataType() == DataType::Float(32)); @@ -109,6 +111,7 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray ComputeProbsFromLogits(NDArray logits, const Array& generation_cfg, const Array& request_ids, const std::vector* cum_num_token) final { + NVTXScopedRange nvtx_scope("Compute probs from logits"); // logits: (n, v) CHECK_EQ(logits->ndim, 2); CHECK_LE(logits->shape[0], max_num_token_); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 04b8551abd..ad2f9b2a79 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -72,13 +73,18 @@ class ModelImpl : 
public ModelObj { /*********************** Model Computation ***********************/ ObjectRef TokenEmbed(IntTuple token_ids, ObjectRef* dst, int offset) final { + NVTXScopedRange nvtx_scope("TokenEmbed"); int num_tokens = token_ids.size(); // Copy input token ids to device. DLDataType dtype(DataType::Int(32)); - NDArray token_ids_nd = token_ids_storage_->AllocNDArray(offset * 4, {num_tokens}, dtype); - int* p_token_ids = static_cast(token_ids_nd->data) + (token_ids_nd->byte_offset) / 4; - for (int i = 0; i < num_tokens; ++i) { - p_token_ids[i] = token_ids[i]; + NDArray token_ids_nd; + { + NVTXScopedRange nvtx_scope("Allocate token_ids at offset"); + token_ids_nd = token_ids_storage_->AllocNDArray(offset * 4, {num_tokens}, dtype); + int* p_token_ids = static_cast(token_ids_nd->data) + (token_ids_nd->byte_offset) / 4; + for (int i = 0; i < num_tokens; ++i) { + p_token_ids[i] = token_ids[i]; + } } ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); @@ -96,6 +102,7 @@ class ModelImpl : public ModelObj { } ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst, int offset) final { + NVTXScopedRange nvtx_scope("ImageEmbed"); CHECK(ft_.image_embed_func_.defined()) << "`image_embed` function is not found in the model. "; auto image_dref_or_nd = ft_.CopyToWorker0(image, "image", image.Shape()); ObjectRef embeddings = ft_.image_embed_func_(image_dref_or_nd, params_); @@ -111,6 +118,7 @@ class ModelImpl : public ModelObj { NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchPrefill"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); @@ -180,6 +188,7 @@ class ModelImpl : public ModelObj { } NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { + NVTXScopedRange nvtx_scope("BatchDecode"); int num_sequence = seq_ids.size(); CHECK(ft_.decode_func_.defined()) @@ -240,6 +249,7 @@ class ModelImpl : public ModelObj { NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchVerify"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index d8a54001d3..0d46d7416b 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -4,6 +4,7 @@ * \brief The implementation for GPU sampler functions. */ #include +#include #include #include "../../random.h" @@ -61,6 +62,7 @@ class GPUSampler : public SamplerObj { const Array& generation_cfg, // const std::vector& rngs, // std::vector* output_prob_dist) final { + NVTXScopedRange nvtx_scope("BatchSampleTokens"); // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_on_device->ndim, 2); From 2b82091ec6dea2ec39f8d4e7de8788d44ea895ec Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 28 Mar 2024 09:30:58 -0400 Subject: [PATCH 121/531] Update huggingface_loader.py --- python/mlc_llm/loader/huggingface_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/loader/huggingface_loader.py b/python/mlc_llm/loader/huggingface_loader.py index 31bc8cfa44..20de641735 100644 --- a/python/mlc_llm/loader/huggingface_loader.py +++ b/python/mlc_llm/loader/huggingface_loader.py @@ -119,7 +119,7 @@ def load( # multiple quantized parameters. 
for name, loader_param in self._load_or_quantize(mlc_name, param, device): # Apply presharding if needed - if name in preshard_funcs: + if preshard_funcs is not None and name in preshard_funcs: for shard_id, shard_param in enumerate(preshard_funcs[name](loader_param)): yield _sharded_param_name(name, shard_id), shard_param else: From 522db058853aa2b83d44b2214707a6f41d873f60 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 29 Mar 2024 01:39:34 -0400 Subject: [PATCH 122/531] [Serve] Separate callback invocation to another thread in AsyncEngine (#2046) This PR enhances the AsyncThreadEngine by separating the callback invocation to another thread, in order to reduce the CPU time overhead of invoking Python callback. --- cpp/serve/async_threaded_engine.cc | 131 ++++++++-- cpp/serve/async_threaded_engine.h | 3 + cpp/serve/engine.cc | 19 +- cpp/serve/engine.h | 17 ++ python/mlc_llm/serve/async_engine.py | 246 +++++++++++------- .../serve/entrypoints/debug_entrypoints.py | 5 +- .../serve/entrypoints/openai_entrypoints.py | 32 +-- .../python/serve/test_serve_engine_grammar.py | 2 +- 8 files changed, 299 insertions(+), 156 deletions(-) diff --git a/cpp/serve/async_threaded_engine.cc b/cpp/serve/async_threaded_engine.cc index ebd97bec3a..49313e4ca1 100644 --- a/cpp/serve/async_threaded_engine.cc +++ b/cpp/serve/async_threaded_engine.cc @@ -30,6 +30,8 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("add_request", &AsyncThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &AsyncThreadedEngineImpl::AbortRequest); TVM_MODULE_VTABLE_ENTRY("run_background_loop", &AsyncThreadedEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &AsyncThreadedEngineImpl::RunBackgroundStreamBackLoop); TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &AsyncThreadedEngineImpl::ExitBackgroundLoop); if (_name == "init_background_engine") { return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { @@ -39,44 +41,87 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { } TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(TVMArgs args) { background_engine_ = CreateEnginePacked(args); } + void InitBackgroundEngine(TVMArgs args) { + Optional request_stream_callback; + try { + request_stream_callback = args.At>(4); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; + } + + CHECK(request_stream_callback.defined()) + << "AsyncThreadedEngine requires request stream callback function, but it is not given."; + request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + bool need_notify = false; + { + std::lock_guard lock(request_stream_callback_mutex_); + request_stream_callback_inputs_.push_back(std::move(delta_outputs)); + ++pending_request_stream_callback_cnt_; + need_notify = stream_callback_waiting_; + } + if (need_notify) { + request_stream_callback_cv_.notify_one(); + } + }; + + std::vector values{args.values, args.values + args.size()}; + std::vector type_codes{args.type_codes, args.type_codes + args.size()}; + TVMArgsSetter setter(values.data(), type_codes.data()); + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + setter(4, request_stream_callback); + background_engine_ = CreateEnginePacked(TVMArgs(values.data(), type_codes.data(), 
args.size())); + } void AddRequest(Request request) final { + bool need_notify = false; { - std::lock_guard lock(mutex_); + std::lock_guard lock(background_loop_mutex_); requests_to_add_.push_back(request); - ++pending_operation_cnt_; + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); } - cv_.notify_one(); } void AbortRequest(const String& request_id) final { + bool need_notify = false; { - std::lock_guard lock(mutex_); + std::lock_guard lock(background_loop_mutex_); requests_to_abort_.push_back(request_id); - ++pending_operation_cnt_; + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); } - cv_.notify_one(); } void RunBackgroundLoop() final { - // The local vectors that load the requests in critical regions. + // The local vectors that load the requests from critical regions. std::vector local_requests_to_add; std::vector local_requests_to_abort; while (!exit_now_.load(std::memory_order_relaxed)) { { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return !background_engine_->Empty() || pending_operation_cnt_.load() > 0 || + std::unique_lock lock(background_loop_mutex_); + engine_waiting_ = true; + background_loop_cv_.wait(lock, [this] { + return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 || exit_now_.load(std::memory_order_relaxed); }); + engine_waiting_ = false; local_requests_to_add = requests_to_add_; local_requests_to_abort = requests_to_abort_; requests_to_add_.clear(); requests_to_abort_.clear(); - pending_operation_cnt_ = 0; + pending_request_operation_cnt_ = 0; } for (Request request : local_requests_to_add) { background_engine_->AddRequest(request); @@ -88,22 +133,57 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { } } + void RunBackgroundStreamBackLoop() final { + // The local vectors that load the request stream callback inputs from critical regions. + std::vector> local_request_stream_callback_inputs; + std::vector flattened_callback_inputs; + + while (!exit_now_.load(std::memory_order_relaxed)) { + { + std::unique_lock lock(request_stream_callback_mutex_); + stream_callback_waiting_ = true; + request_stream_callback_cv_.wait(lock, [this] { + return pending_request_stream_callback_cnt_.load() > 0 || + exit_now_.load(std::memory_order_relaxed); + }); + stream_callback_waiting_ = false; + + local_request_stream_callback_inputs = request_stream_callback_inputs_; + request_stream_callback_inputs_.clear(); + pending_request_stream_callback_cnt_ = 0; + } + for (const Array& callback_inputs : + local_request_stream_callback_inputs) { + for (const RequestStreamOutput& callback_input : callback_inputs) { + flattened_callback_inputs.push_back(callback_input); + } + } + request_stream_callback_(Array(flattened_callback_inputs)); + flattened_callback_inputs.clear(); + } + } + void ExitBackgroundLoop() final { { - std::lock_guard lock(mutex_); + std::lock_guard lock(background_loop_mutex_); exit_now_.store(true); } - cv_.notify_one(); + background_loop_cv_.notify_one(); + request_stream_callback_cv_.notify_one(); } private: /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; + /*! \brief The request stream callback. */ + PackedFunc request_stream_callback_; /*! \brief The mutex ensuring only one thread can access critical regions. 
*/ - std::mutex mutex_; + std::mutex background_loop_mutex_; + std::mutex request_stream_callback_mutex_; /*! \brief The condition variable preventing threaded engine from spinning. */ - std::condition_variable cv_; + std::condition_variable background_loop_cv_; + std::condition_variable request_stream_callback_cv_; /*! \brief A boolean flag denoting if the engine needs to exit background loop. */ std::atomic exit_now_ = false; @@ -121,10 +201,25 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { */ std::vector requests_to_abort_; /*! - * \brief Number of pending operations, should be the size of + * \brief The delta outputs to pass through callback. + * Elements are sended from the background loop thread and + * consumed by the foreground thread. + */ + std::vector> request_stream_callback_inputs_; + /*! + * \brief Number of pending request operations, should be the size of * `requests_to_add_` and `requests_to_abort_`. */ - std::atomic pending_operation_cnt_ = 0; + std::atomic pending_request_operation_cnt_ = 0; + /*! + * \brief Number of pending request stream callback invocations. + * It should be the size of `request_stream_callback_inputs_`. + */ + std::atomic pending_request_stream_callback_cnt_ = 0; + /*! \brief A boolean flag indicating if the engine is waiting for new requests/aborts. */ + bool engine_waiting_ = false; + /*! \brief A boolean flag indicating if the stream callback loop is waiting. */ + bool stream_callback_waiting_ = false; }; TVM_REGISTER_GLOBAL("mlc.serve.create_threaded_engine").set_body_typed([]() { diff --git a/cpp/serve/async_threaded_engine.h b/cpp/serve/async_threaded_engine.h index afb82e3d06..550bd81623 100644 --- a/cpp/serve/async_threaded_engine.h +++ b/cpp/serve/async_threaded_engine.h @@ -33,6 +33,9 @@ class AsyncThreadedEngine { /*! \brief Starts the background request processing loop. */ virtual void RunBackgroundLoop() = 0; + /*! \brief Starts the request stream callback loop. */ + virtual void RunBackgroundStreamBackLoop() = 0; + /*! * \brief Notify the AsyncThreadedEngine to exit the background * request processing loop. This method is invoked by threads diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 2d1e711cad..6c060a7e27 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -305,23 +305,6 @@ void ClearGlobalMemoryManager() { } std::unique_ptr CreateEnginePacked(TVMArgs args) { - static const char* kErrorMessage = - "With `n` models, engine initialization " - "takes (6 + 4 * n) arguments. The first 6 arguments should be: " - "1) (int) maximum length of a sequence, which must be equal or smaller than the context " - "window size of each model; " - "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model " - "weights directory; " - "3) (string) JSON configuration for the KVCache; " - "4) (string) JSON mode for Engine;" - "5) (packed function, optional) global request stream callback function. " - "6) (EventTraceRecorder, optional) the event trace recorder for requests." - "The following (4 * n) arguments, 4 for each model, should be: " - "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; " - "2) (string) Model path which includes weights and mlc-chat-config.json; " - "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; " - "4) (int) Device id, i.e. 
the ordinal index of the device that exists locally."; - ClearGlobalMemoryManager(); const int num_non_model_args = 6; const int num_model_args = 4; @@ -352,7 +335,7 @@ std::unique_ptr CreateEnginePacked(TVMArgs args) { model_infos.emplace_back(model_lib, model_path, DLDevice{device_type, device_id}); } } catch (const dmlc::Error& e) { - LOG(FATAL) << "ValueError: " << e.what() << kErrorMessage; + LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; } return Engine::Create(max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, engine_mode_json_str, request_stream_callback, std::move(trace_recorder), diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 54de1ddc68..9ff38bdc42 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -115,6 +115,23 @@ class Engine { */ std::unique_ptr CreateEnginePacked(TVMArgs args); +constexpr const char* kEngineCreationErrorMessage = + "With `n` models, engine initialization " + "takes (6 + 4 * n) arguments. The first 6 arguments should be: " + "1) (int) maximum length of a sequence, which must be equal or smaller than the context " + "window size of each model; " + "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model " + "weights directory; " + "3) (string) JSON configuration for the KVCache; " + "4) (string) JSON mode for Engine;" + "5) (packed function, optional) global request stream callback function. " + "6) (EventTraceRecorder, optional) the event trace recorder for requests." + "The following (4 * n) arguments, 4 for each model, should be: " + "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; " + "2) (string) Model path which includes weights and mlc-chat-config.json; " + "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; " + "4) (int) Device id, i.e. the ordinal index of the device that exists locally."; + } // namespace serve } // namespace llm } // namespace mlc diff --git a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py index 58636cb83b..590d9a805f 100644 --- a/python/mlc_llm/serve/async_engine.py +++ b/python/mlc_llm/serve/async_engine.py @@ -6,7 +6,17 @@ import sys import threading from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) import tvm @@ -102,6 +112,123 @@ async def __anext__(self) -> List[AsyncStreamOutput]: return result +class _AsyncThreadedEngineState: + """The engine states that the request stream callback function may use. + We use this state class to avoid the callback function from capturing + the AsyncThreadedEngine. + """ + + trace_recorder = None + # The mapping from request ids to request asynchronous stream. + request_tools: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} + num_unfinished_generations: Dict[str, int] = {} + _async_event_loop: Optional[asyncio.AbstractEventLoop] = None + + def __init__(self, enable_tracing: bool) -> None: + if enable_tracing: + self.trace_recorder = EventTraceRecorder() + + def lazy_init_event_loop(self) -> None: + """Lazily set the asyncio event loop so that the event + loop is the main driving event loop of the process. 
+ """ + if self._async_event_loop is None: + self._async_event_loop = asyncio.get_event_loop() + + def get_request_stream_callback(self) -> Callable[[List[data.RequestStreamOutput]], None]: + """Construct a callback function and return.""" + + def _callback(delta_outputs: List[data.RequestStreamOutput]) -> None: + self._request_stream_callback(delta_outputs) + + return _callback + + def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for engine to stream back + the request generation results. + + Parameters + ---------- + delta_outputs : List[data.RequestStreamOutput] + The delta output of each requests. + Check out data.RequestStreamOutput for the fields of the outputs. + + Note + ---- + This callback function uses `call_soon_threadsafe` in asyncio to + schedule the invocation in the event loop, so that the underlying + callback logic will be executed asynchronously in the future rather + than right now. + """ + + # Schedule a callback run in the event loop without executing right now. + # NOTE: This function causes GIL during execution. + self._async_event_loop.call_soon_threadsafe( + self._request_stream_callback_impl, delta_outputs + ) + + def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The underlying implementation of request stream callback.""" + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + tools = self.request_tools.get(request_id, None) + if tools is None: + continue + + self.record_event(request_id, event="start callback") + stream, text_streamers = tools + outputs = [] + for stream_output, text_streamer in zip(stream_outputs, text_streamers): + self.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.record_event(request_id, event="finish detokenization") + + outputs.append( + AsyncStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.num_unfinished_generations[request_id] -= 1 + + # Push new delta text to the stream. + stream.push(outputs) + if self.num_unfinished_generations[request_id] == 0: + stream.finish() + self.request_tools.pop(request_id, None) + self.record_event(request_id, event="finish callback") + + def record_event(self, request_id: str, event: str) -> None: + """Record a event for the the input request in the trace + recorder when the recorder exists. + + Parameters + ---------- + request_id : str + The subject request of the event. + + event : str + The event in a string name. + It can have one of the following patterns: + - "start xxx", which marks the start of event "xxx", + - "finish xxx", which marks the finish of event "xxx", + - "yyy", which marks the instant event "yyy". + The "starts" and "finishes" will be automatically paired in the trace recorder. + """ + if self.trace_recorder is None: + return + self.trace_recorder.add_event(request_id, event) + + class AsyncThreadedEngine: # pylint: disable=too-many-instance-attributes """The asynchronous engine for generate text asynchronously, backed by ThreadedEngine. 
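The `call_soon_threadsafe` hand-off above is the heart of this refactor: the stream-back thread never executes the Python callback logic itself, it only schedules `_request_stream_callback_impl` onto the asyncio event loop. A standalone sketch of that pattern, with invented names and no MLC dependencies (not the actual implementation):

    import asyncio
    import threading
    import time

    def stream_back_worker(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue) -> None:
        # Runs on a plain OS thread (stand-in for the engine stream-back thread).
        for delta in ("hello", " world"):
            time.sleep(0.1)  # pretend to produce a delta output
            # Thread-safe: only *schedules* the push on the event-loop thread.
            loop.call_soon_threadsafe(queue.put_nowait, delta)
        loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream marker

    async def consume() -> None:
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        threading.Thread(target=stream_back_worker, args=(loop, queue), daemon=True).start()
        while (delta := await queue.get()) is not None:
            print(delta, end="", flush=True)

    asyncio.run(consume())

Because the callback body always runs on the event-loop thread, the stream and detokenizer state it touches needs no extra locking on the Python side.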
@@ -145,9 +272,9 @@ def __init__( prefill_chunk_size, self.conv_template_name, ) = _process_model_args(models) - self.trace_recorder = EventTraceRecorder() if enable_tracing else None # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) + self.state = _AsyncThreadedEngineState(enable_tracing) if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( @@ -169,6 +296,7 @@ def __init__( "add_request", "abort_request", "run_background_loop", + "run_background_stream_back_loop", "init_background_engine", "exit_background_loop", ] @@ -178,28 +306,30 @@ def __init__( # The default engine mode: non-speculative engine_mode = EngineMode() - # The mapping from request ids to request asynchronous stream. - self._request_tools: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} - self._num_unfinished_generations: Dict[str, int] = {} - def _background_loop(): self._ffi["init_background_engine"]( max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), engine_mode.asjson(), - self._request_stream_callback, - self.trace_recorder, + self.state.get_request_stream_callback(), + self.state.trace_recorder, *model_args, ) self._ffi["run_background_loop"]() + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + # Create the background engine-driving thread and start the loop. self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() # The main thread request handling asyncio event loop, which will # be lazily initialized. - self._async_event_loop: Optional[asyncio.AbstractEventLoop] = None self._terminated = False def terminate(self): @@ -207,6 +337,7 @@ def terminate(self): self._terminated = True self._ffi["exit_background_loop"]() self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() async def generate( self, @@ -232,10 +363,7 @@ async def generate( """ if self._terminated: raise ValueError("The AsyncThreadedEngine has terminated.") - if self._async_event_loop is None: - # Lazily set the asyncio event loop so that the event - # loop is the main driving event loop of the process. - self._async_event_loop = asyncio.get_event_loop() + self.state.lazy_init_event_loop() def convert_to_data( prompt: Union[str, List[int], Sequence[Union[str, List[int], data.Data]]] @@ -255,7 +383,7 @@ def convert_to_data( # Create the unique stream of the request. stream = AsyncRequestStream() - if request_id in self._request_tools: + if request_id in self.state.request_tools: # Report error in the stream if the request id already exists. stream.push( RuntimeError( @@ -265,11 +393,11 @@ def convert_to_data( ) else: # Record the stream in the tracker - self._request_tools[request_id] = ( + self.state.request_tools[request_id] = ( stream, [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], ) - self._num_unfinished_generations[request_id] = generation_config.n + self.state.num_unfinished_generations[request_id] = generation_config.n self._ffi["add_request"](request) # Iterate the stream asynchronously and yield the token. 
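The `run_background_stream_back_loop` thread started in `__init__` above drives the C++ loop added earlier in this patch, which collects pending outputs under a mutex and wakes only when the consumer is actually waiting. A minimal Python rendering of that producer/consumer batching pattern, with invented names (the real loop is C++ and also checks an exit flag):

    import threading
    import time

    pending = []                  # delta outputs queued by the engine thread
    cv = threading.Condition()
    consumer_waiting = False

    def push(item) -> None:
        # Producer side: called once per delta output.
        with cv:
            pending.append(item)
            if consumer_waiting:  # skip the notify when nobody is blocked
                cv.notify()

    def stream_back_loop(callback) -> None:
        # Consumer side: drain everything queued so far, then invoke the
        # callback once per batch rather than once per output.
        global consumer_waiting
        while True:
            with cv:
                consumer_waiting = True
                cv.wait_for(lambda: pending)
                consumer_waiting = False
                batch, pending[:] = list(pending), []
            callback(batch)

    threading.Thread(target=stream_back_loop, args=(print,), daemon=True).start()
    for i in range(3):
        push(f"delta-{i}")
    time.sleep(0.2)               # give the daemon thread time to drain

Draining into a local batch and making a single `callback(batch)` call is what reduces the per-output Python invocation overhead that this PR targets.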
@@ -292,89 +420,5 @@ async def abort(self, request_id: str) -> None: def _abort(self, request_id: str): """Internal implementation of request abortion.""" - self._request_tools.pop(request_id, None) + self.state.request_tools.pop(request_id, None) self._ffi["abort_request"](request_id) - - def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for engine to stream back - the request generation results. - - Parameters - ---------- - delta_outputs : List[data.RequestStreamOutput] - The delta output of each requests. - Check out data.RequestStreamOutput for the fields of the outputs. - - Note - ---- - This callback function uses `call_soon_threadsafe` in asyncio to - schedule the invocation in the event loop, so that the underlying - callback logic will be executed asynchronously in the future rather - than right now. - """ - # Schedule a callback run in the event loop without executing right now. - # NOTE: This function causes GIL during execution. - self._async_event_loop.call_soon_threadsafe( - self._request_stream_callback_impl, delta_outputs - ) - - def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The underlying implementation of request stream callback.""" - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - tools = self._request_tools.get(request_id, None) - if tools is None: - continue - - self.record_event(request_id, event="start callback") - stream, text_streamers = tools - outputs = [] - for stream_output, text_streamer in zip(stream_outputs, text_streamers): - self.record_event(request_id, event="start detokenization") - delta_text = ( - text_streamer.put(stream_output.delta_token_ids) - if len(stream_output.delta_token_ids) > 0 - else "" - ) - if stream_output.finish_reason is not None: - delta_text += text_streamer.finish() - self.record_event(request_id, event="finish detokenization") - - outputs.append( - AsyncStreamOutput( - delta_text=delta_text, - num_delta_tokens=len(stream_output.delta_token_ids), - delta_logprob_json_strs=stream_output.delta_logprob_json_strs, - finish_reason=stream_output.finish_reason, - ) - ) - if stream_output.finish_reason is not None: - self._num_unfinished_generations[request_id] -= 1 - - # Push new delta text to the stream. - stream.push(outputs) - if self._num_unfinished_generations[request_id] == 0: - stream.finish() - self._request_tools.pop(request_id, None) - self.record_event(request_id, event="finish callback") - - def record_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace - recorder when the recorder exists. - - Parameters - ---------- - request_id : str - The subject request of the event. - - event : str - The event in a string name. - It can have one of the following patterns: - - "start xxx", which marks the start of event "xxx", - - "finish xxx", which marks the finish of event "xxx", - - "yyy", which marks the instant event "yyy". - The "starts" and "finishes" will be automatically paired in the trace recorder. 
- """ - if self.trace_recorder is None: - return - self.trace_recorder.add_event(request_id, event) diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index 45da755986..c069f65ede 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -1,4 +1,5 @@ """MLC LLM server debug entrypoints""" + import json from http import HTTPStatus @@ -40,9 +41,9 @@ async def debug_dump_event_trace(request: fastapi.Request): return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.' ) - if async_engine.trace_recorder is None: + if async_engine.state.trace_recorder is None: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" does not enable tracing' ) - return json.loads(async_engine.trace_recorder.dump_json()) + return json.loads(async_engine.state.trace_recorder.dump_json()) diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index ee4ddf7db9..2a55df041d 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -61,7 +61,7 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) request_id = f"cmpl-{entrypoint_utils.random_uuid()}" - async_engine.record_event(request_id, event="receive request") + async_engine.state.record_event(request_id, event="receive request") # - Check if unsupported arguments are specified. error = entrypoint_utils.check_unsupported_fields(request) @@ -69,9 +69,9 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re return error # - Process prompt and check validity. - async_engine.record_event(request_id, event="start tokenization") + async_engine.state.record_event(request_id, event="start tokenization") prompts = entrypoint_utils.process_prompts(request.prompt, async_engine.tokenizer.encode) - async_engine.record_event(request_id, event="finish tokenization") + async_engine.state.record_event(request_id, event="finish tokenization") if isinstance(prompts, fastapi.responses.JSONResponse): # Errored when processing the prompts return prompts @@ -113,7 +113,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # - Generate new tokens. num_completion_tokens = 0 finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async_engine.record_event(request_id, event="invoke generate") + async_engine.state.record_event(request_id, event="invoke generate") async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): assert len(delta_outputs) == generation_cfg.n choices = [] @@ -158,7 +158,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ), ) yield f"data: {response.model_dump_json()}\n\n" - async_engine.record_event(request_id, event="finish") + async_engine.state.record_event(request_id, event="finish") # - Echo the suffix. 
if request.suffix is not None: @@ -195,7 +195,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_json_strs_list: Optional[List[List[str]]] = ( [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None ) - async_engine.record_event(request_id, event="invoke generate") + async_engine.state.record_event(request_id, event="invoke generate") async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified @@ -218,7 +218,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) suffix = request.suffix if request.suffix is not None else "" - async_engine.record_event(request_id, event="finish") + async_engine.state.record_event(request_id, event="finish") response = CompletionResponse( id=request_id, choices=[ @@ -361,7 +361,7 @@ async def request_chat_completion( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) request_id = f"chatcmpl-{entrypoint_utils.random_uuid()}" - async_engine.record_event(request_id, event="receive request") + async_engine.state.record_event(request_id, event="receive request") # - Check if the model supports chat conversation. conv_template = ServerContext.get_conv_template(request.model) @@ -403,7 +403,7 @@ async def request_chat_completion( # - Get the prompt from template, and encode to token ids. # - Check prompt length - async_engine.record_event(request_id, event="start tokenization") + async_engine.state.record_event(request_id, event="start tokenization") model_config = ServerContext.get_model_config(request.model) prompts = entrypoint_utils.process_prompts( @@ -411,7 +411,7 @@ async def request_chat_completion( async_engine.tokenizer.encode, ) - async_engine.record_event(request_id, event="finish tokenization") + async_engine.state.record_event(request_id, event="finish tokenization") if conv_template.system_prefix_token_ids is not None: prompts[0] = conv_template.system_prefix_token_ids + prompts[0] error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) @@ -431,7 +431,7 @@ async def request_chat_completion( if request.stream: async def completion_stream_generator() -> AsyncGenerator[str, None]: - async_engine.record_event(request_id, event="invoke generate") + async_engine.state.record_event(request_id, event="invoke generate") finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): assert len(delta_outputs) == generation_cfg.n @@ -447,7 +447,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: finish_reason_updated = True if not finish_reason_updated and delta_output.delta_text == "": # Ignore empty delta text when finish reason is not updated. 
- async_engine.record_event(request_id, event="skip empty delta text") + async_engine.state.record_event(request_id, event="skip empty delta text") continue choices.append( @@ -479,9 +479,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: model=request.model, system_fingerprint="", ) - async_engine.record_event(request_id, event="yield delta output") + async_engine.state.record_event(request_id, event="yield delta output") yield f"data: {response.model_dump_json()}\n\n" - async_engine.record_event(request_id, event="finish") + async_engine.state.record_event(request_id, event="finish") yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -495,7 +495,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_json_strs_list: Optional[List[List[str]]] = ( [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None ) - async_engine.record_event(request_id, event="invoke generate") + async_engine.state.record_event(request_id, event="invoke generate") async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified @@ -518,7 +518,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - async_engine.record_event(request_id, event="finish") + async_engine.state.record_event(request_id, event="finish") tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(generation_cfg.n)] if conv_template.use_function_calling: diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index de335f9735..45926002ae 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -185,7 +185,7 @@ async def generate_task( for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") - print(async_engine.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) + print(async_engine.state.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) async_engine.terminate() From ad068c22fd7c67632953e7debf63cf0210766de8 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Fri, 29 Mar 2024 07:06:46 -0400 Subject: [PATCH 123/531] [LLaVa] Fix random token output after first sentence (#2048) Fix Llava random token after first '.' 
token Co-authored-by: Animesh Bohara --- python/mlc_llm/conversation_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 167ed1fb28..ccb4e72bdd 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -284,7 +284,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: Conversation( name="llava", system_template=f"{MessagePlaceholders.SYSTEM.value}", - system_message="", + system_message="\n", roles={"user": "USER", "assistant": "ASSISTANT"}, seps=[" "], role_content_sep=": ", From b4b8e918c102fe1f2d794d4d8cd95826e04537eb Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 29 Mar 2024 15:52:31 +0000 Subject: [PATCH 124/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 31175052db..dc0960bff3 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 31175052dbeeb1c6c4e3fb870024e19872534a7d +Subproject commit dc0960bff3a4cfe0f0b09e02bdb848b4e0d6807a From 1acd5f5eea57653fa355232a1b5bc346cf99d337 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 29 Mar 2024 12:53:24 -0400 Subject: [PATCH 125/531] [Pass] Fix LiftGlobalBufferAlloc for proper GlobalVar struct info (#2053) This PR fixes the GlobalVar struct info mismatch issue cased by pass LiftGlobalBufferAlloc after a latest TVM commit. --- .../compiler_pass/lift_global_buffer_alloc.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py index bf709bce04..68f47db811 100644 --- a/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py +++ b/python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py @@ -1,4 +1,5 @@ """A compiler pass that lifts TIR-level global allocation to Relax.""" + from typing import Dict, List, Tuple import tvm @@ -27,7 +28,7 @@ def __init__(self, mod: IRModule): super().__init__(mod) self.mod = mod self.gv2new_tensor_sinfo: Dict[ - tvm.ir.GlobalVar, Tuple[List[relax.TensorStructInfo], tir.PrimFunc] + tvm.ir.GlobalVar, Tuple[tvm.ir.GlobalVar, List[relax.TensorStructInfo], tir.PrimFunc] ] = {} def transform(self) -> IRModule: @@ -36,8 +37,8 @@ def transform(self) -> IRModule: if isinstance(func, tir.PrimFunc): updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) if len(tensor_sinfo_list) > 0: - self.gv2new_tensor_sinfo[g_var] = (tensor_sinfo_list, func) - self.builder_.update_func(g_var, updated_func) + new_gv = self.builder_.add_func(updated_func, g_var.name_hint) + self.gv2new_tensor_sinfo[g_var] = (new_gv, tensor_sinfo_list, func) self.mod = self.builder_.get() for g_var, func in self.mod.functions_items(): @@ -45,7 +46,9 @@ def transform(self) -> IRModule: updated_func = self.visit_expr(func) updated_func = remove_all_unused(updated_func) self.builder_.update_func(g_var, updated_func) - return self.builder_.get() + + mod = self.builder_.get() + return relax.transform.DeadCodeElimination()(mod) def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed call = self.visit_expr_post_order(call) @@ -56,21 +59,22 @@ def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed return call g_var = call.args[0] - tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] + new_gv, tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] 
assert len(call.sinfo_args) == 1 if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo) if not success: # Cannot resolve TIR var mapping. Fall back to no lifting. - self.builder_.update_func(g_var, func_before_update) self.gv2new_tensor_sinfo.pop(g_var) return call + args = list(call.args) + args[0] = new_gv if isinstance(call.sinfo_args[0], relax.TensorStructInfo): new_call = relax.Call( call.op, - args=call.args, + args=args, sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], attrs=call.attrs, ) @@ -79,7 +83,7 @@ def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed assert isinstance(call.sinfo_args[0], relax.TupleStructInfo) return relax.Call( call.op, - args=call.args, + args=args, sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)], attrs=call.attrs, ) From 2f171b4ddd1c3207af63f7032c63b6fe3cbb4569 Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 29 Mar 2024 17:15:21 +0000 Subject: [PATCH 126/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index dc0960bff3..6d47d37dfe 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit dc0960bff3a4cfe0f0b09e02bdb848b4e0d6807a +Subproject commit 6d47d37dfe0e8f7bd079859d2aa744531887dacb From 55d7dc34726d74fc5874f4c2d0be3f8b0dc3d02c Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Fri, 29 Mar 2024 13:36:22 -0400 Subject: [PATCH 127/531] [Serving] CLI Support for SERVE (#2014) This PR adds CLI support for serve. Usage: `mlc_llm serve [Model]` refer `mlc_llm serve -h` for more options Comments - Supports JIT compilation of Model lib - Added context manager to `ServerContext` class Co-authored-by: Ruihang Lai Co-authored-by: Shrey Gupta --- python/mlc_llm/__main__.py | 7 +- python/mlc_llm/cli/serve.py | 89 +++++++++++++++++++ python/mlc_llm/help.py | 24 ++++- python/mlc_llm/interface/jit.py | 1 + python/mlc_llm/interface/serve.py | 60 +++++++++++++ python/mlc_llm/serve/async_engine.py | 13 ++- .../serve/entrypoints/debug_entrypoints.py | 5 +- .../serve/entrypoints/openai_entrypoints.py | 14 +-- python/mlc_llm/serve/server/__init__.py | 1 + python/mlc_llm/serve/server/__main__.py | 58 ++++++------ python/mlc_llm/serve/server/popen_server.py | 6 +- python/mlc_llm/serve/server/server_context.py | 56 +++++++----- 12 files changed, 271 insertions(+), 63 deletions(-) create mode 100644 python/mlc_llm/cli/serve.py create mode 100644 python/mlc_llm/interface/serve.py diff --git a/python/mlc_llm/__main__.py b/python/mlc_llm/__main__.py index 3888b6839f..857cfc479a 100644 --- a/python/mlc_llm/__main__.py +++ b/python/mlc_llm/__main__.py @@ -1,4 +1,5 @@ """Entrypoint of all CLI commands from MLC LLM""" + import sys from mlc_llm.support import logging @@ -13,7 +14,7 @@ def main(): parser.add_argument( "subcommand", type=str, - choices=["compile", "convert_weight", "gen_config", "chat", "bench"], + choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench"], help="Subcommand to to run. 
(choices: %(choices)s)", ) parsed = parser.parse_args(sys.argv[1:2]) @@ -33,6 +34,10 @@ def main(): elif parsed.subcommand == "chat": from mlc_llm.cli import chat as cli + cli.main(sys.argv[2:]) + elif parsed.subcommand == "serve": + from mlc_llm.cli import serve as cli + cli.main(sys.argv[2:]) elif parsed.subcommand == "bench": from mlc_llm.cli import bench as cli diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py new file mode 100644 index 0000000000..4ad2319390 --- /dev/null +++ b/python/mlc_llm/cli/serve.py @@ -0,0 +1,89 @@ +"""Command line entrypoint of serve.""" + +import json + +from mlc_llm.help import HELP +from mlc_llm.interface.serve import serve +from mlc_llm.support.argparse import ArgumentParser + + +def main(argv): + """Parse command line arguments and call `mlc_llm.interface.serve`.""" + parser = ArgumentParser("MLC LLM Serve CLI") + + parser.add_argument( + "model", + type=str, + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_deploy"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--model-lib-path", + type=str, + default=None, + help=HELP["model_lib_path"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--max-batch-size", + type=int, + default=80, + help=HELP["max_batch_size"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] + ) + parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"]) + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="host name" + ' (default: "%(default)s")', + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="port" + ' (default: "%(default)s")', + ) + parser.add_argument("--allow-credentials", action="store_true", help="allow credentials") + parser.add_argument( + "--allow-origins", + type=json.loads, + default=["*"], + help="allowed origins" + ' (default: "%(default)s")', + ) + parser.add_argument( + "--allow-methods", + type=json.loads, + default=["*"], + help="allowed methods" + ' (default: "%(default)s")', + ) + parser.add_argument( + "--allow-headers", + type=json.loads, + default=["*"], + help="allowed headers" + ' (default: "%(default)s")', + ) + parsed = parser.parse_args(argv) + + serve( + model=parsed.model, + device=parsed.device, + model_lib_path=parsed.model_lib_path, + max_batch_size=parsed.max_batch_size, + max_total_sequence_length=parsed.max_total_seq_length, + prefill_chunk_size=parsed.prefill_chunk_size, + enable_tracing=parsed.enable_tracing, + host=parsed.host, + port=parsed.port, + allow_credentials=parsed.allow_credentials, + allow_origins=parsed.allow_origins, + allow_methods=parsed.allow_methods, + allow_headers=parsed.allow_headers, + ) diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 0464bd0388..13335c99c1 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -1,4 +1,5 @@ """Help message for CLI arguments.""" + HELP = { "config": ( """ @@ -22,10 +23,12 @@ """.strip(), "model": """ A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. +It can also be a link to a HF repository pointing to an MLC compiled model. """.strip(), "model_lib_path": """ The full path to the model library file to use (e.g. a ``.so`` file). 
If unspecified, we will use -the provided ``model`` to search over possible paths. +the provided ``model`` to search over possible paths. It the model lib path is not found, it will be +compiled in a JIT manner. """.strip(), "model_type": """ Model architecture such as "llama". If not set, it is inferred from `mlc-chat-config.json`. @@ -111,7 +114,7 @@ the number of sinks is 4. This flag subjects to future refactoring. """.strip(), "max_batch_size": """ -The maximum allowed batch size set for batch prefill/decode function. +The maximum allowed batch size set for the KV cache to concurrently support. """.strip(), """tensor_parallel_shards""": """ Number of shards to split the model into in tensor parallelism multi-gpu inference. @@ -138,5 +141,22 @@ """.strip(), "generate_length": """ The target length of the text generation. +""".strip(), + "max_total_sequence_length_serve": """ +The KV cache total token capacity, i.e., the maximum total number of tokens that +the KV cache support. This decides the GPU memory size that the KV cache consumes. +If not specified, system will automatically estimate the maximum capacity based +on the vRAM size on GPU. +""".strip(), + "prefill_chunk_size_serve": """ +The maximum number of tokens the model passes for prefill each time. +It should not exceed the prefill chunk size in model config. +If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "enable_tracing_serve": """ +Enable Chrome Tracing for the server. +After enabling, you can send POST request to the "debug/dump_event_trace" entrypoint +to get the Chrome Trace. For example, +"curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "dist/llama"}'" """.strip(), } diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index 06a22eb8fd..25548e0e4a 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -1,4 +1,5 @@ """Just-in-time compilation of MLC-Chat models.""" + import dataclasses import hashlib import json diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py new file mode 100644 index 0000000000..c9b9b161b5 --- /dev/null +++ b/python/mlc_llm/interface/serve.py @@ -0,0 +1,60 @@ +"""Python entrypoint of serve.""" + +from typing import Any, Optional + +import fastapi +import uvicorn +from fastapi.middleware.cors import CORSMiddleware + +from mlc_llm.serve import async_engine, config +from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints +from mlc_llm.serve.server import ServerContext + + +def serve( + model: str, + device: str, + model_lib_path: Optional[str], + max_batch_size: int, + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + enable_tracing: bool, + host: str, + port: int, + allow_credentials: bool, + allow_origins: Any, + allow_methods: Any, + allow_headers: Any, +): # pylint: disable=too-many-arguments, too-many-locals + """Serve the model with the specified configuration.""" + # Initialize model loading info and KV cache config + model_info = async_engine.ModelInfo( + model=model, + model_lib_path=model_lib_path, + device=device, + ) + kv_cache_config = config.KVCacheConfig( + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + ) + # Create engine and start the background loop + engine = async_engine.AsyncThreadedEngine( + model_info, kv_cache_config, enable_tracing=enable_tracing + ) + + 
with ServerContext() as server_context: + server_context.add_model(model, engine) + + app = fastapi.FastAPI() + app.add_middleware( + CORSMiddleware, + allow_credentials=allow_credentials, + allow_origins=allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + ) + + app.include_router(openai_entrypoints.app) + app.include_router(debug_entrypoints.app) + uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py index 590d9a805f..652bfa39f8 100644 --- a/python/mlc_llm/serve/async_engine.py +++ b/python/mlc_llm/serve/async_engine.py @@ -272,6 +272,12 @@ def __init__( prefill_chunk_size, self.conv_template_name, ) = _process_model_args(models) + + for i, model in enumerate(models): + # model_args: + # [model_lib_path, model_path, device.device_type, device.device_id] * N + model.model_lib_path = model_args[i * (len(model_args) // len(models))] + # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) self.state = _AsyncThreadedEngineState(enable_tracing) @@ -404,9 +410,12 @@ def convert_to_data( try: async for request_output in stream: yield request_output - except (Exception, asyncio.CancelledError) as e: # pylint: disable=broad-exception-caught + except ( + Exception, + asyncio.CancelledError, + ) as exception: # pylint: disable=broad-exception-caught await self.abort(request_id) - raise e + raise exception async def abort(self, request_id: str) -> None: """Generation abortion interface. diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index c069f65ede..b95fd4faae 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -36,7 +36,10 @@ async def debug_dump_event_trace(request: fastapi.Request): # - Check the requested model. model = request_dict["model"] - async_engine = ServerContext.get_engine(model) + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + if async_engine is None: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.' diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 2a55df041d..ac8503d5df 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -43,7 +43,8 @@ async def request_models(): """OpenAI-compatible served model query API. API reference: https://platform.openai.com/docs/api-reference/models """ - return ListResponse(data=[ModelResponse(id=model) for model in ServerContext.get_model_list()]) + server_context: ServerContext = ServerContext.current() + return ListResponse(data=[ModelResponse(id=model) for model in server_context.get_model_list()]) ################ v1/completions ################ @@ -55,7 +56,8 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re API reference: https://platform.openai.com/docs/api-reference/completions/create """ # - Check the requested model. 
- async_engine = ServerContext.get_engine(request.model) + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(request.model) if async_engine is None: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' @@ -355,7 +357,8 @@ async def request_chat_completion( API reference: https://platform.openai.com/docs/api-reference/chat """ # - Check the requested model. - async_engine = ServerContext.get_engine(request.model) + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(request.model) if async_engine is None: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' @@ -364,7 +367,7 @@ async def request_chat_completion( async_engine.state.record_event(request_id, event="receive request") # - Check if the model supports chat conversation. - conv_template = ServerContext.get_conv_template(request.model) + conv_template = server_context.get_conv_template(request.model) if conv_template is None: return entrypoint_utils.create_error_response( HTTPStatus.BAD_REQUEST, @@ -405,13 +408,14 @@ async def request_chat_completion( # - Check prompt length async_engine.state.record_event(request_id, event="start tokenization") - model_config = ServerContext.get_model_config(request.model) + model_config = server_context.get_model_config(request.model) prompts = entrypoint_utils.process_prompts( conv_template.as_prompt(model_config), async_engine.tokenizer.encode, ) async_engine.state.record_event(request_id, event="finish tokenization") + if conv_template.system_prefix_token_ids is not None: prompts[0] = conv_template.system_prefix_token_ids + prompts[0] error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) diff --git a/python/mlc_llm/serve/server/__init__.py b/python/mlc_llm/serve/server/__init__.py index cd4fce257c..3f127048b5 100644 --- a/python/mlc_llm/serve/server/__init__.py +++ b/python/mlc_llm/serve/server/__init__.py @@ -1,3 +1,4 @@ """The server related data structure and tools in MLC LLM serve.""" + from .popen_server import PopenServer from .server_context import ServerContext diff --git a/python/mlc_llm/serve/server/__main__.py b/python/mlc_llm/serve/server/__main__.py index e57e9f4757..ed900edd03 100644 --- a/python/mlc_llm/serve/server/__main__.py +++ b/python/mlc_llm/serve/server/__main__.py @@ -1,4 +1,5 @@ """Entrypoint of RESTful HTTP request server in MLC LLM""" + import argparse import json @@ -6,6 +7,8 @@ import uvicorn from fastapi.middleware.cors import CORSMiddleware +from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints + from .. 
import async_engine, config from .server_context import ServerContext @@ -31,23 +34,6 @@ def parse_args_and_initialize() -> argparse.Namespace: parsed = args.parse_args() - # Initialize model loading info and KV cache config - model_info = async_engine.ModelInfo( - model=parsed.model, - model_lib_path=parsed.model_lib_path, - device=parsed.device, - ) - kv_cache_config = config.KVCacheConfig( - max_num_sequence=parsed.max_batch_size, - max_total_sequence_length=parsed.max_total_seq_length, - prefill_chunk_size=parsed.prefill_chunk_size, - ) - # Create engine and start the background loop - engine = async_engine.AsyncThreadedEngine( - model_info, kv_cache_config, enable_tracing=parsed.enable_tracing - ) - - ServerContext.add_model(parsed.model, engine) return parsed @@ -55,17 +41,33 @@ def parse_args_and_initialize() -> argparse.Namespace: # Parse the arguments and initialize the asynchronous engine. args: argparse.Namespace = parse_args_and_initialize() app = fastapi.FastAPI() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + + # Initialize model loading info and KV cache config + model_info = async_engine.ModelInfo( + model=args.model, + model_lib_path=args.model_lib_path, + device=args.device, + ) + kv_cache_config = config.KVCacheConfig( + max_num_sequence=args.max_batch_size, + max_total_sequence_length=args.max_total_seq_length, + prefill_chunk_size=args.prefill_chunk_size, + ) + # Create engine and start the background loop + engine = async_engine.AsyncThreadedEngine( + model_info, kv_cache_config, enable_tracing=args.enable_tracing ) - # Include the routers from subdirectories. - from ..entrypoints import debug_entrypoints, openai_entrypoints + with ServerContext() as server_context: + server_context.add_model(args.model, engine) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) - app.include_router(openai_entrypoints.app) - app.include_router(debug_entrypoints.app) - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + app.include_router(openai_entrypoints.app) + app.include_router(debug_entrypoints.app) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index fcdfe6da39..ed63f6ac51 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -26,8 +26,7 @@ def __init__( # pylint: disable=too-many-arguments host: str = "127.0.0.1", port: int = 8000, ) -> None: - """Please check out `python/mlc_llm/serve/server/__main__.py` - for the server arguments.""" + """Please check out `python/mlc_llm/cli/serve.py` for the server arguments.""" self.model = model self.model_lib_path = model_lib_path self.device = device @@ -43,8 +42,7 @@ def start(self) -> None: Wait until the server becomes ready before return. 
""" cmd = [sys.executable] - cmd += ["-m", "mlc_llm.serve.server"] - cmd += ["--model", self.model] + cmd += ["-m", "mlc_llm", "serve", self.model] cmd += ["--model-lib-path", self.model_lib_path] cmd += ["--device", self.device] cmd += ["--max-batch-size", str(self.max_batch_size)] diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index c18bab466b..baad7b5e7d 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -14,47 +14,63 @@ class ServerContext: and corresponding async engines. """ - _models: Dict[str, async_engine.AsyncThreadedEngine] = {} - _conv_templates: Dict[str, Conversation] = {} - _model_configs: Dict[str, Dict] = {} + server_context: Optional["ServerContext"] = None + + def __init__(self): + self._models: Dict[str, async_engine.AsyncThreadedEngine] = {} + self._conv_templates: Dict[str, Conversation] = {} + self._model_configs: Dict[str, Dict] = {} + + def __enter__(self): + if ServerContext.server_context is not None: + raise RuntimeError("Server context already exists.") + ServerContext.server_context = self + return self + + def __exit__(self, exc_type, exc_value, traceback): + for model_engine in self._models.values(): + model_engine.terminate() + self._models.clear() + self._conv_templates.clear() + self._model_configs.clear() @staticmethod - def add_model(hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: + def current(): + """Returns the current ServerContext.""" + return ServerContext.server_context + + def add_model(self, hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: """Add a new model to the server context together with the engine.""" - if hosted_model in ServerContext._models: + if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") - ServerContext._models[hosted_model] = engine + self._models[hosted_model] = engine # Get the conversation template. 
if engine.conv_template_name is not None: conv_template = ConvTemplateRegistry.get_conv_template(engine.conv_template_name) if conv_template is not None: - ServerContext._conv_templates[hosted_model] = conv_template + self._conv_templates[hosted_model] = conv_template _, config_file_path = _get_model_path(hosted_model) with open(config_file_path, "r", encoding="utf-8") as file: config = json.load(file) - ServerContext._model_configs[hosted_model] = config + self._model_configs[hosted_model] = config - @staticmethod - def get_engine(model: str) -> Optional[async_engine.AsyncThreadedEngine]: + def get_engine(self, model: str) -> Optional[async_engine.AsyncThreadedEngine]: """Get the async engine of the requested model.""" - return ServerContext._models.get(model, None) + return self._models.get(model, None) - @staticmethod - def get_conv_template(model: str) -> Optional[Conversation]: + def get_conv_template(self, model: str) -> Optional[Conversation]: """Get the conversation template of the requested model.""" - conv_template = ServerContext._conv_templates.get(model, None) + conv_template = self._conv_templates.get(model, None) if conv_template is not None: return conv_template.model_copy(deep=True) return None - @staticmethod - def get_model_list() -> List[str]: + def get_model_list(self) -> List[str]: """Get the list of models on serve.""" - return list(ServerContext._models.keys()) + return list(self._models.keys()) - @staticmethod - def get_model_config(model: str) -> Optional[Dict]: + def get_model_config(self, model: str) -> Optional[Dict]: """Get the model config path of the requested model.""" - return ServerContext._model_configs.get(model, None) + return self._model_configs.get(model, None) From 203afab8a9328f287e2b660508f36f72e9859207 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 29 Mar 2024 14:39:02 -0700 Subject: [PATCH 128/531] [Pipeline] Insert hints to enable cuda graph symbolic capture (#2050) * [Pipeline] Add pass to insert hints to enable cuda graph symbolic capture --- .../compiler_pass/attach_support_info.py | 22 ++++++++++++++++++- python/mlc_llm/compiler_pass/pipeline.py | 4 ++++ python/mlc_llm/interface/compile.py | 2 ++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/attach_support_info.py b/python/mlc_llm/compiler_pass/attach_support_info.py index c6ec834b13..dbeb621fdc 100644 --- a/python/mlc_llm/compiler_pass/attach_support_info.py +++ b/python/mlc_llm/compiler_pass/attach_support_info.py @@ -1,6 +1,6 @@ """A couple of passes that simply supportive information onto the IRModule.""" -from typing import Dict +from typing import Dict, List import tvm from tvm import IRModule, relax, tir @@ -46,3 +46,23 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR if isinstance(func, relax.Function): mod[g_var] = func.with_attr("relax.memory_plan_dynamic_func_output", True) return mod + + +@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphCaptureHints") +class AttachCUDAGraphSymbolicCaptureHints: # pylint: disable=too-few-public-methods + """Attach CUDA graph capture hints to the IRModule""" + + def __init__(self, hints: Dict[str, List[str]]): + self.hints = hints + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + for g_var, func in mod.functions_items(): + func_name = g_var.name_hint + if isinstance(func, relax.Function): + if func_name in self.hints: + mod[g_var] = func.with_attr( + 
"relax.rewrite_cuda_graph.capture_symbolic_vars", self.hints[func_name] + ) + + return mod diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index ad19e6a2bf..b85a6a2cf6 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -17,6 +17,7 @@ from .attach_sampler import AttachGPUSamplingFunc from .attach_support_info import ( AttachAdditionalPrimFuncs, + AttachCUDAGraphSymbolicCaptureHints, AttachMemoryPlanAttr, AttachVariableBounds, ) @@ -78,12 +79,14 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments faster_transformer: bool = False, # pylint: disable=unused-argument allreduce_strategy: IPCAllReduceStrategyType = IPCAllReduceStrategyType.NONE, variable_bounds: Dict[str, int] = None, + cuda_graph_symbolic_capture_hints: Dict[str, List[str]] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, metadata: Dict[str, Any] = None, ext_mods: List[nn.ExternModule] = None, debug_dump: Optional[Path] = None, ): variable_bounds = variable_bounds or {} + cuda_graph_symbolic_capture_hints = cuda_graph_symbolic_capture_hints or {} additional_tirs = additional_tirs or {} metadata = metadata or {} ext_mods = ext_mods or [] @@ -95,6 +98,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 0. Add additional information for compilation and remove unused Relax func DispatchKVCacheCreation(target, flashinfer, metadata), AttachVariableBounds(variable_bounds), + AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints), AttachLogitProcessFunc(target), AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 288e0a39b6..4e8bcabd9e 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -162,6 +162,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: logger.info("Running optimizations using TVM Unity") additional_tirs = _apply_preproc_to_params(named_params, model_config) variable_bounds = _get_variable_bounds(model_config) + cuda_graph_symbolic_capture_hints = {"batch_decode": ["batch_size"]} metadata = { "model_type": args.model.name, "quantization": args.quantization.name, @@ -186,6 +187,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: faster_transformer=args.opt.faster_transformer, allreduce_strategy=args.opt.ipc_allreduce_strategy, variable_bounds=variable_bounds, + cuda_graph_symbolic_capture_hints=cuda_graph_symbolic_capture_hints, additional_tirs=additional_tirs, ext_mods=ext_mods, metadata=metadata, From 6431bdaa90968396b51adbcde3148da72d01ba81 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 29 Mar 2024 19:28:19 -0700 Subject: [PATCH 129/531] [Loader] Print message when multi-GPU loader is finished (#2051) * [Loader] Print message when multi-GPU loader is finished * Update multi_gpu_loader.cc * fix --- cpp/loader/multi_gpu_loader.cc | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/cpp/loader/multi_gpu_loader.cc b/cpp/loader/multi_gpu_loader.cc index e1b2eb0711..75e8ca2c23 100644 --- a/cpp/loader/multi_gpu_loader.cc +++ b/cpp/loader/multi_gpu_loader.cc @@ -124,6 +124,13 @@ NDArray ReceiveBroadcastedOrSharded(Device device, const ModelMetadata::Param& p return result; } +std::string FormatDuration(DurationType duration) { + std::ostringstream os; + auto float_seconds = 
std::chrono::duration_cast>(duration).count(); + os << std::fixed << std::setprecision(3) << float_seconds << " s"; + return os.str(); +} + Array LoadMultiGPU(const std::string& model_path, Module relax_vm_module, const std::string& model_config_str) { DiscoWorker* worker = DiscoWorker::ThreadLocal(); @@ -174,10 +181,9 @@ Array LoadMultiGPU(const std::string& model_path, Module relax_vm_modul TVMSynchronize(device.device_type, device.device_id, nullptr); } } - auto f_convert = [](DurationType time) { return static_cast(time.count()) / 1e6; }; LOG(INFO) << "Loading done. Time used:" << std::fixed << std::setprecision(3) // - << " Loading " << f_convert(time_loading) << " s;" - << " Preprocessing " << f_convert(time_preproc) << " s."; + << " Loading " << FormatDuration(time_loading) << " Preprocessing " + << FormatDuration(time_preproc) << "."; } else { for (const NDArrayCacheMetadata::FileRecord& record : ndarray_cache_metadata.records) { for (size_t i = 0; i < record.records.size(); ++i) { @@ -226,7 +232,9 @@ Array LoadMultiGPUPresharded(const std::string& model_path, Module rela const NDArrayCacheMetadata::FileRecord* current_file_; std::string current_file_stream_; params.reserve(model_metadata.params.size()); + DurationType time_loading(0); for (const ModelMetadata::Param& param : model_metadata.params) { + RangeTimer _(&time_loading); bool needs_sharding = !param.preprocs.empty(); std::string param_name = needs_sharding ? static_cast( @@ -244,6 +252,10 @@ Array LoadMultiGPUPresharded(const std::string& model_path, Module rela params.push_back(param_record->Load(device, ¤t_file_stream_)); } + SyncWorker(); + if (worker_id == 0) { + LOG(INFO) << "Loading done. Time used: " << FormatDuration(time_loading) << "."; + } return params; } From 12c9808024d7829804cbd11f70cda205b06ab1a7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 30 Mar 2024 08:27:36 -0400 Subject: [PATCH 130/531] [KVCache] Support matching arbitrary element offset for aux data (#2057) This PR enhances the TIR attention-related functions to support matching arbitrary element offests. This makes room for the KV cache to allocate a large array the all the auxiliary data and do slicing on it. This PR should affect nothing for the current codebase, given all the element offsets are zeros as of now. 
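A minimal TVMScript sketch of the idea, assuming a hypothetical standalone
PrimFunc (this is not code from the patch): leaving `elem_offset` symbolic in
`T.match_buffer` lets the function accept a view sliced out of a larger
auxiliary array, while an offset of zero still matches as before.

    from tvm.script import tir as T


    @T.prim_func
    def copy_position_map(var_src: T.handle, var_dst: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        src_elem_offset = T.int64()
        # With a symbolic elem_offset, any slice big_buf[offset : offset + n]
        # of a larger int32 auxiliary array matches this buffer.
        src = T.match_buffer(var_src, (n,), "int32", elem_offset=src_elem_offset)
        dst = T.match_buffer(var_dst, (n,), "int32")
        for i in range(n):
            with T.block("copy"):
                vi = T.axis.spatial(n, i)
                dst[vi] = src[vi]
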
--- python/mlc_llm/nn/kv_cache.py | 61 +++++++++++++++++-------- python/mlc_llm/op/position_embedding.py | 5 +- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 2ecf017cf4..206e5d4958 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -380,10 +380,13 @@ def tir_kv_cache_transpose_append( T.func_attr({"tir.noalias": T.bool(True)}) ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") num_pages = T.int64() + position_map_elem_offset = T.int32() pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype) k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) - position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + position_map = T.match_buffer( + var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset + ) for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): if position_map[global_pos] != T.int32(-1): with T.block("k_transpose_append"): @@ -421,8 +424,11 @@ def tir_kv_cache_debug_get_kv( seqlen = T.SizeVar("num_tokens_including_cache", "int64") page_size = T.SizeVar("page_size", "int64") num_pages = T.int64() + position_map_elem_offset = T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) - position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + position_map = T.match_buffer( + var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset + ) k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): @@ -471,11 +477,11 @@ def _causal_mask(causal, row, col, kv_len, qo_len): ) -def _declare_length_info(var_length_info, batch_size, sliding_window): +def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset): return ( - T.match_buffer(var_length_info, (3, batch_size), "int32") + T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset) if sliding_window - else T.match_buffer(var_length_info, (batch_size,), "int32") + else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset) ) @@ -553,14 +559,20 @@ def batch_prefill_paged_kv( total_len = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) - page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32") - page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32") - k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") - q_rope_position = 
T.match_buffer(var_q_rope_position, (total_len,), "int32") + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) output = T.match_buffer(var_output, (total_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable # The length information of the sequences. @@ -571,7 +583,7 @@ def batch_prefill_paged_kv( # - "(2, i)" is the attn sink length of the sequence. # - It is in shape `(batch_size,)` when sliding window is disabled, # denoting the "last_page_len". - length_info = _declare_length_info(var_length_info, batch_size, sliding_window) + length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset) # kernel code for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): @@ -918,15 +930,20 @@ def batch_decode_paged_kv( B = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) pages = T.match_buffer( pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype ) - page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32") - page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32") - k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32") - q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32") + page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset) output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable # The length information of the sequences. @@ -937,7 +954,7 @@ def batch_decode_paged_kv( # - "(2, i)" is the attn sink length of the sequence. # - It is in shape `(batch_size,)` when sliding window is disabled, # denoting the "last_page_len". 
- length_info = _declare_length_info(var_length_info, B, sliding_window) + length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) sm_scale = 1.0 / math.sqrt(float(D)) * log2e @@ -1236,14 +1253,18 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches batch_size = T.int32(is_size_var=True) qo_len = T.int32(is_size_var=True) kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) - q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32") + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) - kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32") - q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32") - k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32") + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable diff --git a/python/mlc_llm/op/position_embedding.py b/python/mlc_llm/op/position_embedding.py index e6cb25d856..4f3c2a9c42 100644 --- a/python/mlc_llm/op/position_embedding.py +++ b/python/mlc_llm/op/position_embedding.py @@ -241,11 +241,14 @@ def fused_rope( # pylint: disable=too-many-locals } ) seq_len = T.int64() + position_map_elem_offset = T.int64() qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) - position_map = T.match_buffer(var_position_map, (seq_len,), "int32") + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) for iters in T.grid(seq_len, fused_heads, head_dim): with T.block("llama_fused_rope"): s, h, d = T.axis.remap("SSS", iters) From af7ef3e2aaed09aa06654f67662b395869628431 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 30 Mar 2024 14:15:05 -0400 Subject: [PATCH 131/531] [Serving] Support copy stream in LogitProcessor and GPUSampler (#2058) This PR introduces copy stream to LogitProcessor and GPUSampler for CUDA, so that auxiliary data can be copied on a separate stream and overlap with the computation time. 
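The overlap pattern itself is the standard two-stream recipe. Purely for
illustration, here is a small Python sketch of the same idea using PyTorch CUDA
streams (the patch below implements it in C++ on top of TVM's DeviceAPI; none
of the names in this sketch come from MLC):

    import torch

    assert torch.cuda.is_available()
    copy_stream = torch.cuda.Stream()

    aux_host = torch.arange(1024, dtype=torch.int32).pin_memory()
    aux_dev = torch.empty(1024, dtype=torch.int32, device="cuda")
    x = torch.randn(4096, 4096, device="cuda")

    with torch.cuda.stream(copy_stream):
        # Async host-to-device copy of small auxiliary data on the side stream.
        aux_dev.copy_(aux_host, non_blocking=True)

    y = x @ x  # compute on the default stream overlaps with the copy above

    # The compute stream must not consume aux_dev before the copy has landed.
    torch.cuda.current_stream().wait_stream(copy_stream)
    result = y.sum() + aux_dev.sum()
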
--- cpp/serve/logit_processor.cc | 61 ++++++++++++++++++++------ cpp/serve/sampler/gpu_sampler.cc | 74 ++++++++++++++++++++++++-------- 2 files changed, 105 insertions(+), 30 deletions(-) diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 76495ab8a7..9dc4b1b9c5 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -6,6 +6,7 @@ #include "logit_processor.h" #include +#include #include #include #include @@ -15,9 +16,19 @@ namespace mlc { namespace llm { namespace serve { -inline void CopyArray(NDArray src, NDArray dst) { +inline void CopyArray(NDArray src, NDArray dst, TVMStreamHandle copy_stream) { DLTensor dl_dst = *(dst.operator->()); - NDArray::CopyFromTo(src.operator->(), &dl_dst); + NDArray::CopyFromTo(src.operator->(), &dl_dst, copy_stream); +} + +inline void SyncCopyStream(Device device, TVMStreamHandle compute_stream, + TVMStreamHandle copy_stream) { + // - If there is no particular copy stream, no action is needed. + if (copy_stream == nullptr) { + return; + } + // - Sync two streams. + DeviceAPI::Get(device)->SyncStreamFromTo(device, copy_stream, compute_stream); } /***************** LogitProcessor Implementation *****************/ @@ -62,6 +73,22 @@ class LogitProcessorImpl : public LogitProcessorObj { << "Function \"apply_logit_bias_inplace\" not found in model"; CHECK(apply_penalty_func_.defined()) << "Function \"apply_penalty_inplace\" not found in model"; CHECK(apply_bitmask_func_.defined()) << "Function \"apply_bitmask_inplace\" not found in model"; + + // If the device is CUDA/ROCm, we create a standalone copy stream, in + // purpose to hide the latency of auxiliary stream copy. + if (device.device_type == DLDeviceType::kDLCUDA || + device.device_type == DLDeviceType::kDLROCM) { + // The compute stream is the default stream. + compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device); + copy_stream_ = DeviceAPI::Get(device)->CreateStream(device); + } + } + + ~LogitProcessorImpl() { + // Free the copy stream if defined. + if (copy_stream_ != nullptr) { + DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_); + } } void InplaceUpdateLogits(NDArray logits, // @@ -148,7 +175,8 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray temperature_device = temperature_device_.CreateView({num_total_token}, dtype_f32_); // - Copy arrays to GPU. - CopyArray(/*src=*/temperature_host, /*dst=*/temperature_device); + CopyArray(/*src=*/temperature_host, /*dst=*/temperature_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); // - Call kernel. NDArray probs = softmax_func_(logits.CreateView({num_total_token, 1, vocab_size_}, dtype_f32_), @@ -209,9 +237,10 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray token_logit_bias_device = token_logit_bias_device_.CreateView({num_token}, dtype_f32_); // - Copy arrays to GPU. - CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); - CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); - CopyArray(/*src=*/token_logit_bias_host, /*dst=*/token_logit_bias_device); + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device, copy_stream_); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device, copy_stream_); + CopyArray(/*src=*/token_logit_bias_host, /*dst=*/token_logit_bias_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); // - Call kernel. 
apply_logit_bias_func_(logits, pos2seq_id_device, token_ids_device, token_logit_bias_device); @@ -289,11 +318,12 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray penalties_device = penalties_device_.CreateView({num_seq, 3}, dtype_f32_); // - Copy arrays to GPU. - CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); - CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device); - CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device); - CopyArray(/*src=*/token_cnt_host, /*dst=*/token_cnt_device); - CopyArray(/*src=*/penalties_host, /*dst=*/penalties_device); + CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device, copy_stream_); + CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device, copy_stream_); + CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device, copy_stream_); + CopyArray(/*src=*/token_cnt_host, /*dst=*/token_cnt_device, copy_stream_); + CopyArray(/*src=*/penalties_host, /*dst=*/penalties_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); // - Call kernel. apply_penalty_func_(logits, seq_ids_device, pos2seq_id_device, token_ids_device, @@ -367,8 +397,9 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray bitmask_device = bitmask_device_.CreateView({batch_size, bitmask_size_}, dtype_i32_); // - Copy arrays to GPU. - CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device); - CopyArray(/*src=*/bitmask_host, /*dst=*/bitmask_device); + CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device, copy_stream_); + CopyArray(/*src=*/bitmask_host, /*dst=*/bitmask_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); // - Call kernel. apply_bitmask_func_(logits, seq_ids_device, bitmask_device); @@ -410,6 +441,10 @@ class LogitProcessorImpl : public LogitProcessorObj { NDArray temperature_device_; // Event trace recorder. Optional trace_recorder_; + // The device stream for the default computation operations. + TVMStreamHandle compute_stream_ = nullptr; + // The device stream for copying auxiliary data structure to GPU. + TVMStreamHandle copy_stream_ = nullptr; // A small epsilon. const double eps_ = 1e-5; }; diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 0d46d7416b..a290e64b4d 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -3,6 +3,7 @@ * \file serve/sampler/gpu_sampler.cc * \brief The implementation for GPU sampler functions. */ +#include #include #include #include @@ -14,9 +15,19 @@ namespace mlc { namespace llm { namespace serve { -inline void CopyArray(NDArray src, NDArray dst) { +inline void CopyArray(NDArray src, NDArray dst, TVMStreamHandle copy_stream) { DLTensor dl_dst = *(dst.operator->()); - NDArray::CopyFromTo(src.operator->(), &dl_dst); + NDArray::CopyFromTo(src.operator->(), &dl_dst, copy_stream); +} + +inline void SyncCopyStream(Device device, TVMStreamHandle compute_stream, + TVMStreamHandle copy_stream) { + // - If there is no particular copy stream, no action is needed. + if (copy_stream == nullptr) { + return; + } + // - Sync two streams. 
+ DeviceAPI::Get(device)->SyncStreamFromTo(device, copy_stream, compute_stream); } /*********************** GPU Sampler ***********************/ @@ -54,6 +65,22 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); + + // If the device is CUDA/ROCm, we create a standalone copy stream, in + // purpose to hide the latency of auxiliary stream copy. + if (device.device_type == DLDeviceType::kDLCUDA || + device.device_type == DLDeviceType::kDLROCM) { + // The compute stream is the default stream. + compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device); + copy_stream_ = DeviceAPI::Get(device)->CreateStream(device); + } + } + + ~GPUSampler() { + // Free the copy stream if defined. + if (copy_stream_ != nullptr) { + DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_); + } } std::vector BatchSampleTokens(NDArray probs_on_device, // @@ -151,8 +178,8 @@ class GPUSampler : public SamplerObj { NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_samples}, dtype_f32_); NDArray sample_indices_host = sample_indices_host_.CreateView({num_samples}, dtype_i32_); NDArray sample_indices_device = sample_indices_device_.CreateView({num_samples}, dtype_i32_); - CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device); - CopyArray(/*src=*/sample_indices_host, /*dst=*/sample_indices_device); + CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_); + CopyArray(/*src=*/sample_indices_host, /*dst=*/sample_indices_device, copy_stream_); return {uniform_samples_device, sample_indices_device}; } @@ -201,6 +228,7 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. + SyncCopyStream(device_, compute_stream_, copy_stream_); sampled_token_ids_device = gpu_multinomial_from_uniform_func_( probs_on_device, uniform_samples_device, sample_indices_device); return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, @@ -213,11 +241,25 @@ class GPUSampler : public SamplerObj { NDArray sorted_probs_on_device = argsort_results[0]; NDArray sorted_indices_on_device = argsort_results[1]; + // - Copy auxiliary array for top-p and prob values in ahead. + NDArray top_p_device; + NDArray top_prob_offsets_device; if (need_top_p) { - // - Sample with top_p applied. NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); - NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); - CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device); + top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); + CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + } + if (need_prob_values) { + int num_top_probs = top_prob_offset_indptr.back(); + NDArray top_prob_offsets_host = + top_prob_offsets_host_.CreateView({num_top_probs}, dtype_i32_); + top_prob_offsets_device = top_prob_offsets_device_.CreateView({num_top_probs}, dtype_i32_); + CopyArray(/*src=*/top_prob_offsets_host, /*dst=*/top_prob_offsets_device, copy_stream_); + } + SyncCopyStream(device_, compute_stream_, copy_stream_); + + if (need_top_p) { + // - Sample with top_p applied. 
sampled_token_ids_device = gpu_sample_with_top_p_func_(sorted_probs_on_device, sorted_indices_on_device, uniform_samples_device, sample_indices_device, top_p_device); @@ -229,12 +271,6 @@ class GPUSampler : public SamplerObj { if (need_prob_values) { // - Take the probability values. - int num_top_probs = top_prob_offset_indptr.back(); - NDArray top_prob_offsets_host = - top_prob_offsets_host_.CreateView({num_top_probs}, dtype_i32_); - NDArray top_prob_offsets_device = - top_prob_offsets_device_.CreateView({num_top_probs}, dtype_i32_); - CopyArray(/*src=*/top_prob_offsets_host, /*dst=*/top_prob_offsets_device); Array prob_value_results = gpu_sampler_take_probs_func_( probs_on_device, sorted_indices_on_device, sample_indices_device, sampled_token_ids_device, top_prob_offsets_device); @@ -258,7 +294,7 @@ class GPUSampler : public SamplerObj { ICHECK_EQ(sampled_token_ids_device->ndim, 1); ICHECK_EQ(sampled_token_ids_device->shape[0], num_samples); NDArray sampled_token_ids_host = sampled_token_ids_host_.CreateView({num_samples}, dtype_i32_); - CopyArray(/*src=*/sampled_token_ids_device, /*dst=*/sampled_token_ids_host); + CopyArray(/*src=*/sampled_token_ids_device, /*dst=*/sampled_token_ids_host, compute_stream_); NDArray sampled_probs_host{nullptr}; NDArray top_prob_probs_host{nullptr}; @@ -276,10 +312,10 @@ class GPUSampler : public SamplerObj { sampled_probs_host = sampled_probs_host_.CreateView({num_samples}, dtype_i32_); top_prob_probs_host = top_prob_probs_host_.CreateView({num_top_probs}, dtype_f32_); top_prob_indices_host = top_prob_indices_host_.CreateView({num_top_probs}, dtype_i32_); - CopyArray(/*src=*/sampled_probs_device, /*dst=*/sampled_probs_host); + CopyArray(/*src=*/sampled_probs_device, /*dst=*/sampled_probs_host, compute_stream_); if (num_top_probs > 0) { - CopyArray(/*src=*/top_prob_probs_device, /*dst=*/top_prob_probs_host); - CopyArray(/*src=*/top_prob_indices_device, /*dst=*/top_prob_indices_host); + CopyArray(/*src=*/top_prob_probs_device, /*dst=*/top_prob_probs_host, compute_stream_); + CopyArray(/*src=*/top_prob_indices_device, /*dst=*/top_prob_indices_host, compute_stream_); } } @@ -316,6 +352,10 @@ class GPUSampler : public SamplerObj { NDArray top_prob_offsets_device_; // The event trace recorder for requests. */ Optional trace_recorder_; + // The device stream for the default computation operations. + TVMStreamHandle compute_stream_ = nullptr; + // The device stream for copying auxiliary data structure to GPU. + TVMStreamHandle copy_stream_ = nullptr; const float eps_ = 1e-5; }; From 2600a70dcb28301f21d446c4b08f67734436793c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 31 Mar 2024 05:49:27 +0800 Subject: [PATCH 132/531] [SLM] Stablelm Multi-GPU support (#2052) This PR enables TP function of Stablelm model. 
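To make the sharding hints below concrete: `ShardSingleDim(dim=0, segs=[q, k, v])`
splits each of the q/k/v segments of the fused QKV weight evenly across workers
and gives every worker the concatenation of its slices. A rough NumPy sketch of
that behavior (illustrative only, not MLC's loader code):

    import numpy as np


    def shard_fused_qkv(weight, segs, num_shards, worker_id):
        """Return worker_id's shard of a [sum(segs), hidden] fused weight."""
        pieces = []
        start = 0
        for seg in segs:
            assert seg % num_shards == 0, "each segment must divide evenly"
            step = seg // num_shards
            segment = weight[start : start + seg]
            pieces.append(segment[worker_id * step : (worker_id + 1) * step])
            start += seg
        return np.concatenate(pieces, axis=0)


    # Example: 4 query heads, 2 KV heads, head_dim 8, hidden size 32, 2 shards.
    q, k, v, hidden = 4 * 8, 2 * 8, 2 * 8, 32
    w = np.random.randn(q + k + v, hidden).astype("float32")
    shard0 = shard_fused_qkv(w, [q, k, v], num_shards=2, worker_id=0)
    assert shard0.shape == ((q + k + v) // 2, hidden)

This matches the per-shard head counts in the model code below, where num_heads
and num_key_value_heads are divided by tensor_parallel_shards.
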
--- .../mlc_llm/model/stable_lm/stablelm_model.py | 58 ++++++++++++++++--- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index 8589fbc501..10e16cded6 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -33,6 +34,7 @@ class StableLmConfig(ConfigBase): # pylint: disable=too-many-instance-attribute rope_theta: int intermediate_size: int use_qkv_bias: bool = False # Default to False for Stable-LM 3B model + head_dim: int = 0 context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 @@ -57,6 +59,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -83,9 +88,10 @@ class StableLmAttention(nn.Module): # pylint: disable=too-many-instance-attribu def __init__(self, config: StableLmConfig): self.hidden_size = config.hidden_size self.rope_theta = config.rope_theta - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads + self.tensor_parallel_shards = config.tensor_parallel_shards + self.head_dim = config.head_dim + self.num_heads = config.num_attention_heads // self.tensor_parallel_shards + self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.rotary_ndims = int(config.partial_rotary_factor * self.head_dim) @@ -94,7 +100,7 @@ def __init__(self, config: StableLmConfig): out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=config.use_qkv_bias, ) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads @@ -111,7 +117,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class StableLmMLP(nn.Module): def __init__(self, config: StableLmConfig): - self.intermediate_size = config.intermediate_size + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, out_features=2 * self.intermediate_size, @@ -133,13 +139,45 @@ def __init__(self, config: StableLmConfig): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_heads * hd + k = self.self_attn.num_key_value_heads * hd + v = self.self_attn.num_key_value_heads * hd + i = self.mlp.intermediate_size + _set( + self.self_attn.qkv_proj.weight, + 
tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + if config.use_qkv_bias: + _set( + self.self_attn.qkv_proj.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1)) + _set( + self.mlp.gate_up_proj.weight, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0) + ) + _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class StableLmModel(nn.Module): def __init__(self, config: StableLmConfig): @@ -168,7 +206,7 @@ def __init__(self, config: StableLmConfig): self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = config.head_dim self.vocab_size = config.vocab_size self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards @@ -196,6 +234,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -224,6 +264,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache From 9ecc00edd3cff739cfbd4ae781409d98060a8ac2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 30 Mar 2024 19:22:10 -0400 Subject: [PATCH 133/531] [KVCache] Introducing single page copy func for KV cache fork (#2060) This PR introduces the single page copy TIR function for KV cache. This function is helpful for sequence fork at specified positions. NOTE: this PR is a breaking change, so you will need to re-compile your model and update TVM or the MLC-AI pip package to the latest. 
Related PR: apache/tvm#16813 Co-authored-by: Yaxing Cai --- 3rdparty/tvm | 2 +- python/mlc_llm/nn/kv_cache.py | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 6d47d37dfe..5400532c4b 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6d47d37dfe0e8f7bd079859d2aa744531887dacb +Subproject commit 5400532c4ba37e8a30fcaac488c2ecb05a307e4f diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 206e5d4958..4a058c6e03 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -244,6 +244,7 @@ def __init__( # pylint: disable=too-many-locals rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), rx.extern("flashinfer.merge_state_in_place"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), # fmt: on # pylint: enable=line-too-long @@ -347,6 +348,7 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), # fmt: on # pylint: enable=line-too-long @@ -1539,3 +1541,43 @@ def apply_to_md(sch, block): apply_to_md(sch, sch.get_block("lse_store")) return sch.mod["main"].with_attr("tir.is_scheduled", 1) + + +def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) + + @T.prim_func + def copy_single_page( + var_pages: T.handle, + src_page_id: T.int64, + tgt_page_id: T.int64, + copy_length: T.int64, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + + for b in T.thread_binding( + (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for t in T.thread_binding(tx, thread="threadIdx.x"): + with T.block("copy"): + vh = T.axis.spatial( + num_heads, + T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + ) + vp = T.axis.spatial( + copy_length, + (b * tx + t) % (copy_length * head_dim) // head_dim, + ) + vd = T.axis.spatial( + head_dim, + T.Cast( + "int32", + (b * tx + t) % head_dim, + ), + ) + pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] + pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] + + return copy_single_page From e370ac719a47352aaa07f46424bbd2d15290c2d1 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Sat, 30 Mar 2024 19:39:46 -0400 Subject: [PATCH 134/531] [Python] Implement testing.DebugChat for end-to-end model debugging (#2056) --- python/mlc_llm/testing/__init__.py | 3 + python/mlc_llm/testing/debug_chat.py 
| 459 +++++++++++++++++++++++++++ 2 files changed, 462 insertions(+) create mode 100644 python/mlc_llm/testing/__init__.py create mode 100644 python/mlc_llm/testing/debug_chat.py diff --git a/python/mlc_llm/testing/__init__.py b/python/mlc_llm/testing/__init__.py new file mode 100644 index 0000000000..e803641043 --- /dev/null +++ b/python/mlc_llm/testing/__init__.py @@ -0,0 +1,3 @@ +""" +Test and debug tools for MLC LLM +""" diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py new file mode 100644 index 0000000000..51e7bae586 --- /dev/null +++ b/python/mlc_llm/testing/debug_chat.py @@ -0,0 +1,459 @@ +"""Debug compiled models with TVM instrument""" + +import json +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax +from tvm.contrib import tvmjs +from tvm.runtime import Device, Module, Object, ShapeTuple +from tvm.runtime.relax_vm import VirtualMachine + +from mlc_llm.chat_module import ( + ChatConfig, + GenerationConfig, + _get_chat_config, + _get_generation_config, + _get_model_path, +) +from mlc_llm.conversation_template import ConvTemplateRegistry +from mlc_llm.help import HELP +from mlc_llm.serve.entrypoints import entrypoint_utils +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.style import green, red +from mlc_llm.tokenizer import Tokenizer + + +def _extract_metadata(mod: Module): + return json.loads(VirtualMachine(mod, tvm.runtime.device("cpu"))["_metadata"]()) + + +def _load_params( + model_weight_path: str, device: Device, model_metadata: Dict[str, Any] +) -> List[tvm.nd.NDArray]: + params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) + param_names = [param["name"] for param in model_metadata["params"]] + assert len(param_names) == meta["ParamSize"] + + plist = [] + for param_name in param_names: + plist.append(params[param_name]) + return plist + + +def _get_tvm_module( + model_weight_path: str, lib_path: str, device: Device, instrument: tvm.runtime.PackedFunc +): + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, device) + vm.set_instrument(instrument) + metadata = _extract_metadata(ex) + params = _load_params(model_weight_path, device, metadata) + return vm.module, params, metadata + + +class DefaultDebugInstrument: + """The default debug instrument to use if users don't specify + a customized one. + + This debug instrument will dump the arguments and output of each + VM Call instruction into a .npz file. It will also alert the user + if any function outputs are NaN or INF. 
+ """ + + def __init__(self, debug_out: Path): + """Constructor + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def reset(self, debug_out: Path): + """Reset the state of the Instrument class + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def __call__(self, func, name, before_run, ret_val, *args): + # Determine what functions to look at + if before_run: # Whether before the function is called or after + return + if self.first_nan_occurred: + return + if self.first_inf_occurred: + return + if name.startswith("vm.builtin.") and "attention_with_fused_qkv" not in name: + return + + # Decide what to print or save about the function's arguments (where args[-1] is the + # buffer we write the result to) + func_name = f"f{self.counter}_{name}" + + # Write your own behavior below. For example, we can count the number of INF/NaN in args[-1] + num_nans = np.sum(np.isnan(args[-1].numpy())) + num_infs = np.sum(np.isinf(args[-1].numpy())) + if num_nans > 0: + print(f"{red(f'{func_name} has NaN')}: {num_nans}") + self.first_nan_occurred = True + if num_infs > 0: + print(f"{red(f'{func_name} has INF')}: {num_infs}") + self.first_inf_occurred = True + + # Save the the arguments to npz + arg_dict = {} + for i, arg in enumerate(args): + if isinstance(arg, tvm.nd.NDArray): + arg_dict[f"arg_{i}"] = arg.numpy() + + np.savez(self.debug_out / f"{func_name}.npz", **arg_dict) + + self.counter += 1 + + +class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public-methods + """A chat interface used only for debugging purpose. + + It debugs autoregressive decoding fully in Python via the prefill and + decode interface. It supports debugging instrument (either default or + customized) to dump intermediate values for each VM function call. + + Given a prompt, it also prints out the parsed prompt, input tokens, output + tokens and output text. + + Sample usage: + + dc = DebugChat( + model="./dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + debug_dir=Path("./debug-llama-2"), + model_lib_path="./dist/llama-2-7b-chat-q4f16_1-metal.so", + ) + dc.generate("hello world", 3) + """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + debug_dir: Path, + device: Optional[str] = "auto", + chat_config: Optional[ChatConfig] = None, + debug_instrument: Optional[Any] = None, + ): + """_summary_ + + Parameters + ---------- + model: str + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. + + model_lib_path : str + The full path to the model library file to use (e.g. a ``.so`` file). + + debug_dir: Path + The output folder to store the dumped debug files. + + device : Optional[str] + The description of the device to run on. 
User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. If no 'device_id' + is provided, it will be set to 0 by default. + + chat_config : Optional[ChatConfig] + A ``ChatConfig`` instance partially filled. Will be used to override the + ``mlc-chat-config.json``. + + debug_instrument : Optional[Any] + An instrument function that will be called before/after each Call instruction. + The function have the following signature: + + .. code:: python + + def instrument( + func: Union[VMClosure, PackedFunc], + func_symbol: str, + before_run: bool, + ret_value: any, + *args) -> bool: + pass + + The instrument takes the following parameters: + - func: function object to be called. + - func_symbol: the symbol name of the function. + - before_run: whether it is before or after call. + - ret_value: the return value of the call, only valid after run. + - args: the arguments being passed to call. + """ + self.debug_dir = debug_dir + self.device = detect_device(device) + self.instrument = ( + debug_instrument if debug_instrument else DefaultDebugInstrument(debug_dir / "prefill") + ) + self.mod, self.params, self.metadata = _get_tvm_module( + model, model_lib_path, self.device, self.instrument + ) + self.model_path, self.config_file_path = _get_model_path(model) + self.chat_config = _get_chat_config(self.config_file_path, chat_config) + conv_template = self.chat_config.conv_template + self.conversation = ( + ConvTemplateRegistry.get_conv_template(conv_template) + if isinstance(conv_template, str) + else conv_template + ) + self.tokenizer = Tokenizer(self.model_path) + + self.add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence") + self.begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward") + self.end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward") + self.nd_view_func = tvm.get_global_func("vm.builtin.reshape") + self.sample_topp_from_prob_func = tvm.get_global_func("vm.builtin.sample_top_p_from_prob") + + try: + self.embed_func = self.mod["embed"] + except AttributeError as exc: + raise RuntimeError("DebugChat only supports separate embedding layer") from exc + + self.prefill_func = self.mod["prefill"] + self.decode_func = self.mod["decode"] + self.create_kv_cache_func = None + if self.mod.implements_function("create_flashinfer_paged_kv_cache"): + self.create_kv_cache_func = self.mod["create_flashinfer_paged_kv_cache"] + elif self.mod.implements_function("create_tir_paged_kv_cache"): + self.create_kv_cache_func = self.mod["create_tir_paged_kv_cache"] + else: + # TODO: Support RNN KVState # pylint: disable=fixme + raise RuntimeError("DebugChat cannot find create KV cache function") + + self.appeared_token_freq: Dict[int, int] = {} + + def _tokenize(self, prompt: str) -> tvm.nd.array: + print("======================= Starts Tokenization & Embedding =======================") + # Step 0. 
Generate prompt string using conversation template + self.conversation.messages.append(("user", prompt)) + self.conversation.messages.append(("assistant", None)) + with open(self.config_file_path, "r", encoding="utf-8") as file: + config = json.load(file) + parsed_prompt = self.conversation.as_prompt(config) + print( + "Parsed prompt using conversation template " + f"{green(self.conversation.name)}: {parsed_prompt}" + ) + tokens = entrypoint_utils.process_prompts(parsed_prompt, self.tokenizer.encode) + + # TODO: Handle ImageData in DebugChat # pylint: disable=fixme + assert len(tokens) == 1, "DebugChat will only handle TextData for now" + if self.conversation.system_prefix_token_ids is not None: + tokens[0] = self.conversation.system_prefix_token_ids + tokens[0] + + tokens = tvm.nd.array(np.array(tokens[0]).astype("int32"), device=self.device) + return tokens + + def _embed(self, tokens: tvm.nd.array) -> Tuple[tvm.nd.NDArray, int]: + input_len = tokens.shape[0] + embedding = self.embed_func(tokens, self.params) + embedding = self.nd_view_func(embedding, ShapeTuple([1, input_len, embedding.shape[1]])) + return embedding, input_len + + def _prefill(self, embedding: tvm.nd.NDArray, input_len: int): + print("======================= Starts Prefill =======================") + seq_len_shape = ShapeTuple([input_len]) + max_num_sequence = 1 + page_size = 16 + sliding_window_size = ( + self.chat_config.sliding_window_size + if self.chat_config.sliding_window_size + else self.metadata["sliding_window_size"] + ) + context_window_size = ( + self.chat_config.context_window_size + if self.chat_config.context_window_size + else self.metadata["context_window_size"] + ) + prefill_chunk_size = ( + self.chat_config.prefill_chunk_size + if self.chat_config.prefill_chunk_size + else self.metadata["prefill_chunk_size"] + ) + max_total_sequence_length = ( + sliding_window_size if context_window_size == -1 else context_window_size + ) + support_sliding_window = int(sliding_window_size != -1) + + kv_caches = self.create_kv_cache_func( + ShapeTuple([max_num_sequence]), + ShapeTuple([max_total_sequence_length]), + ShapeTuple([prefill_chunk_size]), + ShapeTuple([page_size]), + ShapeTuple([support_sliding_window]), + ) + self.add_sequence_func(kv_caches, 0) + self.begin_forward_func(kv_caches, ShapeTuple([0]), seq_len_shape) + logits, kv_caches = self.prefill_func(embedding, kv_caches, self.params) + self.end_forward_func(kv_caches) + return logits, kv_caches + + def _decode(self, token: int, kv_caches: Object): + embedding, _ = self._embed( + tvm.nd.array(np.array([token]).astype("int32"), device=self.device) + ) + self.begin_forward_func(kv_caches, ShapeTuple([0]), ShapeTuple([1])) + logits, kv_caches = self.decode_func(embedding, kv_caches, self.params) + self.end_forward_func(kv_caches) + return logits + + def _softmax_with_temperature(self, logits: np.ndarray, temperature: float): + # Adjust logits based on the temperature + logits = np.array(logits) / temperature + logits -= np.max(logits, axis=-1, keepdims=True) + + exp_logits = np.exp(logits, logits) + exp_logits /= np.sum(exp_logits, axis=-1, keepdims=True) + return exp_logits + + def _apply_presence_and_freq_penalty( + self, logits: np.ndarray, presence_penalty: float, freq_penalty: float + ): + for token_id, freq in self.appeared_token_freq.items(): + logits[:, :, token_id] -= freq * freq_penalty + presence_penalty + + def _sample_token_from_logits( + self, logits: tvm.nd.NDArray, generation_config: GenerationConfig + ): + logits_np = logits.numpy() + 
temperature = generation_config.temperature if generation_config.temperature else 1.0 + top_p = generation_config.top_p if generation_config.top_p else 0.95 + presence_penalty = generation_config.presence_penalty + frequency_penalty = generation_config.frequency_penalty + + if presence_penalty != 0.0 or frequency_penalty != 0.0: + self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty) + + self._softmax_with_temperature(logits_np, temperature) + logits = logits.copyfrom(logits_np) + next_token = self.sample_topp_from_prob_func(logits, top_p, random.random()) + return next_token + + def generate( + self, + prompt: str, + generate_length: int, + generation_config: Optional[GenerationConfig] = None, + ): + """Generates the response from the model given a user prompt. User will need to + specify the generation length for debugging purpose. For example, a generation + length of 3 will include 1 prefill step and 2 decode steps. + + Parameters + ---------- + prompt : str + The user input prompt. + + generate_length : int + How many tokens to generate. + + generation_config : Optional[GenerationConfig] + Will be used to override the GenerationConfig in ``mlc-chat-config.json``. + """ + out_tokens = [] + + input_tokens = self._tokenize(prompt) + print(f"{green('Input tokens')}: {input_tokens.numpy()}") + embedding, input_len = self._embed(input_tokens) + logits, kv_caches = self._prefill(embedding, input_len) + generation_config = _get_generation_config(self.chat_config, generation_config) + next_token = self._sample_token_from_logits(logits, generation_config) + out_tokens.append(next_token) + path_str = (self.debug_dir / "prefill").as_posix() + print(f"Debug instrument output dumped to {green(path_str)}") + + print("======================= Starts Decode =======================") + for i in range(generate_length - 1): + self.instrument.reset(self.debug_dir / f"decode_{i}") + logits = self._decode(next_token, kv_caches) + generation_config = _get_generation_config(self.chat_config, generation_config) + next_token = self._sample_token_from_logits(logits, generation_config) + out_tokens.append(next_token) + path_str = (self.debug_dir / f"decode_{i}").as_posix() + print(f"Debug instrument output dumped to {green(path_str)}") + + if next_token in self.conversation.stop_token_ids: + break + + print(f"{green('Generated output tokens')}: {np.array(out_tokens)}") + + out_text = self.tokenizer.decode(out_tokens) + print(f"{green('Generated output text')}: {out_text}") + + +def main(): + """The main function to start a DebugChat CLI""" + + parser = ArgumentParser("MLC LLM Chat Debug Tool") + parser.add_argument( + "prompt", + type=str, + help="The user input prompt.", + ) + parser.add_argument( + "--generate-len", type=int, help="Number of output tokens to generate.", required=True + ) + parser.add_argument( + "--model", + type=str, + help="An MLC model directory that contains `mlc-chat-config.json`", + required=True, + ) + parser.add_argument( + "--model-lib-path", + type=str, + help="The full path to the model library file to use (e.g. 
a ``.so`` file).", + required=True, + ) + parser.add_argument( + "--debug-dir", + type=str, + help="The output folder to store the dumped debug files.", + required=True, + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_compile"] + ' (default: "%(default)s")', + ) + parsed = parser.parse_args() + dc = DebugChat( + model=parsed.model, + model_lib_path=parsed.model_lib_path, + debug_dir=Path(parsed.debug_dir), + device=parsed.device, + ) + + dc.generate(parsed.prompt, parsed.generate_len) + + +if __name__ == "__main__": + main() From 069b73a5dc1b6486d54595bf88e8369925e41afe Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Sun, 31 Mar 2024 08:54:10 -0700 Subject: [PATCH 135/531] [Docs] Fix docs for python server and rest call (#2066) This PR updates the MLC serve documentation for server launching. --- docs/deploy/rest.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 959c235201..e24d65afb5 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -32,9 +32,9 @@ To launch the MLC Server for MLC-Chat, run the following command in your termina .. code:: bash - python -m mlc_llm.serve.server --model MODEL --model-lib-path MODEL_LIB_PATH [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] + python -m mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] ---model The model folder after compiling with MLC-LLM build process. The parameter +MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model folder. In the former case, we will use the provided name to search @@ -89,7 +89,7 @@ The REST API provides the following endpoints: print("Error:", response.status_code) -.. http:get:: /v1/chat/completions +.. http:post:: /v1/chat/completions ------------------------------------------------ From 3e91e70152e3edfdfb4fc07a47ac8191c7355eb0 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 31 Mar 2024 13:05:01 -0400 Subject: [PATCH 136/531] [CI] Enable submodule clone for WASM model compilation (#2068) The incoming WASM runtime requires 3rdparty for builds. This PR enables the submodule clone for WASM model compilation in CI. 
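In effect, `init_git(true)` asks the CI checkout step to also initialize git
submodules so the 3rdparty/ sources are present for the WASM runtime build. A
rough Python equivalent of that behavior (hypothetical helper, not the actual
Groovy CI code):

    import subprocess


    def init_git(checkout_submodules: bool = False) -> None:
        """Prepare the workspace; optionally pull submodules (e.g. 3rdparty/)."""
        if checkout_submodules:
            subprocess.run(
                ["git", "submodule", "update", "--init", "--recursive"], check=True
            )
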
--- ci/jenkinsfile.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/jenkinsfile.groovy b/ci/jenkinsfile.groovy index ec8210c172..0203eba72d 100644 --- a/ci/jenkinsfile.groovy +++ b/ci/jenkinsfile.groovy @@ -225,7 +225,7 @@ stage('Model Compilation') { 'WASM': { node('CPU-SMALL') { ws(per_exec_ws('mlc-llm-compile-wasm')) { - init_git(false) + init_git(true) sh(script: "ls -alh", label: 'Show work directory') unpack_lib('mlc_wheel_vulkan', 'wheels/*.whl') sh(script: "${run_cpu} conda env export --name ci-unittest", label: 'Checkout version') From ed62796189ac99046ee94abc18e7c0d7fdf4a765 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 31 Mar 2024 13:20:02 -0400 Subject: [PATCH 137/531] [Serve] Fork sequence at specified positions (#2067) With PagedKVCache supporting fork at a specified position, this PR updates `Model` interface accordingly. The fork position defaults to -1, which means the last position. --- cpp/serve/model.cc | 4 ++-- cpp/serve/model.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index ad2f9b2a79..6e93061f31 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -340,8 +340,8 @@ class ModelImpl : public ModelObj { void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } - void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final { - ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id); + void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos) final { + ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id, fork_pos); } void RemoveSequence(int64_t seq_id) final { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 11646a6663..4e57d499ef 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -138,7 +138,7 @@ class ModelObj : public Object { virtual void AddNewSequence(int64_t seq_id) = 0; /*! \brief Fork a sequence from a given parent sequence. */ - virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0; + virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0; /*! \brief Remove the given sequence from the KV cache in the model. 
*/ virtual void RemoveSequence(int64_t seq_id) = 0; From 5243b27a383c1d111503a35d1c1ad74cc4455962 Mon Sep 17 00:00:00 2001 From: Linyu Wu <95223577+Celve@users.noreply.github.com> Date: Mon, 1 Apr 2024 03:56:52 +0800 Subject: [PATCH 138/531] [SLM] Add support for RWKV6 model (#1977) * [SLM]: Support for rwkv tokenizer * [SLM] RWKV6 World Support --- cpp/llm_chat.cc | 4 +- python/mlc_llm/interface/gen_config.py | 67 ++- python/mlc_llm/model/model.py | 14 + python/mlc_llm/model/rwkv6/__init__.py | 0 python/mlc_llm/model/rwkv6/rwkv6_loader.py | 70 +++ python/mlc_llm/model/rwkv6/rwkv6_model.py | 473 ++++++++++++++++++ .../mlc_llm/model/rwkv6/rwkv6_quantization.py | 37 ++ 7 files changed, 662 insertions(+), 3 deletions(-) create mode 100644 python/mlc_llm/model/rwkv6/__init__.py create mode 100644 python/mlc_llm/model/rwkv6/rwkv6_loader.py create mode 100644 python/mlc_llm/model/rwkv6/rwkv6_model.py create mode 100644 python/mlc_llm/model/rwkv6/rwkv6_quantization.py diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 8ec3c5ec1d..8cadbe8df4 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -264,8 +264,8 @@ struct FunctionTable { this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward"); this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward"); this->fkvcache_array_popn_ = get_global_func("vm.builtin.kv_state_popn"); - // TODO(mlc-team): enable backtracing when using paged kvcache - this->support_backtracking_kv_ = true; + // note: We use max sequence length = 1 for RNN state for now, so disable back tracking + this->support_backtracking_kv_ = this->use_kv_state == KVStateKind::kAttention; } } diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 890b467688..e0d401920a 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -2,6 +2,7 @@ import dataclasses import json +import re import shutil from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -75,6 +76,59 @@ def apply_defaults(self) -> None: logger.info("[System default] Setting %s: %s", bold(key), value) +def check_string(s: str) -> bool: + """Check whether it's a string.""" + delimit = s[1] + if s[0] != "b" or s[-1] != delimit: + return False + for i in range(2, len(s) - 1): + if s[i] == delimit and s[i - 1] != "\\": + return False + return True + + +def txt2rwkv_tokenizer(vocab: Path, out: Path) -> None: + """Generate tokenizer_model from RWKV vocab file.""" + idx2token = {} + + with vocab.open("r", encoding="utf-8") as f: + lines = f.readlines() + + for l in lines: + idx = int(l[: l.index(" ")]) + raw = l[l.index(" ") : l.rindex(" ")].strip() + if check_string(raw): + x = eval(raw) # pylint: disable=eval-used + x = x.encode("utf-8") if isinstance(x, str) else x + assert isinstance(x, bytes) + assert len(x) == int(l[l.rindex(" ") :]) + idx2token[idx] = x + else: + raise ValueError("Unsupported vocab dictionary") + + with (out / "tokenizer_model").open("wb") as f: + import msgpack # pylint: disable=import-outside-toplevel,import-error + + msgpack.pack(idx2token, f) + + +def json2rwkv_tokenizer(vocab: Path, out: Path) -> None: + """Generate tokenizer_model from RWKV vocab file.""" + idx2token = {} + + with vocab.open("r", encoding="utf-8") as f: + data = json.load(f) + for key, value in data.items(): + x = key.encode("utf-8") if isinstance(key, str) else key + assert isinstance(x, bytes) + idx2token[int(value)] = x + + with (out / "tokenizer_model").open("wb") as f: 
+ import msgpack # pylint: disable=import-outside-toplevel,import-error + + msgpack.pack(idx2token, f) + + def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements config: Path, model: Model, @@ -145,7 +199,18 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b logger.info("%s tokenizer config: %s. Copying to %s", FOUND, file, bold(str(dest))) else: logger.info("%s tokenizer config: %s", NOT_FOUND, file) - # 3.2. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to + # 3.2. Generate `tokenizer_model` for rwkv if `rwkv_vocab_.*` is found + pattern = re.compile(r"rwkv_vocab_v\d{8}\.(json|txt)") + for item in config.parent.iterdir(): + if item.is_file() and pattern.match(item.name): + logger.info( + "%s RWKV vocab file: %s. Genetating %s", FOUND, item, bold("tokenizer_model") + ) + if item.name.endswith(".txt"): + txt2rwkv_tokenizer(item, output) + else: + json2rwkv_tokenizer(item, output) + # 3.3. If we have `tokenizer.model` but not `tokenizer.json`, try convert it to # `tokenizer.json` with `transformers`. tokenizer_json_file = config.parent / "tokenizer.json" tokenizer_model_file = config.parent / "tokenizer.model" diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 9e8d98daa4..946d8af787 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -23,6 +23,7 @@ from .qwen import qwen_loader, qwen_model, qwen_quantization from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization +from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization ModelConfig = Any @@ -308,4 +309,17 @@ class Model: "awq": llava_quantization.awq_quant, }, ), + "rwkv6": Model( + name="rwkv6", + model=rwkv6_model.RWKV6_ForCasualLM, + config=rwkv6_model.RWKV6Config, + source={ + "huggingface-torch": rwkv6_loader.huggingface, + "huggingface-safetensor": rwkv6_loader.huggingface, + }, + quantize={ + "no-quant": rwkv6_quantization.no_quant, + "group-quant": rwkv6_quantization.group_quant, + }, + ), } diff --git a/python/mlc_llm/model/rwkv6/__init__.py b/python/mlc_llm/model/rwkv6/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/rwkv6/rwkv6_loader.py b/python/mlc_llm/model/rwkv6/rwkv6_loader.py new file mode 100644 index 0000000000..47a85f3605 --- /dev/null +++ b/python/mlc_llm/model/rwkv6/rwkv6_loader.py @@ -0,0 +1,70 @@ +""" +This file specifies how MLC's RWKV6 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +from ...loader import ExternMapping +from ...quantization import Quantization +from .rwkv6_model import RWKV6_ForCasualLM, RWKV6Config + + +def huggingface(model_config: RWKV6Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : RWKVConfig + The configuration of the Mistral model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = RWKV6_ForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # rescale + if model_config.rescale_every > 0: + for name in ["feed_forward.value.weight", "attention.output.weight"]: + mlc_name = f"model.blocks.{i}.{name}" + hf_name = f"rwkv.blocks.{i}.{name}" + mlc_param = named_parameters[mlc_name] + + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype, t: x.astype(dtype) / (2**t), + dtype=mlc_param.dtype, + t=i // model_config.rescale_every, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + hf_name = mlc_name.replace("model", "rwkv") + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py new file mode 100644 index 0000000000..0e1887310d --- /dev/null +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -0,0 +1,473 @@ +"""Implementation for RWKV6 architecture.""" + +import dataclasses +from typing import Any, Dict, Optional, Tuple + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Object, Tensor, op +from tvm.script import tir as T + +from mlc_llm.nn.rnn_state import RNNState +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class StateID: + """State ID for RWKV6.""" + + ATT_X = 0 + ATT_KV = 1 + FFN_X = 2 + + +@dataclasses.dataclass +class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the RWKV6 model.""" + + hidden_size: int + intermediate_size: int + num_hidden_layers: int + vocab_size: int + model_version: str + tensor_parallel_shards: int = 1 + rescale_every: int = 0 + head_size: int = 64 + layer_norm_epsilon: float = 1e-5 + context_window_size: int = -1 # RWKV does not have context window limitation. 
+ prefill_chunk_size: int = 4096 + num_heads: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.model_version != "6_0": + raise ValueError(f"Only support RWKV v6_0, got {self.model_version}.") + self.intermediate_size = self.intermediate_size or int((self.hidden_size * 3.5)) // 32 * 32 + self.num_heads = ( + self.hidden_size // self.head_size if self.num_heads == 0 else self.num_heads + ) + if self.num_heads * self.head_size != self.hidden_size: + raise ValueError( + f"hidden_size ({self.hidden_size}) must be diisible " + f"by head_size ({self.head_size})" + ) + if self.tensor_parallel_shards != 1: + raise ValueError("Only support single deice at this moment.") + + +# pylint: disable=invalid-name, missing-docstring +# pylint: disable=too-many-arguments, too-many-locals, redefined-argument-from-local +def create_wkv6_func( + num_heads: int, + head_size: int, + dtype: str, + out_dtype: str, + state_dtype: str, +): + @T.prim_func + def wkv_func( + r: T.handle, + k: T.handle, + v: T.handle, + time_faaaa: T.handle, + w: T.handle, + state: T.handle, + out: T.handle, + out_state: T.handle, + ): + T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1}) + batch_size, seq_len = T.int64(), T.int64() + # Inputs + r_buf = T.match_buffer(r, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + k_buf = T.match_buffer(k, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + v_buf = T.match_buffer(v, (batch_size, seq_len, num_heads, head_size), dtype=dtype) + time_faaaa_buf = T.match_buffer(time_faaaa, (num_heads, head_size), dtype="float32") + w_buf = T.match_buffer(w, (batch_size, seq_len, num_heads, head_size), dtype="float32") + state_buf = T.match_buffer( + state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype + ) + # Outputs + out_buf = T.match_buffer(out, (batch_size, seq_len, num_heads, head_size), dtype=out_dtype) + out_state_buf = T.match_buffer( + out_state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype + ) + for b in T.thread_binding(batch_size, thread="blockIdx.y"): + for h in T.thread_binding(num_heads, thread="blockIdx.x"): + for i in T.thread_binding(head_size, thread="threadIdx.x"): + for j in range(head_size): + with T.block("init_state"): + vb, vh, vi, vj = T.axis.remap("SSSS", [b, h, i, j]) + out_state_buf[vb, vh, vi, vj] = state_buf[vb, vh, vi, vj] + + for t in range(seq_len): + with T.block("comput"): + vb = T.axis.spatial(batch_size, b) + vt = T.axis.opaque(seq_len, t) + vh = T.axis.spatial(num_heads, h) + vi = T.axis.spatial(head_size, i) + out_buf[vb, vt, vh, vi] = 0 + + for k in range(head_size): + at = k_buf[vb, vt, vh, k] * v_buf[vb, vt, vh, vi] + out_buf[vb, vt, vh, vi] += T.cast( + r_buf[vb, vt, vh, k], out_dtype + ) * T.cast( + time_faaaa_buf[vh, k] * at + out_state_buf[vb, vh, vi, k], + out_dtype, + ) + out_state_buf[vb, vh, vi, k] = ( + at + w_buf[vb, vt, vh, k] * out_state_buf[vb, vh, vi, k] + ) + + return wkv_func + + +def token_shift(state: Tensor, x: Tensor): + seq_len = x.shape[1] + + def _te_token_shift(state: te.Tensor, x: te.Tensor): + return te.compute( + x.shape, + lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), + ) + + return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + + +def last_token(x: Tensor): + batch, seq_len, hidden_size = x.shape + assert batch == 1 + + def _te_last_token(x: te.Tensor): + return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, 
x.shape[1] - 1, j]) + + return x if seq_len == 1 else op.tensor_expr_op(_te_last_token, "last_token", [x]) + + +def unbind_to_five(x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + assert x.shape[0] == 5 + + def _te_get_ith(x: te.Tensor, i: int): + return te.compute((1, *x.shape[1:]), lambda _, j, k, l: x[i, j, k, l]) + + return ( + op.reshape(op.tensor_expr_op(_te_get_ith, "unbind_to_five", [x, 0]), x.shape[1:]), + op.reshape(op.tensor_expr_op(_te_get_ith, "unbind_to_five", [x, 1]), x.shape[1:]), + op.reshape(op.tensor_expr_op(_te_get_ith, "unbind_to_five", [x, 2]), x.shape[1:]), + op.reshape(op.tensor_expr_op(_te_get_ith, "unbind_to_five", [x, 3]), x.shape[1:]), + op.reshape(op.tensor_expr_op(_te_get_ith, "unbind_to_five", [x, 4]), x.shape[1:]), + ) + + +class RWKV6_FNN(nn.Module): + def __init__(self, config: RWKV6Config, layer_id: int): + super().__init__() + self.time_maa_k = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_r = nn.Parameter((1, 1, config.hidden_size)) + self.key = nn.Linear(config.hidden_size, config.hidden_size // 2 * 7, bias=False) + self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.value = nn.Linear(config.hidden_size // 2 * 7, config.hidden_size, bias=False) + self.layer_id = layer_id + + def forward(self, x: Tensor, state: RNNState): + batch, _, hidden_size = x.shape + state_x = state.get(self.layer_id, StateID.FFN_X, (batch, hidden_size), x.dtype) + state_x = token_shift(state_x, x) + + state_x = state_x - x + xk = x + state_x * self.time_maa_k + xr = x + state_x * self.time_maa_r + + last_x = last_token(x).reshape(batch, hidden_size) + state = state.set(self.layer_id, StateID.FFN_X, last_x) + + r = op.sigmoid(self.receptance(xr)) + xv = op.square(op.relu(self.key(xk))) + return r * self.value(xv), state + + +class RWKV6_Attention(nn.Module): # pylint: disable=too-many-instance-attributes + """Attention layer for RWKV.""" + + def __init__(self, config: RWKV6Config, layer_id: int): + super().__init__() + self.time_maa_x = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_w = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_k = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_v = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_r = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_g = nn.Parameter((1, 1, config.hidden_size)) + self.time_maa_w1 = nn.Parameter((config.hidden_size, 160)) + self.time_maa_w2 = nn.Parameter((5, 32, config.hidden_size)) + self.time_decay_w1 = nn.Parameter((config.hidden_size, config.head_size)) + self.time_decay_w2 = nn.Parameter((config.head_size, config.hidden_size)) + self.time_decay = nn.Parameter((1, 1, config.hidden_size)) + self.time_faaaa = nn.Parameter((config.num_heads, config.head_size)) + + self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.gate = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.ln_x = nn.GroupNorm(config.num_heads, config.hidden_size) + self.hidden_size = config.hidden_size + self.head_size = config.head_size + self.num_heads = config.num_heads + self.layer_id = layer_id + self.dtype = "float32" + + def forward(self, x: Tensor, state: RNNState): # pylint: disable=too-many-locals + batch, seq_len, hidden_size = x.shape + assert 
hidden_size == self.hidden_size + B, T, H, N = ( # pylint: disable=redefined-outer-name + batch, + seq_len, + self.head_size, + self.num_heads, + ) + state_x = state.get(self.layer_id, StateID.ATT_X, (batch, self.hidden_size), x.dtype) + state_x = token_shift(state_x, x) + state_x = state_x - x + xxx = x + state_x * self.time_maa_x + xxx = op.permute( + op.reshape(op.tanh(op.matmul(xxx, self.time_maa_w1)), (B, T, 5, -1)), [0, 2, 1, 3] + ) + xxx = op.permute( + op.matmul(xxx, self.time_maa_w2), axes=[1, 0, 2, 3] + ) # it's a batch matrix-matrix multiplication + mw, mk, mv, mr, mg = unbind_to_five(xxx) + + kv_state = state.get( + self.layer_id, + StateID.ATT_KV, + (batch, self.num_heads, self.head_size, self.head_size), + "float32", + ) + + xw = x + state_x * (self.time_maa_w + mw) + xk = x + state_x * (self.time_maa_k + mk) + xv = x + state_x * (self.time_maa_v + mv) + xr = x + state_x * (self.time_maa_r + mr) + xg = x + state_x * (self.time_maa_g + mg) + + r = op.reshape(self.receptance(xr), (B, T, N, H)) + k = op.reshape(self.key(xk), (B, T, N, H)) + v = op.reshape(self.value(xv), (B, T, N, H)) + g = op.silu(self.gate(xg)) + + w = op.reshape(self.time_decay, (1, N, H)).astype("float32") + op.reshape( + op.matmul(op.tanh(op.matmul(xw, self.time_decay_w1)), self.time_decay_w2), + (B, T, N, H), + ).astype("float32") + w = op.exp(op.negative(op.exp(w))) + # w = op.reshape(w, [B, T, N, H]) + + out, kv_state = op.tensor_ir_op( + create_wkv6_func( + num_heads=self.num_heads, + head_size=self.head_size, + dtype=self.dtype, + out_dtype="float32", + state_dtype="float32", + ), + "wkv6", + [r, k, v, self.time_faaaa, w, kv_state], + [ + Tensor.placeholder([B, T, N, H], "float32"), + Tensor.placeholder([B, N, H, H], "float32"), + ], + ) + + last_x = last_token(x).reshape(batch, hidden_size) + state = state.set(self.layer_id, StateID.ATT_X, last_x) + state = state.set(self.layer_id, StateID.ATT_KV, kv_state) + out = op.astype(self.ln_x(op.reshape(out, x.shape), channel_axis=-1, axes=[]), self.dtype) + return self.output(out * g), state + + def to(self, dtype: Optional[str] = None): + # RWKV uses special dtype, so we need to convert it. + if dtype is not None: + self.dtype = dtype + + self.time_maa_x.to(dtype) + self.time_maa_w.to(dtype) + self.time_maa_k.to(dtype) + self.time_maa_v.to(dtype) + self.time_maa_r.to(dtype) + self.time_maa_g.to(dtype) + self.time_maa_w1.to(dtype) + self.time_maa_w2.to(dtype) + self.time_decay_w1.to(dtype) + self.time_decay_w2.to(dtype) + self.key.to(dtype) + self.value.to(dtype) + self.receptance.to(dtype) + self.gate.to(dtype) + self.output.to(dtype) + + # These parameters are necessary to be converted to float32. 
+ self.time_decay.to("float32") + self.time_faaaa.to("float32") + self.ln_x.to("float32") + + +class RWKV6_Layer(nn.Module): + def __init__(self, config: RWKV6Config, layer_id: int): + super().__init__() + if layer_id == 0: + self.pre_ln = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.ln1 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.ln2 = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + self.attention = RWKV6_Attention(config, layer_id) + self.feed_forward = RWKV6_FNN(config, layer_id) + self.layer_id = layer_id + self.rescale_every = config.rescale_every + + def forward(self, x: Tensor, state: RNNState) -> Tensor: + if self.layer_id == 0: + x = self.pre_ln(x) + att_x, state = self.attention(self.ln1(x), state) + x += att_x + ffn_x, state = self.feed_forward(self.ln2(x), state) + x += ffn_x + if self.rescale_every > 0 and (self.layer_id + 1) % self.rescale_every == 0: + x = x / 2.0 + return x, state + + +class RWKV6_Model(nn.Module): + """Exact same as LlamaModel.""" + + def __init__(self, config: RWKV6Config): + super().__init__() + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.blocks = nn.ModuleList( + [RWKV6_Layer(config, i) for i in range(config.num_hidden_layers)] + ) + self.ln_out = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_epsilon, + ) + + def forward(self, input_embed: Tensor, state: RNNState): + """Forward pass of the model, passing through all decoder layers.""" + hidden_states = input_embed + for block in self.blocks: + hidden_states, state = block(hidden_states, state) + return self.ln_out(hidden_states), state + + +class RWKV6_ForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + """Same as LlamaForCausalLM, except for the use of sliding window attention.""" + + def __init__(self, config: RWKV6Config): + self.model = RWKV6_Model(config) + self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.head_size = config.head_size + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def embed(self, input_ids: Tensor): + return self.model.embeddings(input_ids) + + def forward(self, input_embed: Tensor, state: RNNState): + """Forward pass.""" + hidden_states, state = self.model(input_embed, state) + hidden_states = last_token(hidden_states) + logits = self.head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, state + + def prefill(self, input_embed: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embed, state) + + def decode(self, input_embed: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embed, state) + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + """Softmax.""" + return op.softmax(logits / temperature, axis=-1) + + def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + """Create RNN state.""" + init_values = [ + op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X + op.zeros((self.num_heads, self.head_size, self.head_size), dtype="float32"), # ATT_KV + op.zeros((self.hidden_size,), dtype=self.dtype), # FFN_X + ] + return RNNState.create( + 
max_batch_size=max_batch_size, + num_hidden_layers=self.num_hidden_layers, + max_history=max_history, + init_values=init_values, + ) + + def get_default_spec(self): + batch_size = 1 + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor( + [batch_size, "seq_len", self.hidden_size], self.dtype + ), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_rnn_state": { + "max_batch_size": int, + "max_history": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/rwkv6/rwkv6_quantization.py b/python/mlc_llm/model/rwkv6/rwkv6_quantization.py new file mode 100644 index 0000000000..ef67568a6f --- /dev/null +++ b/python/mlc_llm/model/rwkv6/rwkv6_quantization.py @@ -0,0 +1,37 @@ +"""This file specifies how MLC's RWKV6 parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from ...loader import QuantizeMapping +from ...quantization import GroupQuantize, NoQuantize +from .rwkv6_model import RWKV6_ForCasualLM, RWKV6Config + + +def group_quant( + model_config: RWKV6Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a RWKV4-architecture model using group quantization.""" + model: nn.Module = RWKV6_ForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: RWKV6Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a GPTBigCode model without quantization.""" + model: nn.Module = RWKV6_ForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map From 8cac74c04fde6d5985f54c764834ed3b0f5ca56d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 31 Mar 2024 17:41:41 -0700 Subject: [PATCH 139/531] [Quantization] Reorganize utils code in group_quantization (#2055) --- .../quantization/group_quantization.py | 81 +++++----------- python/mlc_llm/quantization/utils.py | 95 ++++++++++++++++++- 2 files changed, 115 insertions(+), 61 deletions(-) diff --git a/python/mlc_llm/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py index feb4b0216d..1da5174721 100644 --- a/python/mlc_llm/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py @@ -2,21 +2,24 @@ from dataclasses import dataclass from functools import partial -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union -from tvm import DataType, DataTypeCode, IRModule -from tvm import dlight as dl -from tvm import relax, te, tir, topi +from tvm import DataType, DataTypeCode, 
IRModule, relax, te, tir, topi from tvm.relax.frontend import nn from tvm.runtime import NDArray -from tvm.target import Target from mlc_llm.loader import QuantizeMapping from mlc_llm.nn import MixtralExperts from mlc_llm.support import logging -from mlc_llm.support import tensor_parallel as tp -from .utils import convert_uint_to_float, is_final_fc, is_moe_gate +from .utils import ( + apply_sharding, + compile_quantize_func, + convert_uint_to_float, + is_final_fc, + is_moe_gate, + pack_weight, +) logger = logging.getLogger(__name__) @@ -205,26 +208,6 @@ def _create_quantize_func() -> IRModule: bb.emit_func_output(gv) return bb.finalize() - def _compile_quantize_func(mod: IRModule) -> Callable: - if device_type in ["cuda", "rocm", "metal", "vulkan"]: - target = Target.current() - if target is None: - target = Target.from_device(device) - with target: - mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable - dl.gpu.Reduction(), - dl.gpu.GeneralReduction(), - dl.gpu.Fallback(), - )(mod) - elif device_type == "cpu": - target = "llvm" - mod = relax.transform.LegalizeOps()(mod) - else: - raise NotImplementedError(f"Device type {device_type} is not supported") - ex = relax.build(mod, target=target) - vm = relax.VirtualMachine(ex, device) # pylint: disable=invalid-name - return vm["main"] - key = ( f"({weight.shape}, {weight.dtype}, {device_type}, " f"axis={axis}, output_transpose={output_transpose})" @@ -232,7 +215,7 @@ def _compile_quantize_func(mod: IRModule) -> Callable: quantize_func = self._quantize_func_cache.get(key, None) if quantize_func is None: logger.info("Compiling quantize function for key: %s", key) - quantize_func = _compile_quantize_func(_create_quantize_func()) + quantize_func = compile_quantize_func(_create_quantize_func(), device=device) self._quantize_func_cache[key] = quantize_func return quantize_func(weight) @@ -247,7 +230,6 @@ def _quantize( # pylint: disable=too-many-locals shape = weight.shape # pylint: disable=invalid-name axis = axis if axis >= 0 else len(shape) + axis k = shape[axis] - quantize_dtype = DataType(self.quantize_dtype) # compute scale per group r = te.reduce_axis((0, self.group_size), name="r") # pylint: disable=invalid-name num_group = tir.ceildiv(k, self.group_size) @@ -285,23 +267,15 @@ def _quantize( # pylint: disable=too-many-locals ).astype(self.storage_dtype), ) # compute quantized weight per storage - r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name num_storage = self.num_storage_per_group * num_group quantized_weight_shape = (*shape[:axis], num_storage, *shape[axis + 1 :]) - quantized_weight = te.compute( - shape=quantized_weight_shape, - fcompute=lambda *idx: tir.sum( - tir.if_then_else( - idx[axis] * self.num_elem_per_storage + r < k, - scaled_weight( - *idx[:axis], idx[axis] * self.num_elem_per_storage + r, *idx[axis + 1 :] - ) - << (r * quantize_dtype.bits), - 0, - ), - axis=r, - ), - name="weight", + quantized_weight = pack_weight( + scaled_weight, + axis=axis, + num_elem_per_storage=self.num_elem_per_storage, + weight_dtype=self.quantize_dtype, + storage_dtype=self.storage_dtype, + out_shape=quantized_weight_shape, ) if output_transpose: if len(quantized_weight.shape) != 2 or len(scale.shape) != 2: @@ -378,8 +352,8 @@ def from_linear(src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear": quantized_linear.bias.attrs = src.bias.attrs if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] - _apply_sharding(shard, f"{shard.name}_q_weight", 
quantized_linear.q_weight) - _apply_sharding(shard, f"{shard.name}_q_scale", quantized_linear.q_scale) + apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) + apply_sharding(shard, f"{shard.name}_q_scale", quantized_linear.q_scale) return quantized_linear def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name @@ -607,8 +581,8 @@ def from_mixtral_experts( ) if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] - _apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) - _apply_sharding(shard, f"{shard.name}_q_scale", quantized_mistral_experts.q_scale) + apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) + apply_sharding(shard, f"{shard.name}_q_scale", quantized_mistral_experts.q_scale) return quantized_mistral_experts def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name @@ -653,14 +627,3 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa indptr_dtype=indptr.dtype, group_size=self.group_size, ) - - -def _apply_sharding(shard, name: str, weight: nn.Parameter): - if isinstance(shard, tp.ShardSingleDim): - weight.attrs["shard_strategy"] = tp.ShardSingleDim( - name=name, - dim=shard.dim, - segs=shard.segs, - ) - else: - raise NotImplementedError(f"Unknowing sharding strategy: {shard}") diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 8373b4d62c..260c9a6b45 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -1,8 +1,15 @@ """Common utilities for quantization""" -from typing import List, Optional +from typing import Callable, List, Optional, Sequence -from tvm import te, tir +from tvm import IRModule +from tvm import dlight as dl +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.runtime import DataType +from tvm.target import Target + +from mlc_llm.support import tensor_parallel as tp def convert_uint_to_float( # pylint: disable=too-many-arguments @@ -50,3 +57,87 @@ def is_final_fc(name: str) -> bool: def is_moe_gate(name: str) -> bool: """Check whether the parameter is the MoE gate layer.""" return name.endswith("gate") + + +def compile_quantize_func(mod: IRModule, device) -> Callable: + """Compile a quantization function for a given device.""" + device_type = device.MASK2STR[device.device_type] + if device_type in ["cuda", "rocm", "metal", "vulkan"]: + target = Target.current() + if target is None: + target = Target.from_device(device) + with target: + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod) + elif device_type == "cpu": + target = "llvm" + mod = relax.transform.LegalizeOps()(mod) + else: + raise NotImplementedError(f"Device type {device_type} is not supported") + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device) # pylint: disable=invalid-name + return vm["main"] + + +def apply_sharding(shard_strategy, name: str, weight: nn.Parameter): + """Apply sharding strategy to a weight.""" + if isinstance(shard_strategy, tp.ShardSingleDim): + weight.attrs["shard_strategy"] = tp.ShardSingleDim( + name=name, + dim=shard_strategy.dim, + segs=shard_strategy.segs, + ) + else: + raise NotImplementedError(f"Unknowing sharding strategy: {shard_strategy}") + + +def pack_weight( + weight: te.Tensor, + axis: int, + num_elem_per_storage: int, + 
weight_dtype: str, + storage_dtype: str, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, +): # pylint: disable=too-many-arguments + """Convert a tensor to a packed format by packing consecutive bits. + This can be useful for sub-byte quantization. + + Parameters + ---------- + weight : te.Tensor + The weight + axis : int + The axis to pack. + num_elem_per_storage : int + The number of elements per storage. + weight_dtype : str + The dtype of the input tensor. + storage_dtype : str + The dtype of the packed tensor. + out_shape : Optional[Sequence[tir.PrimExpr]] + The output shape of the packed tensor. Zero-padding is added if needed. + """ + assert weight.dtype == storage_dtype + shape = weight.shape + k = shape[axis] + axis = axis if axis >= 0 else len(shape) + axis + if out_shape is None: + out_shape = (*shape[axis], tir.ceildiv(k, num_elem_per_storage), *shape[axis + 1 :]) + r = te.reduce_axis((0, num_elem_per_storage), name="r") # pylint: disable=invalid-name + packed_weight = te.compute( + shape=out_shape, + fcompute=lambda *idx: tir.sum( + tir.if_then_else( + idx[axis] * num_elem_per_storage + r < k, + weight(*idx[:axis], idx[axis] * num_elem_per_storage + r, *idx[axis + 1 :]) + << (r * DataType(weight_dtype).bits), + tir.const(0, storage_dtype), + ), + axis=r, + ), + name="packed_weight", + ) + return packed_weight From 8a82f93226eb66e6da5e8b208e5855ffabe46e11 Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Mon, 1 Apr 2024 00:10:39 -0400 Subject: [PATCH 140/531] [Serving] Bugfix for empty stop string (#2070) add check for empty stop string; fix Vanilla LM conversation template --- cpp/streamer.cc | 1 + python/mlc_llm/conversation_template.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/streamer.cc b/cpp/streamer.cc index 66e643786d..120225cbd4 100644 --- a/cpp/streamer.cc +++ b/cpp/streamer.cc @@ -177,6 +177,7 @@ StopStrHandlerObj::StopStrHandlerObj(Array stop_strs, // Create the KMP partial match table for each stop string. partial_match_tables_.reserve(num_stop_strs); for (const String& stop_str : stop_strs_) { + CHECK(!stop_str.empty()) << "Stop string cannot be empty."; partial_match_tables_.push_back(CreatePartialMatchTable(stop_str)); } } diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index ccb4e72bdd..5976517c53 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -470,7 +470,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: seps=[""], role_content_sep="", role_empty_sep="", - stop_str=[""], + stop_str=[], stop_token_ids=[2], system_prefix_token_ids=[1], ) From eb3d1e457704d2e3dca07decfadef2bbe9527ee7 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 2 Apr 2024 01:25:14 +0800 Subject: [PATCH 141/531] [SLM] Internlm Multi-GPU support (#2072) This PR enables tensor parallelism support for InternLM model. 
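For intuition, the sharding hints added below split the fused wqkv weight along
dim 0 with segments [q, k, v], so each tensor-parallel worker keeps its own slice
of the query, key and value projections. The following standalone numpy sketch
(illustrative only, not MLC code; the helper name and the sizes are made up for
the example) mimics that segment-wise split:

    # Hypothetical sketch: per-shard slicing of a fused QKV weight, segment by
    # segment, the way a dim-0 shard strategy with segs=[q, k, v] is meant to behave.
    import numpy as np

    def shard_fused_qkv(weight, segs, num_shards):
        # Split dim 0 into the q/k/v segments, then give each shard its slice
        # of every segment and concatenate the slices back together.
        parts = np.split(weight, np.cumsum(segs)[:-1], axis=0)
        shards = []
        for rank in range(num_shards):
            pieces = [np.array_split(p, num_shards, axis=0)[rank] for p in parts]
            shards.append(np.concatenate(pieces, axis=0))
        return shards

    hidden, head_dim, num_heads = 512, 64, 8
    q = k = v = num_heads * head_dim
    wqkv = np.random.randn(q + k + v, hidden).astype("float32")
    shards = shard_fused_qkv(wqkv, segs=[q, k, v], num_shards=2)
    assert all(s.shape == ((q + k + v) // 2, hidden) for s in shards)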
--- .../mlc_llm/model/internlm/internlm_model.py | 60 ++++++++++++++++--- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index d97d253c8f..f8e95ab4ec 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -38,6 +39,7 @@ class InternLMConfig(ConfigBase): # pylint: disable=too-many-instance-attribute prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + head_dim: int = 0 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -58,6 +60,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -83,8 +88,8 @@ def __post_init__(self): class InternLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: InternLMConfig): self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim self.max_position_embeddings = config.context_window_size self.wqkv_pack = nn.Linear( @@ -106,12 +111,14 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class InternLMMLP(nn.Module): def __init__(self, config: InternLMConfig): + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( in_features=config.hidden_size, - out_features=2 * config.intermediate_size, + out_features=2 * self.intermediate_size, bias=False, ) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) def forward(self, x): concat_x1_x2 = self.gate_up_proj(x) @@ -128,13 +135,48 @@ def __init__(self, config: InternLMConfig): config.hidden_size, -1, config.rms_norm_eps, bias=False ) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_heads * hd + k = self.self_attn.num_heads * hd + v = self.self_attn.num_heads * hd + i = self.mlp.intermediate_size + _set( + self.self_attn.wqkv_pack.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + if config.bias: + _set( + self.self_attn.wqkv_pack.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o_weight", dim=1)) + if config.bias: + _set(self.self_attn.o_proj.bias, tp.ShardSingleDim("_shard_o_bias", dim=0)) + _set( + self.mlp.gate_up_proj.weight, + tp.ShardSingleDim("_shard_mlp_gate_up", segs=[i, i], dim=0), + ) + _set(self.mlp.down_proj.weight, tp.ShardSingleDim("_shard_mlp_down_proj", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def 
forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class InternLMModel(nn.Module): def __init__(self, config: InternLMConfig): @@ -160,7 +202,7 @@ def __init__(self, config: InternLMConfig): self.num_hidden_layers = config.num_hidden_layers self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = config.head_dim self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards @@ -188,6 +230,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -216,6 +260,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache From 10017db91c6f65a6d6ff47bccd258207576bd22a Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 1 Apr 2024 13:41:17 -0400 Subject: [PATCH 142/531] [WebGPU] Add mlc wasm runtime, support grammar in web (#2061) * [WebGPU] Add mlc wasm runtime, support grammar in web * Make in web for wasm ci * Fix wasm ci * Fix wasm ci * Change export library arg name * Move macro to cc instead of makefile --- ci/task/test_model_compile.sh | 2 + cpp/serve/grammar/grammar_state_matcher.cc | 22 ++++++++++ docs/install/emcc.rst | 10 +++-- python/mlc_llm/serve/grammar.py | 12 +++++ python/mlc_llm/support/auto_target.py | 24 ++++++++++ scripts/prep_emcc_deps.sh | 5 +++ web/Makefile | 51 ++++++++++++++++++++++ web/README.md | 28 ++++++++++++ web/emcc/mlc_wasm_runtime.cc | 41 +++++++++++++++++ 9 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 web/Makefile create mode 100644 web/README.md create mode 100644 web/emcc/mlc_wasm_runtime.cc diff --git a/ci/task/test_model_compile.sh b/ci/task/test_model_compile.sh index 06201e1d5d..97d784cf23 100755 --- a/ci/task/test_model_compile.sh +++ b/ci/task/test_model_compile.sh @@ -21,7 +21,9 @@ elif [[ ${GPU} == wasm* ]]; then TARGET=wasm pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly export TVM_HOME=$(dirname $(python -c 'import tvm; print(tvm.__file__)')) + export MLC_LLM_HOME=$(pwd) cd $TVM_HOME/web/ && make -j${NUM_THREADS} && cd - + cd $MLC_LLM_HOME/web/ && make -j${NUM_THREADS} && cd - elif [[ ${GPU} == ios ]]; then TARGET=ios pip install --pre -U --force-reinstal -f https://mlc.ai/wheels mlc-ai-nightly diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc 
index 2131e9f112..d9954f1e28 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -176,6 +176,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm bool AcceptStopToken(); friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose); + friend NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher); std::shared_ptr init_ctx_; int max_rollback_steps_; @@ -448,6 +449,8 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr(init_ctx, max_rollback_steps)) {} +#ifndef COMPILE_MLC_WASM_RUNTIME +// This creates tokenizer dependency issue in WASM building for web, hence skipped TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") .set_body_typed([](BNFGrammar grammar, Optional tokenizer, int max_rollback_steps) { auto preproc_start = std::chrono::high_resolution_clock::now(); @@ -461,6 +464,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") << "us" << std::endl; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); +#endif TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") .set_body([](TVMArgs args, TVMRetValue* rv) { @@ -622,6 +626,24 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextRejectedTokens") .set_body_typed(FindNextRejectedTokens); +/*! + * \brief Find the bitmask for the next token as an NDArray. + * \returns An NDArray of the bitmask for the next token of shape (bitmask_size,). + */ +NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) { + auto init_ctx = matcher.as()->init_ctx_; + auto vocab_size = init_ctx->vocab_size; + auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitmask = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, + DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); + auto dltensor = const_cast(bitmask.operator->()); + matcher->FindNextTokenBitmask(dltensor); + return bitmask; +} + +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextTokenBitmaskAsNDArray") + .set_body_typed(FindNextTokenBitmaskAsNDArray); + } // namespace serve } // namespace llm } // namespace mlc diff --git a/docs/install/emcc.rst b/docs/install/emcc.rst index 9320be4592..389d3cc4f8 100644 --- a/docs/install/emcc.rst +++ b/docs/install/emcc.rst @@ -21,16 +21,20 @@ Validate that emcc is accessible in shell emcc --version -Step 2: Set TVM_HOME --------------------- +Step 2: Set TVM_HOME and MLC_LLM_HOME +------------------------------------- We need to set a path to a tvm source in order to build tvm runtime. Note that you do not need to build tvm unity from the source. The source here is only used to build the web runtime component. -Set environment variable in your shell startup profile in to point to ``3rdparty/tvm`` +Set environment variable in your shell startup profile in to point to ``3rdparty/tvm`` (if preferred, you could also +point to your own TVM address if you installed TVM from source). + +Besides, we also need to set ``MLC_LLM_HOME`` so that we can locate ``mlc_wasm_runtime.bc`` when compiling a model library wasm. .. 
code:: bash export TVM_HOME=/path/to/3rdparty/tvm + export MLC_LLM_HOME=/path/to/mlc-llm Step 3: Prepare Wasm Runtime diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index 6e9eac8655..d640c62da2 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -2,6 +2,7 @@ from typing import List, Optional, Tuple, Union +import tvm import tvm._ffi from tvm.runtime import Object @@ -256,6 +257,17 @@ def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]: return _ffi_api.GrammarStateMatcherFindNextRejectedTokens(self, verbose) # type: ignore # pylint: disable=no-member + def find_next_token_bitmask_as_ndarray(self) -> tvm.nd.array: + """Find the ids of the rejected tokens for the next step. + + Returns + ------- + rejected_token_ids : List[int] + A list of rejected token ids. + """ + + return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self) # type: ignore # pylint: disable=no-member + def rollback(self, num_tokens: int) -> None: """Rollback the matcher to a previous state. diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 403af9128e..e09f661ff7 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -1,6 +1,7 @@ """Helper functions for target auto-detection.""" import os +from pathlib import Path from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from tvm import IRModule, relax @@ -197,6 +198,28 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): output = args.output mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True) assert output.suffix == ".wasm" + + # Try to locate `mlc_wasm_runtime.bc` + bc_path = None + bc_candidates = ["web/dist/wasm/mlc_wasm_runtime.bc"] + if os.environ.get("MLC_LLM_HOME", None): + mlc_source_home_dir = os.environ["MLC_LLM_HOME"] + bc_candidates.append( + os.path.join(mlc_source_home_dir, "web", "dist", "wasm", "mlc_wasm_runtime.bc") + ) + error_info = ( + "Cannot find library: mlc_wasm_runtime.bc\n" + + "Make sure you have run `scripts/prep_emcc_deps.sh` and " + + "`export MLC_LLM_HOME=/path/to/mlc-llm` so that we can locate the file. " + + "We tried to look at candidate paths:\n" + ) + for candidate in bc_candidates: + error_info += candidate + "\n" + if Path(candidate).exists(): + bc_path = candidate + if not bc_path: + raise RuntimeError(error_info) + relax.build( mod, target=args.target, @@ -204,6 +227,7 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): system_lib=True, ).export_library( str(output), + libs=[bc_path], ) return build diff --git a/scripts/prep_emcc_deps.sh b/scripts/prep_emcc_deps.sh index 2c1306ca9e..0ccf98698b 100755 --- a/scripts/prep_emcc_deps.sh +++ b/scripts/prep_emcc_deps.sh @@ -9,6 +9,11 @@ TVM_HOME_SET="${TVM_HOME:-}" git submodule update --init --recursive +# Build mlc_wasm_runtime +cd web && make +cd - + +# Build tvm's web runtime if [[ -z ${TVM_HOME_SET} ]]; then echo "Do not find TVM_HOME env variable, use 3rdparty/tvm". echo "Make sure you set TVM_HOME in your env variable to use emcc build correctly" diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 0000000000..48f98b5e81 --- /dev/null +++ b/web/Makefile @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(TVM_HOME) +MLC_LLM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ + -I$(TVM_ROOT)/3rdparty/compiler-rt -I$(TVM_ROOT)/3rdparty/picojson\ + -I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp\ + -I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp/include -I$(MLC_LLM_ROOT)/cpp + +.PHONY: clean all rmtypedep preparetest + +all: dist/wasm/mlc_wasm_runtime.wasm + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes + +EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ + -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +dist/wasm/mlc_wasm_runtime.bc: emcc/mlc_wasm_runtime.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc >dist/wasm/mlc_wasm_runtime.d + $(EMCC) $(EMCC_CFLAGS) -emit-llvm -c -o dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc + +# Compile to wasm here so that errors can be caught earlier (rather than during export_library) +dist/wasm/mlc_wasm_runtime.wasm: dist/wasm/mlc_wasm_runtime.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/mlc_wasm_runtime.wasm $+ $(EMCC_LDFLAGS) + +clean: + @rm -rf dist/wasm lib + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000000..e6e34918db --- /dev/null +++ b/web/README.md @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + +# MLC-LLM WebAssembly Runtime + +This folder contains MLC-LLM WebAssembly Runtime. + +Please refer to https://llm.mlc.ai/docs/install/emcc.html. + +The main step is running `make` under this folder, a step included in `scripts/prep_emcc_deps.sh`. + +`make` creates `web/dist/wasm/mlc_wasm_runtime.bc`, which will be included in the model library wasm +when we compile the model. Thus during runtime, runtimes like WebLLM can directly reuse source +code from MLC-LLM. \ No newline at end of file diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc new file mode 100644 index 0000000000..3f05eb259f --- /dev/null +++ b/web/emcc/mlc_wasm_runtime.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/* + * \file mlc_wasm_runtime.cc + * \brief MLC wasm runtime library pack. + */ + +// configurations for tvm logging +#define TVM_LOG_STACK_TRACE 0 +#define TVM_LOG_DEBUG 0 +#define TVM_LOG_CUSTOMIZE 1 + +// Pass in COMPILE_MLC_WASM_RUNTIME so unsupported code would not be compiled in to the .bc file +#define COMPILE_MLC_WASM_RUNTIME 1 + +#define DMLC_USE_LOGGING_LIBRARY + +// Grammar related +#include "serve/grammar/grammar.cc" +#include "serve/grammar/grammar_parser.cc" +#include "serve/grammar/grammar_serializer.cc" +#include "serve/grammar/grammar_simplifier.cc" +#include "serve/grammar/grammar_state_matcher.cc" +#include "support/encoding.cc" From 91211269b89705d286114b66e437435a83fd0e87 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Apr 2024 14:04:22 -0500 Subject: [PATCH 143/531] [Build] Use TVM_HOME environment variable (#2073) Prior to this commit, the `CMakeLists.txt` file checked a cmake `TVM_HOME` variable, but did not check the usual `TVM_HOME` environment variable. If this variable is set, it should be used. --- CMakeLists.txt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a1644f0894..7f0dd7ef24 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,11 @@ set(USE_GTEST OFF) set(USE_LIBBACKTRACE OFF) set(BUILD_DUMMY_LIBTVM ON) if (NOT DEFINED TVM_HOME) - set(TVM_HOME 3rdparty/tvm) + if(DEFINED ENV{TVM_HOME}) + set(TVM_HOME "$ENV{TVM_HOME}") + else() + set(TVM_HOME 3rdparty/tvm) + endif(DEFINED ENV{TVM_HOME}) endif (NOT DEFINED TVM_HOME) message(STATUS "TVM_HOME: ${TVM_HOME}") add_subdirectory(${TVM_HOME} tvm EXCLUDE_FROM_ALL) @@ -93,7 +97,10 @@ set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm) target_link_libraries(mlc_llm PUBLIC tvm_runtime) target_link_libraries(mlc_llm PRIVATE tokenizers_cpp) -find_library(FLASH_ATTN_LIBRARY flash_attn) +find_library( + FLASH_ATTN_LIBRARY flash_attn + HINTS ${TVM_HOME}/*/3rdparty/libflash_attn/src +) if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND") message(WARNING "Cannot find libflash_attn. The model must not have been built with --use-flash-attn-mqa option.") From b7416c0297dae281a12aa0c94051b53ea2d09404 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 1 Apr 2024 15:04:51 -0400 Subject: [PATCH 144/531] [Serving] Support input chunking (#2069) This PR supports input chunking with regard to customized "prefill chunk size" (field `prefill_chunk_size` in `mlc-chat-config.json`). With this PR, we can now chunk a long input into multiples when there is an upper limit on the prefill chunk size. Only `TokenData` is supported for now. 
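For intuition, the chunking itself amounts to slicing the tokenized input into
pieces no longer than the configured prefill chunk size and prefilling them one
chunk per step. A minimal standalone sketch of the splitting rule (illustrative
only, not the C++ engine code; the function name is made up):

    # Hypothetical sketch: split a long token sequence into chunks bounded by
    # `prefill_chunk_size`; the engine performs the equivalent bookkeeping on
    # the request's remaining TokenData inputs.
    from typing import List

    def chunk_tokens(token_ids: List[int], prefill_chunk_size: int) -> List[List[int]]:
        assert prefill_chunk_size > 0
        return [
            token_ids[i : i + prefill_chunk_size]
            for i in range(0, len(token_ids), prefill_chunk_size)
        ]

    chunks = chunk_tokens(list(range(10000)), prefill_chunk_size=4096)
    assert [len(c) for c in chunks] == [4096, 4096, 1808]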
--- cpp/serve/engine.cc | 9 +- cpp/serve/engine_actions/action_commons.cc | 6 +- .../engine_actions/new_request_prefill.cc | 248 +++++++++++++----- cpp/serve/model.cc | 2 +- python/mlc_llm/serve/async_engine.py | 3 +- python/mlc_llm/serve/engine.py | 3 +- 6 files changed, 191 insertions(+), 80 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 6c060a7e27..abb5c7b6c7 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -146,12 +146,13 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= kv_cache_config_->prefill_chunk_size) { - // If the request input length exceeds the prefill chunk size, + if (request->input_total_length >= max_single_sequence_length_) { + // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. - // Todo(mlc-team): Use "maximum single sequence length" after impl input chunking. Array output{RequestStreamOutput( - request->id, {}, Optional>>(), {String("length")})}; + request->id, std::vector(request->generation_cfg->n), + Optional>>(), + std::vector>(request->generation_cfg->n, String("length")))}; request_stream_callback_.value()(std::move(output)); return; } diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 1fb61ae70a..6eb7a3d84a 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -159,6 +159,9 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, } ICHECK_NE(preempt_rstate_idx, -1); RequestStateEntry rsentry = rstate->entries[preempt_rstate_idx]; + // When the request state entry still has pending inputs, + // it means the request is still in the waiting queue. + bool partially_alive = !rsentry->mstates[0]->inputs.empty(); // Remove from models. // - Clear model speculation draft. @@ -167,7 +170,6 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, rsentry->status = RequestStateStatus::kPending; for (RequestModelState mstate : rsentry->mstates) { mstate->RemoveAllDraftTokens(); - ICHECK(mstate->inputs.empty()); std::vector committed_token_ids; committed_token_ids.reserve(mstate->committed_tokens.size()); for (const SampleResult& committed_token : mstate->committed_tokens) { @@ -197,7 +199,7 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // Remove from running queue. estate->running_queue.erase(estate->running_queue.end() - 1); } - if (preempt_rstate_idx == static_cast(rstate->entries.size()) - 1) { + if (!partially_alive && preempt_rstate_idx == static_cast(rstate->entries.size()) - 1) { // Add to the front of waiting queue. estate->waiting_queue.insert(estate->waiting_queue.begin(), request); } diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 6363f8a537..f93fbc2ded 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -35,87 +35,96 @@ class NewRequestPrefillActionObj : public EngineActionObj { Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. 
- Array rsentries; - std::vector prefill_lengths; + std::vector prefill_inputs; { NVTXScopedRange nvtx_scope("NewRequestPrefill getting requests"); - auto tuple = GetRequestStateEntriesToPrefill(estate); - rsentries = std::move(std::get<0>(tuple)); - prefill_lengths = std::move(std::get<1>(tuple)); - ICHECK_EQ(rsentries.size(), prefill_lengths.size()); - if (rsentries.empty()) { + prefill_inputs = GetRequestStateEntriesToPrefill(estate); + if (prefill_inputs.empty()) { return {}; } } - int num_rsentries = rsentries.size(); + int num_rsentries = prefill_inputs.size(); auto tstart = std::chrono::high_resolution_clock::now(); // - Update status of request states from pending to alive. Array request_ids; std::vector rstates_of_entries; + std::vector status_before_prefill; request_ids.reserve(num_rsentries); rstates_of_entries.reserve(num_rsentries); - for (RequestStateEntry rsentry : rsentries) { + status_before_prefill.reserve(num_rsentries); + for (const PrefillInput& prefill_input : prefill_inputs) { + const RequestStateEntry& rsentry = prefill_input.rsentry; const Request& request = rsentry->request; RequestState request_rstate = estate->GetRequestState(request); request_ids.push_back(request->id); + status_before_prefill.push_back(rsentry->status); rsentry->status = RequestStateStatus::kAlive; - // - Remove the request from waiting queue if all its request states are now alive. - // - Add the request to running queue if all its request states were pending. - bool alive_state_existed = false; - for (const RequestStateEntry& rsentry_ : request_rstate->entries) { - if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { - alive_state_existed = true; + if (status_before_prefill.back() == RequestStateStatus::kPending) { + // - Add the request to running queue if the request state + // status was pending and all its request states were pending. + bool alive_state_existed = false; + for (const RequestStateEntry& rsentry_ : request_rstate->entries) { + if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { + alive_state_existed = true; + } + } + if (!alive_state_existed) { + estate->running_queue.push_back(request); } - } - if (!alive_state_existed) { - estate->running_queue.push_back(request); } rstates_of_entries.push_back(std::move(request_rstate)); } // - Get embedding and run prefill for each model. + std::vector prefill_lengths; + prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); NDArray logits_for_sample{nullptr}; for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { std::vector request_internal_ids; request_internal_ids.reserve(num_rsentries); ObjectRef embeddings = model_workspaces_[model_id].embeddings; int cum_prefill_length = 0; - bool single_input = num_rsentries == 1 && rsentries[0]->mstates[model_id]->inputs.size() == 1; + bool single_input = + num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1; for (int i = 0; i < num_rsentries; ++i) { - RequestModelState mstate = rsentries[i]->mstates[model_id]; - ICHECK_EQ(mstate->GetInputLength(), prefill_lengths[i]); - ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); - ICHECK(!mstate->inputs.empty()); - // Add the sequence to the model, or fork the sequence from its parent. 
- if (rsentries[i]->parent_idx == -1) { - models_[model_id]->AddNewSequence(mstate->internal_id); + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + RequestModelState mstate = rsentry->mstates[model_id]; + auto [input_data, input_length] = + ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length); + if (prefill_lengths[i] == -1) { + prefill_lengths[i] = input_length; } else { - models_[model_id]->ForkSequence(rstates_of_entries[i] - ->entries[rsentries[i]->parent_idx] - ->mstates[model_id] - ->internal_id, - mstate->internal_id); + ICHECK_EQ(prefill_lengths[i], input_length); } - // Enable sliding window for the sequence if it is not a parent. - if (rsentries[i]->child_indices.empty()) { - models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); + + ICHECK(mstate->draft_output_tokens.empty()); + ICHECK(mstate->draft_output_prob_dist.empty()); + if (status_before_prefill[i] == RequestStateStatus::kPending) { + // Add the sequence to the model, or fork the sequence from its parent. + if (rsentry->parent_idx == -1) { + models_[model_id]->AddNewSequence(mstate->internal_id); + } else { + models_[model_id]->ForkSequence( + rstates_of_entries[i]->entries[rsentry->parent_idx]->mstates[model_id]->internal_id, + mstate->internal_id); + } + // Enable sliding window for the sequence if it is not a parent. + if (rsentry->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); + } } request_internal_ids.push_back(mstate->internal_id); - RECORD_EVENT(trace_recorder_, rsentries[i]->request->id, "start embedding"); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - embeddings = - mstate->inputs[i]->GetEmbedding(models_[model_id], - /*dst=*/!single_input ? &embeddings : nullptr, - /*offset=*/cum_prefill_length); - cum_prefill_length += mstate->inputs[i]->GetLength(); + RECORD_EVENT(trace_recorder_, rsentry->request->id, "start embedding"); + for (int i = 0; i < static_cast(input_data.size()); ++i) { + embeddings = input_data[i]->GetEmbedding(models_[model_id], + /*dst=*/!single_input ? &embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[i]->GetLength(); } - RECORD_EVENT(trace_recorder_, rsentries[i]->request->id, "finish embedding"); - // Clean up `inputs` after prefill - mstate->inputs.clear(); + RECORD_EVENT(trace_recorder_, rsentry->request->id, "finish embedding"); } RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); @@ -139,8 +148,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { generation_cfg.reserve(num_rsentries); mstates_for_logitproc.reserve(num_rsentries); for (int i = 0; i < num_rsentries; ++i) { - generation_cfg.push_back(rsentries[i]->request->generation_cfg); - mstates_for_logitproc.push_back(rsentries[i]->mstates[0]); + generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); + mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[0]); } logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]}, logits_for_sample->dtype); @@ -164,7 +173,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { - const RequestStateEntry& rsentry = rsentries[i]; + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + // No sample for rsentries with remaining inputs. 
+ if (!rsentry->mstates[0]->inputs.empty()) { + continue; + } + for (int child_idx : rsentry->child_indices) { if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { // If rstates_of_entries[i]->entries[child_idx] has no committed token, @@ -219,12 +233,14 @@ class NewRequestPrefillActionObj : public EngineActionObj { auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; + // - Remove the request from waiting queue if all its request states + // are now alive and have no remaining chunked inputs. std::vector processed_requests; { processed_requests.reserve(num_rsentries); std::unordered_set dedup_map; - for (int i = 0; i < static_cast(rsentries.size()); ++i) { - const RequestStateEntry& rsentry = rsentries[i]; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { continue; } @@ -233,7 +249,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { bool pending_state_exists = false; for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { - if (rsentry_->status == RequestStateStatus::kPending) { + if (rsentry_->status == RequestStateStatus::kPending || + !rsentry_->mstates[0]->inputs.empty()) { pending_state_exists = true; break; } @@ -250,21 +267,26 @@ class NewRequestPrefillActionObj : public EngineActionObj { } private: + /*! \brief The class of request state entry and its maximum allowed length for prefill. */ + struct PrefillInput { + RequestStateEntry rsentry; + int max_prefill_length; + }; + /*! * \brief Find one or multiple request state entries to run prefill. * \param estate The engine state. * \return The request entries to prefill, together with their input lengths. */ - std::tuple, std::vector> GetRequestStateEntriesToPrefill( - EngineState estate) { + std::vector GetRequestStateEntriesToPrefill(EngineState estate) { if (estate->waiting_queue.empty()) { // No request to prefill. - return {{}, {}}; + return {}; } + std::vector prefill_inputs; + // - Try to prefill pending requests. - std::vector rsentries_to_prefill; - std::vector prefill_lengths; int total_input_length = 0; int total_required_pages = 0; int num_available_pages = models_[0]->GetNumAvailablePages(); @@ -278,12 +300,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { for (const RequestStateEntry& rsentry : rstate->entries) { // A request state entry can be prefilled only when: // - it has inputs, and - // - it is pending, and - // - it has no parent or its parent is alive. + // - it has no parent or its parent is alive and has no remaining input. if (rsentry->mstates[0]->inputs.empty() || - rsentry->status != RequestStateStatus::kPending || (rsentry->parent_idx != -1 && - rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending)) { + (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || + !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { continue; } @@ -292,25 +313,41 @@ class NewRequestPrefillActionObj : public EngineActionObj { (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; total_input_length += input_length; total_required_pages += num_require_pages; + // - Attempt 1. Check if the entire request state entry can fit for prefill. 
if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), total_input_length, total_required_pages, num_available_pages, current_total_seq_len, num_running_rsentries)) { - rsentries_to_prefill.push_back(rsentry); - prefill_lengths.push_back(input_length); + prefill_inputs.push_back({rsentry, input_length}); + num_prefill_rsentries += 1 + rsentry->child_indices.size(); + continue; + } + total_input_length -= input_length; + total_required_pages -= num_require_pages; + + // - Attempt 2. Check if the request state entry can partially fit by input chunking. + ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); + input_length = + std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); + num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + if (input_length > 0 && + CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length}); num_prefill_rsentries += 1 + rsentry->child_indices.size(); - } else { - total_input_length -= input_length; - total_required_pages -= num_require_pages; - prefill_stops = true; - break; } + + // - Prefill stops here. + prefill_stops = true; + break; } if (prefill_stops) { break; } } - return {rsentries_to_prefill, prefill_lengths}; + return prefill_inputs; } /*! \brief Check if the input requests can be prefilled under conditions. */ @@ -323,7 +360,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { // run simultaneously. int spec_factor = engine_mode_->enable_speculative ? engine_mode_->spec_draft_length : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - kv_cache_config_->max_num_sequence) { + std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { return false; } @@ -340,6 +377,79 @@ class NewRequestPrefillActionObj : public EngineActionObj { kv_cache_config_->max_total_sequence_length; } + /*! + * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ + std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, + int max_prefill_length) { + if (mstate->inputs.empty()) { + } + ICHECK(!mstate->inputs.empty()); + std::vector inputs; + int cum_input_length = 0; + inputs.reserve(mstate->inputs.size()); + for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { + inputs.push_back(mstate->inputs[i]); + int input_length = mstate->inputs[i]->GetLength(); + cum_input_length += input_length; + // Case 0. the cumulative input length does not reach the maximum prefill length. + if (cum_input_length < max_prefill_length) { + continue; + } + + // Case 1. the cumulative input length equals the maximum prefill length. + if (cum_input_length == max_prefill_length) { + if (i == static_cast(mstate->inputs.size()) - 1) { + // - If `i` is the last input, we just copy and reset `mstate->inputs`. 
+ mstate->inputs.clear(); + } else { + // - Otherwise, set the new input array. + mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Case 2. cum_input_length > max_prefill_length + // The input `i` itself needs chunking if it is TokenData, + // or otherwise it cannot be chunked. + Data input = mstate->inputs[i]; + inputs.pop_back(); + cum_input_length -= input_length; + const auto* token_input = input.as(); + if (token_input == nullptr) { + // Cannot chunk the input. + if (i != 0) { + mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Split the token data into two parts. + // Return the first part for prefill, and keep the second part. + int chunked_input_length = max_prefill_length - cum_input_length; + ICHECK_GT(input_length, chunked_input_length); + TokenData chunked_input(IntTuple{token_input->token_ids.begin(), + token_input->token_ids.begin() + chunked_input_length}); + TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, + token_input->token_ids.end()}); + inputs.push_back(chunked_input); + cum_input_length += chunked_input_length; + std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + remaining_inputs.insert(remaining_inputs.begin(), remaining_input); + mstate->inputs = remaining_inputs; + return {inputs, cum_input_length}; + } + + ICHECK(false) << "Cannot reach here"; + } + /*! \brief The models to run prefill in. */ Array models_; /*! \brief The logit processor. */ diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 6e93061f31..5ebf26a061 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -395,7 +395,7 @@ class ModelImpl : public ModelObj { embedding_shape = embedding_nd.Shape(); } ICHECK_EQ(embedding_shape.size(), 2); - ICHECK_EQ(embedding_shape[0], prefill_chunk_size_); + ICHECK_GE(embedding_shape[0], prefill_chunk_size_); this->hidden_size_ = embedding_shape[1]; return embedding; } diff --git a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py index 652bfa39f8..341a3880f3 100644 --- a/python/mlc_llm/serve/async_engine.py +++ b/python/mlc_llm/serve/async_engine.py @@ -278,8 +278,7 @@ def __init__( # [model_lib_path, model_path, device.device_type, device.device_id] * N model.model_lib_path = model_args[i * (len(model_args) // len(models))] - # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. - self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) + self.max_input_sequence_length = max_single_sequence_length self.state = _AsyncThreadedEngineState(enable_tracing) if kv_cache_config.max_total_sequence_length is None: diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 0757a0d8e9..607f970a1e 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -338,8 +338,7 @@ def __init__( # pylint: disable=too-many-arguments ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None - # Todo(mlc-team): use `max_single_sequence_length` only after impl input chunking. 
- self.max_input_sequence_length = min(max_single_sequence_length, prefill_chunk_size) + self.max_input_sequence_length = max_single_sequence_length if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( From 52de79860b935593e301cd70d1ab5bd1f2939a90 Mon Sep 17 00:00:00 2001 From: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:51:35 +0100 Subject: [PATCH 145/531] [Docs] API Code Completion Guide (#2054) --- docs/_static/img/code_completion.png | Bin 0 -> 179797 bytes docs/_static/img/ide_code_settings.png | Bin 0 -> 9460 bytes docs/_static/img/ide_code_templates.png | Bin 0 -> 26735 bytes docs/deploy/ide_integration.rst | 179 ++++++++++++++++++++++++ docs/deploy/rest.rst | 2 + docs/index.rst | 1 + scripts/local_deploy_site.sh | 2 +- 7 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 docs/_static/img/code_completion.png create mode 100644 docs/_static/img/ide_code_settings.png create mode 100644 docs/_static/img/ide_code_templates.png create mode 100644 docs/deploy/ide_integration.rst diff --git a/docs/_static/img/code_completion.png b/docs/_static/img/code_completion.png new file mode 100644 index 0000000000000000000000000000000000000000..1008542c339d7c0f2de737770a7477f3800d7e7d GIT binary patch literal 179797 zcmZsj1ymf%wy+5afdD~*6P(}yg1c*QcM0wcVQ}~065QP_xO;FJ+?_Ck4-(w|IrrUj zZ_fK4YxSD$uG-brRlD}y^;I=tN(xeMkqMFE;NaeVmKIlmgL^p+2M7NP={4*Qhw(xe z9NY_KD={&p&thU^N{#??D_b)-IQ&HC*e+So7udlGW*62WTYV3vr1j;y^u+z$p8MkiZgSCRChGtCH)p>}CESFR< zV%4C3*KHPoT%l~)-nR{Jn9RwdS=Sne?OIt#k3rfKB4#?F0d48u{c~t*4=B z>MFh0y`84X+56jG_PUR)J?-nMOcnidiNusL_(Cs=K0m~x)H{-gTdGYiQHUk6+6Qb~ z{*p#vRuwiw&i-MzTrl$hg?2f&Zxf7hcIzCw4U09e22K@6LwNkR^qZUKvGs3niwcbx z>?UqjjxqwD@cP^?0y>b#%Z%C=Q1i9ppDdcoyx?*$db~2s>ci@jV0%5%&8Ic8KS||DF1BJm=-#=kUM&$|#~L_W3jHSJlMP z%*-BW32>^Ta-V}0HE*S+<)kGq$7=$xV=^)Y7@IM<**W|r0>|&h3p=$lb21`xv$M4a z^12C7{3{19?EJ6W%oJq*%Hm`rK%phCL?#AsG$Z3=`o#2!LJ*mZjEvvW)SOpET=L)I zuqy!yOD87>US?)jS63!iHYR|h1v3i|4-fMvR%TXKMpzC;pu4@3ksG5ukn(>D`MVr( zGoXp1m4lNNz@F@{a*d1u&Q1ap6n{PFKY#yIPct{G|31kc`0v+(y&&^nDag}ec~UITK|>g;bQw!@{g4NEvXJP za})#E!5--(_}`xSxA32d{}$wD{_D;EXo>%6=6~IVwX+~HKl6Y3Oc2>2W@rfxP8ja9 zxQLqDi=#|LFYLLiLErrjJQ?Jm^>in3;TI^25;aB=HBA=oF4=j48>-L7(8aA%KYd|a z_7^YqJ6<7uh+X%5x^_+0cHwGEUmK^!25r@)rmjHvTtj%SGdu@rsHh6jC`5&kkY13J zy$bTfz<~eHg|Gm8=cVX}|1-@GWAc@_RZ{f-mibo!-7j7!rxE+befj^2{3}_$h~3b2 z%^9D|K4UOwr3ny&Na}^*HBn%y|>&@KDeuG&VO7#Guh$kI+|m?M9Sy#(Z2m` zz`QgfQO{$)Ff=CS+QWOD7kyN4yOPzt*2u`H`Fw-{6CZ!?66o60T?Z0;zhK~5DZ4~? zuydjZDl$Ht4qN>EtJ3CO);zIQ@>;|Mmq0{xe71N5Cf;J{_}5q}6Isbf!iD4M#u;@RU-@Rqn3{f|prF{?*~yRC{OaUX zU)-;~t6O7?aAUk!smoRF03g-#x!v82Xd(YKl_!C+^JBPaG>NfKT~mM8xqnF)VzeaH zu1=ksl#oCsAfTVkml&K+mjqc<;`>+46-PZfh8uH%4%kpw-Dfd6vh~? z)1{|xk-w_hbvVclBt;%j@%O-U-wZ+MKwack6HnU;+4d!C@0kDFr~XA0lcq92RgT5? 
zySTWx+iR{H{V6Oa?dBbg3rsrhJKNhw5*)IN6PbK0HuE3l(%9eB)YOctKx$Nmp ztA0#5E_uUOS zBO}g#)QM26dt+{loSfWkFZbII=R*yrgV{Hoot;`;J_03vXaua{xk=ByA#L+e(j zGuapzP%3qs=Iy&|?_-3XLU%`#cWOtOUak=rD`Z+vIKo6IXlM|(wzkY=?^P8A3KN~E zDVq3Bn#FW8Y;0_5Oo!ijJwIOac6D_L**9;!*2kOK1%nT~+}+*r*)1^!Vo0lRXXG;@ zhke0gso>V*my*o6j z^h!aa+PNrrOs_7c)$jYzg zo>nQRA3hjk1BqK&mPDHv=0O6>w+}9z)JQ6$e)~wv$<2O7Mn*mbrvoU2gf>%$m z=!(Dm-kZwd$hKWdcF)WRV_5jfDTK4#h-EM3U>t6Tt)v+-6zv_$?ZkftdPWV_)tpn@ z5e;QyIuiONVPhvJhe}URf7*{{xS0W*n3~tkx=%Lr3F~^glU?}vJCo0Ket&dy^sMXo z5j~yL&iV(=$voSaZaD9a8s-}LBdHDV>&&Wej~*ilbkfq&pEj72-hRsBXy2eCARur% ztZX`3r%bwydm_CH#zp=`8U2~iY!RBOr%|Dcaw+N{h&jfKoOTEBNr^s zvDnBRHJ!yIL__!+I%$<*JAwO*JeO6W^T}$94!P-4(}qtQ52>NkZJtaB%-Tt13wpBy zGrNbg*V)(McQa%~M!|Vl~*V8RTXPEE#;o{KDZl#ftn@n^! z!mES`m6#XX`~Ipjd43MAF&ILKgA)wXb1mR6zZdtTQl#EN9iG<*PvZSK~rD=0Lfr!rh zGjXN4&5JhK9iKF-iaWeecN+4pUkYG!KHb@@U6&jwt!^n$#Y09!UZP1m$Wo=rn2AkK za@35IW~A5_ep=8_&narhR+n%eG7uzv%Eerk;PctUU#Vas<=g4?8T3paA-?CnM=dru zO@Sx`CnB?(ZFc)Fdz!(_kX@XZh#Haj#X-bl2x9HT%p(H!JOq>St5zoCfuvwc?X?d9(Wx#t*TR6r= zqJo^cPV!ui%s>*aE!Pif@He%JTbH@cw+n{89Hc9%p;|(f71izM;er-dC>|tAVfGcL zdP}c$SZNe8N^o#J*e2;^?1zVkJwD~bapqr z65u?I<+n}{f=a@#1$7h-?iG{(y}F{(xm(5jtFO-mZ#G{N%aLSkj0l>$KeJ4$pXimu zT%+OLZht{Re9P;!xA}c{HPObne#3CwkkfR73+g8W6jtN+dU&{O2d9vTMOwx*z5Tk? zG3Ngcx!N_ENxy!2$YdR=rbYK#cng#wHfBMLh&mhwS&ePhiTS!QbsxBU?%pjza7BeC zbDDO}W6UpG(hkHJ!w__}N;SWywOg*A$|90%uVaK>!;Fgdwevc(B02mbg1K|=)P1jl zpK@YG(C$}qpNagGNR*lutyj+$X!@G3~7RKW8pCBMW{<#Vo{<1m%ZEepl% zJ81ay(PzF$-blvEs`_`+X)?04l(~*JNNex*9Y4#m1&3pgRUau9&Rxs zb3qS`--iq6Jv}{lgv+%#P5w20#20nrY6)+*SQrCNLdzFs3{r`@S@llbnJTh%WxodV z-G5|^Cu;ujv2Q6SQu$l?!cS?#xb#w&LnVEc0z5cb!FvzN5o_JlY&++p2@x^9k5B}` zmu^D%#=jLwE_`zC`Lf-UjE9F~E6#QTsEJ1hxNl_d&v{m=3yDR|Vu=Ep=+BHvt1YzHNys zr|$vnTc6{iCQN9@x*o4|1A*m?cvHqJrDO!*?2NW*X?x$t@DxZqDKZr0c7x~5TlrOP z;N^0SC4)z?wnq10KBi(#dLQ@5jvP@})@-HGuEBAFXGu!6Bx5Q#6S=OCb=xZ4@-6tX zxPL9>y#iI(OMD{fty@_0!6LBxYv_W;(jcktDw3n;%HrYYWCOH}*f%SqbT5*wi^|{T zF>qEUW{k?pGC*l2A^_6-Ez~$Xsfz+T5LW)^lHDRmjNm_<>4p}o5di) zvUQ{Ore;02t>Df27ZX^J3%g%Kaomr;XNczytzKp0H7-O2vJq(+69t3b>EzJC+VNhJ zU%#n^JH4$Jl3n=2z$B+X9%5v2UvH*sYWPPOE-%tuQ(c8;lP}B0h*_NFQKms}8Mb9C z)5vUSP%&+#9MU_6C#hMnTtDWp&ejAo6df8FkrKWo{AYvEvuR$dee+Q~t=4XpFxw$` zss4!FQ>yf;Y&B_az;_Ks*~7XSth%sYqqK~-^`E$cd&dsAW8#=bohXZYr4=4<8=5*x z?SL{&q$)i5m$4(cMyW15b%8;;jkZ(BtoLvbKSJOg8=KXeV3!&il!R7xqMvuNBi^8aJL-~m3_!2>iv7hs3VIkq_r5jdTte~?Z?$vc5qt4$wpXU zL?@?a?!yZuRb^!QSw!5ipX)q;J@*OlEBh0-bd-A{sOLIJp8QC*{;IB8*&W=9xjL<% zwh)wsb)-p@VamsCv4UaXiI(Dk$;EUqykeAFnr>B5RkqmM2<%=vj^^ z3GcnO(%Oz<*C3el&vz)=LE*zOmvuE>c}N;eGCSjEjA;S5+y2_IKB#B<+Hcv}%=$J_ zcB1^Z+Rb?k1x zarCCUTbLNsL#cp}lQ{3NU%XOLJ}>B6U0gK_20h$fI2(36KixMtN>B~rx7WH(jM)e2 z+PgnHC}>Mi!vr)*8A!P@APABKq|kRt1RlVN!V!p9jfu(zuE1 z9eZ|R2n}MQH~pyE!4hr=5H57y8}pCpr(3#TjK|ccTP}|_Q@-7)w7%p|iV0?X{T#9w z`c}8S!|FiBHAhAUq3c!PWq@8iq0U6G8DSP%BDHHpWI5r=HRvHYfNK5H2=pw>aBJ$5Z*V^*zM0er@SuQhH}NNyF@lBOj#G zIO@0&UrQD9heyo#ONoyvFf?~=YJvQ8V~r3q#W$_c&ii_JQ7(XI_)E|dad4{hIaPAk zj{L`JMu@BK1VWPlo(@Tj1q;|=nGewb=;_@L)AO?=tBwxPmAg%63me+HeG`*nL&t}; z`t7==(Q9hFQBVKl0u{)S&n8@W@wT*)x%*pvQu>)_Z0mf1kX~D&_qv|pE@H6ZS52eq z(WMn`*7ot5ab~FVkE*td-nimsy^~^>&Gu*+cn;9Cq3$~-||1&(f zF_9)E(wL?Oh|E`j-yyk-m^`@unFuJng3n{mM$%NC5c$KR{tY(Lf`scTverH@CX4_6?m-?}u&3`>mJu%Ti6bKddq7x1U1CX)eTy-fNu=KQen+^mwpfaHqyp{s&jM?K$ge z3pM6TPYT+WV0+b`?YWYXk*_V@*9O{t5rYX|x(nu+AaOzetWXZitLX4)B5dr)v@}9B zO-+uXWBKn3LLYIBD%;5T3u%yDJ_;Sk0M7+9DY+R5Rsh~*>$i8W5hx18-oGkr42a1F z)X|O;8@S^?vBU}M6I+MN#0fxXevRwvRolrEU;ytvy=sgE)K9OTRnjQcnxkBG-C3z} zE*#euuV;LlcFd;MOyTRe|FHP2aqX1bT4of(cMJcg;iuQPC{My*&$?CXCccFaKb-SU 
znv6R5^3T6i>B?N_)=@ef_6MhMH@;#62e-^ge*XAh^zj1Ld;H{^annaoQomW3%-H$d z^eSM_b@oc7=Uz!21_@?xcGt5g!w}2YustrzR}T;JbKqT~-VlyaB^^CeuIDMFk80tp zzM0iNmzHW!jZAf(+r|6UHtBXR{yc>fh{pQB2d1@ltZ(K|enYw5Q^wOR!gj>xU}h`1 zxbN4W{9|)?BS{Uy@dU%HJ*YdaT|#2e^Aw0fR{_3avvkzxC?zjaiV`w-P=r?;lVhHb zT{sZ>JGELl@vr6SF$$OLJih-!jk~enN$AG)Jujuaj2WecJ50`mOHTv_Id$2xb$L_xt$OMwHApI1aE8-8SJ$v8#}$uW8C zz%~h+Lud$;by2Bob zhnI<)388f7$v-xT-CM62$m23z1rV5zGq~A#kMGy0xKYYjs}@;yf&2p+jIPuPBsCt4 z5nkoyaKpqbV)U+8nVI9;c@`8MZmkRzFyY>&q2=rmDJd_L+?wuk*kt7RlvgpDusrRU zU?;!O73)H`%xgtEzV+CeI~Rp_cy28-l7Mr~W;VWN%S%~U(9pbXD;`CN)L32AjP3dS z7;PVKXo|0L(uUS!7mwYhgUw_R{1V`sd-(g!snkfv0!6*9=d6ZWUeS+N4ZAo+taBSd z5g_5Q)IR6b#MGt5`=gZ--h|?vz!ioQ^mV@@W{8a}J!3hKj0ivbjwjH^lWzjQ-t$Qi zo9R|Ikw?+uZD@?hau;+SBFdydP(nm(=!@zu#3=H+G;oWQI2Yu5O*!MF-RF_qw}5!# z=-+_Dp`p)A-{5@k5V%5>@8$ZM)!M>PO5xN& zOF|MiPYP9diqs|%Q6%AjCKx=MWq{=&SnvLe+O})=^Q?vjdA%p%hqbSTwU$EWOe<5- z00%Oi4Htx750`)#c;#A`@`I)sWlepP307C_Pa*UQPfyO)pA>p(W36V2$f>DcEApSc zY6Q6Y2;G--{>LtC#`;N77ZU3(>@2j@#WncGtOe}*td)9_+G=-=&oG94&0et4 z$BAaT%LljT%r4dgyauMGiD^MFLu@h`z2~BQYMTUkji)3K|AvF;k9hxD)3#95#WsbIs`(R z)_opw5I}F1dFyjZG&GU|I>Exu0mWT}Un|~AGn9oTXPG5#Oo^O@Un1Ei`f)qkjWix> z1d23(3t&1!-Qo-l#?YoTN^P{yzKG!s+E^?H*snERNUB8JSKN3ig!sSKP!)nj_3lA)Hs=Z$} zTgKNO(2oLdHEdkXo`2v=(lJ=$$Rq@(MMQEGPQGh1WYA=+pNq!TnxNtcs#;Ss(B0ry z)+SiUDYEelx;)Oy2X;N{>)KhdT&g1@RgU*qT_JK=BZoIcRL9g3y$?6t53xg1nFp~J zDt)`$EAgc|Y$7-RjiGilE>6ujBx8BW`Jvv1&yH@_O9xrMi}KsG*c0<;cX0m5HPIx~ z&q%g5%<{*@x)7UkpuBRJVAnaZCnie{;$jxB*)IDB&nrdngJeowCEIHr!)N%)AqJ{M zWI%z+K|}&a*0Q}AY9ueu;+L_)M!m@jeMcPn@`}o0NfLW^c`Ii|^+4qMc3otg@$D{7 zn~1O5i4ZMGo>dE>Eiyzz+C7}uPiJyK9mVRf7C3a5-g}oh@1Stb!hMJXAjR>A`nFA$S&hVe3g(wjJ(xOg%tc`E zdhvO5#nO3%W6$S8s)T+*;;!jgWVrLPH7gPUkT^^tzK2ha7>`RZl)xo8z+G8nxJo+mOtKAr@ z77U%U19`A*n9r zfX`e%*UxTCvAEjP<@n5dc!#@9R!`++#x)1kDkj=#T;tK8LS5IxzOq&Lj!zk88Fi2C zEfKTrH+9POw#S%-Etm^Gwe=^qAC{8UIy%FPzY6S=D_5?LGZw18pZPL0kV(C}#A<){ zdq3&hKc<&3LrW;%eJBHr5e$qiU2V+I+rcRHNl3X7|7AwXUY(S=$V!mecHNhT6QQIt z{Se_dr;$WnuTE2ZdSFeDt{WZmL6-OObpr8na3$$_?n^$y%i63B+kzp3N8+2Ekrlg% zPExw!jgWEZdP%;O_WI4h98$PdWna9cyyV1k%ZRnsM*l1=XpRv_Lu`xV{AVO_5e~R= zL7;*${tXsLbajccK_zb{rHJPJ@VM>-pr}{Q<{=P^F_SI$9h5O$-*;Q4Rz;V= zCBO98tRp_&JFbaXCRA@7QqQIhj{4N0j;lkzB0^UL!eEKC-+73{4VJO~65-YX5_0md zQoksq%fORQ-d>C8tcW3g%%xgZ{~G*4zXOZueR0IE2P-r=B|_sItWu$jWIM$q0P4^2 zk;#nJJi`0h)pmio0NdzaNxJqTqN4C%Pjpx;;H2w{ePE>NWurG=I=HTSUB{QqVzxv%C*mpq0Yyn|_5qaS z5tRt{rx*0!$?}r|roHDq#I01BwCHxIhpy_=9M1u0Mf(zPK7W?egTF&1ik7$0Bs=Uj zAf8|Kg+<9?TIEdEB=kqk(5 zaThOPgnZ4akI5F2L9-Px6G>{6et+il1_E!4t!;LU8F>Syq{v7R$fuDpm!98;;c+R7oY~4I#?kv zh-yc*(5lzm5n&fv?!NX$BhOr>cqlkQ3XPcGdQdIV8@;`yYn83b)JtHne~<=xiSh9( z%2vIQ(RK|MSqJ%A4mcTc;fk2(zh$v35zr$@YlR}@%ORVJ?dwNDH@bvco%RDr1U#D4 z`gd$Z@w{5gfKML6qLDhMV?G<@i^*Y@4YWkn955iM!;PL@b(?G53P)8SMo(d<5kksW zLevu6d5QbNGsOH ztXx^yAjHro!^~b(PgJ%;^a^kRwq#%v5}Nbfv$DQC5*SiXvM_PBV?v(K_~Gdd{$6 z5wCJ8A?@7w=*haQWh^LuLgIWcU9M`(+5iYP=HODV{BayU`IyXLq&os0g|d-OBcxq? zZ$ZMt*(G-6&&O^l)Unz>J?P^l#@jttwV&YEtV|@HkVjA^Esw&*$<7Bq4;VVNs#v~G z1*0(7R86>Yf3}3MHkDp1b@-*x26#J6R-SQJDgbk0RG3eX3d1HraNB*lK z9m?+LW+xw)^Utw=umxDW;vnAr*OrGZF1z&AhxKzmqhjq~++PsmB`C167qxKdMgh2T zr9WG?>|Hs1`sJyKts%m$Ba|BL^ew7#_BU()k@m+D=RWx~WLLAN7lrQ^b;7+d@GJF{ zZOWCkEgcsc_=Ej^Xx$)u%jGY={diR0K97BYC4bW5e}8jJS8m5nUw0#+&n2N97pOy% zBk7%1rZZmMfTNSs0?&y@Q>RxEvA>t3pAW7%ulwAXr{6e76IT+tf?Ro~Hl$6e-`GJ8 z8$Zv!&$B4%$L@EPcXJ%wIX)qj4DNh2X zkYucLg)+@rkZ;~NhC34e6BP&}5m9f(c{WaB!TVak*E=Pb0(39?c-=#9ANo#f_y9Yscaqp+eabU=! 
zO?{(76ZcF1xBImb%CLVhEPk5FqwI`xW#%H09`cf^GEryeL_lxhun|$B(5DpFhUYGGOSj&(Ug#C|ncd+Fjh^Q*{#ZuKI9Ln9)(XG@>!kO4=qJxPoc@6`jo8Ae#yD)ah^8^66qg(Z1^E(H>g-@s4MIWo-Pt)9`qx zb!xzPmzNouO<86%K*0YyhwhK(ImJVD*ep6M7R?T0xl$lN!(Mnly95RxmKBpuoCfMny)qp|DGq>c1v}NLFfjPq@hg*;id;~z*1*l=WmYo_I!>bC&CR#d0o=k1K*}@1P21K^ zA=>C6iwk_MtKGow_lmA5V6CtmGahb!WYTpfYP6F@t!izdr)!d%@}7!(#{P27fUbJq zjTu|{Fv2O)uaL;zo|w)l89SwP4J51VgT%Z1kSj4eb&kf3yM(p#TEQD+!52O*J46V$ z*!RZ!%U@QxN`({5rD8wm{V_HTKERmSYJ;ZlPnt$iBN383W8-5^8jeb=-rmZ@{PS9s z3cWiNIPl$a_SumjQ{1Zkg(`1oPmdd(5fE zYC@$`iK!FyMGgoZ3@1?@2MpHpan95exghAEriQLw3Xz-iq-&VD-EWpxlgtOcS6Am~ zKJ6_Y?^n$7;+2ft3mX}s|3KBh)$Omxf3;XFtteD7v+8xho$*yts6)tNg47M#aqE$N zvfQ-FiuUwGjj+`QC7Ib$Eou+x@I>E#9C#pDt%FZ@MVw{GgYv|G`W4QHNl5%nr~d-P zygmSh3YO1MH9|vBvb;#QzA~!my`3M=47+6((5gmVfotWO896F`9YgEt-6XFI*vN9K zwNq(|*#~={&pitZJJ4E}m)HEBvC_Iblg+WY3|1|sYyc_Uls&db>fan?y`dc+l~jp& zk#c%OvsmZBGW|MQT;ek(Zu>bB=c8>c;RdHoCR0pQi&q-mQE@WqKr5+5fc+9g@Giuj zCUGBh(y9c#K*u=_dwwJp%`;Bqc#Jhx3ywlDH`($8?XzQIV!9nSEGQ`}`yWUfJVoUO?q2K16b-^TUxUZ zX8j7RFs5n0EZ19whK2QDDf0UtucQyE0A?9=|8^BHE1wbis)A;*#Vv%dv1+gyXhCW7 zOwz^^5D+l+H?``_Ye#Ka(k@XRk=3?(b6$&1LCx7iKOW`*Ls9xAAIEx609g`BwzYFs z)0eyy4aq7&8{O;f%#{2MT-*1NGgkPGCSW1|R$d}K7Fj5%-n?6ra{22;9=G#hf}7)} zOC)sk#_*_UG%;gi?DGrR(-DS^vPa8ucTk-LbC0UxxLSB-b)tm(R9=+ZX%Bjh1&BN| z`ToJo9ITxP-52pvV{|c_)l5@VMz*hV%zV0;uREI=jmh}{D#%fN)A2i)c@qyKlr*~w z?>+9ADD-RcLi(XiwE%-f*EeR;jG*ZUBGY2P+`XZYdwfTMeQ?9RE32fuJ2nF1)t8AK zt}k2n&yvywR!Y{31#Od81*!@6zZbOf681_oO1YsmP0qLeP$hm1KgMzaSS)eSR7M;a3dL0)63QK+fO{Iz zXz;vQ%_zd_-ZhaZ$$7y^PcLJbOvdqvj^a=6G3f?RqMari5s@*!+5!D?w9@gCs{d-O zN_E9q(o_vqhpMT6KB_!3$(cg@wRlN?a6n(p#ps}t#}|gbsHn&M8dx>LmJ4ooY5(*s zd&{5(t|>(oNSnL7|mP{l8zj;+T~-Hc~Mty91z ziBO%>ns0AfcWzOdQLuUc-cLE;X<~`(O+Iez&wY7XW51@8;sHw4<+WT&9+*9S8vpB?k#``TiSk*po`GA#?RjmR zDB2E2X6C_g90nyF9c!F*HAEP;%Qf9WE5vS6MRB3Z;jLC;NRad{q_2Wu^U7xMb1M;x zMF#4#R(z8zxXm!&qOkAYS)x!yoKDTJ-0{5OvZZL~iAS8SF#VF~;g@daw}|-Q>;1gz zD53krUOi3yz-i2vzx8ZG7U+^~bqciwcgB(q`%*3Km{wYKLC$EXH6$t9x4YxI0lod@ z%I%Jns0$t>YllD&iRD3YsNS6v&VszRC?FgamW!k-eMvpXSw_7-_o7kr_K;Cz}V z+ntHKb)1=E#U^>C!G8qzQDmFWg)LIqB13BSE6UbrRt0y;#@kp842?B9SG60k-2;Z` zW^pH6d)N=>I~9@{y*1elrqAYULph5d5+Fws4y$nzA=V>u^EUy*ByQ}u7ZIJG*1QkIjJF4gKZiQ)*q^y zDl4nh7!-7M;SZ^5u|Am&Vb~q()|$R?1Okr>N=Vf-V#_s;sz~$OQAcDjFc+qYa5fs_ zv-0iGl=z%E=e4)phfhWK(md);HP0C5x*`yPK_qGDXx8Fz)s+sHyE=*;2 z7o_N7lBQv6kF<&#)%q0#SyyBF9zJzE<3?snUS4iAKb$E<+v@vHqzpAi_6!8Fo5!)5 z(lkeks8Y+FSK`yYv%YHN-WHTkcA#I9r*hgY`c9OQ6qzT)MGK9Gd#tOh20)OB>pjXp zpEX62a#&;U<#l}Nf!|%)T4fkw%|1eT@rVCiq}iOxFu#`|eq%R9kKp|pRqB3X&G}2U zj~FH)uFnVmpC)cP+En=I zM4XHJa8TMOrt2&z*Q5+)r$QuOE&KRe4 zp1VF;l^ldS5_fp-dsRARdCZuM9IUFhjNKphcTB&IWUoCuUUwc<_u~5C46VldSsU); z=m6V!vJ=iwB@D_cvN`xr%47Z54szxj5cniNZWcXdzvB(+kQs!k+t2BZRX_-m$yx@5 zKhQDM749zSh2pt3v$`e6Sp;sBa^j-YeH7&9CxW4?~y<&;=z<_>ZnadXbFeW@-gK=zvi%nWeY$_Bpv*n9Fl%DBjvS zQ9JK+wqYyCrz`C|ck!xk799GG$%x@C3Nk} zy9TNjh~}g}`Q0YZpsA^zu(eVQnE6WNm-plGWaJk*?tw&qq}Ovwax}f!oX`HMy$97< z#53DV2>*GAi!Q`^og7!0G0;17g4RcELXT=DNf!wyM=U586+}O9tVxB;G8t(bk-11q=DGuQ$`J;T zBYIj!lD4bCoHP3EGT!iIEkC7Nx@KzE&+IsYH~sN`ThwS>_bu%1)XvCG9&Z`_KXriM=n6+68)|sIHDAW5=J@!gO4!z3VuIJ2P_lq=Fi3>od zyIOULU3FA{|L&!rr&$35Ep30$8>xzT$(yP5Xacs^3;M3aPmfSvlgk6J+5R}N*v9jE z3Xj49#QDf*nIs^uogJ~QTZup`4IOlPkxLf&SzS1Poto@0;`)#6OmEg)wC6u2_ zZDE5Y3)q2Z=<>eT_}!1)Yrxq87pZ6eyJuI^{Rtl6Rt-K!aoR=<*fk4f4y0e+%6Iu$ z=O5+>0{gJcWkSmMA7RtKD4*yiKrZLA2I%SiMS|-8LYIHNfIn8KpQ5yM=n9HIF$h9h zq4B`P%;oJ;(lW7QIlH6>p}zcwF8}L7!4&=pHYt&J@$vWXDJd;g=}Mz6?^_FieGc`z zY~TcJBX^gtY+=WM8FNLB0D&R1UyeLt^1dmJ@QG4)viIEItV74J43^ zBLOVUUx^KV*$G%GvFY{fGJ~ump}rZ+>caJ0oi`Dng!cbL6NqgCAfnkmjjrldA%mWN 
z9D+8Kl=Uppk~8zgx|bdy`vTwt;Nlt^xM_invqF3-w=dfO*aU|+E%l*ZW1~^%O}g*n zK=AB$@qhE3?ixkm4vTby`ZcSvE?>CsD~(d=^%J}+fQzo$k#%~Wx0#;5IeaB8`wT}N zPP+xm*@;_!OUmdV+G<@QqM36%NvxGQ&>+Hy11*!w=`aD!Eg_aO`+=>Tr>^Tsuz?Aa zxR(yz(MCHVkffsu2D}FKqqi2;s)fQ5(rZxeH5e%_k1)jd$MqhzerANWIEf^UYc} zsp9`e+%Twp)9m|z7$M|FQr`76sa-4C?@Y>rrswUuFzp^iFU9(I7{slCbaR2FbRu&t z_!2XF5xHL?I5a#=o%u{^1Suv|6#67qPp%V3^tVJYU)!EH2j$dbnUp!xuc-?v%MtDvW z@X?PXc>1gc`{iIHU+^%@77sk#fs0|Hn|`r0q%m)w6Ym6U#zl2ai( zM=vnBGFhOc@lLNYUwz4(+c#NrAn{49IQ!zp^hiTPVI-E}q@Yb4_QYz~Xd z7`iOvMD)D+G$(1RV*gWLGAod*&~KNn{I+c&l6`Pxh367XGkKd(x0(O3(^NMtIZ6@Q zigA1IN~uKWxX^je)|+_9yHqn#o!$Pq1%sIdZ?|M(YhnG{6Q07n9c;=maFW<%(pY=0 z+oDs>dg;uPY~EqQTs64(Cj1c#3F(cPx;q{CapUFU@45!p&cO!$w@M5!@ngNJMa=@S-Vy1A!255`AR!X)pj@Rk4u!$;hHYhmhKeJ*k8#`z7P z@8LOl>v8ha#z)w=G&*|^O)OibpQAzjI=>OczP?ULc%+#9)Az+3A5P=(~yB z0j+?p!HDO<4&su`&{l^)o589ffVGeZT`)c6L#$)pSJ@aS%!8ahE0D*EmO1sS=kA>wW zj@*^=CG?ImXd9uk0G{s+d=kh@e_>E4`cA}3B$_Ac7;TjQ*$_&>=gpu)8*)S=^aLTX zeH#BoePp#^Q`!y~5V{HP>~a?ZTO18U`s%^g{=l}lj_ZJZE83Ujyu9K%ka@hv@5LWt z2=wjfT$p6rn@uNCU2w+^M>~-UPfHWKecEt6@LZ{d;+P-Tc_>V!2VZ|qIEYDjOxFmL z3Zy3Y{kKnaBX{^Iek2(IV@^aqhg@=WBpfm#?^&++BT0PUc>E@g+>BKM?!&f+{`MKj zA_jk2vLb{prZtO4-XCnripn@x60_5@+vBWVwnbT&@xFgqb+ACdsZP+bp_9%BWsn5_J zSzs))8|M5}kWfAJ(-gJ!XMyl3Rm)*1#Avs|^=yi9XRf*Z zwR^WE7+)H?MwBM66Sz`-Md846U5`L7>4}z^LAl!smD#O*oEb$o8Cb)r@0Ch^O!?&^ z$@!>nsWVefKk~`2aOR@P+uLC}c)=#vJ(%!eDK8;svM;MGf`gWng_oSBK5< zUfQW1$a9-mfQ%jb{8W0ktB~Qi_7dC!hdTicTscH+HAeJ>7wh{CWMpM|b1_Ty#oZ|D zcHIM2Jkc zUf?>9C5!4ZXB6uO*0IjyPRCrr2RCK(d3@)ue0Pn-zjMH2o;1cVz+477<{h31IuQ6JgP_?Gn;lA}L(ve@hr>^O&}49vVYS`tl#{EE!w+P)DT zNY>WDk-4a=xw=;z?{6gK*U=@urbEH`e}uhrbYKr%fO}gV*P`y>+UQtHR^^ z!4Jb=l^x-^XtDsW1ZZ~SgN|e~Xtf-yc1-$b0zj|Jk5@1v>{=IF`Nf{G`|F@*-IPA; z)l|2w!V-n>oSJjs3S1PQ3=ER`P_!#2cdoun1U6_;x;SdqsFI=*X_U25hT~}$np+nI z7SS{8-K;r1c#XEoP`sghpT0#;BGd|CK1Q;zFYI{n(EFaD}tla-fQ%v}ZoCv&rFkYW3CM|4Tft zR=|9xU~H#FEZ!b{JDL51O_fU(&hkyN4#s$+HTP6TNupdnZ80oH;QS_uWsbGRn4yM@ z8*ihWSA`6qu7mY8%d4>Sa&-QSBcw&<*M|)hV}ipRUxJ6qhB{LObgaGm7wLOMp7ciA ztnhSR-9oTxnm$@vE_0OfY{^DIH_ow9@YLs|F<%ynCZ&fbsq6a7>sA}(rBZ*5-!wFd zQ{$cuKXuCD(U@=+QVy%@W;5+rZ`e;fDtY`%-e%)_5GHj^c=Haw7Fu_0TgZvI$>Su& zl&ArG5Wx|TkLw&(rbH5~{=N6WK?iBCWKt!k)OAxQpLMaC`Seg|E1K?{*)S^=*dkPM zHyYS)u&a6W6jFkLQi6_sfOw7P+q#NH^;h&2+UCiXyks13wZ5@@JtRB&WkFaKr2;K+ zCt9I|s#D1C^maHMwKz&E-Q2y~m2uNC zKDCOcT|Q{~DAw(CoJg*KK&Z3@xMX2!Wz_}LSYfZ!cG;lgB@4NAXY+ITf<`LEKh>(Ws;#tbe&s9A7ixEtN!t=G@K zhV{s?u}_X0ocS59({=T+HG%_ucYz6^)dSrWU(R1s-(7A@9UCjmvh1*$@j@gW=bl?N zi{CR!ono9S9xz^eX;;}%INSbvLG<1oCb{b^v2)U~@6f=A@Za0tDD*Ens%bz9C_)Je zh~QHTdFw9)&qPU`YI}!MM!jal$VLnPYKr^pAFXL$sdu!$WW*2a2Im%+qbH%$JxrQj zz4T|x{m_f945=J=(0PhfnO9MzD5PFyb2P`!(Ne~DKO)F{h-BU5GchSr=Aj@=f|cXmv?JUFO_$od8WRpqZ^x_}7{`7mQ9v`E zpnvJ0=m@=Ec2*yxP-DL41485e(&z5UHf%^4Xn{w3c!Y&ibg=6{H-I1t*a&HA;=Ubc zq&JTn;V{AW86i$)0~B#?+$#6YNZeCu_)rfdb}q#z*r;bwz5=rdOg7@xBnn;#{1P6&8CDxR}q&x^ZlBZ&NAe3oYwY0Kq)Wj>y#sH(; zQLmN*0&j19=r50B(YhzqPL#=Ey=Npos&R7@$utg17qf}XNwoQnk@?f3BY{25CD4aD zH^!k36mwt^^Vfsf9YBiXlg%n>+J#f4n4aD@AWSmP=Z&YKV!9WxvK`2&7G@QIN@Ukq zQ{zSTgi0)pr)ieC=jXjTT#?S*1 zg&S*}S)M^FH@%7FUh`o!^;ks@m}O`%*oO+r8R)|ep9BaE>%&>V*{x8m(fUCnk6I@j z8Z+nGT+SIP)+3s)3Ahi_{6QV%joEkKpOIW)TaM;RS$Qi@Ypm3AU(Ph)8-Akjnh$i& z#jx#OP|;E+C0w!HyHPnwFQcg%AKvKBdyY1aTJS(|FT9D@GW4tau+A>8O{OtJD@C0H zM#rM57))p{2eVBqC2uiD1%Dbb*yp@?!*+4ilZ@97KUZ?4XCi#v-ppi>DNQQGs}p?tt80yQf$5wl#&XlR{B zS(8?chq~E8w>`UUHB83k&c@<`%IAa{7Fx$cDmz@}$5S*Ztu2n@?{z~8+jZ6hg*oER z*Oc)2!&WI3E{(Q4&8CV+KQm1riN}vYRie{t5)6=6wb5m*6=0sZpD|Z=;l+RJaoMVP z-)qQ)nsj}LewnlNyp$8JVVw~kW+(Gxalb;?dBhkYz`D#!rBZIPe0j;`4;-kGBluMF 
zm8bZdo@oBKVT>1p_qE)jXoH`jHH=pt$$~8HJL*{uC+dB?pIO{X-P+K(m@|V{19dM; zcqd=lVh1;qtgXe^8;enSQqE&pI(i5;3_M@%jTGVAlK$i~`%l2%p;hE?a2m5|gCU7hTx^}n$mq0m=rGi4_{7Lm74?nPQWqvWyX z9$lqdW?3(#8$)$rs_)0Yef4Prtuj0Jl|+-?>KJR)86O+hFj>uoS_0p8f9*kn$RpZSc+ll|A~XW zbMXsKW8|6Jv{kE(@-8Z8*5uXTLc+(TiAHs9?mPiax-!Ft{NjOotUc?(5+1HFFk{Z`UUvfx~?T7|vrm zjAiZq%rnV8rK%`z^(t=of~Oj36J}b-Ikg>xF@0J@aN+zJ56)_wG2~RS+JGmatmfbB-lFgS63}94r8mZh1$_XsGeEXmKV{e>$X5F!$c~ z(;=lQ3&I%p-dja}SdTtLZZjXu*04J?oi6$PUc?2V;`x4X?J1#ReaQNY`fyT1#J=0= zy&il*f-!LjSsu&a>Lh^3AzZfqHrWB`%U`a^4<-oLy(}*#ED9$m-vh#{o=8(O%Ce3Z zGkr^k*^vWVoi#}Z!T{#}-$HCd$k#Y3J^%x=!=Y!c}P#F`=-6l|K|Xae6C?){p$7S$EZ}`Z&<*;?wAGv;nLu4gIpyL z{`aW+n_}Eg1hxZ+kXg@hZ6x#m1HAs%T{x7$*glU+Wk9{)KT7@v62%UL;^I3b3t@c+ z`75^Zcj>tyC|?rgG{JcnZUEriNAXte^S5)hP6#UZ!%tbQM4shvoD-&b0-C_Qc!jlm&CAbM*Vi-KmZ~%&CWCk&puu|m z%eTK~ZMNFsIGZY)pyK@v)A`KI5H>w;Re+E~AhL{cu3}Io)t+Q$hmg+clM?eZk(+~5u_t(yUsn9gWKWV&W??8m9<0T_P`cjpE$O~K{ z8F;yC{EZda%osW#Anb-Av#PANm!b;?=N073m$&IWVWzzN4OJjgKtjN?y>vmlEXq{ag z-*%%F|8e~mj)JoMdOO&~Qj(r1%l$;c3wEhOa_B3()RH+afE@^8tZ1Jk3aT@GlV2d6jp z%icJYRMWxCsdkP3uAA~#Z~un;lU3{PD$;1v2L}V)?e5%I?sThAiHnr# z6V1+dxxKBeF&QANTUZH6ww^<7K5wi;#K8&ixb%W2AmC>^E`U3b0BXktz~O3a`@?&S zSX)z9ww^F)*$t4jKW=&!l1NC{M*Rwt zE_c+>fOcq%Y?G3pr=A4D9S~qF5 z;}-P7oF2nNA^VPIyc|0!GzJWgXD86ueE}=p7;bUq3r^(_PIS(4?Do)J#)x#9boT#8 zkRgCC{}1rw50V~8!%3y&d7i%}lcO}D{o@WB8|!^+Y9Fz>7DP>ru?K9du^t0LO3WND zHm5S8Da}3_404~z@)tj#eRe$Z^;(7w^86t32kZC&{{Z$%^rQ8*YAn9hioRI&lA7rXq=d16{4W44n>I z$F_&tTTx@-7Uk{t_s^pOxOr5GiT69dkXUB;6%>HD0`~k+i6&huhb8exQ}51^SUL8> zE6p)^dHvN{D_UzStAG~}rFxaHmb?Z=y9OCiW6y^buRmya)x(2!3ytpck+SGV!=5m` z_1D?hOMhNZZ5|rV4v-BRnc8IG-hWmN5BbM0bvLa7azkk+=BGMT8ZZ@YD<68&wR-zA zOkYV_T;K94o*hU%&C%tW=%a5m7yDMeRqN}7_w3|D+30mtOA8#-x65m^G&Em*q6)T9 zD^7*E?`@c#kpBAzAM}@WY#T`0IQLrH#GQ`$#5u)85`H8{hp3dqzx|;97>_qENE~(P5{7kJNyS5-$w%Rg$u2f@V0mv8{~Nxc5QTm!zH zsdN?x2?@+Qc_j-36qHB~AoR}{0BeI&4#M@ahl0BOKe{74RiIlwHcsB+rROshkDs2N zo{p{>DA9=+4Wl=sBccYx{Y44C(TU-ko~`)8`cOBG(t`izoBmdlDA0ep4<}r`M0wL4 zQ+!rBT~L_EoCqjVmbKz;57+ktKXZ$>r+Wa)^$)U{v`kd%Kp*0tUG?O;WGUYH27~@htVDiC^8{_604B$!Ti(x3+SiD5<3E8jJoHr!Jto zMhHkbZ>et{-ewQ@*HVl{G?-OayfuE`7L7k3C(P&rj3Vp5HvU-?^iv>7~14q z(g840fK%a+?Ede7?*IB`HsU`Z(XuiH&BQquD}J3x{xiv~D6|SiN1`$c%KN0-Oz0c? z`b75cOj?w)GMePE|Ggu=e6FIL*e@h@@2s_>0ZP|i-QUjz?;8`lYnJztMdUv{yX_HI z-j3q`Yhn3Y5B_VoNDBj7Csl&(kkrB03Lk%~@0Fl6@~vLUOf5i1o0E(U8nBS0i-$J zhplY9!U$LvQYqfp!<9}+R`A)$`|`-x^x3k(&%b7@wHpI0aK_h*qtka~6%>#d7^Fr$ z9Og-#{#Voe^{4w<@{gQA;@ti1d<3laW%6qc8w;xlVAu0YY*Y%%w4KaLhr%_niS^`X|P$6g7m-bC^iX!q`@XJ3HJX|?2leVAZak5t-xO3zg+$QdFj74 zDgwQ~FY$l$@>`xkKL3YVSY9%ngc~#c!4*(^_k&W`)fIw7iVZWKG`q4gI9t;#T4m>U z3LnUYUY#%Zp5X+)nwNIF@8yO(a#}RT@LG$1LEmbu>myX~fbSt+;{FLIIK&~fYtcVV zc`2lo9wie$^u4LCThhiptqAaY=7`?eZ-a>{`#uE1cQ^27RTr;_ls5tpw*SQ+eq+@01YT)>slp-j}KI63f46XQ5`i}`+{W9u8>q7*!=Pl zS-#IK+B*2IyC;tQbteyz+dhX>9rs}6Zh!~4)!CubA^cbORw7&-aZV@&07(x58o@TL zQGZOHvZ{iGhvO;tC!P~tj?vY`@YvKE@XXeG{xvx5irJgrkI%Y*iq5sOUHp(mqDtYB z&PX!qHHGNlj-M>Mm%>WL`X$2~;ar7|=$G2IwCWwzPfuc=HzT}q3JOs>JNBM$*R&;- zm8HoSYS7IuCv(+5wncvFX4}}TrV5V5%;sUunTVwIVbuu=w-^_1nys|3-Y_^5dAI;9 z>9<#8V^&?r!r<_JR%P=x(xxt@P|)9namuT!_+{K#!;P7Y7Uv!x; zW2xh6ytD2*hG5LUik!G7dEWWJ-Li`I6MtleT8AGRESx>AmsReLnCbyD+w_2-L3o8` z|bOwDAi8qcHyLa|eNeYW!9$1aeQ?TJf%muHe24Ds8SD+Y$OFJxE3WQ&Zz8;)!) 
z1b*psP0spGD!*6RrL?ZV&mCcy9n>IlVs>eXx;_>loI>#h!s#hj(eXwygn6xW?T?Ov;TG2MT|Aa3SzWB7%BShzLZU5d$EX>7=At7 z@>!w};r8$*DVSuoOv)RH_lPYoXCx*enF&K}Ge2MTdNDwwa`o@+%_INTo$Tq!#U{;g zzU7k@hW?m~NaPsdFZDLbsa@iD+{E!+4ApXWTh?iD;LoPUlR4n-AWFR{bEXG6|#aQg$OkLE7>jka3U^lv_X!-lcj z2d!hz=j~1lmnW^C5dMk<@%BSe(o3h@mT(KG z?j+?82AgNOR5?>`K8Ik$NE{?8|8}njmD&_GN&^VEX($nWT(Zg~%zurgWGre~Rw5;S zr{B=L#}J`x2bkYf%w93gKNO-nT37MP81J{~a zl-1H4jq+;7w6bqkbjqc}e?sMB(?D^OYC^{wEpuA$4drM)V!Sc4g+JwwPbQnXqvUwT zSL<@QqYzbrq{8@ETWa1lg}`LI&Ip2CQVL6q|!Wy7usArusJJgJV0- z`s;oOv6$?sBU?H8jB2*P>>C`s&m51TyzMg#623CkV|DQOS~y`-3`*6vb?d33+<#rJV1p#({a?f?KEYMA+Z}BvX#nUItOiXJT4C>xIl-bF2;5x!MeItuvW* zxXKr^BfP`q4%*#rhv(+vi)6E9RI7@hBVXl0PsfdNP3cDw=Gyu? zwPz`=^2VM{l)bCsM2_|e8#V^Qz_*H&Z=UgSc4bLK6JCTfRi9fkdZPsfuOp1x2wlNi z+)n;tdY%Ba$^_5=7|e6J0qYM=J6a4o2eZYbzrIQ8JyHWQ1g3C{u`+d$IZmN0Fn z3%R=>!4!hF)#}y|dg$=?L474XSddAMx!gl|341S|*db7Ja5-9B$?`<<^F>OiRRJ}YxH`8SZedh+wU zO#(Abdlhqs(q-#K=(O%u1=QKm#GF~58a-?)JGCUw`AE9IPxkcq-t3Q&`ZG91cWk7Pjsd!|$KaC9BAg&v>oP>E_e9<_)IX zZJrDsD{klF3;r8qzevE0vrf0&EqTVpiD5ft^_2_7uoJ*|v>XK;m%r(qoiA`q6&;2L zO334LozL~ugvuDW|BO{XILg2(B6czCF3&aQzTJ|!g1P=rzt{2i)N4v~ z13a_rxbb^8OD|+MZEbDg^!poz#iO|WRr_&G@k2p^2fqu=^l^pt>1STK-+L{0SN)*5 zP|%zT2=hrjRzWmc(xFrh4ipM8XSXe(Y!we#oY`$L5P}+>E-NDTH*?frp%JnlD6c=h zv%k2`L_T&F6bzerpPxVH5Qtcu z2deHs7$(s9&L0u27sU_H0uXQFdCpe*6vpl;k&%z1*ov<&OB_y;24*UbFR54p^~ZI7 zkB7i;AW^>xiISNF8GN2?L8h{c${A%ok>@)Ob9}#(&z}WICgg^IKay|TmDP`v_4pk1 z_C(^Qu(i$J>+%#VuO++HUJAWi4;^PubN?m|gKry&|CAXYA5g-X(50c->f?z_EVO~j z`&}jkqHvkBv0~Lvp2z6~tieh>^Kg!QM{5&7!Ax{$z20m_Wwmx60|0mgeR)0%^-Bdj ztnnH#=)ZCq4RhH6)@(%s`&)cIOq3d}F)TX!^D10(^w=%-W+>Em@)u_sztfYOqBhYWcao=WpPAC&I#%{`&^zdON<31<1 ziYnZ=_RlGIJe}{U4`slx2MpT~z;@O8P{L~0=q^iban_-(K;C3s zI+kA+Vt3}rDuiNw75lVA?+;0#laI*QOT}*D*N?97pl2e})8Ef%7@LtmW)(zF{>6Mj z^n3LUsUJG=1FfWX#HTr5d@Hp4%PkpyTnr>S&3)aA-jg-fQVeDzkzz1m5M6Zh9uV$r zpQ04iOFBNW`rW+Vkv1WAegBIgBzrDhq-dg~Ro`FlonEhG6ck}m_`EbgvvX&$CMGqX zQ+BSl9K4fy?c@d;T|*fGm1fNDCCL|;k#Q*TE8~b{7_F3Jd%YJ z9{W36{AN)bZCLKZY>yyVycKWq&3{b(Gg{GOGQ%ImO}l+ z(~%g#ig4>_)g7AkFo1Vq#=p9bLaV!=Tm4f-6U-ixkqLPQGn*$Ob{@xWoy z2t5iWwFk`=;v5?0qslu88m<=V0gh^ik4$b~^lQa&?vz@r&YM<+SB$tuJ}<~7iD z_adWtP1bGAs>sPph4#qtogX0AWH<;&u9NjTz{;tYTtn1fhW%5BZA8 zLXDi^K8RAGmXSB^PAa5}5n`SQ;$|L*Yw!?l)4@Yiim3yPcTKYvmo`om(B_xtwu>Qm zZsiGxUi{gqveV$*DXj~YQrnsjC;52~*(laHSZHM6&?X9&UuLiNQq;UO`r<`RtETgB zmxMBLy*kK4m1lbP{J&8<(;SO#cb)4OSu{aFKU)N{jq=LBMk~&D?%o}Zqso&K^z8AR zX>_#fRTH#;5N~?$Julm(Bs7k_sk`<2QK7j|)S|)Ghkdgi>_vqzJpBpU6bZhF)?triMw>q-= zDhfWjBx}_%&vXxKl|s$WS`jHmjMKNx&R%{sv~`9d^lg=2X(fj1ib1Rojd0i?2hQ85 z-dSQlOD9!}a>bJ0*plK5^84MK-^dldcac3T*HG?Ypx%qo4TD0T6u{nO>Dj;nXEd3F zn3w^iqk5T6M?nmXYQ<=ot=31^JBOeQm9}Qng;%Zy?1By5!8e>(+1Vcd!Lk%QJU3x1 za;K=UokLdnKu#aQp%A_u97ooM6$5*Z9}8S%bEYV;noco0_#!j-+0jH&G&vSwp`aP5 zVqXjohNTv73t}8vK?_PaIgSlq3Px0iv#-7q&|-WS$1H^>vDSk*dBCOi?naer^CVGkY}w8EuqI*~aO+nX3|Pj9 z?VlNV;j}aGfi$x$JZYFmhYs2p=#yMuQs9gUbve^R6@52G6*mvzEO(ReJ1NNx-dY~> zC4?%gl|xJfF~PE^CFb!E70bm?FL z#_zjHq|{SGJlWTLm92KiW2Eh5L_ecv`gC-?&m;s)PE?QC9DZ7A#_ZcVY=zZfOv95S zv1eKtfzCp2%-U;$*#1EQ-zy&Azy{^qsnBDLbZ~6F-YR|^6*g6wkl>OmO#_eeeI+v+^hKw}z$%%e$7Jw=X@a~1{uiV*= z9~cy*n8g+4Hxo!*L%?ruCDGHPKVMC-_HrZ2O-2?W5@kh-VL&%xFIg^=X%bR>z1yZx zu2MZ`5}Wp!KXEPQ`y8ByIYNr*Re64KyKLz>*cSG_@Lt`kQ;*d;yp#a0!!AD3m~5 zjDpAkDu2w~pDr2y2#T&A3V?O{4X#^E-q$mvh@NDy837t+ynLpQjk}e*QTk0XcY7Nn z-h0`U_xtBw*nFlb?GA;@~`~BV`)i{+8LC#O8MicPMvSXI1 zOG&$BGq}fz*L9FEL)12sHS(>(AaoG4(JvtuPWU9N3p3g8rjm9buwbrptfCqi0l`PB zaXt-TXdS2Nr-qGp69R!y@|he?M-S&~j!#$m4$5^HXzxRvObV=KLfuB6tkEK@vlQNy ztX|UnpbGtY{y5i1i|(`C zV`aSQGs)SNBx|fAxl$4rWr``x9VVll{RK|>=z>Ip(sT?ckur 
z^|5#~(P=l7+AW5Jc8z&A(}36h4zwN@!76b3I}AqON?NciB4QM8@rN4Gi`whRZ6MR) zWvNAD$lQ^S+L3Q6A5geYn6qTVA)r%AQ?{E+ZyRjhW0J#*g>C{v6&nU;tZ{%f;31Sp zVHaerDObB8olW+RmWoH8AEF@ebHUG$s`=JnHQ@Q4!G@#};u=PhTR=opqgfr`<$jK0 z^r5-fQocGocyD>y=u=?7b}Io+P)`(kID1gR!sd1=HVeRPshZoV+}!(2HZ;6=tz8>s zwNVv&gAnEdapL(_g`5(@sgK&fm*oth2wvwr9U=&VVu=C2AW}eJGNQv{qbyHKU6<}j zj`h)EBW9sVvX*Q^W^c}*Oh`uu#J>QW1d$LQ+(7^HnjWTt$)NY#CF~2&TeqH+D1Kir zZKr0CquKqJni#ca0a$&oZ@X52!I3+4yZINiksEq2{a^QwkIDQi-Vt4>6ilcAA|PG1 zf;SV{Qk691cSE0dL0t+)Zg@;w4V^0gcx%N7Vwtf;yER-~w{w!PI*W&9%kA94HA828 z9!)b|CVbrpOC@Vc{e3pV;XDpPZpXQdzV>WB`+TfP^aVHNCQ@!8x-2(cX2fnVewm;GZfCrrI9w`8 z=WAk%ZMR5XJj+!8#8xH7sXuU3ntyornPD>0X^|2olax=eRpqj4R62h7nEQu18<821nLZXv-H!XP;bURJClI5!?r zzS`k*eCDa~>oO>i@v}-9yj!1`kOoPO|CU@h9x_Q2%;y^16cx^x#JEcdJY!dpX#La1 zSce(`16tFTY(|oGc<tE9&Irbs71t$a)@jjp0Bf_VT!o{OcYhtUn4%~YUD^QlE zB4T2ZJ?zjdHFIwRW3o$cJcQW|?=A*?3Y)fWDb=UJ>cS!r<0tspk;HW8SMHIm0!}!y zrvk?1*S#wi>EXf}A1AQOJ=&z#f`fIknbW|qj^k-@y9hUrcA-m55C_f+f%4ZLE&RGt zTOQni?c`pQxUWr}utf&}tgg$@=w7BV5JGE6cT7_oiCqdc*`TQ_m!b-+y$a|Cix~YdAc#;Vo$^!--M#tFADBjDAv$M1BXUM+B zbThIjF*>B>R}l(_$%%!?oYQ@cPAsKbW*H@@BIS;1w{FCF<+8!a&PE~>blXYRW06(? z698>hDiv@%UG&;k<}4o#Lch4W!l-Fbb{`e;svcNaC~u;+AjnWdm1v9Nm}_(BxD!4- z!;*v^glTLj0vUmSfQd2``4s>5>s!eo|7x|*hj$w9Rtn@)*bQ~!=gWRQ6j)5d)WuA0 zgs;$oYf6~GMBh#2iaUJ(c}1-pvnLrAm{p5JI3Q~(%U*<$E_|Td%V zy(Ddq8`7CI(UZ2oDCvj>XUuW!3&fDFT%C+To3U846|tDm*$qeW%AmwSu!b|kn&Jrt z0%n<)ql=3vxKH>I@SIvEi+Qs1ey|!M6@0{~%j!ZbE_KNv4KVQlWNt}oQ8T2g{>Gt0 z-mr>^B#e@2%X?n78GBP{kiOTI*$0?Y zbbMU(6gdlDL`@H@@?!^a3X5&i#@qBPyPSS|Kg%r8HP z$EkQMt$;YRfNjVWO~sv!VRv*=&ezly;4(A9lfqt<55x-&{N4~eI;_e z6j_mlv1+TcJQHp6e~8f6jdAL++mj{j_ge^K)uD$$JJbOBKn8Ht4YVm^lzsJ`!VJ^f^0QjE!N(~sCr;T!CQI%Rl&5)EC@zE*~54iqp`@#d4IDJAZO_`Z)@KG_t?WkP~M zFuBcahKbjA$w{~N5i6cEJN+lC+uHT|C#URHvt=Rz_nue3d)h@`P1`rt#nTaErRwhd ztQ|yTo(9tl4aQhWH<>?&V=mjn$TTxZtjs*DL!snt^WxR25ua&7-=Cppa%r-Dv!Vu# zzL*ukjN=h5h2T1^!B11DHg&$Gp&Cb&;;t)R9B*V?K)31>FH`^q9Z|j%L68-#pdkJ$ zk|5XFJ3zf^B=AC#Nkd-pjt#-&4&cu1(Q7tNuaE_UKasNxTH5k2KR9W!sAKLjZ#^^f zP=M~a)49QtY01fY&qJ=q=2}_AKvDsEvx{upvn%MhU1zTE5*6-OofEWmrQI5vcvn+W zTLSliQi6X)J$2+jQ0CtJI$iIV|D?8oJG8%m1Sntr=dT{_=Ub8R#oJPp|9&0lmQE4n zq#T%Dh3Sh{i2dKk|0?a>3;31KTaS}Uk*NPJ_{96k^YGg!w#mO>!-vECuB?Tk=dIj;C>mn6*DpqmK zip!NS(PlY~JA$Pcd*}hlwmGe+Y_bwhsbjAEPf!wy!z<=dPFg&kr2LCjDg492ujM>! 
zW!DzQ%MQ2T*|sS=m4$^QdoC@-qxDq;4(FLzzl(!L;&78R3hP&eeHQM4*Iu59nppVp zqQ@$Y=(5I3SvWVX{v;6-#^sWb}9&oHwr}9Cal^?{` z&Cis;NaZ(Im_anog9q(oWg+~sEgu+eh*pWv>x0Lk!S|u%TGAuO)$^h=`$Lkuh3am} z1Ps&m$Ex=Ue0FPtH^Gc(R1Q)b2fWm{bsZ z-c88YAu5&v_MQ*&eg%0kF!&klI?VVgPO8kU#^vx`*Y0T+mAUE;STHfA7Y-Hk%0*B4 z6xd61D$GbUOcd3iEIU4-lr4JyEIguuU+UeMov`ZZR@nT|c$&7an#*o_lpBqo5B&A= zbx@(=LXa7MAr@URX0plAIQ&?meLf*ItZ8D9ljE>3*&T+|c2o1kO+jQxLlj*VtMP^} zY01a(%J~FyI#UbRF?@3hIsI2_lL{wYhv>elyJ8>zWeV8guw`}}st~q_rngFfbMSSW zR-KGea#!gGTG^cCcjkkN`}&gYoNJ{}bVp*Zu9--U#ZV&rDkIE-s6S>7I)X2#Uu3_0 zKJN=pAQ=aR9~W0g#BT?u+nlfki@AY1QnI}RW_$ZVAR9_$-2T|ql1NLq%fX85=?KGm$q#RltxgP7lrNV%Pr~GTh0(i5) zO~+28>d|Jyt-7-Z9eiy2)-U$n@8IF!$Aw%nvL#;mYA-k7(RU^dDc+C0@WSg)*1l+S zC>1~;-bZc9yT)D308E)?xC9$z9CSO_@El@0Ire<-oCwAR!W`$Vyoysc%Iu@Or>4C+Gax*bWS3R}RxV-lr z0KxZaxNsCT|IkrYqSTS}kp{lXaf&-?Z6dFI`kCd@Ave`lWO`g4k;zr$b43A#4A4lq zenuEZnDacpQSW(rL8Z+MGJk+46P$_Nc07F{9!;b~AK-B>37#RYN@w%=@ew7r66N|M zyKs7dEICuVv25&w^2#rB-f%(q7l@%fDn7D!dJ+6mIeBNxI}_k}Le$Y_h{Ja@Wdhy& zc8G}lZru*6=74NvW$pI0F2ANr7x5Z+06=b^C7KDgADoO^$zz7oo_0G9zcs_VJ__}b zPUc#mO<>b~^Hixoc%^{TiTo4y%W+5N-S47GnO={k;5k6~os=pz90sz0dDjE6 z`iSiwn?k)Q&~++N!|CrBm$KPL_ztcUB42E%xJ!rU5Eg;`maX>4Uz0L>K#)O3vkzD{ zw}j=^dBso%|AK4nE_0Qxb77)UeL7pR#DXnx7U_RbG>G?s?&o9qMRgWVI=LKefbvjB z{iX{y;V;lWa`?a*Fi*-(=adh`avxE<3brCb#e!) zm{H9&qX7|6Nu`8nT=hTDLm8inXHJnzRGMIw*3-~>p&Vh zd%gJwLh*R6@3?!UD^Q^h{n*5kSq$CA0s+|>U(P8Uu>SuTd*|p#x`5kzVkZ;Znq*>I z6FZqWnb@{%+qP}nnAo;$-_G+sukQN3weFw2Qr%fyb*lQD{o8w=?~{e|@Y&$&=Wzym zS`|FY*X%dto<$G?Jb#kC8~Ku=5~o7oB#I3edvYr>_H9Tb(G|-YPbDJPz^j9O=)uk( zqb7%@2e{LO-px>p`E{;!P_EgC;M6SJhy5N6EC0Z&T%{S;4cA-F@Ih}o{w(=&Gzncm z(Dej&VpI7I)aZ^~%-N2AhCI4|4hrS5trfkFCzSW9CM3Dog*a9RW+7X&jZBFtG>Sop zO9zur5|KLFoRC>OSqIP4L$Y4&Kg{Ti3f%lqQCOpmpZD}~Z!1$_o5eC~Sm~KUeH~%; z;+NBGib@J=M07u(d*`a@f}B0^N<1NE;CZ(rIwP6Csf;m@yv49|Ld;=RrPP#ePh#{c z&%wPV3u`msHnO}(@;JOn?q2yOwuawRP1?_c@{XL+komg|m&-%445F9_LF1`SC0v8M z5j*1T*dL>8f613;`hM4--dyEpC@kC@6~cS*`|E8nN*W|T7zP>8(PEP{?w-w{ZI(#R zR3kt3=xK@Gnq}r%mDggh)zCg#)x3)PqT@tXDZW|56@O~h~pp@%d z={)jtllw#ChpLcw3a{+g#a&Mm5qQ{~b5v}#_qyAHO#Dz^SXaU0{iYaZ)b4Uwq(XYc&XZPs9~<+>8@hN#wfGC(I`EDQfkXLJ6gdq@v0EhzJ+c#F^vD zZY43%Y)O>Ls&tFXCR(OyUFiE*ltFOdk;rizlG5Ewj_+)WNT=zG8N4W&u58^T0xwX8 zw(on&DJdU9OW>cq$nyUfE3y`5MoJaSjCThkOg!=3uY3xqUak(twfX1Ino_K8{A*#> z*>1>b3-^$&Wy>l71`ID8)|te~AqM0#zVnT=vPrY!y^VxAz#g*Fc;E3&o~fc`RrAh@ z5iVLj{1%U`G`+uBeC$cdb(KMsm-YBfceY!>R2T@myrzfj{e{&$Gzk-Ack9fe8132n zSQzOVU?;oF{`uMzdJ5503B9Tl!atWmfS9#78w8zDC5`^ECK{wPX;fItlCU+Bl&`jr z++2kmu-25szXR#;bTgj*ujNE$Xn{HxC3DV~4$)YP zChGlw)TD(L`QJyHk2y6KL@)O~T^n9W=^krLBn+Z1%6sF@naP1(*@(YYAv``zG+c-i z0H;TN6*6`bG#NzW906cy?;lb}k{$*{2w}WUGi$UL-0A8(0X&TWU-jol{g0CeZ0gY7 ze1tHO*b@Y2bt(?nq(v`|AZOG539o)uME{f0@QwYFSLp?x1|9ne;FHe=g z3wSHJ4cZ)`yvh~R1hxYSeCA@@<8($m(pkc#@0mP5qCz9TadioE^9`w(HHzo4vT52u zc^hsol>1r%G>&E57EU!TBlr78Aj`~=`p{R*wUV!Mbw+(?(g%^&a0n|7Heqo8qI+F+ zneLe=s)j`yvdk*a_AQ{C-5l#eIqxr05bEtEhMduGl16xa%@~hwenlZIc3Tmk7WSe# z2^sit_sN+w8En2Vs0Q#NdfvYlS%ug>&}Naoj8uZ*hnkt4cguUgX6-NZ5l|kSSYgDM z*imJo{!j`;ujMJ)VifVO0;Z!Rs5mU2_ej=Go~E6=d=o0pZF5COdalo$Rpa~)rt%Cj zk=`TmiDZXMHSV2zv6ef!)=4sd!R{;X+^H9#Bm9D9<=GEIdp^)hR3Ah0(U78%8A+{% z@r^UNv|9!Q?S;jQA@tIuoJONsf%N*1sPtAq=i&zP{A{nt0&iKAiUJJ4Lxt^vX!Sm~ z_~Y1y%ORykaP8o&@;(=T?yT^YMQ8knQb*1ey)UTRJHjJ#QpdKkmNz$agR4vttx8`rU)Zmi z?}9Hld}lYC7Kf8B>0d;J7GZAaAAe&0W-W*T3;0>e9^aHxj+(A%hxT6VYZ!5>O{UE5 zmU3s@Re3ZeY5VJZF6tlYU-z^)-KP(n{0a?8uo9i!iibFCoI%%e=Lb0UkyTP?*_KpR zi2LB8ro^M|Uw0^=xvSF3iOOArfCN@Rj8r$yoIxYq-c)pd_@gS5h+PwCpMz-4MgnF z`-o2nNILm=Mg^0COX|Trag;&t>H#Btd9tYgatK)eNY!IeO|=vp;fz!edy z*@*T8aq_Qla&3aG=(CsYb)t@j8l^H^K5U|s5)|B%I-U~obeLj6dBF< 
z$x6yiznKL)yK3usItRp2`+&{yGW&UyQL$J;6g74fr8WuWX2n%B2Tifs-0`6VcnP5h z0DPU_B9LF#BAJ}8Ej7&}iU#~rk4W?AFnz2bjSIB5r=v1ZH8&oEgvCnaRP;f2!H*dr zXOzX%FtmFxVvi%TA`w?ehV?iid&9wCl&mwAtGX=T=6i>jy2?ciqcF&;31o{k9C89c zgVX}DXtZ};BWfU4$b&rJQcY`F;EtxEZU?3lKZ5*K+*`fgeTVhN@TsJ?-5%tl5Z^ZJ zfpJnM)6h?7-&(}N<=7%$RGWf-?ApM>PEglZq~l~w?Sh_L<>>JIhv@$Jeb}~;J*GhF z+NQ?-c7@bO>1v`(fhxOPJGBEE1qJrDr|TmdBl9jv+pyKJ!?6Cgslkxj{PS_k(^1Ah zik!?z9!`}3Li5aCJj~>;BwSx#-|=)_z5H4b!f(Yp>V&8SJ>E=7r}R-%1u6qru9~QY z4qiSIr(R$f84kGfqVD3BD}5$$Ka)m#TJ>H#HIFU(0`9V3zl5%JmMjcQZ=elF6OC#T z!Bxb4N44=D72y0&d*J6{h6~J?L1uje@Q*M z$Pcrx_x>r3Usu<{9q}gkhGbjp$3y$Y(R=GSUMv9~sxQ0mIO7QjIMldC#zr&?DHLF1 zrTsb~MT-I4ll6mPc}{D`G4hSFal1pLxizyfiTCTdOWvc&{g7>7aI~828%v{0g1lqL zHQbw#xEH)TM0ed1ohz&U1r4cdlR-@RbK7jz1K|(m@2O>#yO%H6!#3(vm>^e%YrJ2#0B>Fr)no@0 zDiJx!5Qup%)CQzu{w;Az323(bK3$Va)qf;7KlwFh?@FPL@07h^#fM7b6cU-RS=VP7 ze5Ff)JWv(L+~xERPi*SDul~){NmWDW%#aScw^@+maTT0QJ)JiiY-b2vSG6E(X6isc zT^w`?vnUL{#J)ZYPMACK@(5}2Tab=Z066rhzXGPm;jqjm^qXu`Y5N#ZvN6@!%FL%M z{HQ2GUpQCd@BndFanfBZd?DWZYmZ^)s*RpnUHOr_1lTR3@1{B{`Ov{*`zWW{hf-Z_#F-+RViC_3C{GLQe_DoJQjx{ zMRJcNAY@@@F)d|JyhAc{k={|7#?A>ZvS8%px81PBtsk;!z3OJyNFXMsU$8?&)EHt=AWYNxV|0!A^yhKe0UiDBMQ`1`1S-Dg7yy->?;a2~^7~y5#>s;ShP)6%yEd|7SQ{Q_ zgjHc;2*9!F>}!J=R~O{ZHLa@kUBb0D_K1=zn$=f!#^gD_`Ryk1Dz10%T{HOVB;Cgh zi6C5z{(g^1%OG11hRPp3gB-_EEsbiaXMmNg-F5~?RTPu2RgRwLBrKse-g4j!B!9{# zSdjFAGgLw%8M|#$L_Hwa&OE&qm_>X*AjG6U2%ZUVXsidJlPBnV2>F_oxO%9hbV41EqRzR-S8c3wFRd~OZXHrAKl58Y29Qy11nj{)9VA&qOqF-h5tItF}5c_(R z(nTQC$!u1ipI^-fl{_Z#&N_i$pZxA~#8H>;@xjL6DZ$K>(cT741(&5}AvME*Y;)I~ zqUhhWHdaZ_BPTmq5zPyTqRNv@VesEf_^Es}Lz>UL|-h*YI&nt%UFF3&HO#d5jpVdmzcg;cG7$Uv! zEadG_faL0fj^G2k%;X2)lM5kOf9h2k87@`$FGmC6FB zDH^l+|7^V#AJgfPaUI?1^r*06@~GxB8>Q)2A~l%K=i2U%$`yzwgy&`{dxTT!So@4% z=Htn(2C!NqFv`=gLHND=JbUgo$3)66uWUCC2+21|*WsAG^()G7`ALvlIiVYoTb)E< zZdOA80-;r16BoNTtL}*sm+yEvi3+b&%{J@4$3zUp0u7=EJ!DWFGhnwYKq?;p{j$t- zzO<2~i5(8-Y|CY93HCl_Mx~f9kD4DB>($#0jXI`Kl6q~HgawLj;2VqU6%>#OFP#bJ zw#2*tZ$s|>Mt5t4fDhebO^crwwK(cgvJ`<*v~z;n(Z!GwDFyW#3k*~RetjjtRP%TC z;`5f(@e7I8R4m%z-qZ{$uG2-O%fm2aaO1K7DxVb@LVPErYXdRqUl9TjK&X4S`5+xv zQb^~(8si_UwDMQEK6cWAg!se`2G!J8fCfSDx=u2{p%QFIl2yzg768)C{14fTE6efe zf7}jt@-}#B)v`cE%EhNC6@b2ag^GtrPsjt5Fu^5FSa~fmwaF5IH^H+)54s7t80+%| zQNi^+c_b}I#&hrRGsFh&Z}FBk)6O2awxpj4gVT7e}Pea6Hzs%ZB~ga?cw4MRL!sei zJ|7o=)1LmqX0e(Kw_?NHXp6YG6~u5HV?JR)y}pXZb-wpyiwbq-lU)qVV8pl_q8ezi ze!-9-C=VvSbCM@-B>JXEBl+~ReNE_%AYo}h{arg2jN|!~fe;htdRONibjg}1Xx>h= zy8CR)>F!WS_#yP*B&J(hltEl}-{PLdt3I<4OwwA$N|icvmHyEqlPVrDyjDMvzK3}& z64HUU@xpDR=!uyg(}$S!m%5c<7J`@IRV7D8~+#11gzn7jkT(DbuSR>x>FF($Ksdnaz z7k$lghNYVd8UGauV8V_ymP2kt1hTEY!hPSCq5kjwFWuA00A>$ic%{z;VD{4sTm17_ z1=S+@5_f+av0cHDk=fq_Ch^fTLt_r1n!P|6%rB!~_@Nc>PmcYepQJwk@6I~isfYSF+ z$~O2!gQM|O-XFa343yrd>oxk!;t}4YUYF;O%=Y=1jo+)pX_nGf?8$Eb8Vo{fc z_bgPiE9`F)*)+~UDjFOuCMlW|*e=u&4GSwkH<`<=!ff+c9otj^HLO<#ed;i;QWYOY z6+8yqqp~_l0`(W*QK)2l2M1&V489Pq_-EJr)_mj{Mkd^ca)z>Ftf!(=JiZpZ3dRL} z{D8ZqMyd=qSbXUUVskpo>z(o;J}92)?hR^t7{=opEb&TO#og*>hu9K5(4;T^3P`$Z zPL{FGw+eFiwzP;RZ*b}r8cT??E2-$qp6aTR7vpW>Sy6nFm>7<8UQEj|L^7@F0;|`H zq@A={eQP_8W>dV?QYZ9Hf_NX--Xw3GckR?~h8)o~%p`y>kWABY_|&OQs4(+%&N_g> zkyNnl<`M)Jetz{7@nB#q?R!@>7e&mPgpyiMXNGpKk2@y^kf)!Vh@!_PCnufBQiLXf zLKqZ;kX1-W;`Rj-Pc|lZGjEowDbE$_N!E49<`fjsPF5|c zYWzD)qzOb9t9M{bn)NT>PQiw>s+l9U5a`c1>e4a|Jg#T}d5^mm?IEg7+__hRn%Koc zle%4wE)nPq6;y%uLf6?3uk`ySzKjkK%HVS+x)f$o(cjfUHaLtqw!Y@6WaJuADjtG# zG*{KoQKJXu2U9dPvU-(&Qh7i&h^uP^OH5l5<*P@jC7Ntx>>T)M%su=9kk9Td#Fqm& z&){*F-+&y0RmBoID5uFA3V$Esa;uc}+~kV7UWBynaAT~Lj=fRa|1N;9GZ-GJA}p7_bbhanhw7n`A3-PL)pPn_h-ccXzf_|B$+tod}F~ z0aRf!d3Bz{5@Vp62Y?F0kB2ZJ$Ff#Gqz8*^_T>_`(@6{k3|HN!AFKc+|B>()qWlFu 
zDbF!$t-cC$!+#5I2=bU*gASuB15+0W6swa1>);<`sxAhTRUnZB9ajn<@4wBu(ao^nHgpoo{iZCC; z#7v&OA&kEr6ZQRwYDK8Q6PoNqkSjUPdrzUp%CTWHUjqGaNd3dMzX}7j6t}OJU9^&_ zfv}hYHa3m4t`dcqShUB4#l(Gu7_lbC*!&akuZyXM&WP~OL7&dUx~PqRv1E*zKF#Sj z+5xy%H$fWxf0o;Voj0IX|#%ChvZg0e{V1 z4g4+`KyH%Zpza!#i-Ly zK<8Xi{?;bd~noPXtg{)=#bz0?N*#=ObQfV-C!qWs@~`L6>BhQ1Ir&ZD?O!2f#d ze;-gp05~B1hD?F`|L1`Jhfj#EkhH_wP-bd?}00bkC z>1*p4dE1*TzKbX?K-@09k<>N4E*?ytH78|4RLwM2b$;r=kg7xNQVuFp3v^ZRZduq~ zQ)$y8P!s;>6h0Vz=cO9s>!9PQ4DEf8+~t_ALx z(_>@jIGm@dm#&JLOKw`5COTtRFOia|69X8F#qX&khsOCN|nZs31pUCg9K zB#WPuPs%uTdi@69tW_*=o)}c9=;O1ku2x|-T^BRGe~_e{Ev9xsgBp&BI^5Y2G#Ew6 zF`S_Moy8+%tB*UOX!>-;GS}JhEJU8~dQPk~r&kS0N|k^AI7l&c$G^%L9$qrJ4#W)pv0e=ar&iCs&R$n`{@+-5 z&wf$N;NL;*^GOE0qALnBsnID;O}X_uI5&@IxNmms2Ws-x|A51ZKL8~VmbQtD6B(gN zlrrQ9mpR45)}@+7P+wtO`15r`g}_#4M@aLGd-(mvZc*+wHT?_e2l&W6&1%S&y|Jtd zU%C_r>aDa3_s>FDUa6x#w^x=TK-fODf`vp5;H{|tmlzw)?)i6bZP6v>dP-b&EM}a{ z|4c$I)&Gq<7@;nSh=^bzp0+Ib6W&sIa*2IzCfnaJvM*g(gY&F}*^>=NKCXTm9)0lZ_C*!@w3!GlVl`W6cY=XL=X8qT3_wpjS5{fm5jn4n5umKecXE?wIi03 zxpeGjHNcKhFt68kjQ9AaQN%TOSa!eHC%>NG4=TQ8i;TQ;?~)_` zbW?o1*IoJ?a$QpoprbEtUljM>c7kn(cny2gTk)0m340Gs3G!y6OZL;@HfQ!6OT16I zP&E50X+X-$DM74oIk0b#^ANNjA#o*x| z?4NCaL|{L114=DSHg4l&z-4`a@tEA5OBcFRnJ)p=f``UEzMD+9M7BM?tUsuz)aFKp zhWsv*B;Q{qc5En84F9?(5cGiM@Laava+8uAoiN(yI+bd5wMU~7o3GU4A1&3C5uFdm zoEJ-_VePo@4*vS`+}cF)QmZQ!(d%)98nxT(R`TGGYk4roU#x}X%kaLd`hQ+}S>qXTVR%rKSA~kXd#sB&NazA;XyF_?97L)MS=I*yoa3SQm6U%iGsY%&GA}d}a zA80-l)>!g(Gtt1+cEc+sO(`<4yU_3C!bN(FqGnc{+`?=Sj(wt%w8UZ@*7JQ(daq+w z+Hb)W&F3>_3uW)twK1t8>H3V>4?cge(z0xH*-ZbCHe2(Qx}Z=|>I%@NXX9^PP@Uh?NL9fFY zq^_y#Bz_PAJy)e(e@U9GYotU7Ldwt=vr4FFAiKi(^VQH!)KE&LM?S$Az$dw;#j=H~ zI*MrTSY*@y)grOv3WR$~+0lS2Ycc3Sm_ONuE}bWWIL>tIFC;%<`3tnqJ)S%_>A{?( zwg9AV-SCYH8li=7G!n9t-(n+$v@SFxODn;(-s<=e#r1-1{^%_=`NuJ`3L8L@W}BPv zN_!@gsc)T%Aoh}{;mmjd&fI>K*+gQ0q|FlQy4tf@1kTVO0#v$pf3r&G315)$h=_IW zDI5zYB|1-dR~K3V@KSm6JrJA7b~z_sLH8yZC0W4tIKJH13VnHrFC9j{rYC|g{kxy$ zRJePJK07LR?jqkAEbk1|gXM8IEdt2$4l;ECbU&3b%DFP?^;xV7wScP6Er*vh^5X1? 
ztab|PW>4Nihh+x{!!`2PPDpk-|W0q$Q9(L)KDpU1{Ow>e2v5RmFJ-{{fYT-JHGW+a+Zzl+49sV~ z&_M(aL&uf!Z4n}r^i2Y9#-_@?hV6*f~30`0g^Hl;18J_;wOTe-xJs`g(-bLjrhv*x!y%9#aL3Ntq}u2HBis7<1ZCiu&;>{*7HVXmlQb*L>=}?Hhy2QtV?=Ks zL`o!5Nt7KHKGuUdNdyRRkdjU~TXj>_-CgP6@JlsowB!mvyIpGL zBO(Xi$YcP_ql*`LS)MC0>P^z2E;Q>e@Tqa^8N|HJS(WwIe4*zyVj9YH*-L=&bZ=} zYhPd{O9BsY)T-&=x6*1y1mtAvJ165o?X>~+HxAKSgSA%|U0xsF!M-olTAnImOs#`6 z^a{0ObMVDlxlclB&{JV#p0bYFVBvs7Z+_M)0Wn1qSOmsnK-37NB|FPr74%F@TvAN@kU1FF_Oc} zLp7SwXqMuva%quwAr+{$#+7co?(lt@XQG%UW@8wBt1-2T(BQfW6C)>KAdr-#w@KJj zWa2A^8a*{m9OvLluyuZV2^mr2J3nAZCGxF`BbG5F4M?A*Czj8JF^OgNtzWEC^EhF< zS0{{s{10B)t;Rw>f;9qPnXUOjcUI|4qu-maR;rrz6LbazTw<;bMWU5YmN_bFnm%3@ z(LdB_y)$nBVKNx1)Sb-EfQX8WEbwB3S!qDX*}eV#p~^u<-FnC)6P2Bvy?XB6!&lk3 znCsU%mVcgtS*N{N(-aWuZgVrJqo$&`sA(Y~7(7*3RaSLn^*!9xS=lDHCCE|95f=GcnfBY){WB796wv$T9z zxQhN8eQ2uc_Dbcd{g{7ipGnbLw+qvmw$_RW2rX74nf@0(7eHl`$P<~aR)iQJngXQM z^UQ+6(U~M+8dxq;3R&C18Ixjc1{1HpO_Ea3Um)(=G#hGqVP zNOJ&#|NG*Tp2ga)RM1L498N+T$49X?$YA@zdH@^!vJad*-jiK(vjY(Q!cK}sD%HgO z&&$7tfh}n&djmwGG;W;ql|^F4NZJ;X^#wcwFyx3;c@%-mpz01K{|F133ZS*=CXq%% zB$4CsqQ3E9CuXsh5D-Y71e?7d9EB;ENJ7hg-N?ukNeyk?n%~QwZOu?5A>k}i^DYCh zN?!L3+;k=L`KrCZsWE2+tenV!VL!ZE#ApZn*9*UZW9ils4-^vouVd!97rwyzeusA- zJ)xy`Y1Do}gwow%vu0=S7Yt2eEm+2-Y=T7ZZtxWxeqXqT`;!b!RP{zCyqiv0Mdl=> zz>e38w(;?Vq@));b8{*;zl*%CMtj}k%*L(Eei$VJa&%UJ8Q1t2&7^m=$x7x6^OPwA zLK2E)GyQ2W-u39Pc16;RpmAKD?#hH2%ySstu8kyg1Y`5<=1bSy2%}dswYjP24+yHVWnNCAVDS?0Pv#BGTR;by9)O z#ERmhdmxBM#4fArK?iy}oRNZXE$F69Q*6)q<|N+PqsZOs$A; zQ^Qzd_?w*@oiAPiud&L2*WJn00ZN9s!W^4jN(B|>k?mU!1S)314^p+-817b^euqb> znz>m~N>>MTs}Y5c6IDIIpm6EXzf1SYNMarY7KmB80Y4Du%W?R?z+}Q=p(Z`9|jScrjZZn9(bLtYD~2!f*=# z11bxK&+~?wqqbgW^eayRod&dKqKOriTFo0A&UM+?hk%1aN%XAeyVIEpc?35TAa24D zy?KvJ(-pDwiYeKVzq$0&8n1{PMV3SiJcJmnuNoKWMU`2}%h7SHNI4-O9_<{=4ZIJg z&1f1)*oFF@H^q1=n}4p$%d1qQr4|rNzQ#&(?ykgIrT3}vWh^i3D-34H8C~5{+(7h3 zx3sDihD`V8CtX6D!FsHxt2RjQdEe&EMdCcX`LRc<5Xs1hXUvU|EV%`Y538LS#-lhP zj7K*&g=5O$PMNmL&ePwFX97^)Fh=`RUMD8>;LL1rKSj>53>Gb>uuQwnoK?k~BZMpPbV<1_qyZR~? zq~)tKW4mOr(59c5-0USVxjkqz*f)ZKV}b@5SarV*JJl^NC3gB#;NmBJe?ZLK0lG+Q|hA}?Njl!|4kz{v9av=9IA#KZ-Rb#IR51{-kHrPp`K?FKA|W~Y0c3XEkX zJ1%d1tj`EB^lwJ~R~NKmkwdSEo`h#pmzvxXU-oO>-xIOO*on2YVLH7<{v{+@W@~Yq zC@zzCIAzyM;V;?m+A-IAg$Az7R`DQ{huxrI?7UAe{yT9|6IeL0T+$WmWPD0a)o>*u<=FCzC$8W!|N1sF-w%t zmE%7U$*4hvZ~PNuTFyYz7P)YdqPfPXCR41L_B%@0Wl2EX5@G-FkeZ|`3J)_1XtilF z`}x)jnPoV-x%WbK=42lE=wQN4U;vY4V6kd5MveuS_oXAcbF8z%~$MN+|QirM_{ z0gC_NM0|fm2vz~u=jCY9&&LDralGUyt!xnx{h{V?YH@T%w@7|zZ84`jV<)7u+OXcn zTwOcgu4xsCN;$~A(}^|txL@php>rWhIS4^c8X-uFG#EqMe^2KOvnbj!c}=rw40c~S zIUd5_Y)~02S3Na(eTYNbRTMnp5UD`^Po#h!aQN`!m?c|#ohGS zUtBYd^W*)|rPY3Jh$rbvQOv3l&rdp@u0Ej^YkC0F)-UAZQ z5$*Rlx)I~SDoe`i1X|6kWBBOkb%BsW`QNDRC*%*YI#lmkA~mtpI$$Bi!2hqDr}G13 zGAoJ4!@aDN!vBrv{^e|Zk-zxpC)h6%ilP6#0{>TNl~0=g>wtZ!V%PsdbN?&V`v39? 
zTP^ex@K~=L2PoDG+#k4*On6)|YT=oLo=ZuyIP7h%Qxe>Pa&b9vkZv`Gfz?4J?OFhT zoPKd}Bizo>=!^Df2B%r0SQ8kC_U{@4Cm*1P8pS*F@5M8Xw5XM`yD4kqiH=v2l(8U| zGkJonsTvl!a#CLjgR^aO6l2?MRVMH9`_otlb=d+XZb{n{n&~EMe4#-|Dzh7I#WAe` zwyWd$|G;V`ZWd`$`N02`slCEM1f<&1Ih#Ph;!8wno6ZS61MjOA?@&$WYoeyBRMRS7 zNIHeyN(_QF{Hw_-IN~HUSv}G@M%W?aEtdiTwuekR%6vKG|B1$||HhFOm;q_I-X{H4 zD4I-DD3nvLb+QJOt8xQH(o_O6mXrx*3n3@ADVRAXeMCoB?BwF&qX{u*CO0a5-!O$M zIK?V9GAZb;sF0!y!;LVf%=;Bgd6#dLWHBows#MKNgp(F0488Ck1%*=KvjpZ>5a`s4 z%Hit^aP{xG?Q~uKKKy0Mw%1z0YM&@}y?^bJBGoiCMKN?mlx<^Y_I&4_)y_LrTMv@w zh6$%11*~{kP42-eS`8MLx`_43;eN9ApH*~i9tl&Q_Qymi*DgDw7XK&x^*M~-3rs$^ zQ8bkUG_|4NqVl(#g_FON!+(gqI$AD#eu)Dcxa*na+Kpv%=6tM`>3Q#MEMZ1f@^v4# zgM90ou)xpIsCASv^l|4cpzlmM2w%ap0f6U*Ejl(X&e(I2@@~Bwph`Rr>#C^}0aSq5 zuPP*fhS5kO6Ib{wVY8RUo5J(7z$m353{tN1?HE1(oVR^mOxyb2o1QUd`qhvL9}O2e zp(1hm4DA*`1tm9t_9q5;qug?d%my8bxHl)6{yA3*@U<_V><#Ua11rDn1(&A4G&}_& zovea9s2mC+2r1~>yBrqO@pqH>iwm4^i}JQpugqK9&5p~51>ocxu9u5Uw---K>#SR@ zLyfoV5o}1W{FAiE6a9Fjd+?*rbv4zjgBq#}E%cG=nwQX?GS^>s?~V~{p7aDHp7|tk zGcGo-;y7wtzpU8B3BsJJsuGFLrg7tcXu+A3P%Pc{fv&xXFpCPhXj-Ls_vgNo4+NICK= z=Hta&G+KaunzID9i^0RZ`hH4w7w1yUgH>!yRc@g`^9xobef92VB-J|_a~rI{>SI^z zVayH1;^L?iCt>d=kI->~Eu4DZ>RbSUgqNWqcrNf;wdc+^7w)^{513(t#LD{<+?QO_ zv`o>+u$~>D`|w5+;`To;e>LH1XHr8xGdEg%?{`#Y`WZvko_{u+Mc&1eG&?5>DNGTY zzw5@gKfpGKLGZts+Mm!y9riY&vbCOZ-q%b!%y_I2EwQl?{?Q;FQhxVm#P*Ci_Cq%Z z@n7sLb8?a0%+%lvgNp+%y##Y0ZGpY})aht$#13{_mB(8h*+k#EYYnibIrc;B!i#?V zeE$uhU$K|@wt}r!Q!SJ9a)I`rT>w)0`2{^w@5-%~j5ZUreE}6?vQ>|wE|ygxhohP7 z$ombCcYh6>=Z#4&3&uE=%OILz>9eI_|J1|S=*SEZh$^zRB*i|?#P&qco9-j3CN(q{ z9&5NLlv6Y@MHh(r>E7Q_jwhr#3Ed*b+1&+=dN#iYwqG8nQ=R>mD1o=Pf=_J%yU^te zyf{X=eiKDq^e4(7d-FEkX&oCM=99caCRMXsA5#{(Oc~qWQW98$Ic{!s0(qFsqMW_4 zt89}_LD(e2-RqEAJZZ$t+jhFZ$7rLu{{z8PY|6PR+Sd7xN^> zgZ3pFnsK_Rg!-iCjddj{qTARaT3CrxSpK?WE{H6u&2MIJansFmrk4&45e%}~ES;s0 zohC;`oMu1N;>K`@hXCsI|7v4_K6+R z`<&JvB`t51PfighJwMs_$B7JJas;cp^cq1!atWG0W%*oOoz5D8VlAZf3x-*$WwO0Z z;E`2;<}y2{O;dsGPtiouf-Dr4lT)97S)nDE2=MS2lfadH*uy58gwP*w9BiQ!z@by$ zq$g>;;zB<=U9v{`F;l}N#Ti06!Z^j1O8nA-nzrL*O*>4P1eaf^<9A#(3j2qG?XFX* z-5-B>%l4;5C4;7RVdLIml!!qIi2Y49m3_Qt-FG1mSYCQJGAL`+64mLnvHtlG#v-~* zN?5SjW~euk0KRwffJ2xiG>9Rq>Xx+gW3#!V-Kj*HBkFsm)Pm62w%0Le?tp{X;mKC+ z`2ooq;+MF&q_ey_U{ z?8h1!h_>U-hHQU}KYFRtS(M0ORA^6Bi#5_~=u`{--7ML&oSVJqklH2v( zy~eEfFI0Lud(ScXa?h?bjCeI)*qrgW|F{_jGv5a*K|CJYdwm-}Y-I8qR2ikAHm{rf z%NW~?esVdFfhK=L)&>p;^KL9D`}QiV#8EY~Z}a)!;83w3fJhhs*h(IMiIZ(-E~bOv zv@SuGK?ORI@7xG_BTx(^pU9jVoJ=h;lT3@@*M&hw+*w5|*yZ(p+z-NX)|4Gy=rl^+ zw<{1F92{GHuN%SDn^%wu##XqWX% zNNt=(__S4|3!Yu*vL8hG<*6oQ{EL88ukDDj{a-n4Xh>07Y$H}ia}(6;6Yi%7z@b#r zB@{~ls_BoZY!F--0fpo}0_6TILHc06Lmw`Ds(>aj%fBP%#9lj*vgz(=^+9P-7|~%N zzI8!S<*6vVF9y(Mj*`aFN*#kM`MaU;P|?tMio%B>Xf7PBlFY^eHrPe5iHigSZc1e` za<$wZ7N&)91%G~vE0&Qzzt!!TjT~%TcVN_A`^pw>bZ>u7#H*GdFfH-PcU`U)Pn9#` z1^eW`4L4$1`JuE^TDXaYxtK7;9fosJR9>!912JTUj!jBbc*Gqh{%kY||6zsGeGVPCGp@eeiBz0% z%P}sQxmgC3l&EnMGl?XZ3}jw8LK^RJ`@6@$0Jum31mjxilr)=J85R4r%xgi88G&#BuP>5HyFTRlmtAPeFOxz zS|}ziNIUXZ4YziwlEzSUz-k+7_#4#VY@XNXV=5W!i+HdJX$3Vt?*G_O(X3&DSq;Fc zx=`4o!ti0rr{3CaAtE9MM@^axM1+Mgs!^*xw+HOZ>%DyD2drN9%9+2|7gHJS9#Nsv zY3Vq~`)i_{r~g6Bn0y`)_h4_FRMNM_|y_vyB zGKtg`vQtxl#s#slwuhqtK?V%z9NRLyoBgGQnjcIReKJ;YFqXeU{4`EvDgj$|#$WE9 zGYjyeiN9!VMg{G>ltrb~IIB5Vq*;rRJ}xy^6Koe+(EbUENMMYYR?_FF((h2Hq^)fw z8!&{U$`3XRXN!&=NNI<>A+GG5NZZ?H-!`NoA>`1Yl$R!L-{*PHIaNp8k=cNMPGJ4C zq8-M)_EZw4 zScFlHHc15wAq0=dyKo9(VJ9IKbC8Y)g2LLWLxHX*lB1ZmrroJLGqrHI7y0QIiL0y= zLq4D!gRzaNL&{Yo>Tb%s(5mTssj2vLVqi=$(Ipfb-T5U(Ruzj$o`m*rR_h_`LAG004pFYKQ61|B!vM8iSJV9-hp`gD`i3?3tXQ8*3 z^J_c+k;#fea8x(j zw%<%xw#yTDML|t}L@|NNReyryRJYW+h)m+|Op}@!+B9?#De0&OHm69=cMRp3gx<@p 
zPo7Y8+DJeqWcTTMSG-qHdM zSrxX`;WY5&{&aVv!;JvlO)AGMm}gqZvUNaQIerm|B067`qt&1V2mR-7xxasuMrUxG zU!GZxgK6TtV#^#q8Lqrzss0ujhhMt=4T~)i*T6=B2 zl`J<5?Hhg^Yy7llZZw*OIj39?4b5(fV7?(- zf(lLSNtT3ubSNbg>o%eygu&cK(bMWVp2qA_z7mD|m%e~E=1tm5;dg_M$kEZUWp@aq z;nhTjSESkdh{!C-RQt6fhfM_{YkPM(yH^|bV6SlbtfBo^5-ZIWk_L0WC1qffj;h1X zc^+A(BuwiIQGX{s6SGALr84ktvAtBQ(I2bCYqHa7m~C+Rl{&rI1_9TGjukOFiTjyy zV0clKFXoNC9d(7yD=BfgoloR2`F(;EAB?D#M9S2B)39DE!eA_4N6N8^n{Hs5Fx=GW zy}0CU3c>G2_rMwf6B#^-7{}Wir=3S#^rLeES6H#bg4=OV4&nVL!v2$R`t|`GuShQq z!mxHD$%-(EH8UjIL9dNfak|{g6s!aIktfTm`g&Bb52#kDOtYOydM7_gTrnWN7ck7p z1&Esh!?_`$aE}-d3R<}Wj{X=%7W~I-$3QFIb2PL3X{}nzBsI$V5NWa@w*s@7o9g$| z<(jIDrt@+}l(Pr%^JXs-mirnKSvjy?wu)p^e)i03Njs@i-Da8HLuG1Brr8w*KgkCK z(nPz7KEpSNOSXOPH9U5pP!Xj1usv!O7M)fXD7RVct>HMJ9^9LR37s2QJb=jI(bl|3 zQp$dn#)~El#|ujt&&)Tc#4ue@!D&iIiwMzneT0)Ok|!2_sH2*)+_T?|g+Av4%`h%7 z=Fe(iyj^sDc`qW!a2c=SqLD0x0>lVuW%1JP$XAt-ibECK(weD@@!lE)Q-)?7{pace zP}Z!+_g&Nxe{#^mVTuK7Rilf22dc5=8ug8 zOPIewyiaUvl+N08ytbTEwT5-&9WCYqC+!YJ+_;u8&ImH>x$15X#qbX(At>`@lp-6!6p+++WD_E|VKdh=6o8C$;|2 zlc*cZp~BG6Ey8iTB5$(}0)t$V&o#H)i3L=3p_x_(7r}A&DV$Kx<3eg9dIb)JQke0< z9&)}Dx<8g<1~sGI8Zq|c+J^{R>x;^w1vf%>pEkk182An%#}%lx7P4K+DJ>&uSuD5{ zOhvSihtrqPz+VC5XD3f~)l}*I4#x*u)D;)%-QflmWx)lAM;Q;H2-H4~+DXi&=0}|4 zf229d_G2zL`$XKY@V6v-l=H(6Xx&!X7mdCgY(B~g$*B(gaqj#bp>M@He=RL(K139} zqMlPtNJ$>Hq%NBX{j9<;tIErFmYJCe0|PT4i_t{)z__TyDPB6)?4HoCm>ocsWuq`T za5XEhc4*40rT4CSH5wyD=V-t$R(l{Kf%guTQcGiyNc(-h{>7-h&h>HyMvtFhyM9-G zzXtt;MAk!nM1Oh|@%qxRBm!9q<;0|)vfBG-KaFM`hjH@e1FR|Y09cahrR62W}L`J(c-J@9XIt+6S*7>pi_d$UiSsG`P3M?R6Zh`a5TzF}VfUi_Z=?Y_?U!%UT{ zSyD&7Xy|X0a}|^~5TNWay9twV)0m#+knF=wHJr!qN*{s%@!AvTIt; zu#oo-28K&wC+e#O)9jE#8i%V^DNo;uYA7cjje?uB210k zv8IVTGRaTF!^E98-NOcBv`l}!`Ee0KS@1$L)h0WDK=hb*Y(D%%mLV9rDH#p)1+~RJ zOs}bWqdckK$&hm(Yp7}l&eZU$>Sn3pdQpiv(3Kl=rGmRt&kBO}mq0GSL)&W^FZ+}x zaeN2b^4Pem0;Q(>wfb!bzS2y+5Ts2qnx9OU{F0vW7v%Bx;`m2awU^1SCdS#n^5=T_ zf1&wj)^*AV`(yHvao=I-&(8o}@{fNm0xoJ5$s)T%{{L>=`9_EAwMwcknU{6ie_VuP z0WQK{?GJkF?xoexippYSMKiSWyQkdK%By>t_eD=1UtzYEuuZ4rKUy5iwF&WMnBl6k z!_>a#czfJ88qde%0H2YOb8CM5Fk}(bPOkYLaHmBv=yy))cnz6~)kk-5q?kl^C0|$4 zFv=WOTU*er9oCMfDo^q zRLi7DDdTLVQqOF6vq1Pf~7R;Y}TjPZv4qJcs z=YG?A{>HIG8M{lLvK^#eq0dr#JSqFkT*GEPgN^5oFnM1Gp4QCBFS4*>HeFOHc!)1A zirnSzLbavW#X!7ogS61}@W)`o`2ix?89;=mJ#qd)aoU2kvAmf$Bcjp_uAJsEYc-uE@+Hbr{8}u*e zY|m&^DUioTN1-sN-E*7{ao#f*8D$^kR>-sw^XAW#r-22?hioh?BD|kcv$dK~o5Oag zDila%r%P4wS6bN6uQc{)xVDbP(%Y^5Ynhoc_M0MpjXfnT)x%NwcxQPu?L$)z;R7_2YTqY{jP>B91lr_Z7x*6D}SK|x=>+bjKmsD~J#xKb<6=+v-)i>9_QeH{&geg@} zxre6EmX&QNb)z|ptk|!>yQer&?8F?}zZ5|DNZ|(n7{T|OUm(sF=|l=mXp=JcN+uuy z9#1L!B+ToEN>*0%@{>I4Qe>+y0bVFc>fq(qj4&?8wUr|HY7N=dzhgAi40P>B9s^TS}r!_{;Djkz;6$^Biz{e0h3ar0-@_9a-9=NmnPlaTfrRsXVD0p4P0bDzZLC@(ga^89}+FAg?ZF@64dvdrfml}(6>P-hm>l3l-_KF$!0p!2<% z73oj%;7d_*7jrcXz4m>KEt_*Rq(N0xs9>mOlw=bAfc;W~tmv07jX%oV35ki(e4d=} z;W3Y*&BC#weCz5+xw*Bz7#aC;xj-;@KiY^%OEb;yl&K5JmnAc;x{(iCe0jQ;wY}VP zqS2@w{-Q3SRlCO%5ZF7w+tk7pem%+HdWx$TA>Dtp@KXR&3-w?u0Jh^JsmM$grQ&^# zcpTNFhLjzX@nfXJE_krjpwCgHP{zr~b*Wxat=_T&VJhqT_;5Wp%6-7V^Y)V(o~*Ci z#pkX4DGJYrspU5A2;oSWmv?6Kf2bb45NNky6dZ z`_lx0`MkV7v3M7;+W=HRAVw`!`m*HU?g&5y#Jt0cjNIxFopEjhJ$)cR@#EulgWQNx zy0Oi`|2eLvregMvmY4@|m-Q)(Sh6EwWap~MB*Zs+_gis)34K;BJH2DnxWQUrMSVXfF9u--Ys!?e;rO`$@0Iyt` z^zyd^x~e1#xglk#@n99s(SN>ZAj|d6nji$R6VoDk??!z&+$9D%nH?=MT}n;Q5sX;t zwwOe4CMYoHvj6&XSFHTx%iUx`-qSmT#cHrDRxBeS#V@54@GJG>^K0_J6i_b;#fB!{ z!;TCJJtw6p?gLx88e zg2_0VDRI25J=+?~1LRrG3_S7}Zcs|okfd+q#+VIXJZV&Uy8p-~=xm)R!Kh#B*A4Mo z*QF#IgW*dQF*HcdHeWcCrRuEI_ zkig1g_ZJB5N~!CI!0`2pRQvq_wdGPhZNCj54(tn)xH0##!r8JeI;YL*2;ACyn_#$G zhD;LdL2iFSR=fZ;l7DrB+?%9Z@$+)uB04?2#ORJjp$8{lvPF{g{6Lv+qj~w&=qOe1 
zN^kFLUfy22y>X+i{dn7cI?^eY!#NCI_e-duL|R{ke5FkKyn)!scApm!?FEM(`4+>r zg7#Q@_oI2l&HD7thm#jS>9j~m zBh6yZ!&$W3`$NPWe246DF3JgeE@w+NGp|fOA80Z0IdnEQ^72RKxrGsFmbDbECNsh> zn?3o-TrMrRfEs%Xaby$fCx({-3QC59jt9VT!u3wW_p3aI0K7=9XKuYIyY08K;jR&F zi&5$kp9DyGqM7k0lOM!nB#SNRqIZ4AF{$YxPx^$xVD%7(Nz-xYq6xFF@4qbtt*k|% zL7J)T-4A*?MbHM=i{3o_MaO`cq{M$6I*h%86dJ?M*Vfil%!C$^Oz0ZDzWYZu7sb4l`AQ=`@%=)R(J3yfVRpnF>l9Ypxf`1x*J3%o z;7ah5kcwp3<^EOpw!_P1Q17I{PF# z@wJtM(Cb*jIg3BJh%1P)gwi2e-;N~Dh!4;W-BqrX*_Xxfk|q8+D)=y<8i5X%jiYP@ zwfN?6__tR(8^?`5teTeFI~XL-ipFb+a{@t|S-K{JM?j*X5|(}TMfpj#B0Uy!u`_K~ zV8}QhQ`_!gAd7L$pD8#G6s+@gP+?uM5EfsXy`Vx!m47ONZ*SJ<&FZaKMg&dgGy$r; zyr<_~9%ls2Hl4*h30og<>RQExW z;ob?OwP7!iY~Fi9_Ie_VFZ`i%b=vio*xWqIn35xxfDikKcJ9WJRcg1kZoXj#Kwmhn zhNyd!($S%8lL6Y*q}uMF0AVvqU!@n&FmjdNPx;5Qeukrki~QQw*hz80#C@Q&hV17^ znyS!LfpV3T#jINb?85WWipOgL?|HnRL3JpYi_&{*9UM8iXaPp-gqku>Er!bYEatwy zepn<8nXBf9VF?Q6=-Jp1Zw-z#3k7p(s2Bjm{7D!Fv>`r7R!)v&FDbL)X;V6khq8mo zxU4r2{>JEb#ia&JoZnh;978;pC%6X47daZ=p%E{a2P&^B{mscVxh@Qyc6h=S&OXzW zQjy2+P?dYE*UHTzgl0dC$#E!7O-?CPc`9k9Je^mt*lQL($nCGhyvKWe;PM9}bb@hi zo$bpYY4KryBl;ai5OvF(IZMLaZl((9f-?e-Nwv|;A`8fke+*CIQT*V>>%(Ph{w}LH zfFOjp+Uo|fyeMHo!fo3L&6>nAI;H(=DL4hA2uiT8oNQwYjcLP8H6w*3pD!=58~xap z9TOTeze>n~YS^w59<$@*d*0-@Sik(3l65aWW_hkPfc|2}+d&EG1`|!%k$Or+7)4yo z8kYcQOY2OIQi>}1sy9u19wBJWzCp3hrijVNBTkwZI3)jh0Mem_7?U=0rr#%}p}~al zOurzwQiVSo=FzaPAqmBRBaB%5fW~FQb~h=GiH+*ckdLJ*qSZDe&L&{Q_gT%jJX4(` zou?HvtkD!?6>kWnqpvK^7)!-Eoe~&x)Eoy>4W$Ma54Brh-KDU=+ub-DE*liAUYyfv z40o|CpJ{AGz#u!luJgvCcWp}4*0ar2A1%VZl0F<$(`b`yn`fezDo_lHL~FI`dKi>F zPB2Wum)J`+ALKR8&#S$GwrJT@@KjJp^x!%^L>2lSrK%ZY)T4cU5^CCfZ_DAN>BhFL zr3E#6%kQqOX856--jkCRg!Gnr=I1A#zl9c^{0v8#yWM9cRg}qO0v01(85%;(JED?k ztg4o1&e9D*jn}MxH5JaQ-j~2JEiO)}P+W+xJGK@_w`=X-B% zaMmRL_<&Av)vnIh4Cwci)v+=m? zGtwEt_!khGeb)1P53qbD`UEs+K^)VB{3hN1Ydk%$DO_?iOJ#<4@!2H3WTr2ihM#WU zF?ax-STe|;DA0F!?W2>R1b)%{0y*%A_F1CzWrh}97D^Q}huATzB-d1dw0|oWI!h&f zlViL_)E^nSl^VHlRAF?KZ|OH1nQGLVzSW6eyftIiy<>)JduKjs2CPog_BSZprX2WKc>7#XR zEAgBlp0lGW<#6SHnpiHhuM+GcJHx_!mkL=fIp8O4edwe^K*zT*sXtvfsjRH3Lux&l zCoY&$nsZdr01FGN&$1_Q;N|t&3kVt>T1t5p3xnN|_yBe<>Ftn{+42tubCppZH(|h%(5AavSj6&E+ih9{tt#9=i+~Sn&H+{FgKH_r~4SrNx5Iyu`1@x zYCc_rG)bxn3{)m%N!ohN5VlfIU`Zr#$7U+sNCX$xGU0o2&5HP_3qJ27Id+5l)f>IH zV-kg83P<8akpXXJa6PdR}EBY18Q7(RhGguF;3OEzZ&9tyaTO}*W$1LFWNs_#iWG`(dWx*2<(|C4vB+TI;5K%G|yA$_qdppzqU)-ZVCtW zB_iMUy*SAtnU-AaH(h_@L5sM7{$7kAU&YbJIBoT}GDe3If@6(an%%RSaEXJxyKh#7 z3b(`-#YtR)bg8$z(f#!&CqEcD98e)gv4q;%jWElw-EA*Q$Z1FLP`Uenh%x6kMcVJV zwK5gHKpHMrp=UAu?qY|TmS7uB9L{u3c6V5$$5dnoU=K_|@&|(>iI~S!%5@cvG3^(> zMwa@z#I@fxX%vQY^g8MIC^jonOZ?E%`uvNW)S^D>hf-WlhW5buv0Pbz=b%!%o)o|i z^oH@y2-2^T)P9v~TO~G{2e<=!6e9CDOqJnxkZ7s1jw#~(S!PoOdYxZib9ny@MxKk& zgq9G4R^My(oS>;-5A zdt|(3qU9e<0m__`3kfNXkJ+yw?BDwAW{ghOH3LnpHJYqfFR>ij!+#E!zF+mMDW82M zASLBjJ0B;woGY4(dy!q7L#_JQ(=-|n$2t;Bmah||t*^XJ#F|2}n^GV=NQ)RWu^YX| zTS~i}WrY`BqHQ!NAvp){#gwd~=W<3>s;sj1RA}gaF~Y`RIPeZI<`fDYn!9dK-W0ef z*>{l2Q=8CND_?ze>eB%!N?f}|YN8PlpMMxdPD=9oUn*w@#HWW%}kQPobU+RfxNz=xQ| z0#pf?U9@jtjWVDhP1l7`_U-91A!|q9US*IdF3axge`LF73mT&p0?(g+mGIMh1LV-uFCrBWohZQl;mgsi1q>Q>cZF+V zuH#f*g1%^_%1A;4+-01@J@Zm)qd@{L;d^eM(%LWM2MYiUuV15?Y9Dm6w^`n7*4%yl1TR#(rc(_m zs^vIYLQQK;7YL?wH`PW2;-$@-lQLJ)EMAMX&-yX7VBMxBTO|utU8Bg_&MzAgQ=I_-0CfV zTRQ|#*7M^vv6F|J7Z_WgAH3cN1ivqF$4FwSiXwp^Si8T3K^uV=d4N&+#Uh-_50f}= z<_SYSvtKmng>W|;daWg>;Ah{kuJc$dyk=5(9nET@g@j?l_6K(`5^}kx3?M>E@HSzii#aqN$HF|g^h6`DZlSxF>X2WsQPQ5vtTM@| zdX8ob%GHJ{pkoCiWqm+OPx><zvr!X)X0zyg`cfd8G&!P1j zYsAfvEfUVL3s#mpEG9jpq3-})Z!dY_o(RhWJT{0pygbXUk zuVQr==Ba!lnX}LaPbd+12};~?$2YjxncYQ1&DCEXFR`J)V}WBGq6ar6E;2FH8>|`e 
zdEjE`Mwtt5yZj3!V#Sv#+i_@iMO}{X^Evs{^RHkj^9m>+4XdtBDX3`YnV~gNQVXRW zD*XM3wz~}^S}8T{OX=iuidGHs+^_KXnwb_K}g1QL8qD?f&_*7jElhw}^=krS=rd>z+=5+rYGo zMq~g-oNkc$i|8TZ_UCWEH}iop2QoGJtc<~*8jD|}(u>H%WfMv#WA}2kx4zKT9_|+I{gTn*ZaUfH{`UEPs&=ZcE_^gbagnFMlqXrt$EO(6-=h#)wonH%n zUJrB!(z^kVD5h)eJf9Ztkg(_(cONHc((}y(dH}9sdZ~#6PvTH|>&@>h^*2xS zf(hYfKla>!i+gjiCh#R+ZCNZUUvSFr@`OG0#hZR#QCvo55r8`*HPizn7Ut@BydR;g z*Jh881=f^V>+fqTI0uh2c&$XTGILJIs0G2pHXRu(Jp7@m^@bmf-MrsipRCchC$zoX zv2?O@I6_>Sbw7G9J>?U;ars-_`xOPk(u@_Nh%k0xQtKUh6Vl&9`#(ObGo1Xp2ho1` z33S82KiRFnsjlziC|~lM#xqdCe{np2%7*`JK z(EokiKj;eNbP*@T@FL-|CJ zkcAXG_{P#VJdNvs)f^T+`+GE?q}sd(O^JnsCp7-b77ySxwM!Iz%ZwKn|3{- z0o!6e@+bh>A~>bDRkeAH)MRbnnU2`#;aMpl(^L@;|L?K;s;2X20+;~HVQ#XAc#p7X z_d>5S{6=70u%2~{vXy^<6$<{n(hdU?)6bQ(q%-Dql(I4_xk@{WTa}7X>fK0-nJzB+ zEPe)Pws?A-XOK%z4}Km=kHOQeH%z<7Ob==xljcNhz5wP~YlvtdP6(WPdQy&7(EoCh zJa4LNHF>ZdMb*PqGz!%+8nOn0D61rw?;yhjdyO>KENCG&t|Ge5#K~+8_Zn0mRkyW-tE$0uNBJ3-ok@d-p|w(%9YtYWO6x z`}?@^rPjYB%YRQO9d!t|)F~5MW0Q9K?yoarcyIl5mbGr0Q z$`T3pmMW&0ci2H*7lqvuZSknI4PM6Og0O2i?Cs~|xZ2M=%U-3~XVFoHsoL^v6Tu4Z z54ixHk>PC6k@&Ry$&9ex^EpwjPdV{y`&P~JF9h~g)QI3+;qTgm0i3;gFz!kOqM-aL zn+sezckAyQG36Q%`xxmc>=D4<*i(u6!;u))ftl@)h#>Iw0Z!uG-S$UL>15)my{Pu* zA|~6vSJb~2u-45t=Z0pif!OL$Z4ycSPfK%7#`N0lsJI#0wnyu98T+{Ewr}cnS3^4j z0@Y)N@DtX*Y<(S;t$>rPf!3@i5U-c7WnI6pq%@T2 zY75s7&g!h|JT#SS#vj6@_J`I(PB8Reico}-l%!Xe;VGR_-9>Enf<=)9m)<1sU&%n` zDJPbVkr*AY2D^-WX4I-x!Y*#LPGi9NEoR=BJ&{Q=yDU^n;0!^6eWji*F4k1w97Xq_ z1D+!AdXhI8N1@m}qM{iz7DL0wepZ#)=-L-z#8QtU(B%z9?MBU@e9rw*u*`c-4#r)A zTq8T)*dd0>)z4W!+90kGhEKA*H!3LnNX9HyYPr=#FvKk@tZB%C-0G%-^h(Km0z;{= z{n@mBMDO35&nXp}tF=Uphd6z5I7S2{!LE|=GEb{qv4*x5uLlBbcX!{pCv3X6>*rEm zg;7$Fx3~%Nc+|GI(yPt5Eu%)lXG0ss;tT5m9_()t-x4(yR;6R&dt8e@De6?Sb2WM3 z(C#RfGQ{dDAIll1o|I9H0NUkv5P_}&W#My@Y|$Yxb@jNV9_@09*N29rv+db3byb~v zAq>3#2T_T-Yfz0^-ioNzC|4XGAnm~@SwC?$OQ%HDPNVk(zJ+D8@&M|G_6LgbKr}Ry zdC^H0^Pe+-u<_+)4I8lVnNG0#pO|^#+r(MFgKMI{zo~*-_kp#4gxUlJcl(8) zUy~h)o|QwUlvvMh!`|n30kIJ))j`f5b?_;32hSvFyVg(*w!uckRHl}0aJ8U&xL8X+o@)V59bwCkPcV`V-O58fm$~nq9?XT2st{HOcEjvb6m?^mzhXNRE^@#++mJ;yD-{?!dhQ%KS8A27?PS9VTDqR zOAY^(5*8_Pv}i4_T#e_UO6?uZn`)v;b7@nGR&sU%QUsG)C(P2Hhms=7y zeNUQ%91^nbL6?R1qqL_zg|YI4M=vCOHX7>&S*v_m0j=Evvk&J`>-g^)7&1dM%smQo(4y5p@wrI_3 z)`T!7dm2I$PXP`uC;1GcgSzNaE$XYAXZJt^eC=mQnM%!&b&t7im#F=veoELy%$wgw zFB_(gh>6n3kCb%75ouKOG?*Zq+pW)6`kLs4W$ac(eDPz^(_mKNYtc@j&+eRb*1Flw zT!r?vM6LsEhYD4spzEVP(%aoRb@qCFq-I&o0OW>H76{>m#?vnx$eK0fe~Bhvt7n_> zf3e2Wy&E|QUs9D2X{lB_OADBkfn_D>yC-HkBlw@i0uJrd#q~97McrfYD>rv{B0|FT zK>Uee)2V65yb}h^R81lhR&;C}Oj!faN)Ozkygn(|5s$8;Z>W9GF*R2z%Pws6dXVw} zn)Xtp)I6BCvA9O_x&qvdN;^rg`ne=o)k>@gQBn_amOKoewfsOD-ZnDUAUt?lZ1Lf4!eCpqAg}E92a^9@}5 zE9#NAVR6(R39#eo`=2_NDj6Yoo&Wu3r&+|&bu{Nb;PKn9|In)aoy~uR-Hrn8a<^ro zu#eStnk=5NDfM?M@sB_M^82ekqu4?H`{sY##Z~#YCaxsYjo11=u6~kO52PM3Ha7l2 zs}C(HNu8Uqh43F=@LM0}%Uzx+4Oc1K6h?8VseJvk4vsi5FkLQT!p#UPxuJ|sjs5o1 zqrP*Y7ZEY74w`b5WI`UL(fN#I)4*Vx_eJk73H<;2Wx%tQONQANH>*&n;SFm_5|(nz zlMB$Ui`?xV{Q0S2prl@4svZW}i|T@elx)3KEGiNqxGn4D1TYH+ z1$O?N5(QuUPYnMx)WC?}34cJNT3n<*P;E)q55s0DeMYGYD>F(*v9Ao}Vy!hTwfX5k z2O=m$C@WDz*;eQ_tdl>y^?ZEDQ5{!(w7323vX)s`Bw$x;7v?|S3Q*$a+v%Xmv(%P{ z^=pLo-R-S+!PJ1kx0kLzoYWs&-DGjD^F^6UYi&j6Kw;Irs_+b zef1&6sDCoo|NENJ;UL@;QV0Dkch@Q6wKdP~C@xGkXnK?4zhE_N%3;1zjI4DHRj{Q7 zzlN`BuFp-L(y9<-i9S!ysd7z@O=|ex+x;^AE_EG4y!rOOhc)14I*h;7rRT8gJAaE^|FN=s zKSc$O?5CL@z}WvWu2w_f$lkGy);aR;{reBKA^!&HmyM$%?%Y%7%5@Lm{1-4kKkrw` zj>VFiqWRtR2`KYM;5Sa)Un{AH^F&axk|$g?_Rc{P*H8t92@gnr2QOv z3&bK%Me|(6qb@p(6`HUziUEOXV+&WV5+xgrb~7?8>>xehUTQ(S*YbwRK zdwKYL_TadktNQ~ZemX~44Y(Rn1Ge|xAYF||)}*;OX4V<$TnqDx2GZdSm%U~Kt2}4! 
zM3b1Y4|gphqVp8Hw{Frk_@9wVX!GVgk6E=?ouGGFl8B5Gb(F3`gI1T_O-(~EVE{^c ziHQ5MxSF~FNwD~uN?V6I(ggv{mTK*tY*hK@fs~OTruOaW!MG$z>Vsb?LOS-cQa0YB z>9jz}8fzOruRWd4VHrf0jb zOIHuSML)1P?zqLoh)5BR$VwS@d1=C&}(~ zHQ0G%9MdjthhXRCt{=qv!esN;{9~H^-b;V4 zLqH<7TwH@!-uw)15P5lCBjVstuD>{Lu<0Hgl$mat->Yi9(ZOQ2RFiKOovYAR*?9rC zNV7Zc!Erhk`QNOa1QCgp#l*+^xn0x`yADm~iV`DWPikDV@bEyUvEn6R+DZ^N9F6f2 z2n#HCyeQ!ZFQfQKBsuq|a}aMFSq6uMu-|wvgL38bMpYHsj&7t!>8tRs^d|fyIv%;O56i6O32z*jzFY`A;NCU=X@T&?$Re0| zqpixS&-6A!F{1|XqsTsf(Csb>X>&wbWwU{^#$U|LjXXjzXwXWyl7C!(Tk;6){fwY` zV>yt?qWTaK*m7>+&-*nnft|cce*PK6Wj?o!T(8x&q1)bl^^|6qL2%p}n4udb@coBv zN6T~Bd&ipSo32Ja<*^8~(%sEV_pP~*DkQ+&0m>SGISe3wI|24qZG}L8KSbq6m5cD& zLTNtfDuS`A`!3w9>>fF`Vn$FSGTHUZ6+Vh*Aj}I)07$s{(iheVNrmxhm#>B#czP&}q3fncZTymD6 z3Rr0_hGM4}0qPaHy5w&`Nh`LtKQM-jWpm33xLw879LmQFWffDB;?IFb4;w=X!Jmf} zpv>o{rV!iPNvrHZAB)!u3r!?I&YL}}-7_2rfPLd@*~|cg?f{(wPSGeZjL*v_&(p4# z?X%d}%7KwISwn2jME&0Mc+)04zV^bSlg-|i-cEj+va;&Fq~&GYe1$S}!nY~O(c@qI zI(#VFAMNo9WOa$5VeAc@Ds@Q%Wmq5tY2yz-^-^t^MUw^Oqc-YcGx}fGmkUS-#FR&Z z$xAH@Lml3IjOSP2CC=g&E`)Wc9XecEIKR_mr;AsJ&8Xq+G|s+V;i{6=b&lH%Pbwu^ zDIM(fc=0n+cnyBKvUgI~^dJXB=kh0Un(Tms@RmW9(V}g(6ub@GYu|gpj#7^ox!n=j zP7MUs=>twRs$QR`Pp<{o=!XvOe`y4wwzJrJf@Im7A3Yi$R3IszNP-#_?|t89>77Ky zfja#oY%^5;;im9g-UzrUq~=7u24(x!0QjUH;Jjh${!^dB=MIo2vV?y?=qF7kro}5# z<=AzHrK6+U)Tl@G7h6;`MNkYjHdQ3k*Q@LF35~VXt36tcqM4CLlr&{Tt z!IH&VGiA#acZHSh?ZM&bK!O7biCwXcSK~ zN20__Nw@gLTM^a5MX_neCYhO^s+)PY0kmgUE&K7NIum-~kVx`iYnWU}xH%q9CAM9kBOtT)5E~G1-*~6=lkAKanPW~3F3$rX(v;^f5H z9m^Vv=|lNP419rT{sYo#%S+NKB9r@>BL(N+srAj)sm{-$3bk4^vU?MqR^@^;{L)7m zS8^^fIPq3$LmH;JoL5PbOYaOt1CMe>gqk((Y=00|Omd*(JjYsCbp7S|1Tvm-i>|Eb zOuvdQL9#7FJ8|Zej}E(VIvTqg!BTnDGAnV zwPDPZ6ATOmSwJrjMZJQ0)7SJG*=br07$1AHwCb1wG>=kQd<@3FW9oCEe@1S3|Yfi(BUe1JwV-%q7N?}HPzL<45UBE1l9cv5V|h#?1=o8w0E za`wj9;7wrP~V%wNd1!@~;`_rOhu0c5@{ZOOdkWeFvNrp*vh@ zCCo0Fm$_Pl>9?f9fn8Ek*@PCJ);+Uiklq3u1SbzSC(uu7|2VsMzL6r2z9=8MtcG&9 zk1r1xIF;MTOl9d`&}zh6gez4!h8meM8opIYEVbZs2YiZ3&|ZjrZ}}5qjkgX$^K50(SQQAI2G_o2bOhX)4LJY zczqNUs$QV$ZB~A4Sr7HnOewOjx8YR=cM~%2b0ahq5_1`|sA&JNL9W1CRE2?2KNW8p_R7H3sVY-J3e4X>zTg2)$m*{=Gx1?)2| zGUkHehmVzN-(M9w_D4p+Bq-AZD!^Y^GnLD6S1MIhCMV(TUaG39T@}AMZM*Y4v&1#m z>^*gjrGfSWcr5!Ui4uX_`5%M}QbBvfO$KrCWm=-vbuWftDZLzSlz;=`K9PTm%ZW;k z)T;5zvWlhJ51K>zvPCZA_QD!`uVp7=Z%lEC>YXYh-FYbjNC!#D;!kms_p4cC{=6nE z&y^kuRdbf~+P}=$&>T9PLjAR&_UzjTdDRoO&1cIfp*l@9dN3GX87*=@966ox{^Lle z18nF3`IE(26IA-N)xFZ89l zTMA3bz6Xurg&7-vB&c&}6$2b`UtGy#zafR-_9}^rw%1RW#jsU-P{Ou&Q6r=9J7gHd z?#>PgZC!1nsJRq)?v}L87B@lhLUX2}+(?ZmFw(y9(&xil~;I#9{N_wlhvn z%_MFELh7N1ABTk(y;^#z#nJWHOpqe!TJh!}Rx)<}Mm%^KQm?#+N|l_<44{8(Xy9~E z2*(NW<@Ny3b2>%Rwf}5QlG^p>^T}7OZ)U!Ta#| z7>YiX(5PLT-p?Rk@3$YB2l!{*?@8iq;v%bu{=%afErGhMYw`X3_Nir4h0M0+W7oh{C9&1bOD9l=y;Z z)uV@DIArambu}BTPAs_mXk;E)w54A6N}w=lxK^UtZWeW8aZ~S$NW8*$k~4*@j~EA$c!xu$fQo+F4E?)LOxN>-kU|<{GGMFE@F> zAVwtaJC@1IpGm|zLffphNB*#ujkfdk-8fI1iI-LMey}1iUH84WJ<*AFe|rXa>>Y8) zzaiF*>h-%)pY>KH(5-#3O_a;o88RQxZwAz3j8YOifbdAZu>X^i^@4U^eqsLRbaP;; z^yYv~&BSxA%_nz|{61Nm1OnZAG!v5q5%!M)nxSk;;Ejb2yO-(LN>)pf&tm>u6ucfH zaD0Qkc+`0bk1KVB`P~wBsdMD3<86SZ~deDnuoR9~n3yiVIJ>;;(bcN`@IXCWNp#R(t} z)<7Ah@`bx;x3)`8`#YJb6cv@0bI47tQ|Tvkao({idq%`dgxxRgmM?9*RqS-_XXS?u zig0X3)bOYO#xMQpDwyN4wtLh20pju+;a9cf47=G!Zdtipy9MG~zb zQp)>wGZrd*s@`QV%A%<>)Se+9BF`GYk;saNgPq9aYqMT4h)$-Rm)*%4p4Muu$}Asy>>^j z<<0C(+*tIg81e#9(eLrHBUK@`hYsblp-_9tGjkxG;`cnQ_+B0zl*%=fJ-scEMl+^# z#75M5|3A{+Iliv7c^_`ns8OS~QDf}5jqS#^jUC%YgT}UPTa9how)xvV&pGEgeZJ@Y z>-{sIoxRrD>%M31nQN}O=8BwDvFyT;LM^7M>BnRRbE8X8Oe@}pB0)fkT2=pm>(Lw? 
zrCL4*S$Yf|65FrZ0{w{}NDcrWecI6O1E}eXJn@Tok+zWG;}%*BU6&{4v_M)VoSWsO z)HTb5iNd#zku>jdQ#s)~Lz~mlh=jkJ^Z!5`#0b(gtLO`u4F?dsLZ*mX;J!>prG5Z- zG!)1xuJHJ%Yz#A1v4d6a?vL?ozc0GK&pC23O2Zh*Oryppj;5Sxh;X`Gi3O-d@s>8E z&+ZJ%ZZ~VI+of>2sk%1`D<&FXlb`5uj8f}KUtUhy3QkP|76=)q+~ujCrO*WIIVESE5&`6mD#)hyiCKDS3XK zSuM_|Y^RiHu~vW-`H!+kb{HFd;Dof!S9{<^MV~1WIb7^{tm1#J`pf&axBaTMytdl( za2WjGYK0dj|1b6$WMa}kaL2#sQC=Q+U>Y2cs}3X9%bO25ha1LXw5j|%PyEphcmAg0 zqo2t3kGyQcMhL%BQ_fx|siYgkF|BE|1`W9z?|$;VU!GB(%2UvdE?9l<#YD5V66nTy zX{pzNOlvE#$1Z0&F2HzoHvzlE@(|bmqZXtoFk^p1|Jot&#o9G*ekVk9!zV6*3;;OU1Ohg+%le%1eoQGn-Doiayee0HBVXp@E&%MFD z0{dH$i6~iJtwyEI@}gYMKKd_UbBW>0m#D=ry@;P#$ff;-GYp(sAzN22sWzfet-s0) zbsz{F2W>=|O(Ii9@9a|jYKu>bub-8)kOj4nZ#?bf+`076I$!RAz+5Q^ke{+6wn_#A z_aEs80aDy;b+gG|rF>4Zxk!+B8wAb9Tz0INKQb%Ocsb-D_yyFl>11uLZEfT;;r(83 zZ2s`8Aae>ocvchqj~dpp%`b)*MuiWE;bn&n1uA2OgGfXm34s#F6_BFXG7!{xZGxkh z!&Z9Kxq=0nPdEf(V%dY#%w7l0(*}XncJ|n$yu4CyS%XjtB!vZ1LgKP_ z=MYR0QcUl?(v*6sv9DCSMs$;|5tk(D?bbT?bf$00cnMjyMh7vp{1tx)dVUydXtrTM ze9o*qP@YHY__NXn>SouO3Vi zs=v=1a=gdr{IZ!S#d$lFkbj6gW%U!|)ZecL0hG~clQV<5`=JhdZUvl8;072>f0)Z+ z!o?XvTdnGS7EwBwmBdc3^*;b!g8G&nf&2K(^^JqNVnIJTK`NJCOC^^8oL-NxLQrbG z85{=FftS|obtbfM>ukPA+m&*V=~Wjoi1@R+s`vwulW&P;NTk_&Ox@o5^&6~uCrbcD z-2oUDG^x`{Y^agnpOm$4JHH5PcJoRe)%3t4Q0)Ni<>f`G!I~JC(=|UrhrQlu-kiPh zkzXYVtF9<5%QGvB5Q9b|WNS+=^ z-U6JDfjhBUt_HjuO3?m2m1xDovpI8}F)Y2WjL?QcBRH#gvJXrKe`F=rf9O?NXbkns zAVp8pRMvMJot@WiNqTpE6BY=u8C@RkNmQ4S7cWZnQ!+T&-y%L7hW%(EYz8v+jBl?J zo;^5C9kcL~=@+VI{*ZbC3FUVGLq;|}p@^M05fKY|!}9P9FD#gWjA!{9-Kp>|SF2wAuF|VF)yb z(376G8DytPY)}%s6(=mZfsGh$W|1;@>mGQEt*%(AxnbOiAEpi%zP#FaV<9}eVPHiF zcUyjbW~M9G@+(anQoJQC1LIpNkk|mVaE3y_=fU4vfVgq#pXBJanM(xbRn(yeJHEHM z$9twH6@q2j>lzM#h5cf~N!LdMgDPH(cBbFqFn8d9^ zvfvUF6!bT+Q_(K$9TS_u*hBNdmAjU5yurk&%&>H%H1uL|e($+RjwTQGY`7 z?XOi1xSrB-s(7z?7S5+X*^^cGuUL<`u+5qqM^L}HZJ*UQG zTwWqZn*jXvFoo$jw8h6F8R=mwk4P(~`~<^* zNVHDJ8V3>YtHLA~s|7%(0K|EWRiaumG%FR`Ch&!_eeY zOZ8q&_T+Mt14!tkibtVVQPb?GR=jAlgumL%j4p>uyrhT*BWRzr?j1yf*yLeUUtxjL zkg%1_jcNIevt6KIO1bT!sIhsjyG&KD1a500JM{{|oQ6RmJFd;Ctc>*QRxQef<4I*( zhPCbZmc0GdA~uCuzp?mrCO|i*QuyTZD@HehI=SId;?%SpY>r%xUd-Y_r3?%H`Fx75 z?L-Zd2{q(nm5n7Mi*qgA$%5VCX%ef1Om&Tb^5QXooxlj8p$vvBk80;8&*Swa{)-`E zW@ct~(e=yo6Ew%_34X{Hoo@ zI{KB1g%tV&EUsHyNHAMe29_tAq~FjY(P$<}Ah}iJeM0e(L8&e%4MjIVXhv*BT>_TB zd*SaWAZu4tH!JChbk$pibK(+ndE%IZ56&9sN7p^|K)j(oVa*~C2=}%hSJ8r7(+l$T z^TQ`0dE4T0%NU;mt}9JTskr*N;$}`{hpvXppW=w&`YmMFO%zalle#`u_KYl2Z{7W5 zZHIjJYNE*p!Y@X-LJq62h}bxkS2%A?1@akxUl)!A$S!9Jh*4MwSBIEhO`^$?-!=ns#~GZc~kz|`CiQe zraqY*INZ?a;;liP(|1-eDNW|>cD&NaLo%*nP`Rd`Q>3mD<2q7mzPHe#eM|}k%U11> zWSv-7BDPkg${ESip1;svfx5k`b)nr{w+&IHdV%a!Nb+<@EarBEqYSNl!q|L?*O2zv zTOa4isX*>PM{qdYI6M64L9ni+guoF~0Fc^bi&Z&-+pBmhX;#bC15PNtuD;q#^FVxn zhM{C~nNkozo5OW&*NC#zpOWOOHblxvNx6VY&>UrLrDni5VUfD6WE)wQ4q#h&x| z+2yna)6cKbZhzPKpnf;qy#Cy)h#~Lwx>Nn)T6T2EjW74(W;%J-pTfR>5|T;;+X9=3z1kF@7DiqlP0o90p7dTN z&*d~oo&Lio0o**Yk$G5BEd9K8lX=*@kj6G93ZnN$*jFm~0~0!YdBYM}?V8mmteR?u z+cBy=6u?k0@|s~QCWW?!&FpUn)SaRo$d|$UVHK?xEWerjpieuoV&Wo}s0&&7@wkDD z&CZP-EVIkKXpRz7&HDG1b@r!Fp`3|geg|{rMV1>eZA;V5+Z5e$cDUazBf-~Bd3z-B z#Tgd{wVofcFXds-g@kebQw!rD2U+v7RXbLKzC&WP)Y z+8kbNE&P_c|M5p$kf-u^6^;CEJ>Q8{b8?zrsPX&a@$xn= z@*67@n9{l2SYoq(@DK(i*eCtHC!6*qxwMO{%AdPQn4O@decF*#7U zO06if2uJ=#8O(Jx6H97IcxCoIKC|fBBV97CP2&~vJ@ctCYF5wBme=r%g#vw z)}Ca>7?e?~uVXG_xvRSnaxnyg_Dl2G660TZ!57&0oUaat3~w|zU&tti0*aafBqpKl z&AoEu`D!2ne(;&w`w!{j{M^?sD{-N*O3M23SBRLYeK7<1+8ub;5e8yg4S1lO{Rg%^YB(}$uUawGvAC^hf)zyuNjPwQ_>WFmB*A|@M zz?f&^{Xsm@)R^NZ7rqH%<1noM;!`SV+&=b;-i@1^q3d%U>M$qor@tAO78j3AZS5U_ z$Fx91Mve4+#XGk&HrDssgm%a(EycLLz9vIcFp3@F`@95r7bdRsJq_{IyXZj;mJOg< 
z%+?CFigyRrO3mr$IOKZQaF>}|Dwti?d%)Mmt)JD|RJGz+6K`5kS>6@Ee zC+Jy}XKnlweJl$HIm7MsCHPl9agB@$ifxw4_gpV`rth$35I`zO8XA6vsmh^015mf` zfu$dV7A^3gE8AzLrgXZ4+I%M(bLO3Y=sRxyU>na8A}cPYl$MwGDQkN{U3I;Ts~{jr zX*z8!`0$nNld`hk-Xy75tkZFg{M}|4VLWpM4X&dW>(Q~9w6wIJ!|Jngu52zW>cd7x zfI(+Z4_sg%!pWNL$+E*P0b$mgvev>Ijr+YUXse?+mgd)&PpnkE=TovaYbT|(_Mi!H zp06*rqD~`9Nzv;A()wg&)!e*=Cutm{G+aYYHpEL$eA}xzU2Y!hd4?w#agB&ft&fCo zUu4Q}rak%@&0mdt%Gpo4Pfxl08Z^xUqVmeb+HSgY=sgC;m+xNMms>I>%PfH(E<9Dso=#chh<+^#{)eY1-naF!USs&&$aA z`#l{TDb-xTrDWCp!Yn>{S9Yq2sob&Z^lCx?*01FfIqxtRXp2polEBR*&A2zN>2nk#rdcpc>MKsK{ILwou_k3*88-I(l zw?k6ffEz%80|4Z)a27VYq&bb5&F=)Doal>uqP|c72I#c*l*Hocgv^)On-h!tM>Pc@ zO-syj-iyIq=0+t|=pFWwQPOzkNcpFwUND<>^POaJVe0-26tAzGedoyLlu_9zL($%b zvr2YaG|Siw`*0-5NhWk+a0*A$`&Kly$){EzXY6%eKj8|R- zVCo-8e=_%_v(5)eLD*zV)Wpp=!A2pQ!g zy*!@+eVZH*HC(N7%4v5tB7{n3Y^`3N_WMR@I$Tbl!SV1E#aWLX;kri$2L+-8P1!l7 zUB7{OJYTbAS3bABKKsXMJRsK8)X-bln91bsg_~@ls~8_B*q6Svjrz)9_#TU{MxNB8 zD2Dt@XSp4mlt^{AypRF(M6Z#d5ei6aTy8vn(szEGU^@Sa3c?n2YUHI07L$*W9wl|B zTV*Tv9&};v>_=h68KhfJ;%C^vSC2XTpjJ(L1~^FlGBrs@i%g5xwSa`~z=N8Fo7x$f zirlEu&X!dk?6x*1t|v7dD*`G^^84=kLd|MQV6k`tk1GyqiH6K%889YtDpys=!Gix? zg6F}2wdwmPG{f4mfDhQE_3J!W$*vp{FXyN;Ev;{2H!KFKUz_vx2m2Z%Khv<_4&o)_ zlQ5+U8oSr-s6F&Uq#r6oC%_rfT&uWmZ*BSMY)B}jy((QKQ}Ss*1~EH5^a0DT2OmA7 z_3%NcePv5pY8x~cryI0MSJ(M8jccjiHq~Dz7|9qc%wD3JiWN6n9|_r4@bf$@Z|M1; zHBCJyYUHZBywGDIJu#ZABPPg2cSY3*l?tcQ!W7siS4RYQR15hX0M?QN^$8*!6?eb%OB;`cz#Vj zSmxMYg=|*lLiy?}(N0-{yoa0g7~?YB*6B0x;A)R%&_tbX0aJ?dKRMYy*sh>`WUIxb z{{ru8cMnWR0XHSL{)v#amEGZXOEpf$JJ1k39|?YPc<67n=?w3eOj1rTR-#%iWne>{ zt0F!xg!haM3wuEPA^NThb++F|D3r)+MkI|rff5}EM`3?snpmAF3mO6bJ?`XN3)Za-X$L)k)My)IFeS{D;w&7)-X)pZJ68P1bC=ZPd<#)Losyn6*SW#*&bU*N)XoXqJxO4 z(QrjI-{W{*U!u{s*7_RJS&!vnF$lFsG2SI}NERF0kFo6Hpae$|(9U_KwT3T3w<=05 z?Ca56N=eInKgDeVd!chfVX-9$s+2+b+>H5ON|M(I-=8=X3q%?TxEr~QdQPoWGOIne z;4r7l0Zi!!I%$dty$FMSc1U~DlktyYo<$3Vo2LiR58>Hp*_qXqa-?Yu(CE|XtZ-HK zAi=xQfuB#6uKO<~ZX;{0j%&4_XZNcR21NB$NQ#%SiZkudV`Q{*FSY~8Ez{W5Vhb?M zD|&p3bDHc^7~}mYi&meh+L)HWm${FqsQt^Ts@l0a z?(Qct_oTCC0^fM@Dv+=+08xlIGEG?k^L-IcT~#&X^Q zj7|t|9opWlF(-JKe8Dr z7N7~i0jde^3j!^X&8Ph~7LL_gzbR$mwQ~$Qc%v<>YG{7HQ*Q$(BbeiKU6bf1QH~R3 zhevgq2hGjRZ@$PwgBWeXp?7Rn_Od~N1LQ*Axe^y zJn8A|CsOyP<+lJ+ef^s1#!m;5+S&yn3TAuBk%A%=CE|?onVI0o7(0O3=SY2rcF%9k zPZwsq`$lTeC0UF_3=9nAsE39Ku-8JG(TYOLZnuv=NZ#ag5p7N2V8+31D-~p^6p!b8 zTMuGpDTr1#&LD-?nJwjEc4HNYl3?pukfTqMCY8?MhU)qmB^cu+?I)9XGb~Y zogXK`MUO(8mu5mz_;V4=frsCKDSYg-$C9f{aRO=A0@b6z9 zPb)xFBa9N`Td`i3&onmetd9C1^j%AKEkoKZ&|zKfy$5fHQ0|LTs)K{8@~g@x)Nuaz}XV-BQP z|8KsV&WCFBv}MZcV2prSB3X+X3*mgHaLU9_P=k8dxx~@J0M;H&QPs^#A6cHSq!fIu zld-P8{b_e1@~u99|KQ??z?$5S@yDmb9D#fay{P8f1Gcz(ViLkpoExsE^(7OOYuBi# z;pAs;^>jTeUC`n%j=6V>oqy_7Conx4&L2k(;-=j?aSV{ z4GhvP-|Y#$rJdPJhR1TF+XXA<%N=FHYQM13#z`7QukDUl;yX+G(-v7hkJ2Xd`U$88 zV#Sj&DT{`kMDR>wNn}tD2EMVe`_(wmcY6slG?W76t64LpVX=6Qmu4MH`${S9x>sY6 zd&Ku0yafwpaGl4DFw5th+gpm#8RO22Z(x~AKAk*epC{bEyqQ+jh^9e1q_}Fiu%BaL z9c&8_aj9;3zFD*a36W(Pj;1d38=M9>H3`QZs8odBw>P{yPb&0jiOI~xEU{{u(%PohNw=gbJ` zYs)WE7u)3D!e%k_*0Hr$Qwu=qQu_J@(H z-?C+oO4;HV-!Tau=DUcYwZXn3e`FWS1LOQ}BO&Wouqv_U(v#G5MPYJ!*CUY7pO$7* zoq4!$b!Rg2ZBHw)-SJ!p zoXz&vqSmf8ot?KGX`)|siGf$=S5fgy8|L@Q4IVLSm^f)Ep4FXS**W*6oiw+j<7+YE zDA{C&Xe*>W%f@zbx#x=D_0iFHC&C{G@e*4jv1)s#BF>Uo-d3kwZsz*i2rjtXZ=Y5e zjdRo=mIhw%C6lG2!;-tZJ^K61a7+g1{q%>hE^>{iiC0S$wj%PoD{aESXKBD6i-?q* zk7S4|u8dA2A|%u_?{6%mN^@_N7Od82w9Qa*NT!Zn@;?NLG^Wv;T(g}HK`DFZiMwbt z8D%|lxNMwVjyBQR?0iPK5-XvUAv&{You{}g)o77zO}Z~C!^q?>u9}Wz0N!YIeLuN- zx(v|EwlEWAY$%-sRPGEX1%+|BOJRtjf|ayBT{wK5CDDIz3*)G~EIB(`a$UMfuFIXT zDI4mwEAs#l=$FMBZ*4s82GI09gHW5`{C-)h*Rg-0TWO6}pe- 
zmj?nGS}1{b)N}%Y1t7B80uWR+x#a{VQ5>h>GBW054SmX~jj==hsA@p-P1IYJ8#6-u zmZ|o7QjvB&BaX+dkJSZRGcb?{<8M%I?c%c*3O07}BcvnjgE-Qo-VgIIyCE+xFM`6d zg$*}Ye~P5LylP_#NoTyOnw%O}0lAbFpPUL-b^*v^aFjq$cvpHP1KXfK8>^UEo8Pg# zbUOp^LvJMA=!Pgq5Ez%hgd6vR>B+qbE-urP&be;CS|B&8K#>m5%6!sCGYg_O6WgoP z17s2~5pN+%@G2NPccZr89>~D9Xap~n1^em5#m44MQYr_MGbg+_8xf-_*Cw$FIS~*% z<;aw9X=@dWQ!oM5$X8sMks=hK87kwQ1047Sq``u0Cdi9I#~O}&qgaiHaq8>qVYXj% zw{~`dW-U0z)a;{1Ct3A8n1r@s<(B>c~Fs9#?%bf9UGs5nao&! z6Sw9bcFRkU_m6)tqf&yoo5+Tb1JBYj+iQE77cUz%25QlB?vH$EGzJV1YV9_EWck5c z`pgEn=LVvt%}8!+it=l9`oeaiu2DDNEC87u%kjtNh)q&|=S&$9gNrBSs}z8h`or2X zf>+lz5 zOoq1Q?*GM{4M4rQ^J122|JK#Z$^~M)vAr_vpwRz~LH~`adM1F>1sGAWF?GR@dD=TC z;P@<@J|LcGh%-u_=I5|u0`OY`w4R>is3@1L+WG0_tZ~P%tvFwzc;XQR?{xc2*}IP>cn1fgtsjRpNB0=2qvpNYbFd$<8DU1t+=O#A7jG%}L2_6%KP>unbtS~nEg0A+_YffpR8@E+&LOc=R2ZAFKTGq0>e4^w8?KSWl_(E39F=I(|k}IUZA&F zpXPD0^f~?kWX0cKy&}^&vKO10j_Ynvc7?XU#SjZ9S-ZU&EXvVSDwV(4-rjb+=qHWI zHmJ*Z_TM~seN2_mx88rT^QUFp8Bx%sC$$& z6$3Jd&_DnW``spr>HX@fFjRafo{9||P@@?Sw&@l53HPTba92doj)CFZLK}KfS_Ay` zwOw?~{}gDRHh*F7aWQB~kfG5g)=ktDeZc6|K_+f8)R}J_c&mHZq}7OEszB8O$f(Zd z{Vjv!QBnUU9{{Bv{nJ%+ue7pkpnwOppsylx=@A3O;^+ytgm9^x=cC_M_Z6<}*@8cl zzG0ftL4bz zUCJMmy7Oler^eAtwsEjx2#d{*x#2DYhl?Um93V$D+&^!>`qwFD2vRC4z0PF89I?Q( zb=9AP1E<`n1tF!;NS1zBGZ^jvoyA^Oshxn4>Sz)%j2qkchSDB3(l#;X0&|gzjQ_OuGN^+ zp4vk)D$U_{H`mUU;HeB!sc7$u#76t%bvN0IV1NW0Ea+~|SV-;lUbh_LHN7hhsjS1I zv$ObHh^8X*Jj=S}fV8&p4WF_6$3l}r+J;Ad%EcD&zy@I$Yb5>@DFfqG>~WCZlde!F zwQ1O@m#%Ry9c?vNXLI9cd{Su^GZ43Z>#MPy`fa1iN1|XG?mG9kmnNhba3E!>so9xy zGB%uDq@*1KY&0}qr^}uGLU0HOzv&uWn%28ac8zKyl*-D=>Lg?9*LF$GjWWq845tA! z7TfoxM>r7#lwD}WMVVxHAP9xGU;VqLM!lr7CRt2Q8f!>3DiV2sb&~r7@`cPNN&4$} z*OuycQP&CCqHQQJCw6{I7AL?%!?I2VvbALpEG?7Uzi>oKc}@Ad&+~-MhpkSSINHP` zQa7G_rxa;+EO&Nw^E2Yvv(RnQbESh5M-joFO91-BAv}^vTCsveIXhvSep_pcC!i;dK6+MGt+@HTq zoX<&2yoiFvuPRTaA&*`|J+s?+?UQnIW2Rc^M}-4xco=zm236~q+6g#`JeKh;2^}2* zT-P%mr<>!tRrB*igPaAKb)6C_y=#*DO^;DRtJ5@V&Br2UQk=9Uda+gslP&dE^uuyP zpJJJvZt1_Y0FLtNN~cF=jMAd@UgKL6Z7X$vqA5zKCH1<%-7LGwY+AhKGaSa|=XhH@ z&AX)dDP#?T1sT$2WYaEw%)drB|9ODlk-n!rF?OFD1O3ESs2A1U~hJh11*xJda*-a7~j5u^G2@gp+=4cC8$s$4!%CH^4`I z6D4W!@~H-lamm%Q8{O!26D>VmPdzN}YX=B~I5<}83FTvovJ z`EzaBJ4*})9_d;e>xkVDw}=T*xU=viyvPBY@Km6^ump*Y;HYPmk|Kw}%4A~U0q~Gt z@9fiDQ@tqWaY0xMmB(n}J4}0{lxk3w2D}Lg8JK04wSBm|t^0fAG&T8I7N#Cy56ioxofdbV9NH1!c{em-loBO167G291Ro1T*4b-fhY+%NI#*VHNSr?yn zqqmF(j86CJ3X+Q^4};5%0eOi_FBMxXzBf4j!tk1LPKAKM^TWZ*ZTE@R60#8q)bT>D zDC%K-C=)=Vy?#@x6CUczHT|y?qZHrW&?NiDIO`Eq>hN>jjvf~iPo4_nR+m@swkxB7 zY~WwLS$9?rx-fM-q-((AG@>Nod@96-CJu+X?`yQ!VuABzJdDw|PHQc7Iu~tdI}T>z z_pCw+ce;sJ!xBHtD!?&gxf_i8;B>ngU#C;oGn}%JGrG<&ICfjO_y^79jsA3JA&rC) zW^XC_X>fm#SN(blo)jQ+we5@eY-CAP0@ClUnVXbQ3x9>L=+Sqn|7|`OHr!EkVq#(1 zQf}|@cy0ki<}ID7stREdGnO51Aieyr49EKdd`iO+C$eChnj|e)(zEnlV^Nw0vh+I-?7&cG)q=~We+v=eV%SVU2w{~Sh7MW2hw`?INYVBr(`wm!3y9$w`a zS-XqN-?|h2?jQ-}Zl}y?P%Q&HyEqq;*5*`O=CJLt(%QvPQelu|vz%kzGBE68r*MJx z6>{jU@M#K?^NQ!stN=3l)oWy{*f@_1Fv}V_(0>EXyeLzr23XZ-Vn;ludMd^3gIj_N%@; zPpoiuyeBABwK#gZWRf~72gNVr9(^Ao%7z;=jFbK+*7@^wS4VRv85e<2^X)`^mCa%p z^5Gf=1~O`jaN+rn5lfeSZ-kkXurg_yS9$vX--uajk-VGuJ3vM( zAA>$uw}}*uhIz2H3-SB8_@{y3c!45mn>r1Mwe8KY%ud((Mwh*i&Gr$h@@kWO`?Yeb zzTuH@Cy?kYC}YP$Jb8ilgL2w)=_b3{WBbrlYd%)s^FIf`pZnDo^4CiB_Tv0Z!%`sp za_pK#CkJL6=mtrENJ`AUK->UXhc2*=OA_=>gE4W&*A@Fy=18>yu7q;blAU@H#s#D@*(^V z%J(xfC0q=;{zZJ-zibQ8b|(TiP)_XcUI3Arn}Nzu@bH81j|ivgxeNpUvD^iSo2#2B z3?pYMd9l$K8L!f!qRElY$A8`SzdS4D@IQd^d$@vSObr|XP@zLgDOTxcF-;5fyIDvG zLeZ=KtjFuBGrE7-<+NUiJimyf_4l-MC3Es{$OfltM4Bw}va-uT3^wxrdS{h5PvzL1 z;C*oA^g)qenp`;Q|IwKK&yQb=M5YCbUlQ+sWa|8v6H+Wa1teIvG=k}wWcV)z7Agud 
z+)x*@dc%KthOrSy>ifi<*pS)%b7b+GNW~H0jni*pVt=i4d(P+f*d0-xgvEdJ4ug>( z*Acg^zbWxM_Vve31ASF%`{3X~P4%muu`w*@yq%h!{#3luGlOVX<3CCcWhzup(UiZS zkw0%Shy&gB&G9N3CT|V599cJ?V#suGhq}-lKnm?Bfd>^}zFvNP%_|^tATT|f0MCYW zKwf5OCM{fqX!Dr24}(WG)p zO3rXMVX5QwDzI2n&;;nIROYUB>d<#&cjEL*FUp?Y*NfGZ+T9xZio3 z!&~?|=}bRAsHdi;tjv6sC!EsRYMnzPgfOXHlHSaIqt4Kp0l3z2bF$$A+W4e;M_a4d zr$P~+l&ZWTEgos2WIExB%TCzN1vzLItVIY`;UG|2b@=vv+lIIy@7Kb3y#+Vm7>=aJ z#$CqQ%pHkM^FR2UT0%8$ab^^fWRKG$)_~z7B;noS<7m|U+1(32o>mvJN38X27KsLE z&Y9u1T)>0!{EN>mi7@4sF$$71?%90-IAW@o`A>SPE(pjy|>&d%2 zs6os|KqRtTj0A#ewJ1lbJ^}#sA#!u;k750UM2DV_JNbj(SCwXIG-HiP6s_^I{C*&~ znkTR%vSJdmOacQK}DX$((;$_!!!|NYK{zzHFFDeSC)6;-uBV7v5O{35wj0 z?VtJ2@`>WuaQ-AUEa~h8qWdv=R;PetIdKN z`_mpq0iAMaZP2h>>;VJbA&Ap7d?^JZRNNPbX&FYl_8-e>KOY%P zdhr>-?RorTVC3~+j&wV*VjA)xfq*%v#|{~#m83|UQ5wOMQID^73&mH4J3E_F$yAjW z#X}Sp*7({uBv$WdtQ(FO5il6d@P3wUwZl4j$m{P6*sDI&IL3Stye;z#7|f|Ec7z`^ z5U^rUdpg!W)cP^-eskt3HdL66OIb+L{F`7;zw7aq;3KqRH{rF>APLfuZ})5wPIT+j z*83wB=f?353k3Mr-%7U>HD|T<%}ear{h$^*mTPMYn@kVsc2~YIDkX+kW8#7<)S5jp zjVh6<>?c>1aV&x3Ak5fTv~lg2BLhhPz}%^UAOmpvbufZLV&o-axY83`);C8-N7Cx@ zft3?F2^wljZ%7SmV!9V;q;iQ1;AXQvjK^XZzV8F;R}BRB-c#o;NaqILU7h`Zyv8XK8I8eJ?C#4Nc{gMH^fx;lrRcZl?QfAfL3VT+kgbZIb>;h6MFlT4|9%9H>z3^ab*0dA zZ$Fl+{YS`KmGH{A+B+=ewq5znY&LqFtw{VRPMw7Yhwx{ih6k0A0+pGxgp zD36nfT_4;Q*-%Av3`n3NDcUJp=#;=9{tTAU0JP(@&bFYsXcNWPH~6DV0UI1bel;y= zA3R_uxRm`Se&*>l^mR*S9d#))Uk4M{y_?Tl(yTQ)Ff|D%g@A+)Fh{XWqlJtqNGOND zn$~l`;)3L45{o(7u5rH<{n5l(gMQYI5a~AG4=Mg}0iVV{WYi;iXfK`}CUW%&dpj>` zZ!F(Ylj-A$tf1#fO<@Q&QQJdcJqZHb)@>Wx#=Q$1t-ThDg4v84vwFD}*u(Wq00wDIc^WVR{~^95v&_^Ek~6POL)bf~(NG8#fn)dI30h@oz`b$V#yAiA8M? zrbi_aR)uf!7<0a_I&wUW2J+`JORz)}iY~G)XpM9&HDX@tBAz zY=pkcPBYT_a#vg0Crs(B$>iYF)@vatlqgkPVoE8H)+CbSTP*tX6#X_Zd@@Nk!cC!t zU!Y%>SdHVBGFX<$ptDXjoX}YeAwqYEb<(i`zM>L&$b-2~Wx>>JupI65_b) zXdjmEuTA&;ljf*CYoFF&TomcNNa`ayH8iA%Pf;3AGkYOn6djSA%mUSyV!tMY^^8xs zaX;hP+Mpr6M*D6QTeIj{td}6ObmBF&vQzBgz2yh;j8K!z-{#2mYTXbiOt zleTdyh4qQ5TOdq-y+(;d5&w{~oVar2Ha5k)-QI;+BeVi$>oNW;bnCLS4}RYC^^}v= zUNdq+a&}%&&EowwQ?23s#|r-R+WA#YZb_tYAyH*_^zUE#Gw@Ki+{{AQw&9Pq@i|zG zF7RJYZHe%tUeyCX;`v@=QY0@%3C!!IdPmD-V~;Qlq+Gncy?l>hqsfxmc}0)?c%5ue z8Gh1zmf@Zb3vLY`_DCqddx{+M7hh`m)C4Syd~8d+;sX7X*bH zlkOFz$JiPBKDI_!t3lkJ22BmW0RS4lEQU!Cq`2doZDm8Nhu}vm^JA=ha-btE!Us3Z z;-h3_(gn81V-M+A^M{|Mxn8kHFn+}wYM(~hu24pvSKm*?C=zU)yXYazl#Y~FE8?5= z*G=RkpcNEOQ95d{2rhd&-VMLlbX+B|!82-{R3`t(!%FrQsMETX(1X)Y1f?yN z=z(wr7}V|U5nmc-DB;<#3hG&5fX)k<`uV7b{=y0JzG#*kLT*$Xws6^w13}U61ZxFc zLfhRgXqmeNGmQ@2#U-poZ<{tvqSH9a6ps92QJ#DsRAfmKi$CkfM!TzeSctPMWWYiI zgsy#PaiXBTV(OU^QS0awzd!0zUNAkC>;NcN-5o|J$Q5J^gq=Abdgb=BC%u}Q*i3zk z>XVXGvut*V^hnY0h1o45H|%xX=nqV~?9fGwmSW}ZKbsayxMUbb@miZ}GD-voh;UWh zgwI?@hI?bOmU}6CDnW}yi4Btv$;C!55pt1^FMY3emQ5tMG1rOwRv(c1z#%}fy_G4( zi&P8k@r~bGs#4gPtDb_#{bH1x z@39cdRc2p~XNo5>v%IKQ%4ryN`Yf197;&&U;Sik+yccQ3k}wzZN|AbL?wr0wCifC@ z7mmbu8k+Q*I8L1j1rcHJ!g)326@t?_O>#mmBCX|;)mOIKTB3zeY)&{RL{SMjrOmBn z^mZ}xd@%*QzM*^U`GymS-hP4o5y70cTBh)j*o2h$iF6t?y^6OC(3DQrg65)_45n@) zGbfzbh_oe4sM+%%ON1tR@c-EQ%AmNmZrdb4aEAaPXmE!hL7U*N!975L5ZpZkcXxMp zcbCSU;O?%CHgBI#?meg8{Y6nlb+LD^xz-$W4RMomkJ>JJ>Now}6T6mCmZ<)5-b<;y zsG@+A(!3ndiq7aj^cBUPJhI(B{rLJz%%%Vl?<9|+*52E5i=w}A`dRWOvf5+4#hDw7 zK}MyDlZ60>{i(f)2E{FF!--AROgORCxdiNRhJ54USn=9!s{-BOOoi| zOQnteGDq23+;_%Gi5@UQ>wvYqjuw|h-eSL)+69ZMmwD>d@h4ndZb^yNN(t5ZR6h-~ zn!9DP*i_IPaZHMlv5y$Dc`gbr#UyXQt`=s(=-J&$9~s7RZFpn9*_xlw#Ijx-145fH zC4(#`S^g}$_BwP48`&77{EFJuXvzsrL5rL0`oX#u(}6Fr%OLo^7PK5ybrzz>{z1#g zoy5`rqpVt-U{nom6@$Y7*J_C+=N4(>WE>NvxzxAzcv6>`W}{T})EWULc1XD@S_Eo3 zE`EIFnDJ`J1{PEy4fT<~HIq`waOI>b#)kY-DhqeZtSV8aoeNLd3=XL90OQA7+`bla 
z?l#4?8IpwpIvwd8x?a8dII_}CBU~nIDUTWI@g8FdtL^f8W!c>xN_3JeS_3(4+Gc-q^_YZ`P6yDC%J$6jX+N?*#c=UkM6WAq+zaI;aH z%b`Pyrn0DvYZr+;LYjN#A<{b@q7Ku&zx@jJcstca^C*R$9Tt?$m+CEq4GhR`&#V-n zTB{-cSUA4J8tPp`m-JEe42|_@SrWLa?=3=1girh3jL1|<1&p3=Q0(Gw6o`!a^6LVP zE|4+e^VqWzlgNMc1rYb5YQZU^?F_*~x@*~rK15v#KIcXFC=~a(hj&5eXXg;%7;8km zYX)s4Rr!>3i=cNW1WJmN#PF;NK@Nlcxy7ZL14KHS6;PR&p^TKu`YDGACzr6b_;9|g z>lowI(gLGVuJ88+x3m%(?Ul6}3z6MaICfp>rOrv+)lGUd@cy|nG6*vW(quLjWkm^d zoF_+T>*cB<5*s;C93#s?&U}J6)=h~ym^O=Kh zGL{$ZVuF2D#?|EFCwQDh4xkQ(@t?4&2DYa!mDIZ{<;*x0C-{1rEdG>j`m3xp^p1LOvcagyp;?mn zqCh>f`Ob^kp=>5MG8a|MKuq}IZpVrde7(+t#Fg!8WG*^f_9 zH&!I1rhW`*W@YW?mzep2Ei#MCaml3~sTwS3fC#Fa6q0LRDD3T!->lCIB5dCK&-As^ z{f&Rd-Dm~6fnt@lwRJDak!qG|BMXFt?|s_?Y8(<9?M%I++`~dYo(V!8 z6SA5q(Dyt&VDe}Oj5vx?KtlVW0~0s5ecA4L_A6uDm0}%1s=WCe6^hd7QBo+PQmV_@ zjJ*$fShj|$xpFjp4zu01Kg!yP+$|HSw6n#t^DALFgqo<7la$i6!zBl!rQ+vZFJ8$X zB}Wp`$ z=zeaU0AatI?#bUf>T_ja`OU3a1)>V;-$eFtbJtvqOB6NmR48atao`4}T{hZXR_13) zSeE&vw9*P$9;ldUbn8GXgPCF9u1^6@<>G(OfjNK#!3bsk1%M{8FeyBXG%KqTqSYflx*;3#|yk<%6o!%Ki$NH z5y!yAO4B$f77ZGIM4zmdGi^_0eQ#P%;uYkS}@@YG2fRq8%OT-9<1%B9Mj@t+*fe%SY*QO>dH zY#e!JJ)G^G%zUwDbScO0bj^-u)FubR!&13Nj&k35Y->KPZ{blKmD8w!oNEiSLI$npaj@xK96s zi_?@0^^W`Y=Z+GWPbtwX&EnRndsKsF1gA)YR@^)hd?{1WL{4Trnk62)JntR^FD43j z%+T^?>$t`x)X(6QRDjAvsTy|P4f|EmSZ?OWs^O?se@=+}M$eFIcbTs)nvMjhH@(1u^-6WTDRMgK1X%&R41YZtsQ@nCE$NR%M;9Wt+Ab zN*{w2GR&U2!!oFZk9!j(Ol`gSkR*OS`-uvv?|?^jv`nXbsOTlG$7$~m&4ff}Dj3;= z6qsXUb4^y~<>nTgYztvVCqI0qBtE{*RjbfNKt~^l`Y0Y5nOUEHU~D!?MoX5<7sDnd zMZiBX<WA;ZgX z`>46E)x{~~eRy+vF_)t8^zRL&eM=opd@n`b-4QDb?@$O$oyr`r$p(YxCWr^_PkHqKnuyfFpqQUoK1#W~qIhQke_W}{?Arskm8 zAwAczo2$yeIv7eHTUtbX>8jAH?iI_o`%{vYdfvetb3bJ6R4n{DA>J_=WAfcIBCKti z)u8~3Onz4VQ1tt*!i(Ols>feagoc7KrlYsw_N7G~=reoyo?bJKvU!bWqZlV^L^{)( zr9U;0b<7Vd$n7I`seU|b*>Cs?K%}x-r_aP~Qq+d%wUh=Agw2!9h`7oK*OKGJ=+VtpNf>D@^+2<3G;Sds{0lH)-r$UIeWal)L#z5(rX z!5Y58FYZ;YrWrY&N@t8a1l2i+Wxxpx6_eE3rx zj#LP=_Bi@dk2yM}>NF{}@8Yym!5+yX&@H-j?XuJld5;w-K@}+38qR?IlS*8ygeyO0 z*+#5GaG8flB}%^NQQCpHJ5Ae@c!1Ie*Yxx?yB13tPL`WGmRZy5gzc*~ z{00?3ss7n_qF&Tc#MtB+b17$-`|BHz{IM)fF<%B#*><{GM8Y%05sxG|D@HcfG&MEe zkTP4COPz*r&(sA8ZXh@)s4Me}j|`Kx?YZyB5x+_BhmgGyR@-kru}0N0)6GB57GjuF z2}_ijUC1H}Ug=n!9Tn`SDx(dCEt65n&~35(X#T87q5W`Br-E}Ac8Xs8!%dE)3@2)7 zVaw0ym-7QVdl3uZ>oSt`2_BnDW`fpm?6rmh7L&TJeBe#t_HDAeV4J|g*Ak>Ay_lU~ zVx|?pIBvCJtYk`Jo`cuPu?W?iN61b5(YP@~G1Zgrg@(AF{Us%5y(+2U(sac|E9Tf# zBRXT+v@rB~dVUeW-CWwh1q$8hQ>A&9M@GGLiC8dsMTPj{3Ug*vHHwPa`BJ=VmWcRPH*_4$8~- zz=^`Hk`Qev<+V?2gNgmKSG7q_%Ehn)L>a-Y<>D12Nnz~KZtFF97W`&nQqu{nYK2u! 
zd%7Ja7~`u;&O#8D6b~C2fQtAWUr6-8D|9ABJzUR*X&|@Og!0#Q+IoiEU#O_0h59WT zVSxeHUxF2+KIaDYoYYkNcf{R`eD)(>UD_K9Ur__j$&W@J`Gre36s6K4C_jAR%ptd7 z$*xSZa))OmNFi34m3L5S&waf;pzGXcXJ+ir@(`ReQu2&w=OQHVVir%d9zVUGpvyW3 zZsxj0wG;~(OGWfaFpon@Cn@T1DJ356DP_}EpDRQr?jDV}HRn?=8q_oB$$zGh*HXo0 zia~x%9dc7~8skr9d!fumBHCP!6_vRY6Bv*CwYhk* z>P{~hA5S#tHkktv=Qt=kIM!)L&D<3KZDS*Q-{HL6zI$Rht-iPo9I?$5E=uPr>bko>~6p}T!e#cS(6 zIM7}|7&-m|d(*DE#!Fdj!Oam)C;g$<_)rkC)clz|w4cZtXBF03JVjYz%AI6Xp&ugre*X%IGz!Bov8RBR(|DmS=5umwTX3J}B() zwPj1|0ppBpcx4>J_g6t{V~DrG}` zhKRJ>>zXWY>@cR2!*MM%%ei3qiWT2ag(eg!j}-tx`oe8qth?v$K)7Ny~OBn-0o5#)kyZ zd`%lyJ`ua(?YOJtXTU={{>&a*FwUcu$ z2RGcJ1Ur{qY0D0%wA8t+QF*(K3j0hZWJjbL{HS4*e;OrxdJ9)$63#nsfL zZpw*ht!QzFOTMKa)-2KUKU6Gj%RQ$bs>yQqp8`f*C~=BYv2LZBBB%vmW(!t`Kw~1* zv|%mzar;@YM#}#bt{I~Fu$0?AsRRSm^ICxbZ}sx@b}p84ImbY^!@n4I(0|c~El~q4 z%GevkNd6UE4~TVwr`&fdYQ~z|tno_HVi(oHV;dc+1y1 zEXVcY6EDdvwK2Tg0&OJxjr01w;VtjSbfG*R*WrPp`Vem(vlE+4fbB5Pv9&Xh<+h!7 z^4r;qxy^2IsHUjY+*zmTN`ocDGE85=GiuWDFC8BZG z-~U01ia|`oG)G@T ze77LxvPjyuaCUvD5F15Mg!*rd2N$n|V_Y;JdjYMs z&0%c$l=%-rWIRv*C-)nXAjhQreL-iy#*Vsf2zmwC&8Z%GGwfyeBkm#GN)zs3xyIcj z=H@2xGvR*(smUYLigJ<03t~IJIs{w@M3ZR4njK6>+0=u;skiabFw}w>mGi_+Mmjb` z9Iw(p4Un2Iyf2zr^n2(LAkS4r=m9S9|M zwEePaE5SYccPdd+gghr}m9>h9^cRY0Xv&)Vp)Z|H(t4~>`azqyFcMD(qJT1m^f=ih zkC$`jg3$4kd%(z-)vC3q=KGInJerT|dl?zgG4E$*>l$wt6F6qa(|8d9;k$^63db85 z{;2fy0Fw+IZ@ax}qx8qK41aLUBltPPbJ1@cv_=5|>pXhH=XWh?mXVg~JeJC2a9JTn z$JFdV``l4aHG-5SIm3$g|Ko~eAbKN`w%%klOa13zp@c0X!dkL@*N=I9`+$CIc5~l9 z?$7lqcG%&BJU#~mF*0sJ%2_*-LEY^wWmD4K_T0vUhL-yWfZ6`7e+1*50WYBLl0tS` zu6H1vKcwH~dj;5|?Ksemj`y!guZbC(Ve0tE%eeY0KGMta%C1<31CDrZs)3c51Ix4qBXKk%Z^-rWz{A3uB?-gs)V~Oj3TaM9D;5su?Shghv7xIqrH*m3+?ge!L203O7H&RP~Gn`5k zG@;0B=^X3DjE}>SmafV=F51J^r(F(kyCw=P{Ihy~i`?4;TX$rjqm%YE5jF1U*bKPb z%`B$DmNeJ(OB0VAvj+11xN{Q}d_f|Y{g4fuDMDS3d@x_^yN*j6g!RG2$DHNyMUM;1 z93CPT>+^2pwato-@Gh`qp55H&D5iiwISma>A3Vtu-R+9~9Djvw8w_)MqtY@=JkPtG zlO_4}AnNw+`>Ta&hLdS(KrLV3X@}mbeH59vsBT6~3;4os@U%U35!^jL7kn9_LM)^2 z*f>#9O40^oZJ*ok#jn#HvH65^l25v!vl&|f)Rv+eUzlvi`oPbCEURL{cGUPBE#wO% zs78dut`hPX;cNSV{L`;n3XA{S`_`ToNr;(JHaO08{n%}r+guBm!SBu4F3Im6^A52z zh3e+8CZ+V8ij_5{u7i%bAwyfTwdYaz!OtG*#CW+>-`3DGvH-`NTc+mnlQxrO6nZV* z4{MD`JEHdrL067)sJd?s8{s&u?d@Y7>TVM9%21J@%OSbcYr?=ia3$)k&x?nhQ4ME0 zkORVm8@~zk4ez{1n7y!&YGJhfZ{qgMhv_13yXLEZg1uwTTOKtP2!I4Qydu+Km^Nkz zqVNV=E6F7n6tx;@3DNkE9?c}hoRCWc@};N)KmuBO(L8?a^4L^xDQ9Cu-qU+P=xd3{ zl&h6}ye=v!Rn3qG3l62|WzunAa|)!h$=+YKozSUsbCX(txeCVy@4aqUVFt^OTCey> zZO6Dd-h2^xBSQM>H^qBcL=4g@E#<&y})YCZSCk3Ld@f?H-D5O_mpv#`uIHV zblO&a(%D{a+;DgX`cvCdhAwcQ#2XyHVRRPQ-~R!9YXOMYWOusX28TpJ?z3A+Tr7o~YoMY^h=%Y%M$xCX6*PoD5-etSf zllwu=R>Js7yf!Z}*Uk6jxILZDmM@Ji2Y1J959Unv4m2jYdUL>05ibV;QEHG2yNXt!4#$Q)|JMS==3ShpwLW(6LtBz`3v{iSaM(f=vj)y&}ZeazcPz3r!sHUT1?R=1;k z-WsbELhh^Cz(6%>4V_1{V2KAFC;icil}W~u<)fEdE1Xd)tCi-RxpJK&%jHJJ6Zayt za;v4I^0PlAQJWi+$StldZg>^WLPbxOBJ<_t3x`2Rc{}FXka<1(GS{zBPs*VCrY)BP zxtM`QUQA3;(oUywz01|TN41!v2@dj=ZM)r$_LHr1+iJNOu}6<*;BR&VAg3*aH3j!3 zRopKun$clu$=6NsSK6Rp--i3$bYZB*9_LAkyo_Q;h_q@meTJp|{M^}JK5d|SH#9X_ zj4%Sn`LYGwwiUn&|6LnbW&y9qgFSxVE$ZLa3k&RFpx~tBRb|0IGa3<3vtGyJi`ZF9 z(^>vVq5Zg=XII(7-80X-stfTWKKtm3-cfO1NI#fvoR?b5A+S{kv&iSG6QN37D+GE% zc)A)kj9S0RWwmz#2U0X^EodN_G}hpL_bcZ6oUM4lyt9`P8$V_6cv<$CR7v^6OHuNn_AFovVRb zbnWqbMp@8pqWkKiQ?ElhrvTnHH0~%J$8K|Vrv#d(+Z5vP$E8@f_6Mo!(rwwZxckre zXiy-rU*K?Imdc*vhr*`j%>vM~8DAq2;T!*7WcmpB|Nb!p{}ks>4}* z&?+R5&wmQi0CjUZI!Em2O}^^!yc>Yt?`kK;UgVAb5-9K4-#DGfcnJs$CB7NQMt_EO zfMprhC3h`WByC?FWasnGwsxi&5Ce=7q|}E{B`vvrMz%Cty4619*1hyyxt%60L1+5; zReC|K0w!Z02Cz=@q!YyGd5c#WF?Aa}{&!0Rc1G_-GKB(@WS8$v!UbcKOO+pmd0cH& 
zW<6)8Hm9j0n$tZV?i0M~f>2IgM@^2J6o@3c_;7A$@y&L3$fkIthy5I~eeABA>q)*K|7Gc#!dW!m15G&)I7o|6si`FcPkN$I;a@k>yo;?RcC@g$TYq33#p<%MnMX z`_WRdIyc9EHTX!ieja~`Y8-D$qWeRnfhjKeK|I!VjfBIUY3fhAgk8k-(7k8TfT29(Y{J03w4R3NJdMyHJ-We(jp zV*2(Gjs$Gm7u(w3$Q`Htd&K?gq{}3QEp7=%3f|1-`VO(5{ysYD=XnvwwXze4UlAU! zPr_e2eiPU_$~;o!ntyp1CvCYGU)NMQF1luYI#m|z{x$T&)?L~2wXoBL`z&&@?dz03 zV}MDx-Lr|;f`i8coD)RTVSFn~@2NBFF2fW58LCr0$2ab7Gf=&3!|l|55$3hCJD!m` zYzvNl={tMrw|a6bb05_SIRji}^>+Ed741CWi>0cqz{ecWLxjM7g=gKOimIG72zsIB z(OBVe;$;VqmKsd0e6)4;V(0{^ptJh==*YM_Nzid&^^}q)8A;iH_B2cKPz0|$&~m+*^3H2Mr_+n* z!x4nZzK?_;o3`b&e`Mu^vbzL?c~J7HVOubtCYBbn?V7NYs@ zOEyo_`v*VV(^ul#9_@BdGS@o$I+hb$o=bXtS|>YECJ&e5M^DnN5AK#(M_2b;PkF;6 zDq$LXFFP{q)(tB+9b@tqL)-?UH zqh`79)PCyk{#Y`3HIwjsHR-u&+(5|$~jap zq{|0>e&y0FRW6J_-ckUUac>)tlg7!dxgOIt-S!mk)H=vYU+8*0J2W`k;7U$4*;?HH z{CX6=vT`B~PA`r>;T4~EYvxUm6WPw))v`fmq#n*#sx6v!5A;$65>~j3Wwokx>QZ@| zcF~FXLfhe>iM`JSAkdx|YSWSKQ^w;Sc_r3wJdca$ijV=c?Xdt}ZL5XmFkUzP1Zw|{ zy5?i2gQ>pfEf6$^*9+=?om2)Ybile!DW9~P&qVhMz(f8 zL(Rx@%^Ky0-Zhk#q2bprF6^~-}E z=vEForYCTlNQbhzWj6^4sqxJl?~6T{E_mFA9%Mj^$0Kcd1vfQoCkMq-U*Pl9LO#t^7i8GKB-` zp^som@W$lmB}E3G3z=(`@~rf7^k%~5@~_wYkl)AZss;>14~xTUZBhxodCj7>;5lD9tC~ML0yH7wtYyy6@dfU6+)pwapmUUjMzz zGnGS#9hqKE75r)H&}Oq$&uhPbVPgAU^L}%O+a}w6b=rB=F*Uq1ZhsVV1kJNty}bT< zA|(KODk9w}|FWmIjr@K7?;pnkJLD@k?P=kVBv||N@95({AcO@rnd;5x*mUd9h3ftL zg?iZgYe&oigxmiiuz4sefme~{pWK`r4ZP8N^EU?~Mid_KF2FO%#M)j$4IWx3fK1f8!k9GLm-csK** zzZ&^u*&B#iWYCT5d$ii-=3Ko02qfe>$b+8k)23wU)_%Q)Sq3h&4`X!=gRj3{r+qu* z{yu^a@1KQ3=tGg(19<$c6|2zeB>?268J?!xOI$)9O}ruU|9Pr^Jt`k2?BTYV|GDS= zu%6>M=%tnU+6uaBwTpzV0WIlxpq~fQXyL!K+wKH|1gzJ8xgF#sEYG*!a?Dp7OOrsm z{g6)wK+vkCHr_|x=B8eYf}aS2$%_qEXQU%;5Gt>My9JXB2jKuThX%`)%ac|hd1mA? z3qA6}VyV6<1dsWLxP+^xL5i zN0SL**;csXB`jlUylQ1~eZ865u0W1Eh|6|cq5(*S{ow$1nn=ANCE;{j`P=*V-VGE? zlelW=c(yo;YIeW3MXlIKXuqZmsCwA*g6^%JLt8o?c_F+VSG=dZYxm(jEPCzkO^-Xw zUI%RxYY%e`%WCbHxR6mR&l>J;Ynew+zymeeZ1pz*IlxVXT$VF5^u6B(BgxVukC&Tj z`fqZvF*LTX52ifPkh{H}A8vBVjna>1OAz? zY>3tn%aeBARp8&R@%bbA_ogtJ_KgGi=*!di%W<2tbBdnF&Fqhfs_R=|L0FEb^FyER zulCYH6;`!J(Yft*Lm{8#()kuhmO3^+c2dBVyl!?l+HC{Y*_xAd0@DpFhSi$gIRBoy zZSaJg`CO@uZ&YO~6tExU8L_})P+un<&Fx~F+vo{+tmkvTvu2W+k~>>%<0nhuaVXY< zy6ue-H2_xlA7}$IYuFRx5a!Eu>AdF0)?c%~HZxGc^6@nzV$;arAHN=e2O=9m32(()V~}!$>MHv*!h})x^s> zJjmMXp}^}Vfb}5(Ozpf-nX(a*!bj&4I9~yZhX$5%F1nI1OMe{E@Pi{@>YA4Z_S1SXh6R~DX&WQ9btfc!C8aM zkIyPBkbR@d|2mD&Whz6SuA$?6pTwa@SoWE$V;h-MpVAvcN&c{(1JQ+q-*sl5KX3B)_n@)nFUv3$5=D@JPLBpTL>iU9 zAC9m(WHSn9C;7jQc{h_LiB0uxf;qfLG9cI0y1d*~Q2X*DI`Qi)ooi26T_yy{x7hZt z*cT!?(_aaMdX@%ch!kdn1x8 zYV(ZvzK+iCpxcU>fZ4Io`uk3Z;@$5J%z!?Pi&yumuF|!RU>F6$iSb)aK)dnU*H^t> zJRl`{&&4*GGZNB4@R20GA$tiLr%x9HByQ*+*sT?(HQBPHc+vtFG%`w?-$lLZr$Y{_ z`DG^VH;5nk&x&`tJ`CqPb-|O&oPSkv3>>A-aNk74e2g@WYX7X^Blu28i0pUZ;ehlg zM%*}w$P4LpG0a{koOf!~b`qltIb94$;!aIbG3x5vKVgRheY_7s3opVd*-1*#8Vk<` z77=PCY7*li|C2k=R2t;}`E9SkrXV(+4CgPC4=vvgUq7(BUYk16XcBZT%_p_4TdZkv zM&Jqb1|X6hR@~-|NM2H0lNNn&9Q)_av_Y;wJrx>-4qjgnfjvX-rOaCSK_bXx}zM4o{35s4Z*x$yD&bx{0zkJu9 zeJ3Uzu_Q2UNWa#WZQAn7{lX1kH1f}@K^P9CytMwoMoD28$%uf)9JfiJdD#~-tn|(0 zJDCWEMcZdN*f84-4NARhMBBLW7Kx;Dq!{?mU!XTAl|$nAgXG4Dh_uENwie|%TQih$ zB0(ex4Qrmfp2h2i@gJtgP0@$6FzfK1gP+VusEA3&X9A$K<7|69Twc?T7m|Ltc691k zZ?!+l^3C)b7jb^}i+S5O2Sc?o2$M=SsGKizF_7GtwL;pdZ5YztcT~jVw?v%0b--pj+hUEL!T(s9N=%mDrl<}liBS$i(;7{eONc-p4+%t*jB zxD#sZfV7?MY#JjVVE@D>N9cTZkJi+I@a&myYeR z;fMFx9z$U2%O>of4EJk!Vk0!+YA~#{#44#B`KC~Nm3{ovC%tlgN&vDsfzPD=1}IPHzH7G*F;^w>`bGakqyy>8r6jlx8RQ! 
z73q0;F@L9^RcC$~4NSV4FJ8<+%7P4!DVS}Zbf#1Q0WO;Wl;HQexqJ9k68_eXRySj6 zDNQ}`gup(7AuLt~gW*5RZR+C$Ia@DD$3|h` zy+NR1<#1t{^i*2AN}W-Brh&BEv`jNFI@nRWJTaeA<+1s%Wcch~N3Y_;TCo|=9fnno)Kcz| zh3;Tk?|zc6MPGAfiqAcG{UEb+pt0-h>6|nPV=(dchE&&zm-8PzI1SMo%8#RcG}TeN z-!lL3O?MI&?Yu}LwBF2VNfw-FzeW&@)Zpm0{dt>2y0rqwaARfq&2?fzvvvLa zOE^<*xQ0-VqpWun^>zHG0Os=_mN(4C9vX3-Nogg-DoDvz*8Ux1m=Ytf zYfY6JFOKFUrI;{I$DjZ9uk%`Q_?N+FDTqUeVwVPyx{%6V^v@|3Nm6!c zTJr{d9*&j<@+9=k!tq(V&-y!L%Z|c4mhxo@fjw`-;b1CHdbYi8W#(vAVX0iZvEj$; z&|mlk`nd7H=}OCzm4QaU)~nNKorMGDx?`X2->}_STYvrI6&#SHSwP6~BbEf-tm51g zjL+@)Wj8reo2gF9=KCrkYOMK7K7<30{fD@ z9LL%BalOZJfyWWxwrb2R$_lv5`S|87HrIj2<84zKuhU%PDvlh>qu!74I4Z)sQyXO* zlnOu%&=LWHHj3SlD$%Q!p6d?2$4ePRq!Q+FJl>yD#T5KeuGgVEv$#84liCP5U0ITk zJ-0YhLNdxq{4AIHc{0HbfD|0TX%1y#>sT3sTxc2GcKwjHZ=2{FKHL1x=b62A=9*v< z0rx)_Zh(({iAz!urQvo#hUX#R9Yz$RSr=8N)!n#>^( zTI?^r@xK724P|-x;uM8MU>N|AqlBYG^uW)>}MD1C?yU`K&ntfA=GsstwnOoO^UVP7L#CITa8OWvD5=3+Z zUhOTu!L_nw;9Hw;V*W$nEuu6|i{zKpPtI&MndZ*&8r9)nkaq{diH3LJzkm8Za<1!+ zl`W847KB*29iTYmvT30nS1Qpu+uJG^s8Cv(Jpk4}8ubr|e6_jx2*=TgL!i|PyIwPe4@qJMSf zRV=;PH))$cZ0|B-YT-c-?A1RP*g&4fVL=)mG3<$fox{bffvgc#=-~H0_-phZv%!Bp zS+5I%Ir zFFFhAU&ncZx0vgq*q5(XUZ(o_cVxXxj=#5Nj$d)ZC~y9T*_!~ZJO-wy66cx! zyKhaWcvIhgD4DId*?NIvI4Ee8*taVdW)PuoEfkOT;q!W8(eAFl4HnXHA4}n8wd0AV z6A&ux71LkY|3{*Ea$w6~|68Q~E1>Wv&x9o~%Wb9;fMv=&)D^+@MVIx7v{vJ{4V&BT zPZD1?Wha!}+@N31JlhcpxCTH) z7~Zh@#|A%UeX5s|`u90}NGVCi?O1A_=BwMFM@iyr0d7DBP5tbxvaS6~rh&{U*A&XvsDsk8lg{omm%RAd@k>FAMr8GPUtr?YZQwj<53 z&Imaf0D;C3FB0<`RKIQea_IBdDr@veTuAp4(U8v!>m4~$ zT0YTbSdf`V3?IQjl7fEpZ^^F@8G&Et@I|x_s5WRwJPI*32k*u{hmHM$dB}HDj6dBO z{$zek&9ue%rYZuNXKWkJ0!wPnIn(-j-kJAEbCDAnQu;`8kQe5r4FmH)v9sIWeOJiC zirMeb529*87-)k5(?y``zx=BOaPCG(EVxJH5~Qu!qXPPrp0pG^f3!E=>;C?ep8FGI zUg3oJpKkD(5X9Jwg!~IZ_xmJV?%xl|+$6bPSbjF5Z}>IpACozAi3YN48qwIGP6;Kk z=p*>B`gTd_uuG5MFm!2!VO}2Ty(}i+_Mly65jv}ChhiWQt>O*7)kU;CNAPD7;%6Zx zd+)4N2=n$;49*NLUr>kw`Yo$i=Z78wgKVMIPn)DO$UV&NilH5m=gh-@Mg!2n-n5Ms z+D%gh#P!s$Nu2#Hp@P{go|vI3uu9@?`Ur|%h@UgP5RAwu24a)k13&smUa0E`u5{si z85eQHr81-U;P-rLq~62es}yy*9;sc_c}q5ci}|I1&p8#GQY|DQ6Rs z&FJeXPBcmIpc_a(e-LtM-&Zk7h2bj8Ciqt9jr0@6^3_g~KkrGWv!H;CD(U(dHv&p7 zT^aLTC+$o1BE=fMFPh+Y8W?QpYobM*$zE945ZfxYH&c04bQa6fwBPHdYw?&)#?Arf<8U^ z9-CmTF@fSnwpu?mmuS*uF)f3_9EBWZi(LrBkc%c@eU1KX=sq)qZ!Iz&#jvvFs+w+Q z*g+*mKyR90%|{_CVp!Z?R9vHCXus&3$&#!-XG6Bcuv~NMm7B(SY9qTv58;-D{brtN za?MQT+tlC<1CiyMl2CPVfAwhcR}0vQE!HVCsu>3hUvSOSa{eLosSL<+7y&2QEL zlr{OwhOcrg4I<6Q|B|Ke zyZJ1^hJh%)%9i-~`?uF=2w`I3$Y#u#_L#B<%m?N!ZB*1O+FfDQielE$0z%}SdOX5S zQ@+;tOEX%zOs6MN#MF7-#0Dq<-MwhAJ|xmZlyzPF^zu_7TTF&J7P3pi`(G1E+cPgI zExHIKe_&q|N)r0y%V(TYf7}Ax(|!*OA|}M zTyHV)R*a_P4Qci+gRZr`U=ouc~0v5@12ws92(@0%rX zePE5{b9GOF>a=NmbML-I>?*!y)F(v~`RyU%G$)o|lAof4&w%IOc(*e~2>703gxXQ= zzrkMhC~Pi#X)Se@yV$h7n!4(d6a_BTU&n|MH*Vw+(BXr$2h8CRziem4zQ%iRAIarO zu7Kf!@*0Jk?T*~t9SOc}!_q(z>ko5yFvyn~BAglz2;cek z0?y+#hS1{i$pcyXc>NKuZ!glL@cQw@-V$fAe+)VIje`&0Qs`8CE9tEJ7^zPYqFD+hCkaNi*ZoKB3F3@27Rc! z^}q!B5)pY0PYksyLl3bJ(2fMuz%B!g_6eHkDh^*@q;`53+7Yt<`n!6oM5>c~%0-?E z&r4vSvwbZx|7W-7UnhvPny=qb+ALmobsyxQU;`hk9udDF>`SOxZ=86rM__0_yhQZ_ zhD)?%qZFmZa!3NRCkku8;MAVTATHHWayNy=Y=jVbuJsh;VVyW!T1ri8G0Ih9o?G7K zGd2J0NAB);me82gf^kz7j6KZ7-**?RwB|BKqN2&fkyOSTZEyFfhEFe;gp>W!`hD(z zQ^Wq#YQX0VN^)rTD-^ReVMV>M0=L|hcxvn!p!kP`1(GW#xGlCkQm*d%VY&+iraq40 zlAASRbCr$bv7P=u#@+*}sb&2Gl^!GznkMw#qy<3%4ZTJ{iiO_0NCyoF(tGGgk*+~d z5djeak={{CKv4vwN>P*&T4--L_niCgeb@D0Z!KLez>w@c-~MLi*T&F^$mh-(f75)K zoyln_jp%bTfVul$m@5iRUJ#g;V?I@)cB+kkE{hF88RKyk!Ge%SNauB%7cWxW;H~hb zjmThg5kBAOcs*u}84! 
z*;qGhPi}`h%a}|wy8B^5)RV6jLdJiqjS)4e@FflEwR z?q*PBYZ1K&^v9WB>D%U>?;h;hQ+sBHiaK`w&}y>HX^A^;qkOF!EKUM-YdVOQbBMvY9MC>{t1V2O)bt;u6zWF5UL$ID9f0;1}jbuTBU* z&$H(>nB~m)>mhqvY)LuB{3$${K1I^;+SiVK)yAtc-Zyg{5>g!wbfZg@B9kjvM1D)@>tK3&)t%UYS@cu~1Z{KB{63UAU70r~V3! z;39kYX5L6EGk%V?MY#Plb5%y_nJ&FcTEw6+^Nu8YD;r9>-;J9|p4J;3*$eI?=fA5- zN(Dwj%tPo8+dtEOJv(QGi_8jFAr6x*vpZ7xLY1}U1B_%@I(c2LC^eOExEe2)Ljq$- zTy2=sWZU?o`#PkaBiC+MJqqonTg!`}aCmS7a*N5!#!-!@u5=+M>T=)**rNg1?JdED zpy04T9QtN%=ryp>xmTS2;K2h$S0aR1y0DePMl2*E!bXi>btx2@^eOdRse1R`A?m83 z@k&Gmx7jmG>xq#rTSeDME1e~K>~SNS^YF143Z7IECh7>);Pjf|u(EGdcD}H^+>~-b z&}o?@y*fik(1zJ@c&1yK;}cX29w*C|J>LLhU#Arv}-`fbbW7 zRPSi^{cZ-d$=v>S0yT#T#VIHBm;@d4%?g<^_Pzq3C+TOj#TwIYJ6q2G5i1@rDyMM|BiExxOxTn?Ls>Qk{U zmFyIF{$kEnp3-d=N4-Nb9`+{Wn$HB&iQs_&ULT~>oYF& zPDZsVO$}}=tK1Cu`S35WZ$Y_5+gIM87Ih{3)Km{nAku*{ZszhpTD%MOg((r#Lk!ba zzKtd9O%N|xaws~0!G>Ct)4I{FMpP#~OVYaF(w@NzUS)hl(0J{to4?A_{$YcWmwLx; zW&Q^BCjO`7^&cX88lLPXI`&_Na|>aU%g1fbr!Fqekw2*#6?&S7Aq0F=erRBPlSkm) zPjT_9_LmD`WSq-Vf9qSyH~>KC7Lyg_p$wTe3^x&=mR>G4PIWFbxzi=2k{t!hf^^R@ zN%qb$z?{1~HBRvrlsu-tktzC0M!L7(WWr+PT6=oV_o_Bg>(jq|0$o|B9*h&3qm7B* zO?f%r*u;s=u_cp4{-$v`7mE}L6_P36L4#ZTNY<2N1x`2#;}S;-}Lwk z7@CD3T$A|hy8oLp{O79{K5{(QfmG}tp-^a;QxZUxoGMP;{BJ$w-)L%>A30X0)ZX6s zkFS0x4{k(mJ*mmof8x6TmWbdPkodf$L`&Bq$l9K8lc}rS{{pJYPr-k(xD_dW$O=In zA8dNO&z8C{0Qa$*ZVB@^*j@i{Rod*= z^mIWz9{?~t-=_UuF8BYTqyLhVgb}%io++k^d$eo*B^~)(XIk-}(Y5%^xTz?p(h6 z%MUE4w#IgWHX3#J=pbP170~;B|26s6V|}*s2o>YP42TmT5>uUhx7 z-{ZC-shNCmIVheHGD@fyI8xKUJDF=IfZ^_BgTKd6kxG(Jx!LRZYxSr@m9FN}g229X zufF5y)(C9Mo!6D0{(TsK!}^RZ+(h{uVRWZgGHtR7_qLbQ{mVXXpG4s z({!%Bb1rCYS|wx^u(SnY$?Wp@L;%I-%XzoCf39q`j$ohjuP84kbRR?nC1mWZeC+1} zD7wO3g$W)-^>?*+Amibpu=Ds+_14pa>XpFt=fNc3+p%l|95iUacX43L!B{7kr{V;5 zNDfiq7+%3rw(fY5WGN`=ujC zy2Sjqjz8RHA6^^JK?$L8$M*@cdYc=PwJ6{B)wdO1Vc%rIBWI+dW2+Y&Eg2y-vugL( z`4cMBG&KmMlK44!`|bU$AEjkSL7B4Ia<~sS2hoeX&OZ&{eYStPBzXp;zOVeSguNB& z+I##81M+2IWM3 zp1510xkp_B3tMm>6MLRSGuGg@JZn7_1(#)RY&k)gryMgeLeyhC@6oP8Q7?bXZ#>5q zJO*9_*8zL0Z!+Rdf2fc>dT>pOk*w?fRs#<-Z`F5wEl-aW_+qM49v89rCQtQ#V-a@H z)-@+Ue)CP^|F)U{UYmmN-@p92>GmhaDU6i95*~KuR2s#9>L35@*KsKTjmo27)#LSl zvv2>fO{c(fP7&&$@;`5T@*5rsope~|x8^VY*kMnfOp_$sru2XOT;JoA`@z$8|7F0< z$y<{q3`;;u#dgWB%nfv|*ZS5P4L_ryENMiM7hMM5?jI6Qgk;|j=Z)3MXxt3|H6`o| z*cl@w&ztY7GoL4FVG%P4`P=y$cAu8fxXw49sp#03LICfx_07Rp{$b$cjlt+jM<|=p z{l?#FfM?k%{7BP(;cs#*B!7?!e=`bV-&1DzIH>VNj}8=UXBovk5kEqbd1g0bPAmpN zO~ZeeebF)>xZ_X0b?H5ee%ZJYpDSe_Ez1v@ti=nD6V8H7-pAp!Z{q4HGFIMk1P)L_ zbP;)Rr?JaE5O*;n<7Wrl87U7ufO1Ga0d(NEInvGzi-jRmufewlmNR(m0PsEm^mmi5 z1{HtoGj!;+%~Sc89&|Dj#c@NpOQZnN-3cxOf;$f&tokUasYMWlAcKDbao2^4()=Rp z4L&{XyH&U%BiUyoUB9->6@^8ePOxR{-cz1iv)oLfXp&X+?# zFnt!cY#3dgdiT~W_zh^qio~VIw-#ByssLTq1h6--YE%L0Vx@&$9yobS`=Ex+-gBDh zP1V5^;^=Y|+-l8AMSBpmF&(q1aX~tf=+JjS57)$%uW;{T7WtxaxycOpG#QtTzi|k1 zExpnE=w{6oBw^k3m#8XPFU=}s`-`&`%ry)5!!_W9j4JKlvJ1#TKd&Mv^_iUuSnq@f zLw$AhQavOfo=sG!oHpn(Z8MPM9Bl(Vy_RX$lqTaV#FaOeo#m;c; z-(TMhS}$^ST_UGEAGrBluW$?un2&=b99f%wgAz$o*aOzX~OdKA)ot3UGRSG*LvU5ZmP!h zS*cX+fUS?llP84F!wYVH=e*zAw341o@pm`o(ILURYcn2%rQu)gFeO5FOK{4wp^sI#7G|Y>4BKd+d zB_z#ju{u+)^uNup&cQCVktaviK7noWLEZb*NysnimHjR^eHEnGVT+bH+ykc;E1o{A zsjkK^c2?iTE8hK;7eR=w>-N{*3wwaa%e&trEX{OFaJc1O zn&Xg23N?5;v=t&GiEqAt*r3#{gna`9gZJj3;g^A{FI>UObYXe&?HkGGv5(qRYmv`- z-~dE5x{FeX>=imqMiSbBTn>fD`|(sDN=9K|1jP|yutf*%2%5Sud&M>}Y<5Pd%(Y+6 z1p!@5HkDDr26r~5S$FYH^d!`jmE$Ul5SqA92R?cTRPXf%F!6Ep8^QozHE0N1d3E=< zl|X7;)MvC7vWJS6b$3kAa7aBY}J9e zqYww-`{>ha9zYSx$+VIsFim3Ws2W}gG!i@RMxak`SuruXx*F)Os9vJ?^oy;71fJR> z*ICi~!tKZ_RKtQv$OtD{r3t&t9FUyT6I!HPg{~f+^x$TmhUwk#au;R9QoD~b0^CxoW%aroRU8YUzZ{7^I@_KT->bVNpz7k 
z0lkvO@A+=)CVqKBzUwaW^kFXgRM)L>p~9ni#>x}+|76dtnE-Y(nqU2BUrr-<5od~H z^k4l(=pl4d%%Oqg;a&3ByNx3BikVjS8Opbj)%4JNP$hiJ=VZ1pgao2Z+r?rdNV;oj zW1;^YPfRhj3Qha2HuZ6JUqmYHau&l?lfL)zePL1+C~mtHXp~;hxw}`Qm8LA))GC-^ zZSl%Sy;P*J#;FU;MTrnK4h{0UBrhxe zvK4Iq>nQD0kz_N7$9V+q*%Z0R?e@!A$ue$6BpX-qhp7~FXooed>z?#Ndl&=BU88L5 zCemPA%GoY9g&T%vQ8N(r^o#2A=to$b`I3<_0bzGFBot+HT@C*rMjy3MC>XPLDk0Cv zI}_Xs}9k*~J&CMtz| zgI8~y!VPPF;&77~Ld7(pe!58cH#r&ohvFev^ zkHP#SWpcT6GQ(bS-|KMW9i*hZXW`)GO3fQhqnQXj(*uvLhtN_QYa2MA>~Q?lnxdKw zvrN$&%9KbCdaSFuW{Hya9xxC$1{9}8@|2*2WpPnNyYe*nVj6T7^NBY!kBZB-U204& z2Ck<{9L;L{((Owo!q0A1t4f|M;4eVu!1k!Pp$dhVWBd)3Q@yXQjn0)zd7T3FF7Pqn zWYEWRD{h5p^U~8b`lMC9@Rw*3;21X{PoqAKwI@&J4c|G>ywM>H?=6zpIBy8|iwtx@ zEA1bs2lXU)(QC!Z6LVqag11a6U{)y{v82hO(+rg642_j}Gm2ZPgHPQohJ5j-^wrZf zrn&7N9hJ^i4wi9sp|dwga*-$|8gEJA?kGYr@&jgH)ge{pr1N*!KT4!#;7kvNtV75U z+NF_IJ>`C8-GOHdYfdv{Iy|OYxcHh8?SWM7#by_ZH8?v6Y&#HJzTe5kn@B(L~ zZ8m#ZB6wVH9zPC2=BTnE1*L46lB_!@E(=zw7T@m#eA?h=ro!ntR6`pe#Rz#?GSS8y z0>@fA!xWL-5hjQ2*}U(xlN_D@UdlcsJz3kTie5~))I(1l_dH%uO|7di^Lw1Gs*6Th zT!bMjYtiYKyMQQ&y*;r{;YvNIQjYXeo!PU5y2e`X+EtZr80a(5^z}+j`D2v?NwpIi z9b}A>Fh;OV)Aw`lpYd_q$IqquS1EUm|v(dfvR2oMa zk?Fj{jrw^_@MA7P{qwAQ+a59`_UU-oh6?>u@XllnZ%ILnuuGh5n6?h?)K6idg+!Wu z;a07|Mxgc{dvm6PsX8L|x{3$GvcNv-_m#_PjX4=D_T@qKobVMb1j-dL;@P+!ZB2@y4Z55VwGd=f;H~Ep87tRrcH@`8#8ZB zh>5co)IT)H^>mmN0N#LkAN*jCiuIFg(0 zG-s6@OBnsEX#_K4p`#xUgHR$s^(58#WeA%OVyf+I*5?9Hk=60))yWKTZ`9QJud)W= zrnwwLu?AFacIxTT+R~Ilj-3g#lqr{MpGgO(+~!g9_Pm;+muML)xL1wMs!dtEn0Jaq zDZr~e&<4djuY>A2ttmyRFot{*z=Fd7buUebu|m+G5XS*} zjSe=FNkBI~7H2SMP%{CyEc#|Z#jgX~*isr#xKrwZH-5lkm*7!sY!g}GeRbhcax@~! zQSU2!-5JHFx8dCHy9M!o42zn3fbYXAMDa+>AVfuS`}+KejCHB9?l0FRzhRUuW{a7Y zQVGeMrfDU=T@RD1)8RxtgrI@k%|aoiff)OkK7(gRS--zHELeXPKg_ylcQCzZnsHsE2Y%vd+d-_4^jaW@7u4v(_bMRS{Sv z7A!j-q9u?|vK!HJfZ1c}V5n9@VA)vuir>V_A;4i`iq-+h&pa<0i_MmeJhA@Rb6Bk& zC^@{1`paor+d|u)g)3s`+i3hvh*U|KtenW{FE<)-ka@nK;jG5iQS%P|OpRGvm(S4- zY*5$kKCS-Nr$WoOdN0JqP1w6^&Y5ae@<@2fI4_rv>+Xg#k;gb!?(x`O-e;md<~~ag zDdDpdc%*6`!IO=anK>oj&D5JvKXiG8CoO6~MH!PTt7`sKPl^!N4daoZtd4!CWZI%M z4nt{Q{_*1LmH1kxG*x%K`+=L6ws7T`4f>w?DTY(^YB2t#BEd6$5?6!d2911D&65J+ z7#)Kb^Y1n)rcEmPtJYoUms!{lunRqA3fdO!<>QP`jI z^wi>(MY)L<7d%7`*9VzDpyFWuRpT1em?jau>eUXEO<7@ z#$gN6(2dY7ME(s`xDX*VvDo5gPmiBi8$Y+rp-+h+3qK<(ymraQ(%~7qM%#3H@dsa* z&K@@WSs~co3hA;A6aHM}&)Gy?i#nMM8J@MgG3Pirs)11jz0>P~XMpcOje0*^Fi>h6Z z($n`gRC}W8&Zm1ttr2}B1wl{Z(R5QQR#2YFI!@*^iFvpswUBok#&wk?ntDe>>*22% zc}Kg@K|aGzbska6{Hu#2?{IQa1v+dn{OlrIL`zWxae$xJ{z?;d8?L(KHGI`*Y=eK{ z8rB!fw^S&o7oYm1#P(mHG|;#s@e^xbd#;#hkz+E?kij3Hv4;;*6}AL0h8LIoo)%n4 zreYj4aE~f^i}_^28^5SXcr>$>swE*ZX(Vo|+xI-hm%du&F()i&$yAbGf#zBWIi{Y; z%83wvgRT6%3FkZiOL0A8N^gX+R5O3gyl{*PE3gmtDJi!eC}E9QF+P{z{tr4XLkB>2 z*`zzwMxRJXGEu6T>lxJ!a{L+o+7(uYYtxZBHgj<^sp1`*Y{^<2Im|Yuwl941pOEu@+LB!CBb~x-ROx7?n$t38y1v+P zlBB=q=L^qF8;lZFF8EeU|ARy;kEj7?X~gZ?^VC1!Y0=85^xAiRFpGHWbp0<-dBPpI zygyl&y?^2CDEnWU^M6n;5is}boPH_b@y|v7g#rE@pD~nuG8Wn9gMEMev4Ih}nWiWF z$O`RS*-#Tq}2lC9hP za&F4PI_ph4+)Q!Vzv-w4qa1)qmzlefHdKH;lep);8Z?*trXY)G-iEFd%GYU=U$(I9 zj=-T_A%$;P_w-3WZMsuIfDnFu8A)Wl(3gB~Vn@QNF}uG;F%apuG&B*MxQ;U^oO$lX z)pO%nXufLq!q8DsQ&4Tv%hYTH>4!5X^6vFaXRkb@ZTNALXVr7!&sT#dh9}I%=jXOh{M^LoiJqsqlr0$$EL=N4 zn!Jyu>#>rLL;>96_t((g+!Um}3zp;Zr;y_}?2aNDoJ%+52}KA;XI`4esduKo_Ak(n zb746ZgG}&0k?}OHF#xt9R-L@{eGA=l*EM8X)tcRe+f1Z^>fG#1%q8Y^O27SduOF^-v5PG_0)dCPQz$;Y6uwN@K<5bX%(ILh_L>hq%1L0 zBd@?M+MdL-$=lsrAr^chjCUoM=!(a(g?apmkc6TGo8UDN)zzWZJ2>vNIMV6HJ+o@prmEJ{sT=Vi0rQl ztSKShQ%@*?uk_J%&lM+y7(%r06?^t;aU+|9aa*j_cQ5 z+W+1B``u$eLiu4n@HGGAvj95Lf^Yi!-PjkIrn9!WGOlSSj?}Mp-j5%1=3aF75z;Rp z42!Vo)GHmVwOi!X_~jHonK-eYny>?y(0!`$3zc&(ynx5)k!f1sd=Eo144yMn0tNE; 
z?P#Rzs=neLzcaIRx{iEy#?blmSWG0_IeunIsn4(2ff?QLi1j$O`!#=fpMRg)zt_us)ZABs6s+}rqCX{+X8y3-i2i83q}7x9d3)scKRzd48MT4VEUoS_Xn zK|gN5UcTFYsLz!jmyRKXQnRdJ0H>$#0fgi4Fja2s5%THK>A_5aR<2+|Dr~1u^rf7a zx;{f+63=VrLCziGkC*$6X0^AMe`p2|y;|tc+OIS{HbEC}7rP6~1`K^S4GjI|4qT)$ ze8GaRUG8nXe~9IQ#kzd{WkY?nhPL1yp7gHY@6T2C9zf#%@tJatex97+>I8t3-i@g@ zwBMo_XHQEk3$kGzi?~JINw*Y2FkDM|L%q6x0!lV2joJd4deekWj8tDAb(@q^xz)Q% zvU{Bl2uDq45!j4_i_xzWH2eqc0!0jfoQc$YXOgRN;+9YJ!?~Kh93yAD-qcXf!2_f| z0Y!9F#mG0jpMH8alh9SdRLGW67yo$T#g@7fc1W=B)8>x;_Np%JSp&{-wnOD=uN{)} zM~*X!{wvp;e-Ga2`BsrK6HU6J4Er^a}sjMpU+bs}uf4Z9~ z!ixjFLX-FMc~j{E`ZB%%>))Uckf0A#TtUj2Pt){ab#EdhFMYrXr+>zFYJbY^*^-%@ zGXt7)1<}X7vKuUNUPWA@OZtVu+bo9H9Dz%@=_UR<0WP>@Tw$@?zW7$QPtYcBdYzF- z;jb~M4QC&6Cvq>cL9A*StBQwl3^w#mE#jHwmTx}+0R9u~*NLFNOH7lcOHlTl z2RA+Swwfe6C46C;9rqwOR)#mZvJb(jwymAZ6@8XJlub}JJ4h7+p^uDYW02L`(5#EI zd+);Kqy0sr`YBQG7B*sv0T74ZB6gwLYvgd%AQDwV&cT+l7{n?gKO{sd+y;; zxKXrC(fV9hK_rQL4Y2)kon;q(mBeGn4qyN$VNAQ|qmmL}ec!>QZ5QXCW>|YvfcOXq zpEaE2x|A?3A>~|Ol5?6IyFEsw!lf{R3P$wXSU{s#?n^f9KJx5NfjBRdY;{emmT?iO zdttOWxC5zp$86vHi{0FeBoUZZ0A%ZT41Efu({n9;n_;SMy_l0iEmVY5{s_yy9)@6~ zQsAZ)ya?WxKd4bJmS$PNef%^>F%+=^S^L1K3!WpR=5t&w=Y*lqz?zLsaAK9 zRX$c=8Oyf7g$(QSy9Ru358buN7=^Dq199C(_q}GCqP_HKq}GJ49XT$OxRbJkLyAS0 zu6~vs=~A!m!@n!9Dy+-MC(_d$K001e?A(%3p<4PnA&r<@-fT$rRa~9Y4y@*iZFdl| zuLUS4tz~KR$k=O~$-=Z^48LB+M=hAJj$Ey9bGiH3U**@8lhbbF%B4~ds>dU#we{wu zdVA!euJW2L6esTT+-51%g@%2{h6X&$Rr%q2Etai!>*+;VPKaRtaa%6EtVe)3TFE2B zPRv`nWk_Sq!zV(yE@y?ZM`}K%HJ^>bwoXKvEuQ4b7rP2)>h(H98)dZuWF2xyAcJd; zA}mNbd_K`I8z(L+M9vpgoe~38b!xBY@z&E+*H9FO)o4s~Fn+u&y;e=>?LP!VL- zHt$H=FbwgyGf}2Ss%web`X+N9syca-YoP_`~4sUYv$BlSL|6>7y%Eqo=-NE%pWYJN;$pENy*vu7GiX>0H>cn#6x6DhmNZ8 zM}d%~2>v+ok(iq6#;by28SnRZD7Ruz$5cfdk{kb!Q_y-L1`l(I;JS}~xTzMTdz6@Z*v92X`y>ndFN2Fh5hvH_DP3ukJW(>w!8v@vjL^Z6G)ht=@d~ zO&)P0vgxf3UdgJc?8zgF1~7-sbz57JYG-M$Kc$6(KB+by&%F zD5H;V^q>2V@QJw^+0f_BqZtP&kIj#jS{|v*AmLDTQNl{}X|hNe;gWULl?!Nw$rh-S zrq879izyUAmo9ds$3?`)+SqExtW&4+uSiZOw`xwA$*bf$s;eEe!m0P4LP{VP#x6(+ zOr}Z<(=G7__9Q!YF)oKCrPoel`Cn6aoo}0Gm^v4yotlZeMM`FQU7_sqK3)`wL!D#V zQ+dDS))8YMY`H-KACQhVOp2@- z6GpuiMsFn10F;tRIaZL3yMW#ah~h(}P6{jTi8s%vSjQ7t!v;Zd8eOG%gf(u zXniXK=dg?(L-L;7+v~bNi%%NWnqvTp{Vs<|&z=~r<)LlY9nBIQgy5IiNR~+(pEDef zr)M!JVGltxbK+d9CV zq5PLSLmupv7mHjTUUq#|Hkz4BK_(t?pQnRgJvQvkMOj@eg?gkm<%MoU&VhOn@=-mH zZHJl|(yN_~S@zqmo|IVbwtu>l)5C>VDO%#x{80ZV>>JF{Rz4-EH4BvNX{G_4)cv%% zxlFGbylv4Ub<6b^q4%)H9#-^lb=f;@-egg_Xp>8cB>a3Nd`7Y+A1~Y|D@~lbzwdrL z;h~{EfV{;2Fj1e^twE5N;s}ueMEfI11LD$6gy!OxaGpnm&soNL|LFB&5gYG?8^VhE z5!ZO%BJc=Acs{3Rde1^(B&&pqU*obA zNgE#-qJEt^G4khu$Vz)MaacCn#b|O!XyxT%KG@#;*r%NlSidvZxKeS1y&t#aYRVU^ z{rTLBAalR8#xs()nUGy6BbT9VHQaYP9ymt^Lp2j&DGnuQ6uxsUQ@MOMmCn7*=Z3;_ zJw#?eI^|$*OeA`JIG82-QdaGmJ9xI*465Qrg5hh}qUz*a_E9TFIm9NYr5&E;G_KQd zWmh%wpf&TN`ViB$t94f{umP6j!|u+OKFGSH%sgzk-QFuVE09Q)r)6T;g3)e{IW?XC zQLcCEWGBo`mo^HQiy77Uwz<%I7`;tJ*2h)BypiK{p!1BcHRtCXCpRyvVC!z$r(;Og zzpCv#UCa*Szyj0c<^`u$m&j+1s02Z^wIQv%VDXD+_Fvy(9zEck;L0hVCi$x^YR1@H_vXXDN* zeWz#C%1SrpDF;8qbCGBEk|0~kx|SAEi!Big*J3@2FjYt=$YYY?TD9QTTN~)%d*OO%u zS691Gxs=fII|++=3OuScR61S)qp8~tG=?~ZsIa!7%uyCTW)yuJUtW7rIPSDL0j7(| z2=^lJ07>Zv5PsVn;3W{-qY+In?@K~iUbSdFzc>m;sr3sgT&^)S4P&-5r&rp_A-4mE8w0^h$f&xa1UO=@A_7$p2%}z)y5-_JgU33QIJ*brs+YLk81F6 zqS#5DMZKNv>ekG6E-$tiH5KLaOMEe|yISV9#JZ~crIpDNx-UO-(%F<k zEZ1D%L@uWLz^FgOF?UyUX9u>0@XozeGuW$1vOrfW{7kLyauHE&E2$EsQx=|41U-cte*;l=YxHQbp;L_$F?IL9&7@)wi zRv*B(z7V&^9S8fh*XGNS#>9GgqQ&98U6Jn`uJHiLu0|SB9OE57gUg*O)*4iCT1aG2 zYk+jb24{Nra?^R1qdYBQz3e;I7@p2L_0pPBl9R{ga!5L-_V{+--B3=A@+@<)$g>9P7~VXjzl=@p2q5t3Tr%=wEV$zEGwT!-yTbU&WE*@QW#q8L6#1jlmEV`bNj 
z6f-sK=I7f(>%1-GV%97@A!!c7k2IPTPwWLRr2|waOjx*Qd9bmMdm+m)_|`HA?K$mt zjixA3s^xArvD+zpAhEEXNO}(6?FwuC=2w6=9!!+XIGzFr)3Ej zSYveMpB7qKJ#iyY6C!Jw74bAt#TN7*`ds9N4jCkcvLBl_ifMMJ>wJ-SQl(ZizNGrr zB`N~1H}3SBI#BbAa6(TtR%4tdABTl3s2^V!T={?tT|p!C>LPJDa=J;8{D~LmO4m9b z3_ZamC`KMCkEPdQ?XjZxy97#(qQn%i?_@1N7ckvClyaEUNoKih!V|HhyQ(oErV`^C zGG&t-`aF}=d8Xij;27S^nKa_)oGo-tYw|*c5#@OFwm~6J&?k)Yv3lVi&uHBJu8Y8$ zi`eaW6oL*eRy0in@~f|KE-3Kptbb}JXIIC}n<5LHaJjU;Yr34~bkN0-JSfG)h`>+e zwfPuWWr`bBt;L@=E*B;~yP5BRVTD7uHV?W-3yvbcW7vYONn2 zm~wU@N{P%}FMr(6{55RF3QHq8T5j?eXkU?A#u2K$jI(iIYlWWqaWCm$C<0gYG^OSh zvInkI@q4fBN1$0j`OomW^`{+mzeGGN-`J=ZKG$m=EWP*I-R+lN2b!@$s<#tImCWv~ zkMBrp1IWWU3w-9$)nhb%3%B-K-k43Hr!5pEVy_6S51GK7(mBUE2c5e&asR1)z{Y~> z7;V^-Xs}Lx^663XN`zq&i7-s9@!>=ZQXRwjs_OIBE{h7_4>C=iX}v>{B# zb%eIgJ#j9&(PcFMv?k%5_>E)HJMw7QOM!y35V->Qlotn_ld9!8=T*Co_Rm#32Il05 z5fk}VoKSkE+4A+Teyo9*#&>44XCy`EIOZ;W(|P%!!x5DqeLpXX#4+T@i#M!+XUc;B zosbd+E(hxM)2LnZR5D}7IW%{&|2z3P5$#+l=SRiNufCxDT|5>|3ZlJxBPo)hmFh@f zl8~^tQcQ&r$2hUIA6R;UqU%31KF)^m>3}7?dD6%O2sPO&&`(h6W>GIbL|KfW8c6|KnpIJyvd|~C>lZY08LQVoc1*P&X|Gn9UVDHCxviFpsXGyK0p z6Bft|1PBGEF3vB}kAi&9dVnHh(k}Te8N2}~?A;%q#K!=UBkRyZ{{~=O`d+5nV!z#i zXmzsAd+sF9Vd5ksaI>A>5(i%F0Z_wV;2R#`)bIim{cI$U) zt@J9LYS8Z~1?wiD)y3f7%Qx1hd;=Qhy~1oIcne}EoAI!q2g-qMo%i=WGn;U1d#GgoCs1c z^hjAP_Vrx}d9T#-Mmqo;XL7I(M$>n!onuyQ)J^n6>}|;X{f|XU1Ilkfwc-8VTANs= zEs&hPUj)I`7W#Ky*RaesfTL=D^L#h8EeJX5D%ci(Miez1jhTDkWGYjz2u?I03VciX z#Xkd%h2OlXxQ;~ei4Xc7n0=C3Rz1@D(`^`w=iA+x+Z<7gQTSaOgU4azms%2d**|puY>)VqWSOpq8!aGS6N_g>0?d_1*@?Up6 zl*8Qb7n$#->Dy(`%*@G>4(adzpXU&m6hNNV|K)JF$6)e*xD>$r@PC;rfWzjL_Srue zVgM2koEfs;U92b@S- z%Ol;ZlW%43R#dGLevA#icK`Z`LoukX6&?(-wJLxGt0^yt?kAAu3P}3ak7GdK@KO40 zxTKX>Kt8J(Wl(8wODM^a`6}V%+NN*bA`!$Bvzk@0IR=21ICsLn_7B|s{pR);qZgd^ z-AVo@{MI6Pt`<*nrd9*TRY0E9B*-W}a8$fp2byh7(d$c|hc9n>ncDm<&Jm0}0d@`S zFUCA;1UZx*;CISDRxIBf(N|j4Ok!s#xaGG$*PVR#%$YNZk0wr7a^MQ{QHzFpBSru* zi%>Pc62`20Q-;T&uaVVk{Ds+TAJ9PkKyX(1Z~K%6gP}nn@{cK9c?W!XgNw2D@0i>i zsB1jGP3;tem?NRtE6w4jM!LB_Mx8!P<8tp**Qx@i-B+d5`2GEGva+4j#clCV*#3l= zdT@=u9w#gM07m9%SO*}28{y?$QJQ{_K=1tpu)D-Z9pIL3IqaRJ;bI>4^U!YrM_Y!B zC;=*BoPWhE$4qd*fn@$CFg)n+y8@;s=WQ?u)W7zDZeS%vRiW&Q$c3iom~|zP9&2oB zx@-m_TOX57D|k%E$=?24X|}8&dSmz9-uK*!?B@dI#~)pIFaM+b&;=&030>BPypJ7h zQO-EW{>)$h>281kcZ0W>N|$ouKmLXh@R(8U&9WDq5x;>7B$P#T=oTs}IA3of;B^$i zF+RI3lmD_eeB4_7I#&4g-Mu&VERV_44%=uX<$9KYEFsv+c_7c5{-;>)IoMR)zYPv6 zz%Q38?VPj^0*4OpEB|3|XdmhbO_Y3xUDRi3rNHZ2^y%%LP+KF*fXn7{&-7HYxbtgb zxrM6n=`jeD^F=oiSDxRuJ0Sw=ebJX`&Wd^cL%n9jfN5G2L_a7Vtf__Q=9XDZpJY^7 zpDod7s`&`|gkPVhDFgj13e3^TB&%v1h2UJUJgT1{RxHT}_#~UnSkY9#>hJ-h)*4In zJfUklMKAe&*G-oF+v|{`5=-S>*R^5ynrgvzi!?B1$Y~Tk&#t?hD!Uj*<77&DB$iN; z*tYP&)cf>1&E~etIsqd4KyvXe37OhqWJRDnEcqPvsjjG8o^8ebKC#Y?Nq59l zPT%|S=|$ev4?uePgJdkbhPT227)7U~^_t~?CqYgBWP1K;ZJSoRM7~f$SEHv<(-ywF zglLn`l(=SHIRThoI-Ob7nx|S8Fk!0QN~6%a8rfhq2PJnXul`M`N;8Ul^)~=qz*)R< z`mnKZ>MK~m1RM^}@eoG(Vt2!3F%TMU&scYhogOAzsG->nXU^l_!;fP;Uv)b_+s1d1 zMe*rXzF3{t&untuQd%q4osYX8$`)*(nj0n>*pG;w0uqkTF_L$VuNC~iClSO)g^CYK zoK4G2(z3zh&{Ll~P1+(zO*;(KkvC1x)wzvHya9FYiCVW1<0U)u`1V{^qT&FsQ5{^+ zrizx)Y>Z`dFps{x%s@pZOE9r*`~c!>OG#Rz)@&xyt}24^8oZXC^_SIQ?P02qAmbFmw-^ix&BK0; z>N&_*oP)X+zqd)AQfZLW604PyJ)e?f07C^YaeEJc#oeG& zES~0?r^a{T$lR$}$l#IGSsDpv1&+e`jqW}bS||}CNR4Z|b3?haQ%d4vV{gS>Xy{d3 zRSmJ_;!R|~_s^b3@RbhhM8_qKnoGkNJ2coj78Ww7ND#Lh3u zP_D5{$r-{DVIXtHVujLa^a9$%iK7tH%L>b&OizbIHXHQf@Al<vwt{7if-)L9qjq}mX`En6>?gLsITvPE!-%vj$H z;LgFUJH#@b9b{S*rs6jautC`$h~8GcS$~c9CB1mt-0BS!Hkf=vUy)w9@fkf8k_R?B z*Bb^&4dLma@PVM!TWG>*R8)~f9^)DY+cLh$3(N!Ew1KaquyzotHYrW{WtB`i4RzY$ z!?l5hdttePXSfYVS%QHZ8^;_x^RVZ1B!fb@g=!Hu*SwknJ>a@bZK}x*`9jVrjVInQG>>t8`l8);J(vhzh<3QrJ)c8i?8k1@-A96)W%pk4 
[GIT binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/docs/_static/img/ide_code_settings.png b/docs/_static/img/ide_code_settings.png
new file mode 100644
index 0000000000000000000000000000000000000000..8e46f0d3809d8e8f36ad32794da595d17edcd2f3
GIT binary patch
literal 9460
[GIT binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/docs/_static/img/ide_code_templates.png b/docs/_static/img/ide_code_templates.png
new file mode 100644
index 0000000000000000000000000000000000000000..d65026bec091fe6f55c62c6d45b9d03650e1e7b0
GIT binary patch
literal 26735
[GIT binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/docs/deploy/ide_integration.rst b/docs/deploy/ide_integration.rst
new file mode 100644
--- /dev/null
+++ b/docs/deploy/ide_integration.rst
You can always refer to :ref:`convert-weights-via-MLC` for in-depth details on how to convert your model weights. If you are using your own model weights, i.e., you finetuned the model on your personal codebase, it is important to follow these steps to convert the respective weights properly. However, it is also possible to download precompiled weights from the original models, available in the MLC format. See the full list of all precompiled weights `here `__. + +**Example:** + +.. code:: bash + + # convert model weights + mlc_llm convert_weight ./dist/models/CodeLlama-7b-hf \ + --quantization q4f16_1 \ + -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC + +Compile Your Model +------------------ + +Compiling the model architecture is the crucial step to optimize inference for a given platform. However, compilation relies on multiple settings that will impact the runtime. This configuration is specified inside the ``mlc-chat-config.json`` file, which can be generated by the ``gen_config`` command. You can learn more about the ``gen_config`` command `here `__. + +**Example:** + +.. code:: bash + + # generate mlc-chat-config.json + mlc_llm gen_config ./dist/models/CodeLlama-7b-hf \ + --quantization q4f16_1 --conv-template LM \ + -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC + +.. note:: + Make sure to set the ``--conv-template`` flag to ``LM``. This template is specifically tailored to perform vanilla LLM completion, generally adopted by code completion models. + +After generating the MLC model configuration file, we are all set to compile and create the model library. You can learn more about the ``compile`` command `here `__ + +**Example:** + +.. tabs:: + + .. group-tab:: Linux - CUDA + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device cuda -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so + + .. group-tab:: Metal + + For M-chip Mac: + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal.so + + Cross-Compiling for Intel Mac on M-chip Mac: + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device metal:x86-64 -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib + + For Intel Mac: + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib + + .. group-tab:: Vulkan + + For Linux: + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.so + + For Windows: + + .. code:: bash + + # compile model library with specification in mlc-chat-config.json + mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ + --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.dll + +.. note:: + The generated model library can be shared across multiple model variants, as long as the architecture and number of parameters does not change, e.g., same architecture, but different weights (your finetuned model). 
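To make the note above concrete, here is a minimal sketch of reusing the compiled library with a
hypothetical finetuned checkpoint. The ``CodeLlama-7b-hf-finetuned`` paths below are placeholders
rather than files shipped by MLC: only weight conversion and config generation are repeated, and the
model library compiled from the base model is passed unchanged when the weights are deployed in the
next step.

.. code:: bash

   # convert the finetuned weights (placeholder path) with the same quantization
   mlc_llm convert_weight ./dist/models/CodeLlama-7b-hf-finetuned \
       --quantization q4f16_1 \
       -o ./dist/CodeLlama-7b-hf-finetuned-q4f16_1-MLC

   # regenerate mlc-chat-config.json for the finetuned weights
   mlc_llm gen_config ./dist/models/CodeLlama-7b-hf-finetuned \
       --quantization q4f16_1 --conv-template LM \
       -o ./dist/CodeLlama-7b-hf-finetuned-q4f16_1-MLC

   # the previously built ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so is reused as-is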
+ +Setting up the Inference Entrypoint +----------------------------------- + +You can now locally deploy your compiled model with the MLC serve module. To find more details about the MLC LLM API visit our :ref:`deploy-rest-api` page. + +**Example:** + +.. code:: bash + + python -m mlc_llm.serve.server \ + --model dist/CodeLlama-7b-hf-q4f16_1-MLC \ + --model-lib-path ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so + +Configure the IDE Extension +--------------------------- + +After deploying the LLM we can easily connect the IDE with the MLC Rest API. In this guide, we will be using the Hugging Face Code Completion extension `llm-ls `__ which has support across multiple IDEs (e.g., `vscode `__, `intellij `__ and `nvim `__) to connect to an external OpenAI compatible API (i.e., our MLC LLM :ref:`deploy-rest-api`). + +After installing the extension on your IDE, open the ``settings.json`` extension configuration file: + +.. figure:: /_static/img/ide_code_settings.png + :width: 450 + :align: center + :alt: settings.json + +| + +Then, make sure to replace the following settings with the respective values: + +.. code:: javascript + + "llm.modelId": "dist/CodeLlama-7b-hf-q4f16_1-MLC" + "llm.url": "http://127.0.0.1:8000/v1/completions" + "llm.backend": "openai" + +This will enable the extension to send OpenAI compatible requests to the MLC Serve API. Also, feel free to tune the API parameters. Please refer to our :ref:`deploy-rest-api` documentation for more details about these API parameters. + +.. code:: javascript + + "llm.requestBody": { + "best_of": 1, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "logprobs": false, + "top_logprobs": 0, + "logit_bias": null, + "max_tokens": 128, + "seed": null, + "stop": null, + "suffix": null, + "temperature": 1.0, + "top_p": 1.0 + } + +The llm-ls extension supports a variety of different model code completion templates. Choose the one that best matches your model, i.e., the template with the correct tokenizer and Fill in the Middle tokens. + +.. figure:: /_static/img/ide_code_templates.png + :width: 375 + :align: center + :alt: llm-ls templates + +| + +After everything is all set, the extension will be ready to use the responses from the MLC Serve API to provide off-the-shelf code completion on your IDE. + +.. figure:: /_static/img/code_completion.png + :width: 700 + :align: center + :alt: IDE Code Completion + +| + +Conclusion +---------- + +Please, let us know if you have any questions. Feel free to open an issue on the `MLC LLM repo `__! diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index e24d65afb5..621a22fb71 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -1,3 +1,5 @@ +.. _deploy-rest-api: + Rest API ======== diff --git a/docs/index.rst b/docs/index.rst index 504b667285..485567b37e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -158,6 +158,7 @@ It is recommended to have at least 6GB free VRAM to run it. deploy/python.rst deploy/ios.rst deploy/android.rst + deploy/ide_integration.rst .. 
toctree:: :maxdepth: 1 diff --git a/scripts/local_deploy_site.sh b/scripts/local_deploy_site.sh index 9e75aaecde..52ba40b6fe 100755 --- a/scripts/local_deploy_site.sh +++ b/scripts/local_deploy_site.sh @@ -5,4 +5,4 @@ set -euxo pipefail scripts/build_site.sh -cd site && jekyll serve --skip-initial-build --host localhost --baseurl /mlc-llm --port 8888 +cd site && jekyll serve --skip-initial-build --host localhost --baseurl / --port 8888 From 12ca8fdbe2a24f43bbc72241a76735dbad8c2026 Mon Sep 17 00:00:00 2001 From: Yu Xuanchi Date: Tue, 2 Apr 2024 23:37:09 +0800 Subject: [PATCH 146/531] =?UTF-8?q?Allow=20"mlc=5Fllm=20--host"=20option?= =?UTF-8?q?=20to=20override=20host=20triple=20the=20model=20compi=E2=80=A6?= =?UTF-8?q?=20(#2074)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow "mlc_llm --host" option to override host triple the model compile to --- python/mlc_llm/support/auto_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index e09f661ff7..56f0940165 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -41,7 +41,7 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T The hint for the host CPU, default is "auto". """ target, build_func = _detect_target_gpu(target_hint) - if target.host is None: + if target.host is None or host_hint != "auto": target = Target(target, host=_detect_target_host(host_hint)) if target.kind.name == "cuda": # Enable thrust for CUDA From 63fc9723c17261ad839f3f8267fdc7c89707ebb6 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:15:42 -0400 Subject: [PATCH 147/531] [Web] Move prep emcc deps script to web folder (#2077) --- docs/install/emcc.rst | 2 +- python/mlc_llm/support/auto_target.py | 2 +- web/README.md | 2 +- {scripts => web}/prep_emcc_deps.sh | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename {scripts => web}/prep_emcc_deps.sh (100%) diff --git a/docs/install/emcc.rst b/docs/install/emcc.rst index 389d3cc4f8..f82292e00c 100644 --- a/docs/install/emcc.rst +++ b/docs/install/emcc.rst @@ -51,7 +51,7 @@ Now we can prepare wasm runtime using the script in mlc-llm repo .. code:: bash - ./scripts/prep_emcc_deps.sh + ./web/prep_emcc_deps.sh We can then validate the outcome diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 56f0940165..6e64247ea8 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -209,7 +209,7 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): ) error_info = ( "Cannot find library: mlc_wasm_runtime.bc\n" - + "Make sure you have run `scripts/prep_emcc_deps.sh` and " + + "Make sure you have run `./web/prep_emcc_deps.sh` and " + "`export MLC_LLM_HOME=/path/to/mlc-llm` so that we can locate the file. " + "We tried to look at candidate paths:\n" ) diff --git a/web/README.md b/web/README.md index e6e34918db..f4fc808b1f 100644 --- a/web/README.md +++ b/web/README.md @@ -21,7 +21,7 @@ This folder contains MLC-LLM WebAssembly Runtime. Please refer to https://llm.mlc.ai/docs/install/emcc.html. -The main step is running `make` under this folder, a step included in `scripts/prep_emcc_deps.sh`. +The main step is running `make` under this folder, a step included in `web/prep_emcc_deps.sh`. 
`make` creates `web/dist/wasm/mlc_wasm_runtime.bc`, which will be included in the model library wasm when we compile the model. Thus during runtime, runtimes like WebLLM can directly reuse source diff --git a/scripts/prep_emcc_deps.sh b/web/prep_emcc_deps.sh similarity index 100% rename from scripts/prep_emcc_deps.sh rename to web/prep_emcc_deps.sh From 5bc3ffa6f682a4cf42fdeba3a4c505d0e7c08c3c Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Wed, 3 Apr 2024 03:54:26 +0800 Subject: [PATCH 148/531] [SLM] Qwen Multi-GPU support (#2075) --- python/mlc_llm/model/qwen/qwen_model.py | 62 +++++++++++++++++++------ 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 5cd979e589..09bb8e854f 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -36,6 +37,7 @@ class QWenConfig(ConfigBase): # pylint: disable=too-many-instance-attributes prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + head_dim: int = 0 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -56,6 +58,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -73,7 +78,6 @@ def __post_init__(self): bold("context_window_size"), ) self.prefill_chunk_size = self.context_window_size - assert self.tensor_parallel_shards == 1, "QWEN currently does not support sharding." 
# pylint: disable=invalid-name,missing-docstring @@ -82,16 +86,12 @@ def __post_init__(self): class QWenAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWenConfig): self.hidden_size = config.hidden_size - self.rope_theta = config.rotary_emb_base self.num_heads = config.num_attention_heads // config.tensor_parallel_shards - self.head_dim = self.hidden_size // self.num_heads - self.projection_size = config.kv_channels * config.num_attention_heads - self.c_attn = nn.Linear( - in_features=config.hidden_size, - out_features=3 * self.projection_size, - bias=True, - ) - self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=False) + self.head_dim = config.head_dim + + self.c_attn = nn.Linear(config.hidden_size, 3 * self.num_heads * self.head_dim, bias=True) + + self.c_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) def forward( # pylint: disable=too-many-locals self, @@ -134,13 +134,45 @@ def __init__(self, config: QWenConfig): self.ln_1 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) self.ln_2 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.attn.num_heads * hd + k = self.attn.num_heads * hd + v = self.attn.num_heads * hd + i = self.mlp.intermediate_size // 2 + _set( + self.attn.c_attn.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set( + self.attn.c_attn.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.attn.c_proj.weight, tp.ShardSingleDim("_shard_attn_c_proj", dim=1)) + _set( + self.mlp.gate_up_proj.weight, + tp.ShardSingleDim("_shard_mlp_gate_up_proj", segs=[i, i], dim=0), + ) + _set(self.mlp.c_proj.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.ln_2(hidden_states)) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class QWenModel(nn.Module): def __init__(self, config: QWenConfig): @@ -165,7 +197,7 @@ def __init__(self, config: QWenConfig): self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.num_attention_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_attention_heads + self.head_dim = config.head_dim self.tensor_parallel_shards = config.tensor_parallel_shards self.rotary_emb_base = config.rotary_emb_base self.dtype = "float32" @@ -191,6 +223,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.transformer.wte(input_ids) def prefill(self, inputs: Tensor, paged_kv_cache: PagedKVCache): @@ -221,6 +255,8 @@ def decode(self, inputs: Tensor, paged_kv_cache: PagedKVCache): return logits, paged_kv_cache def batch_prefill(self, inputs: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache): + if 
self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(inputs, paged_kv_cache, logit_positions) return logits, paged_kv_cache From 96b8c33e13fc902fd9cde7fee42215c641b48e02 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Apr 2024 08:22:37 -0700 Subject: [PATCH 149/531] Fix mismatch of metadata func and global symbol (#2078) * Fix mismatch of metadata func and global symbol * Update estimate_memory_usage.py --- python/mlc_llm/compiler_pass/estimate_memory_usage.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mlc_llm/compiler_pass/estimate_memory_usage.py b/python/mlc_llm/compiler_pass/estimate_memory_usage.py index 9b4de3a5cc..cdd7e7105a 100644 --- a/python/mlc_llm/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -22,14 +22,16 @@ def __init__(self, metadata: Dict[str, Any]): def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" + func_name = "_metadata" + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name - with bb.function("main", params=[]): + with bb.function(func_name, params=[]): bb.emit_func_output(relax.StringImm(json.dumps(metadata))) - return bb.finalize()["main"] + return bb.finalize()[func_name] self.metadata["memory_usage"] = _MemoryEstimator().run(mod) - mod["_metadata"] = _emit_metadata(self.metadata) + mod[func_name] = _emit_metadata(self.metadata) return mod From 1d345273f086abf5ef1c9dcc592148841bd5a6a9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 3 Apr 2024 19:10:22 -0400 Subject: [PATCH 150/531] [Disco] Set worker CPU affinity with env variable (#2042) This PR enables setting the CPU affinity of disco workers in MLC, following the support in apache/tvm#16807. The purpose is to try reduce the CPU core switch overhead brought to disco workers which may cause extra bubble times in disco workers before/during tasks. We use a macro `MLC_DISCO_WORKER_CPU_BINDING` to specify the CPU affinities of workers. This is by default not used. To enable it, you can run the command like ```shell MLC_DISCO_WORKER_CPU_BINDING=64,65,66,67 python some_mlc_app.py ``` to specify the four CPU core ids for the four workers. 
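As a small addition to the example above, a hypothetical two-shard (two-worker) deployment could pin
its workers in the same way; the core ids are machine specific and `some_mlc_app.py` is the same
placeholder script as above:

```shell
# pin the disco workers of a hypothetical 2-GPU run to cores 16 and 17
MLC_DISCO_WORKER_CPU_BINDING=16,17 python some_mlc_app.py
```

Note that `GetDiscoWorkerCPUBinding` in the diff below raises a fatal error when fewer core ids are
listed than there are workers.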
--- cpp/serve/function_table.cc | 31 +++++++++++++++++++++++++++++++ cpp/support/utils.h | 24 ++++++++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 cpp/support/utils.h diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index f4466c875b..a7f878c1ba 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -13,17 +13,44 @@ #include #include +#include #include #include #include #include "../support/load_bytes_from_file.h" +#include "../support/utils.h" #include "sampler/sampler.h" namespace mlc { namespace llm { namespace serve { +Optional GetDiscoWorkerCPUBinding(int num_workers) { + const char* raw_cpu_binding = std::getenv("MLC_DISCO_WORKER_CPU_BINDING"); + if (raw_cpu_binding == nullptr) { + return NullOpt; + } + + std::string cpu_binding_str(raw_cpu_binding); + std::vector cpu_ids_str = Split(cpu_binding_str, ','); + std::vector cpu_ids; + for (const std::string& cpu_id_str : cpu_ids_str) { + try { + cpu_ids.push_back(std::stol(cpu_id_str)); + } catch (std::invalid_argument const& ex) { + LOG(FATAL) << "Invalid MLC_DISCO_WORKER_CPU_BINDING \"" << cpu_binding_str << "\""; + } + } + if (static_cast(cpu_ids.size()) < num_workers) { + LOG(FATAL) << "Insufficient number of specified CPU workers in MLC_DISCO_WORKER_CPU_BINDING, " + "expecting at least " + << num_workers << "CPU ids but only " << cpu_ids.size() << " are given."; + } + + return IntTuple{cpu_ids}; +} + PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name) { return PackedFunc([sess, func = std::move(sess_func), name = std::move(name)]( TVMArgs args, TVMRetValue* rv) -> void { @@ -100,6 +127,10 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object } return SessionFuncAsPackedFunc(sess, func, name); }; + if (Optional cpu_ids = GetDiscoWorkerCPUBinding(/*num_workers=*/num_shards)) { + IntTuple cpu_ids_value = cpu_ids.value(); + sess->CallPacked(sess->GetGlobalFunc("runtime.disco.bind_worker_to_cpu_core"), cpu_ids_value); + } this->get_global_func = [this](const std::string& name) -> PackedFunc { return SessionFuncAsPackedFunc(sess, sess->GetGlobalFunc(name), name); }; diff --git a/cpp/support/utils.h b/cpp/support/utils.h new file mode 100644 index 0000000000..5360f0496c --- /dev/null +++ b/cpp/support/utils.h @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file utils.h + * \brief Utility functions. 
+ */ +#include +#include +#include + +namespace mlc { +namespace llm { + +inline std::vector Split(const std::string& str, char delim) { + std::string item; + std::istringstream is(str); + std::vector ret; + while (std::getline(is, item, delim)) { + ret.push_back(item); + } + return ret; +} + +} // namespace llm +} // namespace mlc From 7f1aacc01d75b7f1c44980d5a9e91364dff44154 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 3 Apr 2024 18:04:07 -0700 Subject: [PATCH 151/531] [Quantization] Introduce PerTensor and F8 quantization (#2079) * [Quantization] Introduce PerTensor and F8 quantization * address comments --- .../fuse_dequantize_matmul_ewise.py | 2 +- python/mlc_llm/interface/compiler_flags.py | 8 +- .../mlc_llm/model/llama/llama_quantization.py | 24 +- python/mlc_llm/model/mixtral/mixtral_model.py | 2 + .../model/mixtral/mixtral_quantization.py | 24 +- python/mlc_llm/model/model.py | 2 + python/mlc_llm/nn/expert.py | 3 +- python/mlc_llm/op/cutlass.py | 11 +- python/mlc_llm/op/moe_matmul.py | 120 ++++ python/mlc_llm/quantization/__init__.py | 2 + .../mlc_llm/quantization/fp8_quantization.py | 97 +++ .../quantization/per_tensor_quantization.py | 555 ++++++++++++++++++ python/mlc_llm/quantization/quantization.py | 12 + python/mlc_llm/quantization/utils.py | 42 +- 14 files changed, 892 insertions(+), 12 deletions(-) create mode 100644 python/mlc_llm/quantization/fp8_quantization.py create mode 100644 python/mlc_llm/quantization/per_tensor_quantization.py diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py index f8a64c8cda..0943828933 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -15,7 +15,7 @@ def transform_module( ) -> IRModule: """IRModule-level transformation""" seq = [] - for n_aux_tensor in [1, 2, 3, 4]: + for n_aux_tensor in [0, 1, 2, 3, 4]: for match_ewise in [0, 1, 2, 6]: if match_ewise == 6 and n_aux_tensor != 4: continue diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index f3a6092f6d..2d0d668672 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -103,7 +103,13 @@ def _flashinfer(target) -> bool: def _cublas_gemm(target, quantization) -> bool: """correct cublas_gemm flag""" - if not (target.kind.name == "cuda" and quantization.name in ["q0f16", "q0f32"]): + if not target.kind.name == "cuda": + return False + if not ( + quantization.name in ["q0f16", "q0f32"] + or "e4m3" in quantization.name + or "e5m2" in quantization.name + ): return False return self.cublas_gemm diff --git a/python/mlc_llm/model/llama/llama_quantization.py b/python/mlc_llm/model/llama/llama_quantization.py index cf67288585..e3878eed74 100644 --- a/python/mlc_llm/model/llama/llama_quantization.py +++ b/python/mlc_llm/model/llama/llama_quantization.py @@ -5,7 +5,13 @@ from tvm.relax.frontend import nn from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.quantization import ( + AWQQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, + PerTensorQuantize, +) from .llama_model import LlamaConfig, LlamaForCasualLM @@ -67,3 +73,19 @@ def no_quant( model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map + + +def per_tensor_quant( + model_config: LlamaConfig, + quantization: PerTensorQuantize, +) -> 
Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama-architecture model using per-tensor quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py index ec8025f3dc..db41dc31ce 100644 --- a/python/mlc_llm/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -49,11 +49,13 @@ def __init__(self, config: MixtralConfig): self.num_local_experts, in_features=config.hidden_size, out_features=2 * self.intermediate_size, + tensor_parallel_shards=config.tensor_parallel_shards, ) self.e2 = MixtralExperts( self.num_local_experts, in_features=self.intermediate_size, out_features=config.hidden_size, + tensor_parallel_shards=config.tensor_parallel_shards, ) self.dtype = "float32" diff --git a/python/mlc_llm/model/mixtral/mixtral_quantization.py b/python/mlc_llm/model/mixtral/mixtral_quantization.py index 0e8130e051..e405cae140 100644 --- a/python/mlc_llm/model/mixtral/mixtral_quantization.py +++ b/python/mlc_llm/model/mixtral/mixtral_quantization.py @@ -5,7 +5,13 @@ from tvm.relax.frontend import nn from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize +from mlc_llm.quantization import ( + AWQQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, + PerTensorQuantize, +) from .mixtral_model import MixtralConfig, MixtralForCasualLM @@ -59,3 +65,19 @@ def no_quant( model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map + + +def per_tensor_quant( + model_config: MixtralConfig, + quantization: PerTensorQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Mixtral model using per-tensor quantization.""" + model: nn.Module = MixtralForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 946d8af787..119cfded4c 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -83,6 +83,7 @@ class Model: "group-quant": llama_quantization.group_quant, "ft-quant": llama_quantization.ft_quant, "awq": llama_quantization.awq_quant, + "per-tensor-quant": llama_quantization.per_tensor_quant, }, ), "mistral": Model( @@ -139,6 +140,7 @@ class Model: "no-quant": mixtral_quantization.no_quant, "group-quant": mixtral_quantization.group_quant, "ft-quant": mixtral_quantization.ft_quant, + "per-tensor-quant": mixtral_quantization.per_tensor_quant, }, ), "gpt_neox": Model( diff --git a/python/mlc_llm/nn/expert.py b/python/mlc_llm/nn/expert.py index 481b430baf..d6c38db248 100644 --- a/python/mlc_llm/nn/expert.py +++ b/python/mlc_llm/nn/expert.py @@ -8,12 +8,13 @@ class MixtralExperts(nn.Module): """Mixtral experts""" - def __init__(self, num_local_experts, in_features, out_features): + def __init__(self, num_local_experts, in_features, out_features, tensor_parallel_shards=1): self.num_local_experts = num_local_experts self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter((num_local_experts, out_features, in_features)) self.dtype = "float32" + self.tensor_parallel_shards = tensor_parallel_shards def forward(self, x: 
Tensor, indptr: Tensor): # pylint: disable=invalid-name,missing-docstring assert x.ndim == 2 diff --git a/python/mlc_llm/op/cutlass.py b/python/mlc_llm/op/cutlass.py index 275d61f20a..6b0e21578e 100644 --- a/python/mlc_llm/op/cutlass.py +++ b/python/mlc_llm/op/cutlass.py @@ -45,23 +45,22 @@ def group_gemm( assert x.ndim == 2 assert weight.ndim == 3 assert indptr.ndim == 1 - assert weight.shape[2] == x.shape[1] assert weight.shape[0] == indptr.shape[0] assert indptr.dtype == "int64" out_dtype = out_dtype if out_dtype else x.dtype weight_dtype = weight_dtype if weight_dtype else weight.dtype - if x.dtype == "e5m2_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + if x.dtype == "e5m2_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e5m2_e5m2_fp16" - elif x.dtype == "e4m3_float8" and weight.dtype == "e5m2_float8" and out_dtype == "float16": + elif x.dtype == "e4m3_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e4m3_e5m2_fp16" - elif x.dtype == "e4m3_float8" and weight.dtype == "e4m3_float8" and out_dtype == "float16": + elif x.dtype == "e4m3_float8" and weight_dtype == "e4m3_float8" and out_dtype == "float16": func_name = "cutlass.group_gemm_e4m3_e4m3_fp16" - elif x.dtype == "float16" and weight.dtype == "float16" and out_dtype == "float16": + elif x.dtype == "float16" and weight_dtype == "float16" and out_dtype == "float16": func_name = "cutlass.group_gemm_fp16_sm90" else: raise NotImplementedError( - f"Unsupported data type: x={x.dtype}, weight={weight.dtype}, out={out_dtype}" + f"Unsupported data type: x={x.dtype}, weight={weight_dtype}, out={out_dtype}" ) if "float8" in x.dtype: diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 169140a597..95d7fed941 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -1,5 +1,7 @@ """Mixture of Experts operators""" +from typing import Literal, Optional + from tvm import DataType, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -175,6 +177,124 @@ def _func( ) +def dequantize_float8_gemv( + x: Tensor, + w: Tensor, + scale: Optional[Tensor], + indptr: Tensor, + quantize_dtype: Literal["e5m2_float8", "e4m3_float8"], +) -> Tensor: + """GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized in + fp8 e5m2 or e4m3. It needs to be dequantized before the GEMV computation. + + Parameters + ---------- + x : Tensor + For project-in, the input tensor of shape (1, in_features); and for project-out, the input + shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated + experts per token. + + w : Tensor + The quantized weight tensor of shape (local_experts, out_features, in_features) + + scale : Optional[Tensor] + The optional scale tensor of shape (1,) + + indptr : Tensor + The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the + number of activated experts per token. + + quantize_dtype : Literal["e5m2_float8", "e4m3_float8"] + The quantize dtype of the weight tensor, which is either e5m2_float8 or e4m3_float8. 
+ """ + (x_leading_dim, in_features), model_dtype = x.shape, x.dtype + (local_experts, out_features, _), storage_dtype = w.shape, w.dtype + _, experts_per_tok = indptr.shape + quantize_dtype_bits = DataType(quantize_dtype).bits + num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits + num_storage = tir.ceildiv(in_features, num_elem_per_storage) + + def _dequantize(w, s, e, i, j): + if num_elem_per_storage == 1: + w = tir.reinterpret(quantize_dtype, w[e, i, j]) + else: + tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) + w = w[e, i, j // num_elem_per_storage] + shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) + w = tir.reinterpret( + quantize_dtype, + tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype("uint8"), + ) + w = w.astype(model_dtype) + if s is not None: + w = w * s[0] + return w + + def access_x(x, e, j): + return x[0, j] if x_leading_dim == 1 else x[e, j] + + @T.prim_func(private=True) + def _func_with_scale( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + scale: T.Buffer((1,), model_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, scale, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + @T.prim_func(private=True) + def _func_without_scale( + x: T.Buffer((x_leading_dim, in_features), model_dtype), + w: T.Buffer((local_experts, out_features, num_storage), storage_dtype), + indptr: T.Buffer((1, experts_per_tok), "int32"), + o: T.Buffer((experts_per_tok, out_features), model_dtype), + ): + T.func_attr({"op_pattern": 4, "tir.noalias": True}) # kOutEWiseFusable + for expert_id in T.thread_binding(experts_per_tok, thread="blockIdx.y"): + with T.block("gemv_o"): + e = T.axis.spatial(experts_per_tok, expert_id) + y = T.alloc_buffer((out_features, in_features), model_dtype) + for i1, i2 in T.grid(out_features, in_features): + with T.block("dequantize"): + i, j = T.axis.remap("SS", [i1, i2]) + y[i, j] = _dequantize(w, None, indptr[0, e], i, j) + for i1, i2 in T.grid(out_features, in_features): + with T.block("gemv"): + i, j = T.axis.remap("SR", [i1, i2]) + with T.init(): + o[e, i] = T.cast(T.float16(0), model_dtype) + o[e, i] += access_x(x, e, j) * y[i, j] + + if scale is not None: + return op.tensor_ir_op( + _func_with_scale, + "moe_dequantize_gemv", + args=[x, w, scale, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + return op.tensor_ir_op( + _func_without_scale, + "moe_dequantize_gemv", + args=[x, w, indptr], + out=Tensor.placeholder([experts_per_tok, out_features], model_dtype), + ) + + def group_gemm(x: Tensor, w: Tensor, indptr: Tensor): # pylint: disable=too-many-statements """Group GEMM in MoE models. 
diff --git a/python/mlc_llm/quantization/__init__.py b/python/mlc_llm/quantization/__init__.py index 31016a9952..a076958650 100644 --- a/python/mlc_llm/quantization/__init__.py +++ b/python/mlc_llm/quantization/__init__.py @@ -1,6 +1,8 @@ """A subpackage for quantization and dequantization algorithms""" from .awq_quantization import AWQQuantize +from .fp8_quantization import FP8PerTensorQuantizeMixtralExperts from .ft_quantization import FTQuantize from .group_quantization import GroupQuantize from .no_quantization import NoQuantize +from .per_tensor_quantization import PerTensorQuantize from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py new file mode 100644 index 0000000000..573dfdef28 --- /dev/null +++ b/python/mlc_llm/quantization/fp8_quantization.py @@ -0,0 +1,97 @@ +""" Quantization techniques for FP8 """ + +import numpy as np +from tvm import nd, relax +from tvm.relax.frontend import nn + +from mlc_llm.nn import MixtralExperts + +from ..op import cutlass, extern, moe_matmul +from . import per_tensor_quantization as ptq +from .utils import apply_sharding + + +class FP8PerTensorQuantizeMixtralExperts( + ptq.PerTensorQuantizeMixtralExperts +): # pylint: disable=too-many-instance-attributes + """MixtralExperts with per-tensor quantization in FP8.""" + + def __init__( + self, + num_local_experts, + in_features, + out_features, + config: ptq.PerTensorQuantize, + tensor_parallel_shards=1, + ): # pylint: disable=too-many-arguments + super().__init__(num_local_experts, in_features, out_features, config) + self.tensor_parallel_shards = tensor_parallel_shards + + @staticmethod + def from_mixtral_experts( + src: "MixtralExperts", + config: ptq.PerTensorQuantize, + ) -> "FP8PerTensorQuantizeMixtralExperts": + """ + Converts a non-quantized MixtralExperts to a per-tensor quantized MixtralExperts. + + Parameters + ---------- + src : MixtralExperts + The non-quantized MixtralExperts + + weight_config : GroupQuantize + The group quantization weight_config. + + Returns + ------- + ret : MixtralExpertsFP8 + The per-tensor quantized MixtralExperts. 
+ """ + quantized_mistral_experts = FP8PerTensorQuantizeMixtralExperts( + num_local_experts=src.num_local_experts, + in_features=src.in_features, + out_features=src.out_features, + config=config, + tensor_parallel_shards=src.tensor_parallel_shards, + ) + + if "shard_strategy" in src.weight.attrs: + shard = src.weight.attrs["shard_strategy"] + apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) + # scale doesn't need to be sharded since it's the same for all shards + + return quantized_mistral_experts + + def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + w = self.q_weight + if indptr.ndim == 2: + assert indptr.shape[0] == 1 + return moe_matmul.dequantize_float8_gemv( + x, w, self.q_scale, indptr, self.config.weight_dtype + ) + + if extern.get_store().cutlass_group_gemm: + # NOTE: calibration scale should be used to convert x to fp8 when calibration is enabled + x = nn.op.astype(x, dtype=self.config.activation_dtype) + scale = ( + self.q_scale.astype("float32") + if self.q_scale is not None + else nn.wrap_nested( + relax.Constant(nd.array(np.array([1.0]).astype("float32"))), "scale" + ) + ) + return cutlass.group_gemm( + x, w, indptr, scale, self.config.weight_dtype, self.config.model_dtype + ) + # Note: convert_weight is target agnostic, so a fallback must be provided + w = nn.tensor_expr_op( + self.config.dequantize_float8, + "dequantize", + args=[w, self.q_scale, self.config.weight_dtype], + ) + return moe_matmul.group_gemm(x, w, indptr) + + +# pylint: disable=protected-access +ptq.PerTensorQuantizeMixtralExperts._IMPL["fp8"] = FP8PerTensorQuantizeMixtralExperts diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py new file mode 100644 index 0000000000..c2776b2a86 --- /dev/null +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -0,0 +1,555 @@ +"""The per-tensor quantization config""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union + +from tvm import DataType, DataTypeCode, IRModule, te, tir, topi +from tvm.relax.frontend import nn +from tvm.runtime import NDArray + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.nn import MixtralExperts +from mlc_llm.support import logging + +from .utils import ( + apply_sharding, + compile_quantize_func, + convert_uint_packed_fp8_to_float, + is_final_fc, + pack_weight, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class PerTensorQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for per-tensor quantization""" + + name: str + kind: str + activation_dtype: Literal["e4m3_float8", "e5m2_float8"] + weight_dtype: Literal["e4m3_float8", "e5m2_float8"] + storage_dtype: Literal["uint32"] + model_dtype: Literal["float16"] + quantize_embedding: bool = True + quantize_final_fc: bool = True + + num_elem_per_storage: int = 0 + max_int_value: int = 0 + use_scale: bool = True + + def __post_init__(self): + assert self.kind == "per-tensor-quant" + self.num_elem_per_storage = ( + DataType(self.storage_dtype).bits // DataType(self.weight_dtype).bits + ) + self.max_int_value = int(tir.max_value(self.weight_dtype).value) + self._quantize_func_cache = {} + + def quantize_model( + self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str + ) -> nn.Module: + """ + Quantize model with per-tensor quantization + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. 
+ + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. + """ + + class _Mutator(nn.Mutator): + def __init__(self, config: PerTensorQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for per-tensor quantization of nn.Module nodes. + + Parameters + ---------- + name : str + The name of the current node. + + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. + """ + weight_name = f"{name}.weight" + param_names = ( + [f"{name}.q_weight", f"{name}.q_scale"] + if self.config.use_scale + else [ + f"{name}.q_weight", + ] + ) + if isinstance(node, nn.Linear) and ( + not is_final_fc(name) or self.config.quantize_final_fc + ): + self.quant_map.param_map[weight_name] = param_names + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return PerTensorQuantizeLinear.from_linear(node, self.config) + if isinstance(node, nn.Embedding) and self.config.quantize_embedding: + self.quant_map.param_map[weight_name] = param_names + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return PerTensorQuantizeEmbedding.from_embedding(node, self.config) + if isinstance(node, MixtralExperts): + self.quant_map.param_map[weight_name] = param_names + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return PerTensorQuantizeMixtralExperts.from_mixtral_experts(node, self.config) + return self.visit(name, node) + + model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def quantize_weight(self, weight) -> List[NDArray]: + """ + Quantize weight with per-tensor quantization. + + Parameters + ---------- + weight : NDArray + The weight to quantize. + + Returns + ------- + ret : List[NDArray] + The quantized weight and the scale if use_scale is True. 
+ """ + device = weight.device + device_type = device.MASK2STR[device.device_type] + + def _create_quantize_func() -> IRModule: + if DataType(self.weight_dtype).type_code in [ + DataTypeCode.E4M3Float, + DataTypeCode.E5M2Float, + ]: + quantize_func = self._quantize_float8 + else: + assert NotImplementedError() + + class Quantizer(nn.Module): + """Quantizer module for per-tensor quantization.""" + + def main(self, weight: nn.Tensor): # pylint: disable=missing-function-docstring + return quantize_func(weight) + + mod = Quantizer() + mod, _ = mod.export_tvm( # pylint: disable=unbalanced-tuple-unpacking + spec={"main": {"weight": nn.spec.Tensor(weight.shape, weight.dtype)}} + ) + return mod + + key = f"({weight.shape}, {weight.dtype}, {device_type}" + quantize_func = self._quantize_func_cache.get(key, None) + if quantize_func is None: + logger.info("Compiling quantize function for key: %s", key) + quantize_func = compile_quantize_func(_create_quantize_func(), device) + self._quantize_func_cache[key] = quantize_func + return quantize_func(weight) + + def _quantize_float8( # pylint: disable=too-many-locals + self, + weight: nn.Tensor, + ) -> Union[Tuple[nn.Tensor], Tuple[nn.Tensor, nn.Tensor]]: + """Per-tensor quantization for weight tensor, defined in tensor expression.""" + + quantize_dtype = DataType(self.weight_dtype) + + if self.use_scale: + # min_scaling_factor taken from TRT-LLM + def _compute_scale(x: te.Tensor) -> te.Tensor: + max_abs = topi.max(topi.abs(x)) + min_scaling_factor = tir.const(1.0 / (self.max_int_value * 512.0), self.model_dtype) + scale = topi.maximum( + max_abs.astype(self.model_dtype) / self.max_int_value, min_scaling_factor + ) + scale = topi.expand_dims(scale, axis=0) + return scale + + scale = nn.tensor_expr_op(_compute_scale, "compute_scale", args=[weight]) + else: + scale = None + + def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor: + elem_storage_dtype = f"uint{quantize_dtype.bits}" + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda *idx: tir.Cast( + self.storage_dtype, + tir.reinterpret( + elem_storage_dtype, + tir.Cast( + quantize_dtype, + weight(*idx) / scale(0) if scale is not None else weight(*idx), + ), + ), + ), + ) + + packed_weight = pack_weight( + scaled_weight, + axis=-1, + num_elem_per_storage=self.num_elem_per_storage, + weight_dtype=self.weight_dtype, + storage_dtype=self.storage_dtype, + ) + + return packed_weight + + quantized_weight = nn.tensor_expr_op( + _compute_quantized_weight, "compute_quantized_weight", args=[weight, scale] + ) + + if self.use_scale: + return quantized_weight, scale + return (quantized_weight,) + + def _dequantize( + self, + q_weight: te.Tensor, + scale: Optional[te.Tensor] = None, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, + ) -> te.Tensor: + if self.use_scale: + assert scale is not None + if DataType(self.weight_dtype).type_code in [ + DataTypeCode.E4M3Float, + DataTypeCode.E5M2Float, + ]: + return self.dequantize_float8(q_weight, scale, self.weight_dtype, out_shape) + raise NotImplementedError() + + def dequantize_float8( + self, + q_weight: te.Tensor, + scale: Optional[te.Tensor], + quantize_dtype: str, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, + ) -> te.Tensor: + """Dequantize a fp8 tensor to higher-precision float.""" + weight = convert_uint_packed_fp8_to_float( + q_weight, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + quantize_dtype, + axis=-1, + out_shape=out_shape, + ) + if scale is not None: + weight = 
weight * scale + return weight + + +class PerTensorQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with per-tensor quantization.""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: Union[int, tir.Var], + config: PerTensorQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + self.q_weight = nn.Parameter( + (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), + config.storage_dtype, + ) + if config.use_scale: + self.q_scale = nn.Parameter((1,), config.model_dtype) + else: + self.q_scale = None + if bias: + self.bias = nn.Parameter( + (out_features,), config.model_dtype if out_dtype is None else out_dtype + ) + else: + self.bias = None + + @classmethod + def from_linear(cls, src: nn.Linear, config: PerTensorQuantize) -> "PerTensorQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a per-tensor quantized PerTensorQuantizeLinear + + Parameters + ---------- + src : nn.Linear + The non-quantized nn.Linear. + + config : PerTensorQuantize + The per-tensor quantization config. + + Returns + ------- + ret : PerTensorQuantizeLinear + The per-tensor quantized PerTensorQuantizeLinear layer. + """ + out_features, in_features = src.weight.shape + quantized_linear = cls( + in_features=in_features, + out_features=out_features, + config=config, + bias=getattr(src, "bias", None) is not None, + out_dtype=src.out_dtype, + ) + if quantized_linear.bias is not None: + quantized_linear.bias.attrs = src.bias.attrs + if "shard_strategy" in src.weight.attrs: + shard = src.weight.attrs["shard_strategy"] + apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) + # scale doesn't need to be sharded since it's the same for all shards + return quantized_linear + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for per-tensor quantized linear layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the per-tensor quantized linear layer. + """ + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.in_features), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) + w = nn.op.permute_dims(w) + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x + + def to(self, dtype: Optional[str] = None) -> None: + """ + Override to() such that we do not convert bias if there is an out_dtype. + Otherwise, we might run into dtype mismatch when computing x + self.bias. 
+ """ + self.q_weight.to(dtype=dtype) + if self.q_scale: + self.q_scale.to(dtype=dtype) + if self.bias is not None and self.out_dtype is None: + self.bias.to(dtype=dtype) + if dtype is not None and isinstance(getattr(self, "dtype", None), str): + self.dtype = dtype # pylint: disable=attribute-defined-outside-init + + +class PerTensorQuantizeEmbedding(nn.Module): + """An nn.Embedding module with group quantization""" + + def __init__(self, num: Union[int, tir.Var], dim: int, config: PerTensorQuantize): + self.num = num + self.dim = dim + self.config = config + self.q_weight = nn.Parameter( + (num, tir.ceildiv(dim, config.num_elem_per_storage)), config.storage_dtype + ) + if self.config.use_scale: + self.q_scale = nn.Parameter((1,), config.model_dtype) + else: + self.q_scale = None + + @staticmethod + def from_embedding( + embedding: nn.Embedding, config: PerTensorQuantize + ) -> "PerTensorQuantizeEmbedding": + """ + Converts a non-quantized nn.Embedding to a per-tensor quantized PerTensorQuantizeEmbedding + + Parameters + ---------- + linear : nn.Embedding + The non-quantized nn.Embedding. + + config : PerTensorQuantize + The per-tensor quantization config. + + Returns + ------- + ret : PerTensorQuantizeEmbedding + The per-tensor quantized embedding layer. + """ + num, dim = embedding.weight.shape + return PerTensorQuantizeEmbedding(num, dim, config) + + def forward(self, x: nn.Tensor): # pylint: disable=invalid-name + """ + Forward method for per-tensor quantized embedding layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the embedding layer. + """ + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + tir.IntImm("int64", self.num) if isinstance(self.num, int) else weight.shape[0], + tir.IntImm("int64", self.dim), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) + if x.ndim == 1: + return nn.op.take(w, x, axis=0) + return nn.op.reshape( + nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0), + shape=[*x.shape, self.dim], + ) + + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which dequantizes the weight + and multiplies it with the input tensor. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the lm_head layer. 
+ """ + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + tir.IntImm("int64", self.num) if isinstance(self.num, int) else weight.shape[0], + tir.IntImm("int64", self.dim), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) + w = nn.op.permute_dims(w) + return nn.op.matmul(x, w, out_dtype="float32") + + +class PerTensorQuantizeMixtralExperts(nn.Module): # pylint: disable=too-many-instance-attributes + """An MixtralExperts module with group quantization""" + + _IMPL: Dict[str, Type["PerTensorQuantizeMixtralExperts"]] = {} + + def __init__( + self, + num_local_experts, + in_features, + out_features, + config: PerTensorQuantize, + ): # pylint: disable=too-many-arguments + self.num_local_experts = num_local_experts + self.in_features = in_features + self.out_features = out_features + self.config = config + self.q_weight = nn.Parameter( + ( + num_local_experts, + out_features, + tir.ceildiv(in_features, config.num_elem_per_storage), + ), + config.storage_dtype, + ) + if config.use_scale: + self.q_scale = nn.Parameter((1,), config.model_dtype) + else: + self.q_scale = None + + @staticmethod + def from_mixtral_experts( + src: "MixtralExperts", + config: PerTensorQuantize, + ) -> "PerTensorQuantizeMixtralExperts": + """ + Converts a non-quantized MixtralExperts to a per-tensor quantized + PerTensorQuantizeMixtralExperts + + Parameters + ---------- + src : MixtralExperts + The non-quantized MixtralExperts + + config : PerTensorQuantize + The per-tensor quantization config + + Returns + ------- + ret : PerTensorQuantizeMixtralExperts + The per-tensor quantized MixtralExperts layer + """ + if DataType(config.weight_dtype).type_code in [ + DataTypeCode.E4M3Float, + DataTypeCode.E5M2Float, + ]: + return PerTensorQuantizeMixtralExperts._IMPL["fp8"].from_mixtral_experts(src, config) + raise NotImplementedError() + + def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """Forward method for per-tensor quantized mistral experts. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + indptr: nn.Tensor + The indptr tensor + + Returns + ------- + ret : nn.Tensor + The output tensor for the per-tensor quantized mistral experts layer. + """ + raise NotImplementedError() diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index 3fab898fb2..1b2d8695cf 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -5,6 +5,7 @@ from .ft_quantization import FTQuantize from .group_quantization import GroupQuantize from .no_quantization import NoQuantize +from .per_tensor_quantization import PerTensorQuantize Quantization = Any """Quantization is an object that represents an quantization algorithm. 
It is required to @@ -117,4 +118,15 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr storage_dtype="int8", model_dtype="float16", ), + "e5m2_e5m2_f16": PerTensorQuantize( + name="e5m2_e5m2_f16", + kind="per-tensor-quant", + activation_dtype="e5m2_float8", + weight_dtype="e5m2_float8", + storage_dtype="uint32", + model_dtype="float16", + quantize_final_fc=True, + quantize_embedding=False, + use_scale=False, + ), } diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 260c9a6b45..3edd53959c 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -94,6 +94,44 @@ def apply_sharding(shard_strategy, name: str, weight: nn.Parameter): raise NotImplementedError(f"Unknowing sharding strategy: {shard_strategy}") +def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments + weight: te.Tensor, + num_elem_per_storage: int, + storage_dtype: str, + model_dtype: str, + quant_dtype: str, + axis: int = -1, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, +) -> te.Tensor: + """Unpack a fp8 value from the storage dtype and convert to float.""" + assert quant_dtype in ["e4m3_float8", "e5m2_float8"] + bits = DataType(quant_dtype).bits + elem_storage_dtype = DataType(f"uint{bits}") + tir_bin_mask = tir.const((1 << bits) - 1, "uint8") + if axis < 0: + axis += len(weight.shape) + if out_shape is None: + out_shape = ( + *weight.shape[:axis], + weight.shape[axis] * num_elem_per_storage, + *weight.shape[axis + 1 :], + ) + axis = axis if axis >= 0 else len(out_shape) + axis + return te.compute( + shape=out_shape, + fcompute=lambda *idx: tir.reinterpret( + quant_dtype, + tir.bitwise_and( + tir.shift_right( + weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]), + ((idx[axis] % num_elem_per_storage) * bits).astype(storage_dtype), + ).astype(elem_storage_dtype), + tir_bin_mask, + ), + ).astype(model_dtype), + ) + + def pack_weight( weight: te.Tensor, axis: int, @@ -122,10 +160,12 @@ def pack_weight( """ assert weight.dtype == storage_dtype shape = weight.shape + if axis < 0: + axis += len(shape) k = shape[axis] axis = axis if axis >= 0 else len(shape) + axis if out_shape is None: - out_shape = (*shape[axis], tir.ceildiv(k, num_elem_per_storage), *shape[axis + 1 :]) + out_shape = (*shape[:axis], tir.ceildiv(k, num_elem_per_storage), *shape[axis + 1 :]) r = te.reduce_axis((0, num_elem_per_storage), name="r") # pylint: disable=invalid-name packed_weight = te.compute( shape=out_shape, From 700206b20dd63dfd8674e7615d302b0baee7904c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 4 Apr 2024 18:24:59 -0400 Subject: [PATCH 152/531] [Serving][Refactor] Rename AsyncThreadedEngine to ThreadedEngine (#2081) This PR renames the AsyncThreadedEngine to ThreadedEngine to prepare for follow up refactors of Python interface. Meanwhile, this PR exposes a creation function for AsyncThreadedEngine so that it can be further used by others, such as JSONFFIEngine. 
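To make the intent of this refactor more concrete, below is a hedged Python sketch (an editorial
illustration, not part of the patch) of how the packed-function surface shown in the diff that
follows can be driven from another component. The entry names come from `ThreadedEngineModule`'s
vtable and the `mlc.serve.create_threaded_engine` global; the arguments of `init_background_engine`
are elided, and loading the MLC LLM runtime library (e.g. via `import mlc_llm`) is assumed to be
what registers these globals.

```python
import threading

import tvm
import mlc_llm  # noqa: F401  # assumed to load the runtime lib and register the globals

# Create the module exposing the threaded engine's packed functions.
engine = tvm.get_global_func("mlc.serve.create_threaded_engine")()

# engine["init_background_engine"](...)  # elided: models, KV cache config, stream callback

# Each background loop runs on its own thread.
loop_thread = threading.Thread(target=engine["run_background_loop"])
stream_thread = threading.Thread(target=engine["run_background_stream_back_loop"])
loop_thread.start()
stream_thread.start()

# Other threads would call engine["add_request"] / engine["abort_request"] here.

# Shutdown: ask the loops to exit, then join.
engine["exit_background_loop"]()
loop_thread.join()
stream_thread.join()
```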
--- ..._threaded_engine.cc => threaded_engine.cc} | 58 +++++++++++-------- ...nc_threaded_engine.h => threaded_engine.h} | 25 +++++--- 2 files changed, 51 insertions(+), 32 deletions(-) rename cpp/serve/{async_threaded_engine.cc => threaded_engine.cc} (85%) rename cpp/serve/{async_threaded_engine.h => threaded_engine.h} (65%) diff --git a/cpp/serve/async_threaded_engine.cc b/cpp/serve/threaded_engine.cc similarity index 85% rename from cpp/serve/async_threaded_engine.cc rename to cpp/serve/threaded_engine.cc index 49313e4ca1..61ce2e51d6 100644 --- a/cpp/serve/async_threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -1,9 +1,9 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/async_threaded_engine.cc - * \brief The implementation for asynchronous threaded serving engine in MLC LLM. + * \file serve/threaded_engine.cc + * \brief The implementation for threaded serving engine in MLC LLM. */ -#include "async_threaded_engine.h" +#include "threaded_engine.h" #include #include @@ -23,24 +23,9 @@ namespace serve { using tvm::Device; using namespace tvm::runtime; -/*! \brief The implementation of AsyncThreadedEngine. */ -class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { +/*! \brief The implementation of ThreadedEngine. */ +class ThreadedEngineImpl : public ThreadedEngine { public: - TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); - TVM_MODULE_VTABLE_ENTRY("add_request", &AsyncThreadedEngineImpl::AddRequest); - TVM_MODULE_VTABLE_ENTRY("abort_request", &AsyncThreadedEngineImpl::AbortRequest); - TVM_MODULE_VTABLE_ENTRY("run_background_loop", &AsyncThreadedEngineImpl::RunBackgroundLoop); - TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", - &AsyncThreadedEngineImpl::RunBackgroundStreamBackLoop); - TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &AsyncThreadedEngineImpl::ExitBackgroundLoop); - if (_name == "init_background_engine") { - return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { - SelfPtr self = static_cast(_self.get()); - self->InitBackgroundEngine(args); - }); - } - TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(TVMArgs args) { Optional request_stream_callback; try { @@ -50,7 +35,7 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { } CHECK(request_stream_callback.defined()) - << "AsyncThreadedEngine requires request stream callback function, but it is not given."; + << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { @@ -158,7 +143,9 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { flattened_callback_inputs.push_back(callback_input); } } - request_stream_callback_(Array(flattened_callback_inputs)); + if (!flattened_callback_inputs.empty()) { + request_stream_callback_(Array(flattened_callback_inputs)); + } flattened_callback_inputs.clear(); } } @@ -222,10 +209,35 @@ class AsyncThreadedEngineImpl : public AsyncThreadedEngine, public ModuleNode { bool stream_callback_waiting_ = false; }; +/*! \brief The implementation of ThreadedEngine. 
*/ +class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { + public: + TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); + TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); + TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &ThreadedEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); + if (_name == "init_background_engine") { + return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { + SelfPtr self = static_cast(_self.get()); + self->InitBackgroundEngine(args); + }); + } + TVM_MODULE_VTABLE_END(); +}; + TVM_REGISTER_GLOBAL("mlc.serve.create_threaded_engine").set_body_typed([]() { - return Module(make_object()); + return Module(make_object()); }); +std::unique_ptr CreateThreadedEnginePacked(TVMArgs args) { + std::unique_ptr threaded_engine = std::make_unique(); + threaded_engine->InitBackgroundEngine(args); + return std::move(threaded_engine); +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/async_threaded_engine.h b/cpp/serve/threaded_engine.h similarity index 65% rename from cpp/serve/async_threaded_engine.h rename to cpp/serve/threaded_engine.h index 550bd81623..90447e28d8 100644 --- a/cpp/serve/async_threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -1,10 +1,10 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/async_threaded_engine.h - * \brief The header of threaded asynchronous serving engine in MLC LLM. + * \file serve/threaded_engine.h + * \brief The header of threaded serving engine in MLC LLM. */ -#ifndef MLC_LLM_SERVE_ASYNC_THREADED_ENGINE_H_ -#define MLC_LLM_SERVE_ASYNC_THREADED_ENGINE_H_ +#ifndef MLC_LLM_SERVE_THREADED_ENGINE_H_ +#define MLC_LLM_SERVE_THREADED_ENGINE_H_ #include @@ -19,16 +19,16 @@ namespace serve { using namespace tvm::runtime; /*! - * \brief The interface asynchronous threaded engine in MLC LLM. + * \brief The interface threaded engine in MLC LLM. * The threaded engine keeps running a background request processing * loop on a standalone thread. Ensuring thread safety, it exposes * `AddRequest` and `AbortRequest` to receive new requests or * abortions from other threads, and the internal request processing * is backed by a normal engine wrapped inside. */ -class AsyncThreadedEngine { +class ThreadedEngine { public: - virtual ~AsyncThreadedEngine() = default; + virtual ~ThreadedEngine() = default; /*! \brief Starts the background request processing loop. */ virtual void RunBackgroundLoop() = 0; @@ -37,7 +37,7 @@ class AsyncThreadedEngine { virtual void RunBackgroundStreamBackLoop() = 0; /*! - * \brief Notify the AsyncThreadedEngine to exit the background + * \brief Notify the ThreadedEngine to exit the background * request processing loop. This method is invoked by threads * other than the engine-driving thread. */ @@ -50,8 +50,15 @@ class AsyncThreadedEngine { virtual void AbortRequest(const String& request_id) = 0; }; +/*! + * \brief Create a ThreadedEngine from packed arguments in TVMArgs. + * \param args The arguments of engine construction. + * \return The constructed threaded engine in unique pointer. 
+ */ +std::unique_ptr CreateThreadedEnginePacked(TVMArgs args); + } // namespace serve } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_ASYNC_THREADED_ENGINE_H_ +#endif // MLC_LLM_SERVE_THREADED_ENGINE_H_ From 2e9cc1ccba974336fb4605ef24fe4c55a3909ebc Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 4 Apr 2024 17:48:21 -0700 Subject: [PATCH 153/531] [Serving] Add cuda profiling in benchmark test (#2084) * [Serving] Add cuda profiling in benchmark test --- tests/python/serve/benchmark.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index fe914d1073..d544f4b371 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -23,15 +23,16 @@ def _parse_args(): args.add_argument("--dataset", type=str, required=True) args.add_argument("--device", type=str, default="auto") args.add_argument("--num-prompts", type=int, default=500) - args.add_argument("--batch-size", type=int, default=80) + args.add_argument("--max-num-sequence", type=int, default=80) args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) args.add_argument("--json-output", type=bool, default=False) + args.add_argument("--cuda-profile", type=bool, default=False) parsed = args.parse_args() parsed.model = os.path.dirname(parsed.model_lib_path) - assert parsed.batch_size % 16 == 0 + assert parsed.max_num_sequence % 16 == 0 assert parsed.page_size == 16 return parsed @@ -108,7 +109,7 @@ def benchmark(args: argparse.Namespace): model = ModelInfo(args.model, args.model_lib_path, args.device) kv_cache_config = KVCacheConfig( page_size=args.page_size, - max_num_sequence=args.batch_size, + max_num_sequence=args.max_num_sequence, max_total_sequence_length=args.max_total_seq_length, ) @@ -138,6 +139,15 @@ def engine_generate(): total_prefill_tokens.append(engine_stats["total_prefill_tokens"]) total_decode_tokens.append(engine_stats["total_decode_tokens"]) + if args.cuda_profile: + import cuda + import cuda.cudart + + cuda.cudart.cudaProfilerStart() + engine_generate() + cuda.cudart.cudaProfilerStop() + return + e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs) single_token_prefill_latency = np.array(single_token_prefill_latency) single_token_decode_latency = np.array(single_token_decode_latency) From 41da87a8a5c7ca33c0a1b9b4d63bc5a6ab2c9cad Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 4 Apr 2024 20:48:28 -0400 Subject: [PATCH 154/531] [Grammar] Fix broken grammar tests (#2083) This PR fixes some grammar parser tests that were broken. 
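The tests broke because `BNFGrammar.from_ebnf_string` now takes the name of the main rule as its second positional argument (and reports a different error when that rule is missing). A minimal call that matches the updated tests, reusing a grammar from the diff below, looks like this (the two trailing boolean flags are passed exactly as the tests do):

```python
from mlc_llm.serve import BNFGrammar

# The main rule name ("main") is now passed explicitly; the two boolean flags
# are the same ones the updated tests pass after the grammar string.
ebnf = """main ::= b c
b ::= "b"
c ::= "c"
"""
grammar = BNFGrammar.from_ebnf_string(ebnf, "main", True, False)
print(grammar.to_string())
```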
--- tests/python/serve/test_grammar_parser.py | 27 ++++++++++--------- .../test_grammar_state_matcher_custom.py | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index 325b0a5117..10eacdf9b9 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -17,7 +17,7 @@ def test_bnf_simple(): b ::= (([b])) c ::= (([c])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -36,7 +36,7 @@ def test_ebnf(): c_1 ::= (([acep-z] c_1) | ([acep-z])) d_1 ::= ("" | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -60,7 +60,7 @@ def test_star_quantifier(): e_star_2 ::= [g]* d_1_choice ::= (([b] [c] [d]) | ([p] [q])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -75,7 +75,7 @@ def test_char(): rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) """ # Disable unwrap_nesting_rules to expose the result before unwrapping. - bnf_grammar = BNFGrammar.from_ebnf_string(before, False, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False) after = bnf_grammar.to_string() assert after == expected @@ -90,7 +90,7 @@ def test_space(): """ expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -101,7 +101,7 @@ def test_nest(): expected = """main ::= (([a] main_choice) | ([e] [f])) main_choice ::= (([b]) | ([c] [d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -122,7 +122,7 @@ def test_flatten(): empty_test ::= ("" | ([d]) | ([a])) sequence_test_choice ::= (([c]) | ([d])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -159,7 +159,7 @@ def test_json(): exponent_choice_1 ::= ("" | ([+]) | ([\-])) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_string() assert after == expected @@ -176,9 +176,9 @@ def test_to_string_roundtrip(): c_2 ::= [acep-z] d_1 ::= [d] | "" """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, True, False) + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False) output_string_2 = bnf_grammar_2.to_string() assert output_string_1 == output_string_2 @@ -240,7 +240,8 @@ def test_error(): with pytest.raises( TVMError, - match='TVMError: EBNF parse error at line 1, column 10: There must be a rule named 
"main"', + match="TVMError: EBNF parse error at line 1, column 10: " + 'The main rule with name "main" is not found.', ): BNFGrammar.from_ebnf_string('a ::= "a"') @@ -256,7 +257,7 @@ def test_to_json(): '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' ) - bnf_grammar = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) after = bnf_grammar.to_json(False) assert after == expected @@ -271,7 +272,7 @@ def test_to_json_roundtrip(): c_2 ::= (([acep-z])) d_1 ::= ("" | ([d])) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) output_json_1 = bnf_grammar_1.to_json(False) bnf_grammar_2 = BNFGrammar.from_json(output_json_1) output_json_2 = bnf_grammar_2.to_json(False) diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index 5bdc8ecc4b..6fc48705d1 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -12,7 +12,7 @@ import tvm.testing from pydantic import BaseModel -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher from mlc_llm.tokenizer import Tokenizer From 791623ae669a590dd2141657a9b202f3c1b02ae7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 4 Apr 2024 20:48:36 -0400 Subject: [PATCH 155/531] [Serving][Fix] Fix chunked prefill condition (#2082) This PR fixes a bug when trying to chunk an input and do prefill. The stats prior ot this PR was wrong. --- cpp/serve/engine_actions/new_request_prefill.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index f93fbc2ded..5ff8ee923e 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -330,6 +330,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); num_require_pages = (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; if (input_length > 0 && CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), total_input_length, total_required_pages, num_available_pages, From 7e0f102936999d812380c736a9e0efe077748caa Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 5 Apr 2024 07:14:19 -0400 Subject: [PATCH 156/531] [Conversation] Fix RedPajama conversation template (#2087) As reported and discussed in #2086, this PR fixes the RedPajama template. 
--- python/mlc_llm/conversation_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 5976517c53..e71e6734f7 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -335,7 +335,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: roles={"user": "", "assistant": ""}, seps=["\n"], role_content_sep=": ", - role_empty_sep=": ", + role_empty_sep=":", stop_str=[""], stop_token_ids=[0], ) From c2f2e595919ff7f97f22e22cad8d28c6b9447ef9 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 5 Apr 2024 07:14:46 -0400 Subject: [PATCH 157/531] [Serving][Refactor] Python interface refactor (#2085) This PR is an initial major Python interface refactor of MLC Serve. With this PR, `mlc_llm.serve` in Python now exposes two engine classes: `AsyncEngine` and `Engine`. Both classes have two entrypoints, `chat_completion` and `completion` which conform to OpenAI Python API (reference: https://github.com/openai/openai-python). As the name suggested, `AsyncEngine` works asynchronously, and `Engine` works synchronously. It worths noting that the `Engine` since this PR is different from the `Engine` so far. The new `Engine` does not provide interfaces for batch generation. For robustness and correctness, the old `Engine` in Python is moved to `mlc_llm.serve.sync_engine.SyncEngine`. We do not directly expose this SyncEngine, and it now mainly serves testing and debug purposes. It is useful to check the correctness of new features, because of its simplicity. It keeps the low-level interface to directly invoke `step()` function of the engine, and also keeps the low-level batch generation interface. Our REST API entry points defined under `mlc_llm/serve/entrypoints/` are also refactored accordingly to adapt to the latest Python API in MLC Serve. In short, most of the logic in OpenAI API entry points are moved to Python API, which simplifies the implementation of entry points. Please note that this is the first (also the largest) planned refactor. We will follow up with some other refactors, which have smaller scopes compared with this PR. The planned refactors include: * provide submodule interface to align OpenAI Python package in https://github.com/openai/openai-python * refactor the constructor interface of `Engine`/`AsyncEngine` to align the MLC serve CLI interface. 
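To make the new surface concrete, here is a minimal usage sketch of the refactored asynchronous API. The constructor arguments mirror `mlc_llm/interface/serve.py` in this patch; the model paths are placeholders, and the streamed `delta.content` access assumes the OpenAI-style response fields the commit message refers to:

```python
import asyncio

from mlc_llm.serve import AsyncEngine, KVCacheConfig, engine_base


async def main():
    # Placeholder paths: substitute a compiled MLC model and its library.
    model_info = engine_base.ModelInfo(
        model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        model_lib_path="dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
        device="auto",
    )
    async_engine = AsyncEngine(model_info, KVCacheConfig(page_size=16, max_num_sequence=16))

    # chat_completion streams OpenAI-style ChatCompletionStreamResponse objects.
    async for response in async_engine.chat_completion(
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        model=model_info.model,
        stream=True,
    ):
        for choice in response.choices:
            print(choice.delta.content or "", end="", flush=True)
    print()


asyncio.run(main())
```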
--- python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/interface/serve.py | 14 +- .../mlc_llm/protocol/conversation_protocol.py | 129 +- python/mlc_llm/protocol/error_protocol.py | 34 + .../mlc_llm/protocol/openai_api_protocol.py | 89 +- python/mlc_llm/protocol/protocol_utils.py | 10 - python/mlc_llm/serve/__init__.py | 3 +- python/mlc_llm/serve/async_engine.py | 432 ------ python/mlc_llm/serve/data.py | 43 +- python/mlc_llm/serve/engine.py | 1251 ++++++++++------- python/mlc_llm/serve/engine_base.py | 1066 ++++++++++++++ python/mlc_llm/serve/engine_utils.py | 97 ++ .../serve/entrypoints/debug_entrypoints.py | 12 +- .../serve/entrypoints/entrypoint_utils.py | 150 -- .../serve/entrypoints/openai_entrypoints.py | 470 ++----- python/mlc_llm/serve/server/__main__.py | 73 - python/mlc_llm/serve/server/popen_server.py | 5 +- python/mlc_llm/serve/server/server_context.py | 38 +- python/mlc_llm/serve/sync_engine.py | 332 +++++ python/mlc_llm/testing/debug_chat.py | 4 +- tests/python/serve/benchmark.py | 7 +- tests/python/serve/evaluate_engine.py | 8 +- tests/python/serve/server/test_server.py | 13 +- tests/python/serve/test_serve_async_engine.py | 114 +- .../serve/test_serve_async_engine_spec.py | 15 +- tests/python/serve/test_serve_engine.py | 431 ++---- .../python/serve/test_serve_engine_grammar.py | 16 +- tests/python/serve/test_serve_engine_image.py | 23 +- tests/python/serve/test_serve_engine_spec.py | 14 +- tests/python/serve/test_serve_sync_engine.py | 402 ++++++ 30 files changed, 3251 insertions(+), 2045 deletions(-) create mode 100644 python/mlc_llm/protocol/error_protocol.py delete mode 100644 python/mlc_llm/serve/async_engine.py create mode 100644 python/mlc_llm/serve/engine_base.py create mode 100644 python/mlc_llm/serve/engine_utils.py delete mode 100644 python/mlc_llm/serve/entrypoints/entrypoint_utils.py delete mode 100644 python/mlc_llm/serve/server/__main__.py create mode 100644 python/mlc_llm/serve/sync_engine.py create mode 100644 tests/python/serve/test_serve_sync_engine.py diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index e0d401920a..d22aa7d231 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -289,6 +289,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "rwkv_world", "rwkv", "gorilla", + "gorilla-openfunctions-v2", "guanaco", "dolly", "oasst", diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c9b9b161b5..df64488a72 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -6,7 +6,8 @@ import uvicorn from fastapi.middleware.cors import CORSMiddleware -from mlc_llm.serve import async_engine, config +from mlc_llm.protocol import error_protocol +from mlc_llm.serve import config, engine, engine_base from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -28,7 +29,7 @@ def serve( ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Initialize model loading info and KV cache config - model_info = async_engine.ModelInfo( + model_info = engine_base.ModelInfo( model=model, model_lib_path=model_lib_path, device=device, @@ -39,12 +40,10 @@ def serve( prefill_chunk_size=prefill_chunk_size, ) # Create engine and start the background loop - engine = async_engine.AsyncThreadedEngine( - model_info, kv_cache_config, enable_tracing=enable_tracing - ) + 
async_engine = engine.AsyncEngine(model_info, kv_cache_config, enable_tracing=enable_tracing) with ServerContext() as server_context: - server_context.add_model(model, engine) + server_context.add_model(model, async_engine) app = fastapi.FastAPI() app.add_middleware( @@ -57,4 +56,7 @@ def serve( app.include_router(openai_entrypoints.app) app.include_router(debug_entrypoints.app) + app.exception_handler(error_protocol.BadRequestError)( + error_protocol.bad_request_error_handler + ) uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 1c2a3cb2e4..482cce54c8 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -5,8 +5,6 @@ from pydantic import BaseModel, Field, field_validator -from ..serve import data - # The message placeholders in the message prompts according to roles. class MessagePlaceholders(Enum): @@ -113,17 +111,25 @@ def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T: return Conversation.model_validate(json_dict) # pylint: disable=too-many-branches - def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]: + def as_prompt(self, config=None) -> List[Any]: """Convert the conversation template and history messages to a single prompt. + + Returns + ------- + prompts : List[Union[str, "mlc_llm.serve.data.Data"]] + The prompts converted from the conversation messages. + We use Any in the signature to avoid cyclic import. """ + from ..serve import data # pylint: disable=import-outside-toplevel + # - Get the system message. system_msg = self.system_template.replace( MessagePlaceholders.SYSTEM.value, self.system_message ) # - Get the message strings. 
- message_list: List[Union[str, data.ImageData]] = [] + message_list: List[Union[str, data.Data]] = [] separators = list(self.seps) if len(separators) == 1: separators.append(separators[0]) @@ -136,55 +142,48 @@ def as_prompt(self, config=None) -> List[Union[str, data.ImageData]]: if role not in self.roles.keys(): raise ValueError(f'Role "{role}" is not a supported role in {self.roles.keys()}') separator = separators[role == "assistant"] # check assistant role - if content is not None: - role_prefix = ( - "" - # Do not append role prefix if this is the first message and there - # is already a system message - if (not self.add_role_after_system_message and system_msg != "" and i == 0) - else self.roles[role] + self.role_content_sep + + if content is None: + message_list.append(self.roles[role] + self.role_empty_sep) + continue + + role_prefix = ( + "" + # Do not append role prefix if this is the first message and there + # is already a system message + if (not self.add_role_after_system_message and system_msg != "" and i == 0) + else self.roles[role] + self.role_content_sep + ) + if isinstance(content, str): + message_list.append( + role_prefix + + self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, content + ) + + separator ) - if isinstance(content, str): - message_string = ( - role_prefix - + self.role_templates[role].replace( - MessagePlaceholders[role.upper()].value, content - ) - + separator + continue + + message_list.append(role_prefix) + + for item in content: + assert isinstance(item, dict), "Content should be a string or a list of dicts" + assert "type" in item, "Content item should have a type field" + if item["type"] == "text": + message = self.role_templates[role].replace( + MessagePlaceholders[role.upper()].value, item["text"] ) - message_list.append(message_string) + message_list.append(message) + elif item["type"] == "image_url": + assert config is not None, "Model config is required" + image_url = _get_url_from_item(item) + message_list.append(data.ImageData.from_url(image_url, config)) else: - message_list.append(role_prefix) - for item in content: - assert isinstance( - item, dict - ), "Content should be a string or a list of dicts" - assert "type" in item, "Content item should have a type field" - if item["type"] == "text": - message_list.append( - self.role_templates[role].replace( - MessagePlaceholders[role.upper()].value, item["text"] - ) - ) - elif item["type"] == "image_url": - assert config is not None, "Model config is required" - - # pylint: disable=import-outside-toplevel - from ..serve.entrypoints.entrypoint_utils import ( - get_image_from_url, - ) - - image_url = _get_url_from_item(item) - message_list.append(get_image_from_url(image_url, config)) - else: - raise ValueError(f"Unsupported content type: {item['type']}") - - message_list.append(separator) - else: - message_string = self.roles[role] + self.role_empty_sep - message_list.append(message_string) - - prompt = _combine_consecutive_strings(message_list) + raise ValueError(f"Unsupported content type: {item['type']}") + + message_list.append(separator) + + prompt = _combine_consecutive_messages(message_list) if not any(isinstance(item, data.ImageData) for item in message_list): # Replace the last function string placeholder with actual function string @@ -215,11 +214,27 @@ def _get_url_from_item(item: Dict) -> str: return image_url -def _combine_consecutive_strings(lst): - result = [] - for item in lst: - if isinstance(item, str) and result and isinstance(result[-1], str): - 
result[-1] += item +def _combine_consecutive_messages(messages: List[Any]) -> List[Any]: + """Combining consecutive strings into one. + + Parameters + ---------- + messages : List[Union[str, "mlc_llm.serve.data.Data"]] + The input messages to be combined. + We use Any in the signature to avoid cyclic import. + + Returns + ------- + updated_messages : List[Union[str, "mlc_llm.serve.data.Data"]] + The combined messages + """ + if len(messages) == 0: + return [] + + combined_messages = [messages[0]] + for message in messages[1:]: + if isinstance(message, str) and isinstance(combined_messages[-1], str): + combined_messages[-1] += message else: - result.append(item) - return result + combined_messages.append(message) + return combined_messages diff --git a/python/mlc_llm/protocol/error_protocol.py b/python/mlc_llm/protocol/error_protocol.py new file mode 100644 index 0000000000..83a201f578 --- /dev/null +++ b/python/mlc_llm/protocol/error_protocol.py @@ -0,0 +1,34 @@ +"""Error protocols in MLC LLM""" + +from http import HTTPStatus + +import fastapi +from pydantic import BaseModel + + +class BadRequestError(ValueError): + """The exception for bad requests in engines.""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class ErrorResponse(BaseModel): + """The class of error response.""" + + object: str = "error" + message: str + code: int = None + + +def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse: + """Create a JSON response that reports error with regarding the input message.""" + return fastapi.responses.JSONResponse( + ErrorResponse(message=message, code=status_code.value).model_dump_json(), + status_code=status_code.value, + ) + + +async def bad_request_error_handler(_request: fastapi.Request, e: BadRequestError): + """The handler of BadRequestError that converts an exception into error response.""" + return create_error_response(status_code=HTTPStatus.BAD_REQUEST, message=e.args[0]) diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index fa4893447f..6f5754dee1 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -4,13 +4,16 @@ """ # pylint: disable=missing-class-docstring + +import json import time from typing import Any, Dict, List, Literal, Optional, Tuple, Union import shortuuid from pydantic import BaseModel, Field, field_validator, model_validator -from mlc_llm.serve.config import ResponseFormat +from .conversation_protocol import Conversation +from .error_protocol import BadRequestError ################ Commons ################ @@ -82,7 +85,7 @@ class CompletionRequest(BaseModel): """ model: str - prompt: Union[str, List[int], List[Union[str, List[int]]]] + prompt: Union[str, List[int]] best_of: int = 1 echo: bool = False frequency_penalty: float = 0.0 @@ -100,7 +103,7 @@ class CompletionRequest(BaseModel): top_p: float = 1.0 user: Optional[str] = None ignore_eos: bool = False - response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) + response_format: Optional[RequestResponseFormat] = None @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -214,7 +217,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None user: Optional[str] = None ignore_eos: bool = False - response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat) + response_format: 
Optional[RequestResponseFormat] = None @field_validator("frequency_penalty", "presence_penalty") @classmethod @@ -249,6 +252,74 @@ def check_logprobs(self) -> "ChatCompletionRequest": raise ValueError('"logprobs" must be True to support "top_logprobs"') return self + def check_message_validity(self) -> None: + """Check if the given chat messages are valid. Return error message if invalid.""" + for i, message in enumerate(self.messages): + if message.role == "system" and i != 0: + raise BadRequestError( + f"System prompt at position {i} in the message list is invalid." + ) + if message.role == "tool": + raise BadRequestError("Tool as the message author is not supported yet.") + if message.tool_call_id is not None: + if message.role != "tool": + raise BadRequestError("Non-tool message having `tool_call_id` is invalid.") + if isinstance(message.content, list): + if message.role != "user": + raise BadRequestError("Non-user message having a list of content is invalid.") + if message.tool_calls is not None: + if message.role != "assistant": + raise BadRequestError("Non-assistant message having `tool_calls` is invalid.") + raise BadRequestError("Assistant message having `tool_calls` is not supported yet.") + + def check_function_call_usage(self, conv_template: Conversation) -> None: + """Check if function calling is used and update the conversation template. + Return error message if invalid request format for function calling. + """ + + # return if no tools are provided or tool_choice is set to none + if self.tools is None or (isinstance(self.tool_choice, str) and self.tool_choice == "none"): + conv_template.use_function_calling = False + return + + # select the tool based on the tool_choice if specified + if isinstance(self.tool_choice, dict): + if self.tool_choice["type"] != "function": # pylint: disable=unsubscriptable-object + raise BadRequestError("Only 'function' tool choice is supported") + + if len(self.tool_choice["function"]) > 1: # pylint: disable=unsubscriptable-object + raise BadRequestError("Only one tool is supported when tool_choice is specified") + + for tool in self.tools: # pylint: disable=not-an-iterable + if ( + tool.function.name + == self.tool_choice["function"][ # pylint: disable=unsubscriptable-object + "name" + ] + ): + conv_template.use_function_calling = True + conv_template.function_string = tool.function.model_dump_json() + return + + # pylint: disable=unsubscriptable-object + raise BadRequestError( + f"The tool_choice function {self.tool_choice['function']['name']}" + " is not found in the tools list" + ) + # pylint: enable=unsubscriptable-object + + if isinstance(self.tool_choice, str) and self.tool_choice != "auto": + raise BadRequestError(f"Invalid tool_choice value: {self.tool_choice}") + + function_list = [] + for tool in self.tools: # pylint: disable=not-an-iterable + if tool.type != "function": + raise BadRequestError("Only 'function' tool type is supported") + function_list.append(tool.function.model_dump()) + + conv_template.use_function_calling = True + conv_template.function_string = json.dumps(function_list) + class ChatCompletionResponseChoice(BaseModel): finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None @@ -291,6 +362,9 @@ class ChatCompletionStreamResponse(BaseModel): model: str system_fingerprint: str object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + usage: UsageInfo = Field( + default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda + ) 
################################################ @@ -315,6 +389,8 @@ def openai_api_get_generation_config( request: Union[CompletionRequest, ChatCompletionRequest] ) -> Dict[str, Any]: """Create the generation config from the given request.""" + from ..serve.config import ResponseFormat # pylint: disable=import-outside-toplevel + kwargs: Dict[str, Any] = {} arg_names = [ "n", @@ -337,5 +413,8 @@ def openai_api_get_generation_config( kwargs["max_tokens"] = -1 if request.stop is not None: kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop - kwargs["response_format"] = ResponseFormat(**request.response_format.model_dump(by_alias=True)) + if request.response_format is not None: + kwargs["response_format"] = ResponseFormat( + **request.response_format.model_dump(by_alias=True) + ) return kwargs diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index a9a68a1f82..f4273d0302 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -2,8 +2,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel - from ..serve.config import GenerationConfig from . import RequestProtocol from .openai_api_protocol import ChatCompletionRequest as OpenAIChatCompletionRequest @@ -14,14 +12,6 @@ ) -class ErrorResponse(BaseModel): - """The class of error response.""" - - object: str = "error" - message: str - code: int = None - - def get_unsupported_fields(request: RequestProtocol) -> List[str]: """Get the unsupported fields of the request. Return the list of unsupported field names. diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 8e06de7b54..e165128ea3 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,10 +2,9 @@ # Load MLC LLM library by importing base from .. import base -from .async_engine import AsyncThreadedEngine from .config import EngineMode, GenerationConfig, KVCacheConfig from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData -from .engine import Engine +from .engine import AsyncEngine, Engine from .grammar import BNFGrammar, GrammarStateMatcher from .json_schema_converter import json_schema_to_ebnf from .request import Request diff --git a/python/mlc_llm/serve/async_engine.py b/python/mlc_llm/serve/async_engine.py deleted file mode 100644 index 341a3880f3..0000000000 --- a/python/mlc_llm/serve/async_engine.py +++ /dev/null @@ -1,432 +0,0 @@ -"""The MLC LLM Asynchronous Serving Engine. -Acknowledgment: Part of the code was adapted from the vLLM project. -""" - -import asyncio -import sys -import threading -from dataclasses import dataclass -from typing import ( - Any, - AsyncGenerator, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - Union, -) - -import tvm - -from ..streamer import TextStreamer -from ..tokenizer import Tokenizer -from . import data -from .config import EngineMode, GenerationConfig, KVCacheConfig -from .engine import ModelInfo, _estimate_max_total_sequence_length, _process_model_args -from .event_trace_recorder import EventTraceRecorder -from .request import Request - - -@dataclass -class AsyncStreamOutput: - """The output of AsyncThreadedEngine.generate - - Attributes - ---------- - delta_text : str - The delta text generated since the last output. - - num_delta_tokens : int - The number of delta tokens generated since the last output. 
- - delta_logprob_json_strs : Optional[List[str]] - The list of logprob JSON strings since the last output, - or None if the request does not require logprobs. - - finish_reason : Optional[str] - The finish reason of the request, or None if unfinished. - """ - - delta_text: str - num_delta_tokens: int - delta_logprob_json_strs: Optional[List[str]] - finish_reason: Optional[str] - - -class AsyncRequestStream: - """The asynchronous stream for requests. - - Each request has its own unique stream. - The stream exposes the method `push` for engine to push new generated - delta text to the stream, and the method `finish` for engine to mark - the finish of generation. - - The stream implements `__aiter__` and `__anext__`, which the engine - can use to iterates all the generated tokens in order asynchronously. - """ - - # The asynchronous queue to hold elements of either a list of - # AsyncStreamOutput or an exception. - if sys.version_info >= (3, 9): - _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object - Union[List[AsyncStreamOutput], Exception] - ] - else: - _queue: asyncio.Queue - # The finish flag. - _finished: bool - - def __init__(self) -> None: - self._queue = asyncio.Queue() - self._finished = False - - def push(self, item_or_exception: Union[List[AsyncStreamOutput], Exception]) -> None: - """Push a new token to the stream.""" - if self._finished: - # No new item is expected after finish. - self._queue.put_nowait( - RuntimeError( - "The request has already finished. " - "The stream is not supposed to accept new items." - ) - ) - return - self._queue.put_nowait(item_or_exception) - - def finish(self) -> None: - """Mark the finish of the generation in the stream.""" - self._queue.put_nowait(StopIteration()) - self._finished = True - - def __aiter__(self): - return self - - async def __anext__(self) -> List[AsyncStreamOutput]: - result = await self._queue.get() - if isinstance(result, StopIteration): - raise StopAsyncIteration - if isinstance(result, Exception): - raise result - return result - - -class _AsyncThreadedEngineState: - """The engine states that the request stream callback function may use. - We use this state class to avoid the callback function from capturing - the AsyncThreadedEngine. - """ - - trace_recorder = None - # The mapping from request ids to request asynchronous stream. - request_tools: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} - num_unfinished_generations: Dict[str, int] = {} - _async_event_loop: Optional[asyncio.AbstractEventLoop] = None - - def __init__(self, enable_tracing: bool) -> None: - if enable_tracing: - self.trace_recorder = EventTraceRecorder() - - def lazy_init_event_loop(self) -> None: - """Lazily set the asyncio event loop so that the event - loop is the main driving event loop of the process. - """ - if self._async_event_loop is None: - self._async_event_loop = asyncio.get_event_loop() - - def get_request_stream_callback(self) -> Callable[[List[data.RequestStreamOutput]], None]: - """Construct a callback function and return.""" - - def _callback(delta_outputs: List[data.RequestStreamOutput]) -> None: - self._request_stream_callback(delta_outputs) - - return _callback - - def _request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for engine to stream back - the request generation results. - - Parameters - ---------- - delta_outputs : List[data.RequestStreamOutput] - The delta output of each requests. 
- Check out data.RequestStreamOutput for the fields of the outputs. - - Note - ---- - This callback function uses `call_soon_threadsafe` in asyncio to - schedule the invocation in the event loop, so that the underlying - callback logic will be executed asynchronously in the future rather - than right now. - """ - - # Schedule a callback run in the event loop without executing right now. - # NOTE: This function causes GIL during execution. - self._async_event_loop.call_soon_threadsafe( - self._request_stream_callback_impl, delta_outputs - ) - - def _request_stream_callback_impl(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The underlying implementation of request stream callback.""" - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - tools = self.request_tools.get(request_id, None) - if tools is None: - continue - - self.record_event(request_id, event="start callback") - stream, text_streamers = tools - outputs = [] - for stream_output, text_streamer in zip(stream_outputs, text_streamers): - self.record_event(request_id, event="start detokenization") - delta_text = ( - text_streamer.put(stream_output.delta_token_ids) - if len(stream_output.delta_token_ids) > 0 - else "" - ) - if stream_output.finish_reason is not None: - delta_text += text_streamer.finish() - self.record_event(request_id, event="finish detokenization") - - outputs.append( - AsyncStreamOutput( - delta_text=delta_text, - num_delta_tokens=len(stream_output.delta_token_ids), - delta_logprob_json_strs=stream_output.delta_logprob_json_strs, - finish_reason=stream_output.finish_reason, - ) - ) - if stream_output.finish_reason is not None: - self.num_unfinished_generations[request_id] -= 1 - - # Push new delta text to the stream. - stream.push(outputs) - if self.num_unfinished_generations[request_id] == 0: - stream.finish() - self.request_tools.pop(request_id, None) - self.record_event(request_id, event="finish callback") - - def record_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace - recorder when the recorder exists. - - Parameters - ---------- - request_id : str - The subject request of the event. - - event : str - The event in a string name. - It can have one of the following patterns: - - "start xxx", which marks the start of event "xxx", - - "finish xxx", which marks the finish of event "xxx", - - "yyy", which marks the instant event "yyy". - The "starts" and "finishes" will be automatically paired in the trace recorder. - """ - if self.trace_recorder is None: - return - self.trace_recorder.add_event(request_id, event) - - -class AsyncThreadedEngine: # pylint: disable=too-many-instance-attributes - """The asynchronous engine for generate text asynchronously, - backed by ThreadedEngine. - - This class wraps a synchronous threaded engine that runs on - a standalone thread inside, and exports the asynchronous `generate` - method as the main text generation interface, which yields the - generated tokens. The internal threaded engine keeps running an - event loop that drives the engine. - - Parameters - ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - - engine_mode : Optional[EngineMode] - The Engine execution mode. - - enable_tracing : bool - A boolean indicating if to enable event logging for requests. 
- """ - - def __init__( - self, - models: Union[ModelInfo, List[ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, - enable_tracing: bool = False, - ) -> None: - if isinstance(models, ModelInfo): - models = [models] - ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - self.conv_template_name, - ) = _process_model_args(models) - - for i, model in enumerate(models): - # model_args: - # [model_lib_path, model_path, device.device_type, device.device_id] * N - model.model_lib_path = model_args[i * (len(model_args) // len(models))] - - self.max_input_sequence_length = max_single_sequence_length - self.state = _AsyncThreadedEngineState(enable_tracing) - - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence - ) - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." - ) - - module = tvm.get_global_func("mlc.serve.create_threaded_engine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "add_request", - "abort_request", - "run_background_loop", - "run_background_stream_back_loop", - "init_background_engine", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(tokenizer_path) - if engine_mode is None: - # The default engine mode: non-speculative - engine_mode = EngineMode() - - def _background_loop(): - self._ffi["init_background_engine"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_mode.asjson(), - self.state.get_request_stream_callback(), - self.state.trace_recorder, - *model_args, - ) - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. - self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - # The main thread request handling asyncio event loop, which will - # be lazily initialized. - self._terminated = False - - def terminate(self): - """Terminate the engine.""" - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - async def generate( - self, - prompt: Union[str, List[int], Sequence[Union[str, List[int], data.Data]]], - generation_config: GenerationConfig, - request_id: str, - ) -> AsyncGenerator[List[AsyncStreamOutput], Any]: - """Asynchronous text generation interface. - The method is a coroutine that streams a list of AsyncStreamOutput - at a time via yield. The returned list length is the number of - parallel generations specified by `generation_config.n`. - - Parameters - ---------- - prompt : Union[str, List[int]] - The input prompt in forms of text string or a list of token ids. 
- - generation_config : GenerationConfig - The generation config of the request. - - request_id : str - The unique identifier (in string) or this generation request. - """ - if self._terminated: - raise ValueError("The AsyncThreadedEngine has terminated.") - self.state.lazy_init_event_loop() - - def convert_to_data( - prompt: Union[str, List[int], Sequence[Union[str, List[int], data.Data]]] - ) -> List[data.Data]: - if isinstance(prompt, data.Data): - return [prompt] - if isinstance(prompt, str): - return [data.TextData(prompt)] - if isinstance(prompt[0], int): - return [data.TokenData(prompt)] # type: ignore - return [convert_to_data(x)[0] for x in prompt] # type: ignore - - # Create the request with the given id, input data, generation - # config and the created callback. - input_data = convert_to_data(prompt) - request = Request(request_id, input_data, generation_config) - - # Create the unique stream of the request. - stream = AsyncRequestStream() - if request_id in self.state.request_tools: - # Report error in the stream if the request id already exists. - stream.push( - RuntimeError( - f'The request id "{request_id} already exists. ' - 'Please make sure the request id is unique."' - ) - ) - else: - # Record the stream in the tracker - self.state.request_tools[request_id] = ( - stream, - [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], - ) - self.state.num_unfinished_generations[request_id] = generation_config.n - self._ffi["add_request"](request) - - # Iterate the stream asynchronously and yield the token. - try: - async for request_output in stream: - yield request_output - except ( - Exception, - asyncio.CancelledError, - ) as exception: # pylint: disable=broad-exception-caught - await self.abort(request_id) - raise exception - - async def abort(self, request_id: str) -> None: - """Generation abortion interface. - - Parameter - --------- - request_id : str - The id of the request to abort. 
- """ - self._abort(request_id) - - def _abort(self, request_id: str): - """Internal implementation of request abortion.""" - self.state.request_tools.pop(request_id, None) - self._ffi["abort_request"](request_id) diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index 8444e3f363..b8ffc8da8f 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -1,8 +1,9 @@ """Classes denoting multi-modality data used in MLC LLM serving""" from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple +import tvm import tvm._ffi from tvm.runtime import Object from tvm.runtime.ndarray import NDArray @@ -81,6 +82,46 @@ def image(self) -> NDArray: def __len__(self): return self.embed_size + @staticmethod + def from_url(url: str, config: Dict) -> "ImageData": + """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" + + # pylint: disable=import-outside-toplevel, import-error + import base64 + from io import BytesIO + + import requests + from PIL import Image + from transformers import CLIPImageProcessor + + if url.startswith("data:image"): + # The image is encoded in base64 format + base64_image = url.split(",")[1] + image_data = base64.b64decode(base64_image) + image_tensor = Image.open(BytesIO(image_data)).convert("RGB") + elif url.startswith("http"): + response = requests.get(url, timeout=5) + image_tensor = Image.open(BytesIO(response.content)).convert("RGB") + else: + raise ValueError(f"Unsupported image URL format: {url}") + + image_input_size = config["model_config"]["vision_config"]["image_size"] + image_embed_size = ( + image_input_size // config["model_config"]["vision_config"]["patch_size"] + ) ** 2 + + image_processor = CLIPImageProcessor( + size={"shortest_edge": image_input_size}, + crop_size={"height": image_input_size, "width": image_input_size}, + ) + image_features = tvm.nd.array( + image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( + "float16" + ) + ) + image_data = ImageData(image_features, image_embed_size) + return image_data + @dataclass class SingleRequestStreamOutput: diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 607f970a1e..1f856c907c 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1,279 +1,423 @@ """The MLC LLM Serving Engine.""" -import json -import os -import subprocess -import sys -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union - -import tvm -from tvm.runtime import Device - -from mlc_llm.protocol.conversation_protocol import Conversation -from mlc_llm.serve import data +import asyncio +import queue +from typing import Any, AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import data, engine_utils +from mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.request import Request +from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging -from mlc_llm.support.auto_device import detect_device -from mlc_llm.support.style import green -from ..chat_module import _get_chat_config, _get_lib_module_path, _get_model_path -from ..streamer import TextStreamer -from ..tokenizer import Tokenizer -from . 
import data -from .config import EngineMode, GenerationConfig, KVCacheConfig -from .event_trace_recorder import EventTraceRecorder -from .request import Request +from . import engine_base logging.enable_logging() logger = logging.getLogger(__name__) -@dataclass -class ModelInfo: - """The model info dataclass. +class AsyncEngine(engine_base.EngineBase): + """The AsyncEngine in MLC LLM that provides the asynchronous + interfaces with regard to OpenAI API. Parameters ---------- - model : str - The identifier of the input model. - It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), - or a full path to a model directory - (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - - device : str - The device where to run the model. - It can be "auto", "device_name" (e.g., "cuda") or - "device_name:device_id" (e.g., "cuda:1"). - - model_lib_path : str - The path to the compiled library of the model. - E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. """ - model: str - model_lib_path: str - device: Device = "auto" # type: ignore - - def __post_init__(self): - if isinstance(self.device, str): - self.device = detect_device(self.device) - assert isinstance(self.device, Device) - - -def _create_tvm_module( - creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None -) -> Dict[str, Callable]: - """Internal method to create a module.""" - if creator_args is None: - creator_args = [] - module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) - return {key: module[key] for key in ffi_funcs} - - -def _process_model_args( - models: List[ModelInfo], -) -> Tuple[List[Any], List[str], str, int, int, Optional[str]]: - """Process the input ModelInfo to get the engine initialization arguments.""" - max_single_sequence_length = int(1e9) - prefill_chunk_size = int(1e9) - tokenizer_path: Optional[str] = None - conv_template_name: Optional[str] = None - config_file_paths: List[str] = [] - - def _convert_model_info(model: ModelInfo) -> List[Any]: - nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conv_template_name - - device = model.device - model_path, config_file_path = _get_model_path(model.model) - config_file_paths.append(config_file_path) - chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if chat_config.context_window_size and chat_config.context_window_size != -1: - max_single_sequence_length = min( - max_single_sequence_length, - chat_config.context_window_size, - ) - if chat_config.prefill_chunk_size: - prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) - if tokenizer_path is None: - tokenizer_path = model_path - if conv_template_name is None: - assert isinstance(chat_config.conv_template, Conversation) - conv_template_name = chat_config.conv_template.name - # Try look up model library, and do JIT compile if model library not found. 
- try: - model_lib_path = _get_lib_module_path( - model=model.model, - model_path=model_path, - chat_config=chat_config, - model_lib_path=model.model_lib_path, - device_name=device.MASK2STR[device.device_type], - config_file_path=config_file_path, - ) - except FileNotFoundError: - from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - - model_lib_path = str( - jit.jit( - model_path=Path(model_path), - chat_config=asdict(chat_config), - device=device, - ) - ) - return [model_lib_path, model_path, device.device_type, device.device_id] - - model_args: List[Any] = sum( - (_convert_model_info(model) for model in models), - start=[], - ) - - assert prefill_chunk_size != int(1e9) - return ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - conv_template_name, - ) - - -def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals - models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int -) -> int: - """Estimate the max total sequence length (capacity) of the KV cache.""" - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, config_file_path in zip(models, config_file_paths): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - config_file_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. - with open(config_file_path, mode="rt", encoding="utf-8") as file: - json_object = json.load(file) - model_config = json_object["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 + def __init__( + self, + models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + enable_tracing: bool = False, + ) -> None: + super().__init__("async", models, kv_cache_config, engine_mode, enable_tracing) + + async def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. 
+ """ + self._abort(request_id) + + async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: + """Asynchronous chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + async for response in chatcmpl_generator: + yield response + + async def completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> 
AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """Asynchronous completion interface with OpenAI API compatibility. + The method is a coroutine that streams CompletionResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + async for response in cmpl_generator: + yield response + + async def _handle_chat_completion( + self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: + """The implementation fo asynchronous ChatCompletionRequest handling. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + ( + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), ) - # Get single-card GPU size. - gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) - if gpu_size_bytes is None: - gpu_size_bytes = models[0].device.total_global_memory - if gpu_size_bytes is None: - raise ValueError( - "Cannot read total GPU global memory from device. " - 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' 
+ finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompts, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, ) - - max_total_sequence_length = int( + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + async def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """The implementation fo asynchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ ( - int(gpu_size_bytes) * 0.90 - - params_bytes - - temp_func_bytes - - kv_aux_workspace_bytes - - model_workspace_bytes - - logit_processor_workspace_bytes + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.tokenizer, + self.max_input_sequence_length, ) - / kv_bytes_per_token - ) - assert max_total_sequence_length > 0, ( - "Cannot estimate KV cache capacity. " - f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" - ) - - if models[0].device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - max_total_sequence_length = min(max_total_sequence_length, 32768) - - total_size = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - + kv_bytes_per_token * max_total_sequence_length - ) - logger.info( - "%s: %d.", - green('Estimated KVCacheConfig "max_total_sequence_length"'), - max_total_sequence_length, - ) - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", - green("Estimated total single GPU memory usage"), - total_size / 1024 / 1024, - params_bytes / 1024 / 1024, - (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, - (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, - ) - return int(max_total_sequence_length) - - -class Engine: - """The Python interface of request serving engine for MLC LLM. - - The engine can run one or multiple LLM models internally for - text generation. Usually, when there are multiple models, - speculative inference will be activated, where the first model - (index 0) is the main "large model" that has better generation - quality, and all other models are "small" models that used for - speculation. - - The engine receives requests from the "add_request" method. For - an given request, the engine will keep generating new tokens for - the request until finish (under certain criterion). 
After finish, - the engine will return the generation result through the callback - function provided by the request. + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + async for delta_outputs in self._generate( + prompt, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") + + async def _generate( + self, + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]: + """Internal asynchronous text generation interface of AsyncEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. + + Parameters + ---------- + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. + + generation_config : GenerationConfig + The generation config of the request. + + request_id : str + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. + """ + if self._terminated: + raise ValueError("The AsyncThreadedEngine has terminated.") + self.state.async_lazy_init_event_loop() + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Create the unique async request stream of the request. + stream = engine_base.AsyncRequestStream() + if request_id in self.state.async_streamers: + # Report error in the stream if the request id already exists. + stream.push( + RuntimeError( + f'The request id "{request_id} already exists. ' + 'Please make sure the request id is unique."' + ) + ) + else: + # Record the stream in the tracker + self.state.async_streamers[request_id] = ( + stream, + [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], + ) + self.state.async_num_unfinished_generations[request_id] = generation_config.n + self._ffi["add_request"](request) + + # Iterate the stream asynchronously and yield the output. 
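As a usage sketch of the asynchronous streaming path above, assuming `async_engine` is an AsyncEngine instance and that the yielded ChatCompletionStreamResponse chunks mirror the OpenAI streaming schema (choices[*].delta.content); the model id is hypothetical.

import asyncio

async def demo(async_engine) -> None:
    # Stream responses are yielded one chunk at a time by the coroutine above.
    async for chunk in async_engine.chat_completion(
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        model="Llama-2-7b-chat-hf-q4f16_1",
        max_tokens=64,
        stream=True,
    ):
        for choice in chunk.choices:
            print(choice.delta.content or "", end="", flush=True)

# asyncio.run(demo(async_engine))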
+ try: + async for request_output in stream: + yield request_output + except ( + Exception, + asyncio.CancelledError, + ) as exception: # pylint: disable=broad-exception-caught + await self.abort(request_id) + raise exception + + def _abort(self, request_id: str): + """Internal implementation of request abortion.""" + self.state.async_streamers.pop(request_id, None) + self.state.async_num_unfinished_generations.pop(request_id, None) + self._ffi["abort_request"](request_id) + + +class Engine(engine_base.EngineBase): + """The Engine in MLC LLM that provides the synchronous + interfaces with regard to OpenAI API. Parameters ---------- @@ -284,21 +428,6 @@ class Engine: kv_cache_config : KVCacheConfig The configuration of the paged KV cache. - request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] - The provided callback function to handle the generation - output. It has the signature of `(str, data.TokenData, bool) -> None`, - where - - the first string is the request id, - - the TokenData contains the generated **delta** token ids since - the last invocation of the callback on the specific request, - - the optional string value denotes the finish reason if the - generation of the request is finished, or None if it has not finished. - - The callback function is optional at construction, but it needs to - be set before the engine executing requests. This can be done via - the `set_request_stream_callback` method. Otherwise, the engine will raise - exception. - engine_mode : Optional[EngineMode] The Engine execution mode. @@ -306,247 +435,391 @@ class Engine: A boolean indicating if to enable event logging for requests. """ - def __init__( # pylint: disable=too-many-arguments + def __init__( self, - models: Union[ModelInfo, List[ModelInfo]], + models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], kv_cache_config: KVCacheConfig, engine_mode: Optional[EngineMode] = None, - request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, enable_tracing: bool = False, - ): - if isinstance(models, ModelInfo): - models = [models] - ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - self.conv_template_name, - ) = _process_model_args(models) - self._ffi = _create_tvm_module( - "mlc.serve.create_engine", - ffi_funcs=[ - "init", - "add_request", - "abort_request", - "step", - "stats", - "reset", - "get_request_stream_callback", - "set_request_stream_callback", - ], - ) - self.trace_recorder = EventTraceRecorder() if enable_tracing else None - self.max_input_sequence_length = max_single_sequence_length + ) -> None: + super().__init__("sync", models, kv_cache_config, engine_mode, enable_tracing) - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence - ) - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." - ) + def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. 
+ """ + self._ffi["abort_request"](request_id) + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Synchronous chat completion interface with OpenAI API compatibility. + The method streams back ChatCompletionStreamResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. - if engine_mode is None: - # The default engine mode: non-speculative - engine_mode = EngineMode() - - self._ffi["init"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_mode.asjson(), - request_stream_callback, - self.trace_recorder, - *model_args, + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, ) - self.tokenizer = Tokenizer(tokenizer_path) + for response in chatcmpl_generator: + yield response - def generate( # pylint: disable=too-many-locals + def completion( # pylint: disable=too-many-arguments,too-many-locals self, - prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], - generation_config: Union[GenerationConfig, List[GenerationConfig]], - ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: - """Generate texts for a list of input prompts. - Each prompt can be a string or a list of token ids. - The generation for each prompt is independent. - Return the generation results, one for each prompt. 
+ *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion interface with OpenAI API compatibility. + The method streams back CompletionResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. Parameters ---------- - prompts : Union[str, List[str], List[int], List[List[int]]] - One or a list of input prompts for text generation. - Each prompt can be a string or a list of token ids. - - generation_config : Union[GenerationConfig, List[GenerationConfig]] - The generation config for each requests. - If the it is a single GenerationConfig instance, - this config will be shared by all the prompts. - Otherwise, one generation config is required for every - prompt. - - Returns - ------- - output_text : List[List[str]] - The text generation results, one list of strings for each input prompt. - The length of each list is the parallel generation `n` in - generation config. - - output_logprobs_str : List[Optional[List[List[str]]]] - The logprob strings of each token for each input prompt, or None - if an input prompt does not require logprobs. + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. """ - if isinstance(prompts, str): - # `prompts` is a single string. - prompts = [prompts] - else: - assert isinstance(prompts, list), ( - "Input `prompts` is expected to be a string, a list of " - "str, a list of token ids or multiple lists of token ids. " + if request_id is None: + request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( + openai_api_protocol.CompletionRequest( + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id, + ) + for response in cmpl_generator: + yield response + + def _handle_chat_completion( + self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """The implementation fo synchronous ChatCompletionRequest handling. 
+ + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + ( + prompts, + generation_cfg, + use_function_calling, + prompt_length, + ) = engine_base.process_chat_completion_request( + request, + request_id, + self.state, + self.model_config_dicts[0], + self.tokenizer.encode, + self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), + ) + + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + num_completion_tokens = 0 + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompts, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, ) - if len(prompts) == 0: - return [], [] - if isinstance(prompts[0], int): - # `prompts` is a list of token ids - prompts = [prompts] # type: ignore - - num_requests = len(prompts) - if not isinstance(generation_config, list): - generation_config = [generation_config] * num_requests - - assert ( - len(generation_config) == num_requests - ), "Number of generation config and number of prompts mismatch" - - num_finished_generations = 0 - output_texts: List[List[str]] = [] - output_logprobs_str: List[Optional[List[List[str]]]] = [] - text_streamers: List[List[TextStreamer]] = [] - for i in range(num_requests): - output_texts.append([]) - output_logprobs_str.append([] if generation_config[i].logprobs else None) - text_streamers.append([]) - for _ in range(generation_config[i].n): - output_texts[i].append("") - text_streamers[i].append(TextStreamer(self.tokenizer)) - if output_logprobs_str[i] is not None: - output_logprobs_str[i].append([]) - - num_total_generations = sum(cfg.n for cfg in generation_config) - - # Save a copy of the original function callback since `generate` - # overrides the callback function. - # The original callback will be set back later on. - original_callback = self._ffi["get_request_stream_callback"]() - - # Define the callback function for request generation results - def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): - nonlocal num_finished_generations - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - rid = int(request_id) - - assert len(stream_outputs) == generation_config[rid].n - for i, (stream_output, text_streamer) in enumerate( - zip(stream_outputs, text_streamers[rid]) - ): - if output_logprobs_str[rid] is not None: - assert stream_output.delta_logprob_json_strs is not None - output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs - - delta_text = ( - text_streamer.put(stream_output.delta_token_ids) - if len(stream_output.delta_token_ids) > 0 - else "" - ) - if stream_output.finish_reason is not None: - delta_text += text_streamer.finish() - - output_texts[rid][i] += delta_text - if stream_output.finish_reason is not None: - num_finished_generations += 1 - - # Override the callback function in engine. 
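A corresponding sketch for the synchronous completion() interface defined above, which yields through a plain generator instead of a coroutine. It assumes the CompletionResponse chunks mirror the OpenAI completions schema (choices[*].text); the model id is hypothetical.

def demo_completion(engine) -> str:
    text = ""
    for chunk in engine.completion(
        model="Llama-2-7b-chat-hf-q4f16_1",  # hypothetical model id
        prompt="The capital of France is",
        max_tokens=16,
        stream=True,
    ):
        for choice in chunk.choices:
            text += choice.text or ""
    return text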
- self._ffi["set_request_stream_callback"](request_stream_callback) - - def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: - if isinstance(prompt, str): - return [data.TextData(prompt)] - if isinstance(prompt[0], int): - return [data.TokenData(prompt)] # type: ignore - return prompt # type: ignore - - # Add requests to engine. - for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): - input_data = convert_to_data(prompt) # type: ignore - self.add_request( - Request( - request_id=str(req_id), - inputs=input_data, - generation_config=generation_cfg, - ) + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + + def _handle_completion( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """The implementation fo synchronous CompletionRequest handling. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + ( + prompt, + generation_cfg, + prompt_length, + echo_response, + ) = engine_base.process_completion_request( + request, + request_id, + self.state, + self.tokenizer, + self.max_input_sequence_length, + ) + if echo_response is not None: + yield echo_response + + num_completion_tokens = 0 + finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] + self.state.record_event(request_id, event="invoke generate") + for delta_outputs in self._generate(prompt, generation_cfg, request_id): # type: ignore + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, ) + if response is not None: + yield response - while num_finished_generations != num_total_generations: - self.step() - - # Restore the callback function in engine. - self._ffi["set_request_stream_callback"](original_callback) - return output_texts, output_logprobs_str + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") - def add_request(self, request: Request) -> None: - """Add a new request to the engine. + def _generate( # pylint: disable=too-many-locals + self, + prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]], + generation_config: GenerationConfig, + request_id: str, + ) -> Iterator[List[engine_base.CallbackStreamOutput]]: + """Internal synchronous text generation interface of AsyncEngine. + The method is a coroutine that streams a list of CallbackStreamOutput + at a time via yield. The returned list length is the number of + parallel generations specified by `generation_config.n`. Parameters ---------- - request : Request - The request to add. - """ - self._ffi["add_request"](request) + prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]] + The input prompt in forms of text strings, lists of token ids or data. - def abort_request(self, request_id: str) -> None: - """Abort the generation of the request corresponding to the input request id. 
+ generation_config : GenerationConfig + The generation config of the request. - Parameters - ---------- request_id : str - The unique id of the request to abort. + The unique identifier (in string) or this generation request. + + Yields + ------ + request_output : List[engine_base.CallbackStreamOutput] + The delta generated outputs in a list. + The number of list elements equals to `generation_config.n`, + and each element corresponds to the delta output of a parallel + generation. """ - self._ffi["abort_request"](request_id) - - def step(self) -> None: - """The main function that the engine takes a step of action. - - At each step, the engine may decide to - - run prefill for one (or more) requests, - - run one-step decode for the all existing requests - ... + if self._terminated: + raise ValueError("The engine has terminated.") + + # Create the request with the given id, input data, generation + # config and the created callback. + input_data = engine_utils.convert_prompts_to_data(prompt) + request = Request(request_id, input_data, generation_config) + + # Record the stream in the tracker + self.state.sync_output_queue = queue.Queue() + self.state.sync_text_streamers = [ + TextStreamer(self.tokenizer) for _ in range(generation_config.n) + ] + self.state.sync_num_unfinished_generations = generation_config.n + self._ffi["add_request"](request) - In the end of certain actions (e.g., decode), the engine will - check if any request has finished, and will return the - generation results for those finished requests. - """ - self._ffi["step"]() - - def reset(self) -> None: - """Reset the engine, clean up all running data and statistics.""" - self._ffi["reset"]() - - def stats(self) -> Dict[str, float]: - """The engine runtime statistics. - We collect the following entries: - - single token prefill latency (s/tok): avg latency of processing one token in prefill - - single token decode latency (s/tok): avg latency of processing one token in decode - - engine time for prefill (sec) - - engine time for decode (sec) - - total number of processed tokens in prefill. - - total number of processed tokens in decode. - """ - stats_json_str = self._ffi["stats"]() - return json.loads(stats_json_str) + # Iterate the stream asynchronously and yield the token. 
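The loop that follows drains a blocking queue that the engine's callback thread fills. A self-contained sketch of that producer/consumer handoff (the pattern only, not MLC code):

import queue
import threading

output_queue: "queue.Queue" = queue.Queue()

def engine_callback_thread() -> None:
    # Stands in for the engine pushing delta outputs; None marks completion.
    for delta in ["The", " capital", " is", " Paris.", None]:
        output_queue.put(delta)

threading.Thread(target=engine_callback_thread).start()
while (delta := output_queue.get()) is not None:  # blocks like sync_output_queue.get()
    print(delta, end="")
print()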
+ try: + while self.state.sync_num_unfinished_generations > 0: + delta_outputs = self.state.sync_output_queue.get() + request_outputs = self._request_stream_callback_impl(delta_outputs) + for request_output in request_outputs: + yield request_output + except Exception as exception: # pylint: disable=broad-exception-caught + self.abort(request_id) + raise exception + + def _request_stream_callback_impl( + self, delta_outputs: List[data.RequestStreamOutput] + ) -> List[List[engine_base.CallbackStreamOutput]]: + """The underlying implementation of request stream callback of Engine.""" + batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + self.state.record_event(request_id, event="start callback") + outputs: List[engine_base.CallbackStreamOutput] = [] + for stream_output, text_streamer in zip(stream_outputs, self.state.sync_text_streamers): + self.state.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.state.record_event(request_id, event="finish detokenization") + + outputs.append( + engine_base.CallbackStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.state.sync_num_unfinished_generations -= 1 + batch_outputs.append(outputs) + self.state.record_event(request_id, event="finish callback") + return batch_outputs diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py new file mode 100644 index 0000000000..21bb928df3 --- /dev/null +++ b/python/mlc_llm/serve/engine_base.py @@ -0,0 +1,1066 @@ +"""The MLC LLM Serving engine base class.""" + +# pylint: disable=too-many-lines + +import asyncio +import json +import os +import queue +import subprocess +import sys +import threading +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import tvm +from tvm.runtime import Device + +from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path +from mlc_llm.protocol import openai_api_protocol, protocol_utils +from mlc_llm.protocol.conversation_protocol import Conversation +from mlc_llm.serve import data, engine_utils +from mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.event_trace_recorder import EventTraceRecorder +from mlc_llm.streamer import TextStreamer +from mlc_llm.support import logging +from mlc_llm.support.auto_device import detect_device +from mlc_llm.support.style import green +from mlc_llm.tokenizer import Tokenizer + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclass +class ModelInfo: + """The model info dataclass. + + Parameters + ---------- + model : str + The identifier of the input model. + It may be a compiled model's id (e.g., "Llama-2-7b-chat-hf-q4f16_1"), + or a full path to a model directory + (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") + + device : str + The device where to run the model. + It can be "auto", "device_name" (e.g., "cuda") or + "device_name:device_id" (e.g., "cuda:1"). 
+ + model_lib_path : str + The path to the compiled library of the model. + E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + """ + + model: str + model_lib_path: str + device: Device = "auto" # type: ignore + + def __post_init__(self): + if isinstance(self.device, str): + self.device = detect_device(self.device) + assert isinstance(self.device, Device) + + +def _process_model_args( + models: List[ModelInfo], +) -> Tuple[List[Any], List[str], str, int, int, Conversation]: + """Process the input ModelInfo to get the engine initialization arguments.""" + max_single_sequence_length = int(1e9) + prefill_chunk_size = int(1e9) + tokenizer_path: Optional[str] = None + conversation: Optional[Conversation] = None + config_file_paths: List[str] = [] + + def _convert_model_info(model: ModelInfo) -> List[Any]: + nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conversation + + device = model.device + model_path, config_file_path = _get_model_path(model.model) + config_file_paths.append(config_file_path) + chat_config = _get_chat_config(config_file_path, user_chat_config=None) + if chat_config.context_window_size and chat_config.context_window_size != -1: + max_single_sequence_length = min( + max_single_sequence_length, + chat_config.context_window_size, + ) + if chat_config.prefill_chunk_size: + prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) + if tokenizer_path is None: + tokenizer_path = model_path + if conversation is None: + assert isinstance(chat_config.conv_template, Conversation) + conversation = chat_config.conv_template + # Try look up model library, and do JIT compile if model library not found. + try: + model_lib_path = _get_lib_module_path( + model=model.model, + model_path=model_path, + chat_config=chat_config, + model_lib_path=model.model_lib_path, + device_name=device.MASK2STR[device.device_type], + config_file_path=config_file_path, + ) + except FileNotFoundError: + from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel + + model_lib_path = str( + jit.jit( + model_path=Path(model_path), + chat_config=asdict(chat_config), + device=device, + ) + ) + return [model_lib_path, model_path, device.device_type, device.device_id] + + model_args: List[Any] = sum( + (_convert_model_info(model) for model in models), + start=[], + ) + + assert prefill_chunk_size != int(1e9) + assert conversation is not None + return ( + model_args, + config_file_paths, + tokenizer_path, + max_single_sequence_length, + prefill_chunk_size, + conversation, + ) + + +def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals + models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int +) -> int: + """Estimate the max total sequence length (capacity) of the KV cache.""" + assert len(models) != 0 + + kv_bytes_per_token = 0 + kv_aux_workspace_bytes = 0 + model_workspace_bytes = 0 + logit_processor_workspace_bytes = 0 + params_bytes = 0 + temp_func_bytes = 0 + + for model, config_file_path in zip(models, config_file_paths): + # Read metadata for the parameter size and the temporary memory size. 
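As a rough worked example of the capacity estimate assembled below, with all numbers purely illustrative (loosely shaped like a 7B-class model on a 24 GB GPU, not measured values from this patch):

# Per-token KV size: head_dim * num_kv_heads * num_layers * 4 bytes (f16 K and V)
# plus a small bookkeeping overhead, mirroring the formula below.
head_dim, num_kv_heads, num_layers = 128, 32, 32
kv_bytes_per_token = head_dim * num_kv_heads * num_layers * 4 + 1.25  # ~0.5 MB per token

gpu_bytes = 24 * 1024**3             # assumed 24 GB card
params_bytes = 4 * 1024**3           # assumed ~4 GB of quantized weights
other_workspace_bytes = 1 * 1024**3  # assumed temp/aux/logit workspaces combined

max_total_sequence_length = int(
    (gpu_bytes * 0.90 - params_bytes - other_workspace_bytes) / kv_bytes_per_token
)
print(max_total_sequence_length)     # on the order of tens of thousands of tokens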
+ cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + config_file_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + params_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + + cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-kv-cache-metadata-in-json", + ] + kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) + kv_cache_metadata = json.loads(kv_cache_metadata_str) + + # Read model config and compute the kv size per token. + with open(config_file_path, mode="rt", encoding="utf-8") as file: + json_object = json.load(file) + model_config = json_object["model_config"] + vocab_size = model_config["vocab_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + num_layers = kv_cache_metadata["num_hidden_layers"] + head_dim = kv_cache_metadata["head_dim"] + num_qo_heads = kv_cache_metadata["num_attention_heads"] + num_kv_heads = kv_cache_metadata["num_key_value_heads"] + hidden_size = head_dim * num_qo_heads + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 + kv_aux_workspace_bytes += ( + (max_num_sequence + 1) * 88 + + prefill_chunk_size * (num_qo_heads + 1) * 8 + + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + + 48 * 1024 * 1024 + ) + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + # Get single-card GPU size. + gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) + if gpu_size_bytes is None: + gpu_size_bytes = models[0].device.total_global_memory + if gpu_size_bytes is None: + raise ValueError( + "Cannot read total GPU global memory from device. " + 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' + ) + + max_total_sequence_length = int( + ( + int(gpu_size_bytes) * 0.90 + - params_bytes + - temp_func_bytes + - kv_aux_workspace_bytes + - model_workspace_bytes + - logit_processor_workspace_bytes + ) + / kv_bytes_per_token + ) + assert max_total_sequence_length > 0, ( + "Cannot estimate KV cache capacity. " + f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" + ) + + if models[0].device.device_type == Device.kDLMetal: + # NOTE: Metal runtime has severe performance issues with large buffers. + # To work around the issue, we limit the KV cache capacity to 32768. + max_total_sequence_length = min(max_total_sequence_length, 32768) + + total_size = ( + params_bytes + + temp_func_bytes + + kv_aux_workspace_bytes + + model_workspace_bytes + + logit_processor_workspace_bytes + + kv_bytes_per_token * max_total_sequence_length + ) + logger.info( + "%s: %d.", + green('Estimated KVCacheConfig "max_total_sequence_length"'), + max_total_sequence_length, + ) + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. 
Temporary buffer: %.2f MB)", + green("Estimated total single GPU memory usage"), + total_size / 1024 / 1024, + params_bytes / 1024 / 1024, + (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, + (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, + ) + return int(max_total_sequence_length) + + +@dataclass +class CallbackStreamOutput: + """The output of Engine._generate and AsyncEngine._generate + + Attributes + ---------- + delta_text : str + The delta text generated since the last output. + + num_delta_tokens : int + The number of delta tokens generated since the last output. + + delta_logprob_json_strs : Optional[List[str]] + The list of logprob JSON strings since the last output, + or None if the request does not require logprobs. + + finish_reason : Optional[str] + The finish reason of the request, or None if unfinished. + """ + + delta_text: str + num_delta_tokens: int + delta_logprob_json_strs: Optional[List[str]] + finish_reason: Optional[str] + + +class AsyncRequestStream: + """The asynchronous stream for requests in AsyncEngine. + + Each request has its own unique stream. + The stream exposes the method `push` for engine to push new generated + delta text to the stream, and the method `finish` for engine to mark + the finish of generation. + + The stream implements `__aiter__` and `__anext__`, which the engine + can use to iterates all the generated tokens in order asynchronously. + """ + + # The asynchronous queue to hold elements of either a list of + # CallbackStreamOutput or an exception. + if sys.version_info >= (3, 9): + _queue: asyncio.Queue[ # pylint: disable=unsubscriptable-object + Union[List[CallbackStreamOutput], Exception] + ] + else: + _queue: asyncio.Queue + # The finish flag. + _finished: bool + + def __init__(self) -> None: + self._queue = asyncio.Queue() + self._finished = False + + def push(self, item_or_exception: Union[List[CallbackStreamOutput], Exception]) -> None: + """Push a new token to the stream.""" + if self._finished: + # No new item is expected after finish. + self._queue.put_nowait( + RuntimeError( + "The request has already finished. " + "The stream is not supposed to accept new items." + ) + ) + return + self._queue.put_nowait(item_or_exception) + + def finish(self) -> None: + """Mark the finish of the generation in the stream.""" + self._queue.put_nowait(StopIteration()) + self._finished = True + + def __aiter__(self): + return self + + async def __anext__(self) -> List[CallbackStreamOutput]: + result = await self._queue.get() + if isinstance(result, StopIteration): + raise StopAsyncIteration + if isinstance(result, Exception): + raise result + return result + + +class EngineState: + """The engine states that the request stream callback function may use. + + This class is used for both AsyncEngine and Engine. + AsyncEngine uses the fields and methods starting with "async", + and Engine uses the ones starting with "sync". + + - For AsyncEngine, the state contains an asynchronous event loop, + the streamers and the number of unfinished generations for each request + being processed. + - For Engine, the state contains a callback output blocking queue, + the text streamers and the number of unfinished requests. + + We use this state class to avoid the callback function from capturing + the AsyncEngine. + + The state also optionally maintains an event trace recorder, which can + provide Chrome tracing when enabled. 
+ """ + + trace_recorder = None + # States used for AsyncEngine + async_event_loop: Optional[asyncio.AbstractEventLoop] = None + async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} + async_num_unfinished_generations: Dict[str, int] = {} + # States used for Engine + sync_output_queue: queue.Queue = queue.Queue() + sync_text_streamers: List[TextStreamer] = [] + sync_num_unfinished_generations: int = 0 + + def __init__(self, enable_tracing: bool) -> None: + """Constructor.""" + if enable_tracing: + self.trace_recorder = EventTraceRecorder() + + def record_event(self, request_id: str, event: str) -> None: + """Record a event for the the input request in the trace + recorder when the recorder exists. + + Parameters + ---------- + request_id : str + The subject request of the event. + + event : str + The event in a string name. + It can have one of the following patterns: + - "start xxx", which marks the start of event "xxx", + - "finish xxx", which marks the finish of event "xxx", + - "yyy", which marks the instant event "yyy". + The "starts" and "finishes" will be automatically paired in the trace recorder. + """ + if self.trace_recorder is None: + return + self.trace_recorder.add_event(request_id, event) + + def get_request_stream_callback( + self, kind: Literal["async", "sync"] + ) -> Callable[[List[data.RequestStreamOutput]], None]: + """Construct a callback function and return. + + The callback function has signature + "Callable[[List[data.RequestStreamOutput]], None]", + whose input is a list of "data.RequestStreamOutput". + Each "data.RequestStreamOutput" is the delta output of a request, + generated from the engine. + """ + + f_callback = ( + self._async_request_stream_callback + if kind == "async" + else self._sync_request_stream_callback + ) + + def _callback(delta_outputs: List[data.RequestStreamOutput]) -> None: + f_callback(delta_outputs) + + return _callback + + def async_lazy_init_event_loop(self) -> None: + """Lazily set the asyncio event loop so that the event + loop is the main driving event loop of the process. + """ + if self.async_event_loop is None: + self.async_event_loop = asyncio.get_event_loop() + + def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for AsyncEngine to stream back + the request generation results. + + Note + ---- + This callback function uses `call_soon_threadsafe` in asyncio to + schedule the invocation in the event loop, so that the underlying + callback logic will be executed asynchronously in the future rather + than right now. + """ + + # Schedule a callback run in the event loop without executing right now. + # NOTE: This function causes GIL during execution. 
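A self-contained sketch of the cross-thread scheduling used here: the engine's callback thread must not touch asyncio objects directly, so it hands work to the event loop via call_soon_threadsafe (pattern only, not MLC code):

import asyncio
import threading

async def main() -> None:
    loop = asyncio.get_running_loop()
    done = asyncio.Event()

    def engine_thread() -> None:
        # Runs off the event loop; schedule the "callback impl" onto the loop.
        loop.call_soon_threadsafe(print, "delta outputs handled on the event loop")
        loop.call_soon_threadsafe(done.set)

    threading.Thread(target=engine_thread).start()
    await done.wait()

asyncio.run(main())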
+ self.async_event_loop.call_soon_threadsafe( + self._async_request_stream_callback_impl, delta_outputs + ) + + def _async_request_stream_callback_impl( + self, delta_outputs: List[data.RequestStreamOutput] + ) -> None: + """The underlying implementation of request stream callback for AsyncEngine.""" + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + streamers = self.async_streamers.get(request_id, None) + if streamers is None: + continue + + self.record_event(request_id, event="start callback") + stream, text_streamers = streamers + outputs = [] + for stream_output, text_streamer in zip(stream_outputs, text_streamers): + self.record_event(request_id, event="start detokenization") + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + self.record_event(request_id, event="finish detokenization") + + outputs.append( + CallbackStreamOutput( + delta_text=delta_text, + num_delta_tokens=len(stream_output.delta_token_ids), + delta_logprob_json_strs=stream_output.delta_logprob_json_strs, + finish_reason=stream_output.finish_reason, + ) + ) + if stream_output.finish_reason is not None: + self.async_num_unfinished_generations[request_id] -= 1 + + # Push new delta text to the stream. + stream.push(outputs) + if self.async_num_unfinished_generations[request_id] == 0: + stream.finish() + self.async_streamers.pop(request_id, None) + self.async_num_unfinished_generations.pop(request_id, None) + self.record_event(request_id, event="finish callback") + + def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: + """The request stream callback function for Engine to stream back + the request generation results. + """ + # Put the delta outputs to the queue in the unblocking way. + self.sync_output_queue.put_nowait(delta_outputs) + + +class EngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods + """The base engine class, which implements common functions that + are shared by Engine and AsyncEngine. + + This class wraps a threaded engine that runs on a standalone + thread inside and streams back the delta generated results via + callback functions. The internal threaded engine keeps running an + loop that drives the engine. + + Engine and AsyncEngine inherits this EngineBase class, and implements + their own methods to process the delta generated results received + from callback functions and yield the processed delta results in + the forms of standard API protocols. + + Parameters + ---------- + kind : Literal["async", "sync"] + The kind of the engine. "async" for AsyncEngine and "sync" for Engine. + + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. 
+ """ + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + kind: Literal["async", "sync"], + models: Union[ModelInfo, List[ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + enable_tracing: bool = False, + ) -> None: + if isinstance(models, ModelInfo): + models = [models] + ( + model_args, + config_file_paths, + tokenizer_path, + max_single_sequence_length, + prefill_chunk_size, + self.conv_template, + ) = _process_model_args(models) + + self.model_config_dicts = [] + for i, model in enumerate(models): + # model_args: + # [model_lib_path, model_path, device.device_type, device.device_id] * N + model.model_lib_path = model_args[i * (len(model_args) // len(models))] + with open(config_file_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + self.state = EngineState(enable_tracing) + self.max_input_sequence_length = max_single_sequence_length + + if kv_cache_config.max_total_sequence_length is None: + kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( + models, config_file_paths, kv_cache_config.max_num_sequence + ) + if kv_cache_config.prefill_chunk_size is None: + kv_cache_config.prefill_chunk_size = prefill_chunk_size + elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: + raise ValueError( + f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " + f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " + "models. Please specify a smaller prefill chunk size." + ) + + module = tvm.get_global_func("mlc.serve.create_threaded_engine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "add_request", + "abort_request", + "run_background_loop", + "run_background_stream_back_loop", + "init_background_engine", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(tokenizer_path) + if engine_mode is None: + # The default engine mode: non-speculative + engine_mode = EngineMode() + + def _background_loop(): + self._ffi["init_background_engine"]( + max_single_sequence_length, + tokenizer_path, + kv_cache_config.asjson(), + engine_mode.asjson(), + self.state.get_request_stream_callback(kind), + self.state.trace_recorder, + *model_args, + ) + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
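The constructor below starts two background threads and terminate() later signals and joins them. A generic, stand-alone sketch of that run-loop / exit-signal / join pattern (not MLC code):

import queue
import threading

requests: "queue.Queue" = queue.Queue()

def background_loop() -> None:
    # Analogous to run_background_loop(): block on work until told to exit.
    while (item := requests.get()) is not None:
        print("processing", item)

thread = threading.Thread(target=background_loop)
thread.start()
requests.put("request-0")
requests.put(None)  # analogous to exit_background_loop()
thread.join()       # analogous to terminate() joining the background threads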
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + """Terminate the engine.""" + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + +def process_chat_completion_request( # pylint: disable=too-many-arguments + request: openai_api_protocol.ChatCompletionRequest, + request_id: str, + engine_state: EngineState, + model_config: Dict[str, Any], + f_tokenize: Callable[[str], List[int]], + max_input_sequence_length: int, + conv_template: Conversation, +) -> Tuple[List[Union[List[int], data.Data]], GenerationConfig, bool, int]: + """Process the given ChatCompletionRequest, apply request validity + checks, and return the processed prompts, and other info. + + Parameters + ---------- + request : openai_api_protocol.ChatCompletionRequest + The request to be processed and checked. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model_config : Dict[str, Any] + The model configuration dictionary. + + f_tokenize : Callable[[str], List[int]] + The tokenizer encode function. + + max_input_sequence_length : int + The maximum allowed total prompt length. + + conv_template : Conversation + The conversation template of the model. + + Returns + ------- + prompts : List[Union[List[int], data.Data]] + The prompts, in a list. + Each element is a list of token ids or a "data.Data" instance. + + generation_cfg : GenerationConfig + The generation config of the request got from the input request. + + use_function_calling : bool + A boolean flag indicating if the request uses function call. + + prompt_length : int + The total prompt length. + """ + engine_state.record_event(request_id, event="receive request") + # - Check if unsupported arguments are specified. + engine_utils.check_unsupported_fields(request) + + # - Process messages and update the conversation template in three steps: + # i. Check the message validity. + # ii. Add the input messages to the conversation template. + # iii. Add the additional message for the assistant. + request.check_message_validity() + # - Check for function calling usage and update the conversation template + request.check_function_call_usage(conv_template) + + for message in request.messages: + role = message.role + content = message.content + if role == "system": + assert isinstance(content, str) + conv_template.system_message = content if content is not None else "" + continue + assert role != "tool", "Internal error: tool role." + conv_template.messages.append((role, content)) + conv_template.messages.append(("assistant", None)) + + # - Get the prompt from template, and encode to token ids. 
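# Sketch of how the message loop in process_chat_completion_request folds an
# OpenAI-style message list into a conversation template: the system message
# overrides the template's system prompt, user/assistant turns are appended in
# order, and a trailing ("assistant", None) slot marks where generation should
# continue. `ConvState` is an illustrative stand-in for the Conversation class.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ConvState:
    system_message: str = ""
    messages: List[Tuple[str, Optional[str]]] = field(default_factory=list)


def apply_chat_messages(conv: ConvState, messages: List[dict]) -> ConvState:
    for message in messages:
        role, content = message["role"], message.get("content")
        if role == "system":
            conv.system_message = content if content is not None else ""
            continue
        if role == "tool":
            raise ValueError("tool-authored messages are rejected by earlier validation")
        conv.messages.append((role, content))
    conv.messages.append(("assistant", None))  # the slot the model will fill
    return conv


conv = apply_chat_messages(
    ConvState(),
    [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi!"},
    ],
)
print(conv.system_message, conv.messages)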
+ # - Check prompt length + engine_state.record_event(request_id, event="start tokenization") + prompts = engine_utils.process_prompts( # type: ignore + conv_template.as_prompt(model_config), f_tokenize + ) + engine_state.record_event(request_id, event="finish tokenization") + + if conv_template.system_prefix_token_ids is not None: + if isinstance(prompts[0], list): + prompts[0] = conv_template.system_prefix_token_ids + prompts[0] + else: + prompts.insert(0, conv_template.system_prefix_token_ids) + prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length) + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config( + request, + extra_stop_token_ids=conv_template.stop_token_ids, + extra_stop_str=conv_template.stop_str, + ) + return prompts, generation_cfg, conv_template.use_function_calling, prompt_length + + +def process_chat_completion_stream_output( # pylint: disable=too-many-arguments + delta_outputs: List[CallbackStreamOutput], + request_id: str, + engine_state: EngineState, + model: str, + generation_cfg: GenerationConfig, + use_function_calling: bool, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Tuple[Optional[openai_api_protocol.ChatCompletionStreamResponse], int]: + """Process the delta outputs of a single request of ChatCompletion, + convert the delta output to ChatCompletionStreamResponse and return. + + Parameters + ---------- + delta_outputs : List[CallbackStreamOutput] + The delta outputs of a request. + The list length is the number of parallel generation specified by "n". + Each element corresponds to a generation. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model : str + The requested model. + + generation_cfg : GenerationConfig + The generation config of the request. + + use_function_calling : bool + A boolean flag indicating if the request uses function call. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + response : Optional[openai_api_protocol.ChatCompletionStreamResponse] + The converted OpenAI API ChatCompletionStreamResponse instance. + It can be none when there is no content. + + num_completion_tokens : int + The updated number of total completion tokens. + It is sum of the input number and the number of new completion tokens + from the given delta outputs. + """ + assert len(delta_outputs) == generation_cfg.n + choices = [] + num_new_completion_tokens = 0 + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + num_new_completion_tokens += delta_output.num_delta_tokens + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = ( + delta_output.finish_reason if not use_function_calling else "tool_calls" + ) + finish_reason_updated = True + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. 
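# Sketch of the per-delta bookkeeping in process_chat_completion_stream_output:
# the first non-None finish reason for each parallel generation is latched
# (mapped to "tool_calls" when function calling is active), and deltas that
# carry neither new text nor a finish-reason update are dropped so the stream
# does not emit empty chunks. Plain tuples stand in for CallbackStreamOutput.
from typing import List, Optional, Tuple


def collect_choices(
    deltas: List[Tuple[str, Optional[str]]],  # (delta_text, finish_reason) per generation
    finish_reasons: List[Optional[str]],      # updated in place, one slot per generation
    use_function_calling: bool,
) -> List[Tuple[int, str, Optional[str]]]:
    choices = []
    for i, (delta_text, finish_reason) in enumerate(deltas):
        finish_reason_updated = False
        if finish_reason is not None and finish_reasons[i] is None:
            finish_reasons[i] = "tool_calls" if use_function_calling else finish_reason
            finish_reason_updated = True
        if not finish_reason_updated and delta_text == "":
            continue  # nothing new for this generation, skip it
        choices.append((i, delta_text, finish_reasons[i]))
    return choices


reasons: List[Optional[str]] = [None, None]
print(collect_choices([("Hello", None), ("", "stop")], reasons, False))
# [(0, 'Hello', None), (1, '', 'stop')]
print(reasons)  # [None, 'stop']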
+ engine_state.record_event(request_id, event="skip empty delta text") + continue + + choices.append( + openai_api_protocol.ChatCompletionStreamResponseChoice( + index=i, + finish_reason=finish_reasons[i], + delta=openai_api_protocol.ChatCompletionMessage( + content=delta_output.delta_text, role="assistant" + ), + logprobs=( + openai_api_protocol.LogProbs( + content=[ + openai_api_protocol.LogProbsContent.model_validate_json( + logprob_json_str + ) + for logprob_json_str in delta_output.delta_logprob_json_strs + ] + ) + if delta_output.delta_logprob_json_strs is not None + else None + ), + ) + ) + + if len(choices) == 0 and num_new_completion_tokens == 0: + # Skip return when there is no delta output and no number of completion tokens. + return None, num_completion_tokens + num_completion_tokens += num_new_completion_tokens + response = openai_api_protocol.ChatCompletionStreamResponse( + id=request_id, + choices=choices, + model=model, + system_fingerprint="", + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + engine_state.record_event(request_id, event="yield delta output") + return response, num_completion_tokens + + +def process_completion_request( + request: openai_api_protocol.CompletionRequest, + request_id: str, + engine_state: EngineState, + tokenizer: Tokenizer, + max_input_sequence_length: int, +) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: + """Process the given CompletionRequest, apply request validity + checks, and return the processed prompts, and other info. + + Parameters + ---------- + request : openai_api_protocol.CompletionRequest + The request to be processed and checked. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + tokenizer : Tokenizer + The tokenizer instance of the model. + + max_input_sequence_length : int + The maximum allowed total prompt length. + + Returns + ------- + prompt : List[int] + The prompt in a list of token ids. + + generation_cfg : GenerationConfig + The generation config of the request got from the input request. + + prompt_length : int + The total prompt length. + + echo_response : Optional[openai_api_protocol.CompletionResponse] + The CompletionResponse of the echoing part, when argument "echo" + of the input request is specified. + """ + engine_state.record_event(request_id, event="receive request") + # - Check if unsupported arguments are specified. + engine_utils.check_unsupported_fields(request) + + # - Process prompt and check validity. + engine_state.record_event(request_id, event="start tokenization") + prompts = engine_utils.process_prompts(request.prompt, tokenizer.encode) + engine_state.record_event(request_id, event="finish tokenization") + prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length) + prompt = prompts[0] + assert isinstance(prompt, list) + + # Process generation config. Create request id. + generation_cfg = protocol_utils.get_generation_config(request) + + # - Echo back the prompt. 
+ echo_response = None + if request.echo: + text = tokenizer.decode(prompt) + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice(index=i, text=text) + for i in range(generation_cfg.n) + ], + model=request.model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=0, + ), + ) + echo_response = response + return prompt, generation_cfg, prompt_length, echo_response + + +def process_completion_stream_output( # pylint: disable=too-many-arguments + delta_outputs: List[CallbackStreamOutput], + request_id: str, + engine_state: EngineState, + model: str, + generation_cfg: GenerationConfig, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Tuple[Optional[openai_api_protocol.CompletionResponse], int]: + """Process the delta outputs of a single request of Completion, + convert the delta output to CompletionResponse and return. + + Parameters + ---------- + delta_outputs : List[CallbackStreamOutput] + The delta outputs of a request. + The list length is the number of parallel generation specified by "n". + Each element corresponds to a generation. + + request_id : str + The id of the request. + + engine_state : EngineState + The state of the engine. + + model : str + The requested model. + + generation_cfg : GenerationConfig + The generation config of the request. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + response : Optional[openai_api_protocol.CompletionResponse] + The converted OpenAI API CompletionResponse instance. + It can be none when there is no content. + + num_completion_tokens : int + The updated number of total completion tokens. + It is sum of the input number and the number of new completion tokens + from the given delta outputs. + """ + assert len(delta_outputs) == generation_cfg.n + choices = [] + num_new_completion_tokens = 0 + for i, delta_output in enumerate(delta_outputs): + finish_reason_updated = False + if delta_output.finish_reason is not None and finish_reasons[i] is None: + finish_reasons[i] = delta_output.finish_reason + finish_reason_updated = True + num_new_completion_tokens += delta_output.num_delta_tokens + if not finish_reason_updated and delta_output.delta_text == "": + # Ignore empty delta text when finish reason is not updated. + continue + + choices.append( + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reasons[i], + text=delta_output.delta_text, + logprobs=( + openai_api_protocol.LogProbs( + content=[ + openai_api_protocol.LogProbsContent.model_validate_json( + logprob_json_str + ) + for logprob_json_str in delta_output.delta_logprob_json_strs + ] + ) + if delta_output.delta_logprob_json_strs is not None + else None + ), + ) + ) + + if len(choices) == 0 and num_new_completion_tokens == 0: + # Skip return when there is no delta output and no number of completion tokens. 
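# Sketch of the "echo" handling earlier in this hunk: when a completion request
# sets echo=True, the decoded prompt is returned first as a response whose
# choices all carry the prompt text and whose completion token count is still
# zero. `decode` is a stand-in for the model tokenizer's decode method, and the
# plain dict stands in for the CompletionResponse protocol object.
from typing import Callable, Dict, List


def build_echo_payload(
    request_id: str, prompt_token_ids: List[int], n: int, decode: Callable[[List[int]], str]
) -> Dict:
    text = decode(prompt_token_ids)
    return {
        "id": request_id,
        "choices": [{"index": i, "text": text} for i in range(n)],
        "usage": {
            "prompt_tokens": len(prompt_token_ids),
            "completion_tokens": 0,  # nothing has been generated yet
        },
    }


def fake_decode(token_ids: List[int]) -> str:
    return " ".join(str(tid) for tid in token_ids)


print(build_echo_payload("cmpl-123", [101, 7592], n=2, decode=fake_decode))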
+ return None, num_completion_tokens + num_completion_tokens += num_new_completion_tokens + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=choices, + model=model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + engine_state.record_event(request_id, event="yield delta output") + return response, num_completion_tokens + + +def create_completion_suffix_response( + request: openai_api_protocol.CompletionRequest, + request_id: str, + prompt_length: int, + finish_reasons: List[Optional[str]], + num_completion_tokens: int, +) -> Optional[openai_api_protocol.CompletionResponse]: + """Create the suffix response of Completion request + when the request requires suffix. + + Parameters + ---------- + request : openai_api_protocol.CompletionRequest + The request whose suffix response if to be created. + + request_id : str + The id of the request. + + prompt_length : int + The total prompt length. + + finish_reasons : List[Optional[str]] + The list of finish reasons of each generation. + The list length is the number of parallel generation specified by "n". + This list is updated in place. + + num_completion_tokens : int + The number of total completion tokens so far. + + Returns + ------- + suffix_response : Optional[openai_api_protocol.CompletionResponse] + The created OpenAI API CompletionResponse instance for the suffix. + Or None if the request does not require suffix. + """ + # - Echo the suffix. + if request.suffix is None: + return None + assert all(finish_reason is not None for finish_reason in finish_reasons) + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reason, + text=request.suffix, + ) + for i, finish_reason in enumerate(finish_reasons) + ], + model=request.model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=prompt_length, + completion_tokens=num_completion_tokens, + ), + ) + return response diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py new file mode 100644 index 0000000000..d1c96e37d4 --- /dev/null +++ b/python/mlc_llm/serve/engine_utils.py @@ -0,0 +1,97 @@ +"""Utility functions for MLC Serve engine""" + +import uuid +from typing import Callable, List, Union + +from mlc_llm.serve import data + +from ..protocol import RequestProtocol, error_protocol, protocol_utils + + +def random_uuid() -> str: + """Generate a random id in hexadecimal string.""" + return uuid.uuid4().hex + + +def check_unsupported_fields(request: RequestProtocol) -> None: + """Check if the request has unsupported fields. Raise BadRequestError if so.""" + unsupported_fields = protocol_utils.get_unsupported_fields(request) + if len(unsupported_fields) != 0: + unsupported_fields = [f'"{field}"' for field in unsupported_fields] + raise error_protocol.BadRequestError( + f'Request fields {", ".join(unsupported_fields)} are not supported right now.', + ) + + +def check_and_get_prompts_length( + prompts: List[Union[List[int], data.ImageData]], max_input_sequence_length: int +) -> int: + """Check if the total prompt length exceeds the max single sequence + sequence length allowed by the served model. Raise BadRequestError if so. + Return the total prompt length. 
+ """ + total_length: int = 0 + for prompt in prompts: + total_length += len(prompt) + if total_length > max_input_sequence_length: + raise error_protocol.BadRequestError( + f"Request prompt has {total_length} tokens in total," + f" larger than the model input length limit {max_input_sequence_length}.", + ) + return total_length + + +def process_prompts( + input_prompts: Union[str, List[int], List[Union[str, List[int], data.ImageData]]], + ftokenize: Callable[[str], List[int]], +) -> List[Union[List[int], data.ImageData]]: + """Convert all input tokens to list of token ids with regard to the + given tokenization function. + For each input prompt, return the list of token ids after tokenization. + """ + error_msg = f"Invalid request prompt {input_prompts}" + + # Case 1. The prompt is a single string. + if isinstance(input_prompts, str): + return [ftokenize(input_prompts)] + + assert isinstance(input_prompts, list) + if len(input_prompts) == 0: + raise error_protocol.BadRequestError(error_msg) + + # Case 2. The prompt is a list of token ids. + if isinstance(input_prompts[0], int): + assert isinstance(input_prompts, list) + if not all(isinstance(token_id, int) for token_id in input_prompts): + raise error_protocol.BadRequestError(error_msg) + return [input_prompts] # type: ignore + + # Case 3. A list of prompts. + output_prompts: List[Union[List[int], data.ImageData]] = [] + for input_prompt in input_prompts: + if isinstance(input_prompt, str): + output_prompts.append(ftokenize(input_prompt)) + elif isinstance(input_prompt, list) and all( + isinstance(token_id, int) for token_id in input_prompt + ): + output_prompts.append(input_prompt) + elif isinstance(input_prompt, data.ImageData): + output_prompts.append(input_prompt) + else: + raise error_protocol.BadRequestError(error_msg) + return output_prompts + + +def convert_prompts_to_data( + prompts: Union[str, List[int], List[Union[str, List[int], data.Data]]] +) -> List[data.Data]: + """Convert the given prompts in the combination of token id lists + and/or data to all data.""" + if isinstance(prompts, data.Data): + return [prompts] + if isinstance(prompts, str): + return [data.TextData(prompts)] + if isinstance(prompts[0], int): + assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts) + return [data.TokenData(prompts)] # type: ignore + return [convert_prompts_to_data(x)[0] for x in prompts] # type: ignore diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index b95fd4faae..fe76696163 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -5,8 +5,8 @@ import fastapi -from ..server import ServerContext -from . 
import entrypoint_utils +from mlc_llm.protocol import error_protocol +from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() @@ -26,11 +26,11 @@ async def debug_dump_event_trace(request: fastapi.Request): # Parse the JSON string request_dict = json.loads(request_json_str) except json.JSONDecodeError: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) if "model" not in request_dict: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) @@ -41,11 +41,11 @@ async def debug_dump_event_trace(request: fastapi.Request): async_engine = server_context.get_engine(model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" is not served.' ) if async_engine.state.trace_recorder is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{model}" does not enable tracing' ) diff --git a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py b/python/mlc_llm/serve/entrypoints/entrypoint_utils.py deleted file mode 100644 index b0895f2fe7..0000000000 --- a/python/mlc_llm/serve/entrypoints/entrypoint_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Utility functions for server entrypoints""" - -import uuid -from http import HTTPStatus -from io import BytesIO -from typing import Callable, Dict, List, Optional, Union - -import fastapi - -from mlc_llm.serve import data - -from ...protocol import RequestProtocol -from ...protocol.protocol_utils import ErrorResponse, get_unsupported_fields - - -def random_uuid() -> str: - """Generate a random id in hexadecimal string.""" - return uuid.uuid4().hex - - -def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse: - """Create a JSON response that reports error with regarding the input message.""" - return fastapi.responses.JSONResponse( - ErrorResponse(message=message, code=status_code.value).model_dump_json(), - status_code=status_code.value, - ) - - -def check_unsupported_fields( - request: RequestProtocol, -) -> Optional[fastapi.responses.JSONResponse]: - """Check if the request has unsupported fields. Return an error if so.""" - unsupported_fields = get_unsupported_fields(request) - if len(unsupported_fields) != 0: - unsupported_fields = [f'"{field}"' for field in unsupported_fields] - return create_error_response( - HTTPStatus.BAD_REQUEST, - message=f'Request fields {", ".join(unsupported_fields)} are not supported right now.', - ) - return None - - -def check_prompts_length( - prompts: List[List[int]], max_input_sequence_length: int -) -> Optional[fastapi.responses.JSONResponse]: - """Check if the total prompt length exceeds the max single sequence - sequence length allowed by the served model. Return an error if so. 
- """ - total_length = 0 - for prompt in prompts: - total_length += len(prompt) - if total_length > max_input_sequence_length: - return create_error_response( - HTTPStatus.BAD_REQUEST, - message=f"Request prompt has {total_length} tokens in total," - f" larger than the model input length limit {max_input_sequence_length}.", - ) - return None - - -def process_prompts( - input_prompts: Union[ - str, List[int], List[Union[str, List[int]]], List[Union[str, data.ImageData]] - ], - ftokenize: Callable[[str], List[int]], -) -> Union[List[Union[List[int], data.ImageData]], fastapi.responses.JSONResponse]: - """Convert all input tokens to list of token ids with regard to the - given tokenization function. - For each input prompt, return the list of token ids after tokenization. - """ - error_msg = f"Invalid request prompt {input_prompts}" - - # Case 1. The prompt is a single string. - if isinstance(input_prompts, str): - return [ftokenize(input_prompts)] - - assert isinstance(input_prompts, list) - if len(input_prompts) == 0: - return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - # Case 2. The prompt is a list of token ids. - if isinstance(input_prompts[0], int): - if not all(isinstance(token_id, int) for token_id in input_prompts): - return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - return [input_prompts] - - # Case 3. A list of prompts. - output_prompts: List[List[int]] = [] - for input_prompt in input_prompts: - is_str = isinstance(input_prompt, str) - is_token_ids = isinstance(input_prompt, list) and all( - isinstance(token_id, int) for token_id in input_prompt - ) - is_image = isinstance(input_prompt, data.ImageData) - if not (is_str or is_token_ids or is_image): - return create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - output_prompts.append(ftokenize(input_prompt) if is_str else input_prompt) # type: ignore - return output_prompts - - -def get_image_from_url(url: str, config: Dict) -> data.ImageData: - """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" - - # pylint: disable=import-outside-toplevel, import-error - import base64 - - import requests - import tvm - from PIL import Image - from transformers import CLIPImageProcessor - - if url.startswith("data:image"): - # The image is encoded in base64 format - base64_image = url.split(",")[1] - image_data = base64.b64decode(base64_image) - image_tensor = Image.open(BytesIO(image_data)).convert("RGB") - elif url.startswith("http"): - response = requests.get(url, timeout=5) - image_tensor = Image.open(BytesIO(response.content)).convert("RGB") - else: - raise ValueError(f"Unsupported image URL format: {url}") - - image_input_size = get_image_input_size(config) - image_embed_size = get_image_embed_size(config) - - image_processor = CLIPImageProcessor( - size={"shortest_edge": image_input_size}, - crop_size={"height": image_input_size, "width": image_input_size}, - ) - image_features = tvm.nd.array( - image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( - "float16" - ) - ) - image_data = data.ImageData(image_features, image_embed_size) - return image_data - - -def get_image_embed_size(config: Dict) -> int: - """Get the image embedding size from the model config file.""" - image_size = config["model_config"]["vision_config"]["image_size"] - patch_size = config["model_config"]["vision_config"]["patch_size"] - embed_size = (image_size // patch_size) ** 2 - return embed_size - - -def get_image_input_size(config: Dict) 
-> int: - """Get the image input size from the model config file.""" - image_size = config["model_config"]["vision_config"]["image_size"] - return image_size diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index ac8503d5df..0625ea6aae 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -2,23 +2,17 @@ # pylint: disable=too-many-locals,too-many-return-statements,too-many-statements import ast -import json from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Sequence, Union +from typing import AsyncGenerator, Dict, List, Optional, Union import fastapi -from mlc_llm.serve import data - -from ...protocol import protocol_utils -from ...protocol.conversation_protocol import Conversation -from ...protocol.openai_api_protocol import ( +from mlc_llm.protocol import error_protocol +from mlc_llm.protocol.openai_api_protocol import ( ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, - ChatCompletionStreamResponse, - ChatCompletionStreamResponseChoice, ChatFunctionCall, ChatToolCall, CompletionRequest, @@ -30,8 +24,8 @@ ModelResponse, UsageInfo, ) -from ..server import ServerContext -from . import entrypoint_utils +from mlc_llm.serve import engine_utils +from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() @@ -59,130 +53,30 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re server_context: ServerContext = ServerContext.current() async_engine = server_context.get_engine(request.model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) - request_id = f"cmpl-{entrypoint_utils.random_uuid()}" - async_engine.state.record_event(request_id, event="receive request") - - # - Check if unsupported arguments are specified. - error = entrypoint_utils.check_unsupported_fields(request) - if error is not None: - return error - - # - Process prompt and check validity. - async_engine.state.record_event(request_id, event="start tokenization") - prompts = entrypoint_utils.process_prompts(request.prompt, async_engine.tokenizer.encode) - async_engine.state.record_event(request_id, event="finish tokenization") - if isinstance(prompts, fastapi.responses.JSONResponse): - # Errored when processing the prompts - return prompts - if len(prompts) > 1: - return entrypoint_utils.create_error_response( - HTTPStatus.BAD_REQUEST, - message="Entrypoint /v1/completions only accept single prompt. " - f"However, {len(prompts)} prompts {prompts} are received.", - ) - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) - if error is not None: - return error - prompt = prompts[0] - - # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config(request) + request_id = f"cmpl-{engine_utils.random_uuid()}" # Streaming response. if request.stream: + # We manually get the first response from generator to + # capture potential exceptions in this scope, rather then + # the StreamingResponse scope. 
+ stream_generator = async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) async def completion_stream_generator() -> AsyncGenerator[str, None]: - # - Echo back the prompt. - if request.echo: - text = async_engine.tokenizer.decode(prompt) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice(index=i, text=text) - for i in range(generation_cfg.n) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=0, - ), - ) + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json()}\n\n" + async for response in stream_generator: yield f"data: {response.model_dump_json()}\n\n" - - # - Generate new tokens. - num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): - assert len(delta_outputs) == generation_cfg.n - choices = [] - for i, delta_output in enumerate(delta_outputs): - finish_reason_updated = False - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - finish_reason_updated = True - num_completion_tokens += delta_output.num_delta_tokens - if not finish_reason_updated and delta_output.delta_text == "": - # Ignore empty delta text when finish reason is not updated. - continue - - choices.append( - CompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - text=delta_output.delta_text, - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_output.delta_logprob_json_strs - ] - ) - if delta_output.delta_logprob_json_strs is not None - else None - ), - ) - ) - - if len(choices) == 0: - # Skip yield when there is no delta output. - continue - response = CompletionResponse( - id=request_id, - choices=choices, - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - async_engine.state.record_event(request_id, event="finish") - - # - Echo the suffix. - if request.suffix is not None: - assert all(finish_reason is not None for finish_reason in finish_reasons) - response = CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=request.suffix, - ) - for i, finish_reason in enumerate(finish_reasons) - ], - model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), - ) - yield f"data: {response.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -190,140 +84,58 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. 
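# Sketch of the "peek the first chunk" pattern used in the streaming branch
# above: the first item of the async generator is awaited eagerly so that
# engine-side errors surface in the request handler (where they can become a
# proper HTTP error) instead of inside the StreamingResponse body. Note that
# the built-in anext() requires Python 3.10+, which is why the real code marks
# it for pylint and the type checker. The generator below is a toy stand-in.
import asyncio
from typing import AsyncGenerator, Optional


async def fake_engine_stream() -> AsyncGenerator[str, None]:
    for chunk in ("{\"delta\": \"Hel\"}", "{\"delta\": \"lo\"}"):
        yield chunk


async def handle_request() -> AsyncGenerator[str, None]:
    stream = fake_engine_stream()
    first: Optional[str]
    try:
        # Awaiting here means exceptions raise in this scope, not mid-response.
        first = await anext(stream)
    except StopAsyncIteration:
        first = None

    async def sse_body() -> AsyncGenerator[str, None]:
        if first is None:
            yield "data: [DONE]\n\n"
            return
        yield f"data: {first}\n\n"
        async for chunk in stream:
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"

    return sse_body()


async def main() -> None:
    body = await handle_request()
    async for line in body:
        print(line, end="")


asyncio.run(main())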
- init_output_text = "" if not request.echo else async_engine.tokenizer.decode(prompt) - output_texts = [init_output_text for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - suffix = request.suffix if request.suffix is not None else "" - async_engine.state.record_event(request_id, event="finish") - response = CompletionResponse( + return CompletionResponse( id=request_id, choices=[ CompletionResponseChoice( index=i, finish_reason=finish_reason, - text=output_text + suffix, + text=output_text, logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object - i - ] - ] - ) - if logprob_json_strs_list is not None - else None + LogProbs(content=logprob_results[i]) if logprob_results is not None else None ), ) for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) ], model=request.model, - usage=UsageInfo( - prompt_tokens=len(prompt), - completion_tokens=num_completion_tokens, - ), + usage=UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens), ) - return response ################ v1/chat/completions ################ -def chat_completion_check_message_validity( - messages: List[ChatCompletionMessage], -) -> Optional[str]: - """Check if the 
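# Sketch of how the non-streaming path above folds the internally streamed
# chunks back into one response: text is concatenated per choice index, the
# first finish reason per choice is kept, and the usage numbers are taken from
# the last chunk because they are cumulative on the engine side. Plain dicts
# stand in for the OpenAI protocol objects.
from typing import Dict, List, Optional


def aggregate_chunks(chunks: List[Dict], n: int) -> Dict:
    output_texts = ["" for _ in range(n)]
    finish_reasons: List[Optional[str]] = [None for _ in range(n)]
    usage = {"prompt_tokens": 0, "completion_tokens": 0}
    for chunk in chunks:
        usage = chunk["usage"]  # cumulative, so the last chunk wins
        for choice in chunk["choices"]:
            idx = choice["index"]
            output_texts[idx] += choice["text"]
            if choice["finish_reason"] is not None and finish_reasons[idx] is None:
                finish_reasons[idx] = choice["finish_reason"]
    return {
        "choices": [
            {"index": i, "text": text, "finish_reason": reason}
            for i, (text, reason) in enumerate(zip(output_texts, finish_reasons))
        ],
        "usage": usage,
    }


chunks = [
    {"choices": [{"index": 0, "text": "Hel", "finish_reason": None}],
     "usage": {"prompt_tokens": 3, "completion_tokens": 1}},
    {"choices": [{"index": 0, "text": "lo", "finish_reason": "stop"}],
     "usage": {"prompt_tokens": 3, "completion_tokens": 2}},
]
print(aggregate_chunks(chunks, n=1))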
given chat messages are valid. Return error message if invalid.""" - for i, message in enumerate(messages): - if message.role == "system" and i != 0: - return f"System prompt at position {i} in the message list is invalid." - if message.role == "tool": - return "Tool as the message author is not supported yet." - if message.tool_call_id is not None: - if message.role != "tool": - return "Non-tool message having `tool_call_id` is invalid." - if isinstance(message.content, list): - if message.role != "user": - return "Non-user message having a list of content is invalid." - if message.tool_calls is not None: - if message.role != "assistant": - return "Non-assistant message having `tool_calls` is invalid." - return "Assistant message having `tool_calls` is not supported yet." - return None - - -def check_function_call_usage( - request: ChatCompletionRequest, conv_template: Conversation -) -> Optional[str]: - """Check if function calling is used and update the conversation template. - Return error message if invalid request format for function calling. - """ - - # return if no tools are provided or tool_choice is set to none - if request.tools is None or ( - isinstance(request.tool_choice, str) and request.tool_choice == "none" - ): - conv_template.use_function_calling = False - return None - - # select the tool based on the tool_choice if specified - if isinstance(request.tool_choice, dict): - if request.tool_choice["type"] != "function": - return "Only 'function' tool choice is supported" - - if len(request.tool_choice["function"]) > 1: - return "Only one tool is supported when tool_choice is specified" - - for tool in request.tools: - if tool.function.name == request.tool_choice["function"]["name"]: - conv_template.use_function_calling = True - conv_template.function_string = tool.function.model_dump_json() - return None - - return ( - f"The tool_choice function {request.tool_choice['function']['name']}" - " is not found in the tools list" - ) - - if isinstance(request.tool_choice, str) and request.tool_choice != "auto": - return f"Invalid tool_choice value: {request.tool_choice}" - - function_list = [] - for tool in request.tools: - if tool.type != "function": - return "Only 'function' tool type is supported" - function_list.append(tool.function.model_dump()) - - conv_template.use_function_calling = True - conv_template.function_string = json.dumps(function_list) - return None - - def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" @@ -360,132 +172,30 @@ async def request_chat_completion( server_context: ServerContext = ServerContext.current() async_engine = server_context.get_engine(request.model) if async_engine is None: - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message=f'The requested model "{request.model}" is not served.' ) - request_id = f"chatcmpl-{entrypoint_utils.random_uuid()}" - async_engine.state.record_event(request_id, event="receive request") - - # - Check if the model supports chat conversation. - conv_template = server_context.get_conv_template(request.model) - if conv_template is None: - return entrypoint_utils.create_error_response( - HTTPStatus.BAD_REQUEST, - message=f'The requested model "{request.model}" does not support chat.', - ) - - # - Check if unsupported arguments are specified. 
- error = entrypoint_utils.check_unsupported_fields(request) - if error is not None: - return error - - # - Process messages and update the conversation template in three steps: - # i. Check the message validity. - # ii. Add the input messages to the conversation template. - # iii. Add the additional message for the assistant. - error_msg = chat_completion_check_message_validity(request.messages) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - # Check for function calling usage and update the conversation template - error_msg = check_function_call_usage(request, conv_template) - if error_msg is not None: - return entrypoint_utils.create_error_response(HTTPStatus.BAD_REQUEST, message=error_msg) - - for message in request.messages: - role = message.role - content = message.content - if role == "system": - assert isinstance(content, str) - conv_template.system_message = content if content is not None else "" - continue - - assert role != "tool", "Internal error: tool role." - conv_template.messages.append((role, content)) - conv_template.messages.append(("assistant", None)) - - # - Get the prompt from template, and encode to token ids. - # - Check prompt length - async_engine.state.record_event(request_id, event="start tokenization") - - model_config = server_context.get_model_config(request.model) - prompts = entrypoint_utils.process_prompts( - conv_template.as_prompt(model_config), - async_engine.tokenizer.encode, - ) - - async_engine.state.record_event(request_id, event="finish tokenization") - - if conv_template.system_prefix_token_ids is not None: - prompts[0] = conv_template.system_prefix_token_ids + prompts[0] - error = entrypoint_utils.check_prompts_length(prompts, async_engine.max_input_sequence_length) - if error is not None: - return error - - prompt: Sequence[Union[List[int], data.ImageData]] = prompts - - # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config( - request, - extra_stop_token_ids=conv_template.stop_token_ids, - extra_stop_str=conv_template.stop_str, - ) + request_id = f"chatcmpl-{engine_utils.random_uuid()}" # Streaming response. if request.stream: + # We manually get the first response from generator to + # capture potential exceptions in this scope, rather then + # the StreamingResponse scope. + stream_generator = async_engine._handle_chat_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) async def completion_stream_generator() -> AsyncGenerator[str, None]: - async_engine.state.record_event(request_id, event="invoke generate") - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): - assert len(delta_outputs) == generation_cfg.n - choices = [] - for i, delta_output in enumerate(delta_outputs): - finish_reason_updated = False - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = ( - delta_output.finish_reason - if not conv_template.use_function_calling - else "tool_calls" - ) - finish_reason_updated = True - if not finish_reason_updated and delta_output.delta_text == "": - # Ignore empty delta text when finish reason is not updated. 
- async_engine.state.record_event(request_id, event="skip empty delta text") - continue - - choices.append( - ChatCompletionStreamResponseChoice( - index=i, - finish_reason=finish_reasons[i], - delta=ChatCompletionMessage( - content=delta_output.delta_text, role="assistant" - ), - logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in delta_output.delta_logprob_json_strs - ] - ) - if delta_output.delta_logprob_json_strs is not None - else None - ), - ) - ) - - if len(choices) == 0: - # Skip yield when there is no delta output. - continue - response = ChatCompletionStreamResponse( - id=request_id, - choices=choices, - model=request.model, - system_fingerprint="", - ) - async_engine.state.record_event(request_id, event="yield delta output") + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json()}\n\n" + async for response in stream_generator: yield f"data: {response.model_dump_json()}\n\n" - async_engine.state.record_event(request_id, event="finish") yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -493,39 +203,42 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. - output_texts = ["" for _ in range(generation_cfg.n)] + num_prompt_tokens = 0 num_completion_tokens = 0 - finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - logprob_json_strs_list: Optional[List[List[str]]] = ( - [[] for _ in range(generation_cfg.n)] if generation_cfg.logprobs else None + output_texts = ["" for _ in range(request.n)] + finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] + logprob_results: Optional[List[List[LogProbsContent]]] = ( + [[] for _ in range(request.n)] if request.logprobs else None ) - async_engine.state.record_event(request_id, event="invoke generate") - async for delta_outputs in async_engine.generate(prompt, generation_cfg, request_id): + + async for response in async_engine._handle_chat_completion( # pylint: disable=protected-access + request, request_id + ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, # and abort the request from engine if so. 
await async_engine.abort(request_id) - return entrypoint_utils.create_error_response( + return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[choice.index] += choice.logprobs.content - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - if delta_output.finish_reason is not None and finish_reasons[i] is None: - finish_reasons[i] = delta_output.finish_reason - output_texts[i] += delta_output.delta_text - num_completion_tokens += delta_output.num_delta_tokens - if logprob_json_strs_list is not None: - assert delta_output.delta_logprob_json_strs is not None - logprob_json_strs_list[i] += delta_output.delta_logprob_json_strs assert all(finish_reason is not None for finish_reason in finish_reasons) - async_engine.state.record_event(request_id, event="finish") - - tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(generation_cfg.n)] - if conv_template.use_function_calling: + tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(request.n)] + use_function_calling = any(finish_reason == "tool_calls" for finish_reason in finish_reasons) + if use_function_calling: for i, output_text in enumerate(output_texts): try: fn_json_list = convert_function_str_to_json(output_text) @@ -557,20 +270,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: finish_reason=finish_reasons[i], message=( ChatCompletionMessage(role="assistant", content=output_text) - if (not conv_template.use_function_calling or finish_reason == "error") + if not use_function_calling or finish_reason == "error" else ChatCompletionMessage(role="assistant", tool_calls=tool_calls) ), logprobs=( - LogProbs( - content=[ - LogProbsContent.model_validate_json(logprob_json_str) - for logprob_json_str in logprob_json_strs_list[ # pylint: disable=unsubscriptable-object - i - ] - ] - ) - if logprob_json_strs_list is not None - else None + LogProbs(content=logprob_results[i]) if logprob_results is not None else None ), ) for i, (output_text, finish_reason, tool_calls) in enumerate( @@ -579,7 +283,5 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ], model=request.model, system_fingerprint="", - usage=UsageInfo( - prompt_tokens=sum(len(item) for item in prompt), completion_tokens=num_completion_tokens - ), + usage=UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens), ) diff --git a/python/mlc_llm/serve/server/__main__.py b/python/mlc_llm/serve/server/__main__.py deleted file mode 100644 index ed900edd03..0000000000 --- a/python/mlc_llm/serve/server/__main__.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Entrypoint of RESTful HTTP request server in MLC LLM""" - -import argparse -import json - -import fastapi -import uvicorn -from fastapi.middleware.cors import CORSMiddleware - -from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints - -from .. 
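# Sketch of the tool-call post-processing step above: once every generation has
# finished, outputs from a function-calling conversation are parsed into
# (name, arguments) records, and a generation whose output cannot be parsed is
# downgraded to finish_reason="error" so the raw text is returned instead. The
# Python-literal input format accepted by `parse_calls` below is an assumption
# for illustration only; the real entrypoint parses the model-specific
# function-call syntax with `ast`.
import ast
from typing import Dict, List, Optional, Tuple


def parse_calls(output_text: str) -> Optional[List[Dict]]:
    try:
        calls = ast.literal_eval(output_text)
    except (ValueError, SyntaxError):
        return None
    if not isinstance(calls, list) or not all(isinstance(c, dict) for c in calls):
        return None
    return calls


def finalize_generation(output_text: str) -> Tuple[str, List[Dict]]:
    calls = parse_calls(output_text)
    if calls is None:
        return "error", []  # fall back: the plain text message is used instead
    return "tool_calls", [
        {"type": "function", "function": {"name": c["name"], "arguments": c["arguments"]}}
        for c in calls
    ]


print(finalize_generation("[{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]"))
print(finalize_generation("not a function call"))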
import async_engine, config -from .server_context import ServerContext - - -def parse_args_and_initialize() -> argparse.Namespace: - """Parse the server arguments and initialize the engine.""" - - args = argparse.ArgumentParser() # pylint: disable=redefined-outer-name - args.add_argument("--model", type=str, required=True) - args.add_argument("--model-lib-path", type=str, required=True) - args.add_argument("--device", type=str, default="auto") - args.add_argument("--max-batch-size", type=int, default=80) - args.add_argument("--max-total-seq-length", type=int) - args.add_argument("--prefill-chunk-size", type=int) - args.add_argument("--enable-tracing", action="store_true") - - args.add_argument("--host", type=str, default="127.0.0.1", help="host name") - args.add_argument("--port", type=int, default=8000, help="port") - args.add_argument("--allow-credentials", action="store_true", help="allow credentials") - args.add_argument("--allowed-origins", type=json.loads, default=["*"], help="allowed origins") - args.add_argument("--allowed-methods", type=json.loads, default=["*"], help="allowed methods") - args.add_argument("--allowed-headers", type=json.loads, default=["*"], help="allowed headers") - - parsed = args.parse_args() - - return parsed - - -if __name__ == "__main__": - # Parse the arguments and initialize the asynchronous engine. - args: argparse.Namespace = parse_args_and_initialize() - app = fastapi.FastAPI() - - # Initialize model loading info and KV cache config - model_info = async_engine.ModelInfo( - model=args.model, - model_lib_path=args.model_lib_path, - device=args.device, - ) - kv_cache_config = config.KVCacheConfig( - max_num_sequence=args.max_batch_size, - max_total_sequence_length=args.max_total_seq_length, - prefill_chunk_size=args.prefill_chunk_size, - ) - # Create engine and start the background loop - engine = async_engine.AsyncThreadedEngine( - model_info, kv_cache_config, enable_tracing=args.enable_tracing - ) - - with ServerContext() as server_context: - server_context.add_model(args.model, engine) - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - app.include_router(openai_entrypoints.app) - app.include_router(debug_entrypoints.app) - uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index ed63f6ac51..9529316010 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -17,9 +17,9 @@ class PopenServer: # pylint: disable=too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments self, model: str, - model_lib_path: str, device: str = "auto", *, + model_lib_path: Optional[str] = None, max_batch_size: int = 80, max_total_sequence_length: Optional[int] = None, enable_tracing: bool = False, @@ -43,7 +43,8 @@ def start(self) -> None: """ cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] - cmd += ["--model-lib-path", self.model_lib_path] + if self.model_lib_path is not None: + cmd += ["--model-lib-path", self.model_lib_path] cmd += ["--device", self.device] cmd += ["--max-batch-size", str(self.max_batch_size)] if self.max_total_sequence_length is not None: diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index baad7b5e7d..ab103c05f8 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ 
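# Sketch of the PopenServer command-line change above: "--model-lib-path" is
# now appended only when a library path is actually supplied, presumably so the
# serve CLI can resolve the model library itself otherwise. A minimal
# stand-alone version of that assembly logic (model name is a placeholder):
import sys
from typing import List, Optional


def build_serve_command(
    model: str,
    device: str = "auto",
    model_lib_path: Optional[str] = None,
    max_batch_size: int = 80,
) -> List[str]:
    cmd = [sys.executable, "-m", "mlc_llm", "serve", model]
    if model_lib_path is not None:
        cmd += ["--model-lib-path", model_lib_path]
    cmd += ["--device", device]
    cmd += ["--max-batch-size", str(max_batch_size)]
    return cmd


print(build_serve_command("demo-model"))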
b/python/mlc_llm/serve/server/server_context.py @@ -1,12 +1,8 @@ """Server context that shared by multiple entrypoint files.""" -import json from typing import Dict, List, Optional -from ...chat_module import _get_model_path -from ...conversation_template import ConvTemplateRegistry -from ...protocol.conversation_protocol import Conversation -from .. import async_engine +from ..engine import AsyncEngine class ServerContext: @@ -17,9 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, async_engine.AsyncThreadedEngine] = {} - self._conv_templates: Dict[str, Conversation] = {} - self._model_configs: Dict[str, Dict] = {} + self._models: Dict[str, AsyncEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,46 +25,22 @@ def __exit__(self, exc_type, exc_value, traceback): for model_engine in self._models.values(): model_engine.terminate() self._models.clear() - self._conv_templates.clear() - self._model_configs.clear() @staticmethod def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: async_engine.AsyncThreadedEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - # Get the conversation template. - if engine.conv_template_name is not None: - conv_template = ConvTemplateRegistry.get_conv_template(engine.conv_template_name) - if conv_template is not None: - self._conv_templates[hosted_model] = conv_template - - _, config_file_path = _get_model_path(hosted_model) - with open(config_file_path, "r", encoding="utf-8") as file: - config = json.load(file) - self._model_configs[hosted_model] = config - - def get_engine(self, model: str) -> Optional[async_engine.AsyncThreadedEngine]: + def get_engine(self, model: str) -> Optional[AsyncEngine]: """Get the async engine of the requested model.""" return self._models.get(model, None) - def get_conv_template(self, model: str) -> Optional[Conversation]: - """Get the conversation template of the requested model.""" - conv_template = self._conv_templates.get(model, None) - if conv_template is not None: - return conv_template.model_copy(deep=True) - return None - def get_model_list(self) -> List[str]: """Get the list of models on serve.""" return list(self._models.keys()) - - def get_model_config(self, model: str) -> Optional[Dict]: - """Get the model config path of the requested model.""" - return self._model_configs.get(model, None) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py new file mode 100644 index 0000000000..e8bc0288cf --- /dev/null +++ b/python/mlc_llm/serve/sync_engine.py @@ -0,0 +1,332 @@ +"""The MLC LLM synchronized engine. + +NOTE: This engine defined in this file directly wraps the underlying +Engine implementation in C++, is not optimized by multi-threading and +does not offer standard OpenAI API interface. + +We do not expose it and use it by default. As of now it mainly serves +the test and debug purpose because of its simplicity. 
+""" + +import json +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import tvm + +from mlc_llm.serve import data +from mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine_base import ( + ModelInfo, + _estimate_max_total_sequence_length, + _process_model_args, +) +from mlc_llm.serve.event_trace_recorder import EventTraceRecorder +from mlc_llm.serve.request import Request +from mlc_llm.streamer import TextStreamer +from mlc_llm.support import logging +from mlc_llm.tokenizer import Tokenizer + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def _create_tvm_module( + creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None +) -> Dict[str, Callable]: + """Internal method to create a module.""" + if creator_args is None: + creator_args = [] + module = tvm.get_global_func(creator, allow_missing=False)(*creator_args) + return {key: module[key] for key in ffi_funcs} + + +class SyncEngine: + """The Python interface of synchronize request serving engine for MLC LLM. + + The engine receives requests from the "add_request" method. For + an given request, the engine will keep generating new tokens for + the request until finish (under certain criterion). After finish, + the engine will return the generation result through the callback + function provided by the request. + + NOTE: This engine directly wraps the underlying Engine implementation + in C++, is not optimized by multi-threading and does not offer standard + OpenAI API interface. We do not expose it and use it by default. + As of now it mainly serves the test and debug purpose because of its + simplicity. + + Parameters + ---------- + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] + The provided callback function to handle the generation + output. It has the signature of `(str, data.TokenData, bool) -> None`, + where + - the first string is the request id, + - the TokenData contains the generated **delta** token ids since + the last invocation of the callback on the specific request, + - the optional string value denotes the finish reason if the + generation of the request is finished, or None if it has not finished. + + The callback function is optional at construction, but it needs to + be set before the engine executing requests. This can be done via + the `set_request_stream_callback` method. Otherwise, the engine will raise + exception. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. 
+ """ + + def __init__( # pylint: disable=too-many-arguments + self, + models: Union[ModelInfo, List[ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, + enable_tracing: bool = False, + ): + if isinstance(models, ModelInfo): + models = [models] + ( + model_args, + config_file_paths, + tokenizer_path, + max_single_sequence_length, + prefill_chunk_size, + self.conv_template_name, + ) = _process_model_args(models) + self._ffi = _create_tvm_module( + "mlc.serve.create_engine", + ffi_funcs=[ + "init", + "add_request", + "abort_request", + "step", + "stats", + "reset", + "get_request_stream_callback", + "set_request_stream_callback", + ], + ) + self.trace_recorder = EventTraceRecorder() if enable_tracing else None + self.max_input_sequence_length = max_single_sequence_length + + if kv_cache_config.max_total_sequence_length is None: + kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( + models, config_file_paths, kv_cache_config.max_num_sequence + ) + if kv_cache_config.prefill_chunk_size is None: + kv_cache_config.prefill_chunk_size = prefill_chunk_size + elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: + raise ValueError( + f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " + f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " + "models. Please specify a smaller prefill chunk size." + ) + + if engine_mode is None: + # The default engine mode: non-speculative + engine_mode = EngineMode() + + self._ffi["init"]( + max_single_sequence_length, + tokenizer_path, + kv_cache_config.asjson(), + engine_mode.asjson(), + request_stream_callback, + self.trace_recorder, + *model_args, + ) + self.tokenizer = Tokenizer(tokenizer_path) + + def generate( # pylint: disable=too-many-locals + self, + prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]], + generation_config: Union[GenerationConfig, List[GenerationConfig]], + ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]: + """Generate texts for a list of input prompts. + Each prompt can be a string or a list of token ids. + The generation for each prompt is independent. + Return the generation results, one for each prompt. + + Parameters + ---------- + prompts : Union[str, List[str], List[int], List[List[int]]] + One or a list of input prompts for text generation. + Each prompt can be a string or a list of token ids. + + generation_config : Union[GenerationConfig, List[GenerationConfig]] + The generation config for each requests. + If the it is a single GenerationConfig instance, + this config will be shared by all the prompts. + Otherwise, one generation config is required for every + prompt. + + Returns + ------- + output_text : List[List[str]] + The text generation results, one list of strings for each input prompt. + The length of each list is the parallel generation `n` in + generation config. + + output_logprobs_str : List[Optional[List[List[str]]]] + The logprob strings of each token for each input prompt, or None + if an input prompt does not require logprobs. + """ + if isinstance(prompts, str): + # `prompts` is a single string. + prompts = [prompts] + else: + assert isinstance(prompts, list), ( + "Input `prompts` is expected to be a string, a list of " + "str, a list of token ids or multiple lists of token ids. 
" + ) + if len(prompts) == 0: + return [], [] + if isinstance(prompts[0], int): + # `prompts` is a list of token ids + prompts = [prompts] # type: ignore + + num_requests = len(prompts) + if not isinstance(generation_config, list): + generation_config = [generation_config] * num_requests + + assert ( + len(generation_config) == num_requests + ), "Number of generation config and number of prompts mismatch" + + num_finished_generations = 0 + output_texts: List[List[str]] = [] + output_logprobs_str: List[Optional[List[List[str]]]] = [] + text_streamers: List[List[TextStreamer]] = [] + for i in range(num_requests): + output_texts.append([]) + output_logprobs_str.append([] if generation_config[i].logprobs else None) + text_streamers.append([]) + for _ in range(generation_config[i].n): + output_texts[i].append("") + text_streamers[i].append(TextStreamer(self.tokenizer)) + if output_logprobs_str[i] is not None: + output_logprobs_str[i].append([]) + + num_total_generations = sum(cfg.n for cfg in generation_config) + + # Save a copy of the original function callback since `generate` + # overrides the callback function. + # The original callback will be set back later on. + original_callback = self._ffi["get_request_stream_callback"]() + + # Define the callback function for request generation results + def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): + nonlocal num_finished_generations + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + rid = int(request_id) + + assert len(stream_outputs) == generation_config[rid].n + for i, (stream_output, text_streamer) in enumerate( + zip(stream_outputs, text_streamers[rid]) + ): + if output_logprobs_str[rid] is not None: + assert stream_output.delta_logprob_json_strs is not None + output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs + + delta_text = ( + text_streamer.put(stream_output.delta_token_ids) + if len(stream_output.delta_token_ids) > 0 + else "" + ) + if stream_output.finish_reason is not None: + delta_text += text_streamer.finish() + + output_texts[rid][i] += delta_text + if stream_output.finish_reason is not None: + num_finished_generations += 1 + + # Override the callback function in engine. + self._ffi["set_request_stream_callback"](request_stream_callback) + + def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data.Data]: + if isinstance(prompt, str): + return [data.TextData(prompt)] + if isinstance(prompt[0], int): + return [data.TokenData(prompt)] # type: ignore + return prompt # type: ignore + + # Add requests to engine. + for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): + input_data = convert_to_data(prompt) # type: ignore + self.add_request( + Request( + request_id=str(req_id), + inputs=input_data, + generation_config=generation_cfg, + ) + ) + + while num_finished_generations != num_total_generations: + self.step() + + # Restore the callback function in engine. + self._ffi["set_request_stream_callback"](original_callback) + return output_texts, output_logprobs_str + + def add_request(self, request: Request) -> None: + """Add a new request to the engine. + + Parameters + ---------- + request : Request + The request to add. + """ + self._ffi["add_request"](request) + + def abort_request(self, request_id: str) -> None: + """Abort the generation of the request corresponding to the input request id. + + Parameters + ---------- + request_id : str + The unique id of the request to abort. 
+ """ + self._ffi["abort_request"](request_id) + + def step(self) -> None: + """The main function that the engine takes a step of action. + + At each step, the engine may decide to + - run prefill for one (or more) requests, + - run one-step decode for the all existing requests + ... + + In the end of certain actions (e.g., decode), the engine will + check if any request has finished, and will return the + generation results for those finished requests. + """ + self._ffi["step"]() + + def reset(self) -> None: + """Reset the engine, clean up all running data and statistics.""" + self._ffi["reset"]() + + def stats(self) -> Dict[str, float]: + """The engine runtime statistics. + We collect the following entries: + - single token prefill latency (s/tok): avg latency of processing one token in prefill + - single token decode latency (s/tok): avg latency of processing one token in decode + - engine time for prefill (sec) + - engine time for decode (sec) + - total number of processed tokens in prefill. + - total number of processed tokens in decode. + """ + stats_json_str = self._ffi["stats"]() + return json.loads(stats_json_str) diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 51e7bae586..a88f3d68b8 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -21,7 +21,7 @@ ) from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.help import HELP -from mlc_llm.serve.entrypoints import entrypoint_utils +from mlc_llm.serve import engine_utils from mlc_llm.support.argparse import ArgumentParser from mlc_llm.support.auto_device import detect_device from mlc_llm.support.style import green, red @@ -261,7 +261,7 @@ def _tokenize(self, prompt: str) -> tvm.nd.array: "Parsed prompt using conversation template " f"{green(self.conversation.name)}: {parsed_prompt}" ) - tokens = entrypoint_utils.process_prompts(parsed_prompt, self.tokenizer.encode) + tokens = engine_utils.process_prompts(parsed_prompt, self.tokenizer.encode) # type: ignore # TODO: Handle ImageData in DebugChat # pylint: disable=fixme assert len(tokens) == 1, "DebugChat will only handle TextData for now" diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index d544f4b371..dd6d59c72f 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -10,9 +10,10 @@ import numpy as np from transformers import AutoTokenizer -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve import GenerationConfig, KVCacheConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine def _parse_args(): @@ -114,7 +115,7 @@ def benchmark(args: argparse.Namespace): ) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncEngine(model, kv_cache_config) # Sample prompts from dataset prompts, generation_config = sample_requests( args.dataset, args.num_prompts, args.model, args.json_output diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index bbd2089f4c..82c9dfa534 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -4,8 +4,9 @@ import random from typing import List, Tuple -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import GenerationConfig, KVCacheConfig +from 
mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine def _parse_args(): @@ -21,7 +22,6 @@ def _parse_args(): parsed.model = os.path.dirname(parsed.model_lib_path) assert parsed.batch_size % 16 == 0 assert parsed.page_size == 16 - assert parsed.max_total_seq_length >= 2048 return parsed @@ -52,7 +52,7 @@ def benchmark(args: argparse.Namespace): ) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncEngine(model, kv_cache_config) print(args) for num_requests in [1, 2, 4, 8, 16, 32, 64]: diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index 286d64a874..cca9a4265e 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -181,7 +181,7 @@ def check_openai_stream_response( usage = response["usage"] assert isinstance(usage, dict) assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 + assert usage["prompt_tokens"] >= 0 if completion_tokens is not None: assert usage["completion_tokens"] <= completion_tokens @@ -255,6 +255,7 @@ def test_openai_v1_completions( "prompt": prompt, "max_tokens": max_tokens, "stream": stream, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -310,7 +311,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reasons=["length"], + finish_reasons=["length", "stop"], completion_tokens=max_tokens, ) else: @@ -323,7 +324,7 @@ def test_openai_v1_completions_openai_package( model=served_model[0], object_str="text_completion", num_choices=1, - finish_reasons=["length"], + finish_reasons=["length", "stop"], completion_tokens=max_tokens, ) @@ -362,6 +363,7 @@ def test_openai_v1_completions_echo( "max_tokens": max_tokens, "echo": True, "stream": stream, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -412,6 +414,7 @@ def test_openai_v1_completions_suffix( "max_tokens": max_tokens, "suffix": suffix, "stream": stream, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -511,6 +514,7 @@ def test_openai_v1_completions_temperature( "max_tokens": max_tokens, "stream": stream, "temperature": 0.0, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -664,6 +668,7 @@ def test_openai_v1_completions_logit_bias( "max_tokens": max_tokens, "stream": stream, "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. 
+ "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -710,6 +715,7 @@ def test_openai_v1_completions_presence_frequency_penalty( "stream": stream, "frequency_penalty": 2.0, "presence_penalty": 2.0, + "ignore_eos": True, } response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) @@ -753,6 +759,7 @@ def test_openai_v1_completions_seed( "max_tokens": max_tokens, "stream": False, "seed": 233, + "ignore_eos": True, } response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index a1a2791bf7..f87c11547a 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,8 +3,8 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncThreadedEngine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import AsyncEngine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine_base import ModelInfo prompts = [ "What is the meaning of life?", @@ -28,25 +28,25 @@ async def test_engine_generate(): ) kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncThreadedEngine(model, kv_cache_config) + async_engine = AsyncEngine(model, kv_cache_config) num_requests = 10 max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens, n=3) + generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) output_texts: List[List[str]] = [ ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_outputs in async_engine.generate( + async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n @@ -76,5 +76,107 @@ async def generate_task( del async_engine +async def test_chat_completion(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + async_engine = AsyncEngine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 32 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate chat completion task for request {request_id}") + rid = int(request_id) + async for response in async_engine.chat_completion( + messages=[{"role": "user", "content": prompt}], + model=model.model, + max_tokens=max_tokens, + n=n, + request_id=request_id, + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + output_texts[rid][choice.index] += choice.delta.content + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. 
+ print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + +async def test_completion(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + async_engine = AsyncEngine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate completion task for request {request_id}") + rid = int(request_id) + async for response in async_engine.completion( + prompt=prompt, + model=model.model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=request_id, + ): + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + if __name__ == "__main__": asyncio.run(test_engine_generate()) + asyncio.run(test_chat_completion()) + asyncio.run(test_completion()) diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 10ed7a4729..b142bce7ae 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,13 +3,8 @@ import asyncio from typing import List -from mlc_llm.serve import ( - AsyncThreadedEngine, - EngineMode, - GenerationConfig, - KVCacheConfig, -) -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import AsyncEngine, EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine_base import ModelInfo prompts = [ "What is the meaning of life?", @@ -38,7 +33,7 @@ async def test_engine_generate(): kv_cache_config = KVCacheConfig(page_size=16) engine_mode = EngineMode(enable_speculative=True) # Create engine - async_engine = AsyncThreadedEngine([llm, ssm], kv_cache_config, engine_mode) + async_engine = AsyncEngine([llm, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 @@ -49,14 +44,14 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"generate task for request {request_id}") rid = int(request_id) - async for delta_outputs in async_engine.generate( + async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 9f56f507ca..cece8a1e27 100644 --- 
a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -1,18 +1,9 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -from typing import Callable, List, Optional +from typing import List -import numpy as np - -from mlc_llm.serve import ( - Engine, - GenerationConfig, - KVCacheConfig, - Request, - RequestStreamOutput, - data, -) -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig +from mlc_llm.serve.engine_base import ModelInfo prompts = [ "What is the meaning of life?", @@ -28,345 +19,87 @@ ] -def create_requests( - num_requests: int, - stop_token_id: Optional[int] = None, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - max_tokens_low: int = 256, - max_tokens_high: int = 257, -) -> List[Request]: - assert num_requests >= 0 and num_requests <= len(prompts) - - stop_token_ids = [stop_token_id] if stop_token_id is not None else [] - requests = [] - for req_id, prompt in zip(range(num_requests), prompts): - max_tokens = np.random.randint(max_tokens_low, max_tokens_high) - requests.append( - Request( - request_id=str(req_id), - inputs=data.TextData(prompt), - generation_config=GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - ), - ) - ) - return requests - - -def test_engine_basic(): - """Test engine **without continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the same max_tokens. This means all requests - will end together. - - Engine keeps running `step` for estimated number of steps (number of - requests + max_tokens - 1). Then check the output of each request. - """ - - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations). - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 256 # [32, 128, 256] - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for req_id, output in enumerate(outputs): - print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_1(): - """Test engine **with continuous batching**. 
- - - Add all requests to the engine altogether in the beginning. - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. - """ - - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - max_tokens_low = 128 - max_tokens_high = 384 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - assert fin_time == request.generation_config.max_tokens - 1 - - -def test_engine_continuous_batching_2(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the stop token. So each request keeps generating - until having the stop token or reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - +def test_engine_generate(): # Initialize model loading info and KV cache config model = ModelInfo( "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - if fin_time < num_requests + max_tokens - 2: - print(f"Request {req_id} ends early on the stop token") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + engine = Engine(model, kv_cache_config) + num_requests = 10 + max_tokens = 256 + generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) + + output_texts: List[List[str]] = [ + ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) + ] + for rid in range(num_requests): + print(f"generating for request {rid}") + for delta_outputs in engine._generate(prompts[rid], generation_cfg, request_id=str(rid)): + assert len(delta_outputs) == generation_cfg.n + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text + + # Print output. + print("All finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") -def test_engine_continuous_batching_3(): - """Test engine **with continuous batching**. + engine.terminate() + del engine - - Add requests randomly between time [0, 200). - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` until all requests finish. 
- Then check the output of each request. - """ +def test_chat_completion(): # Initialize model loading info and KV cache config model = ModelInfo( "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens_low = 64 - max_tokens_high = 192 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - finished_requests: int = 0 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - self.finished_requests += 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - def all_finished(self) -> bool: - return self.finished_requests == num_requests - + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - timer = CallbackTimer() - engine = Engine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Assign the time to add requests to engine - request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] - - # Run steps - while not timer.all_finished(): - timer.step() - - # Add requests to engine - for req_id, add_time in enumerate(request_add_time): - if add_time == timer.timer: - print(f"add request {req_id} at step {timer.timer}") - engine.add_request(requests[req_id]) + engine = Engine(model, kv_cache_config) - engine.step() + num_requests = 2 + max_tokens = 64 + n = 2 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + for response in engine.chat_completion( + messages=[{"role": "user", "content": prompts[rid]}], + model=model.model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + output_texts[rid][choice.index] += choice.delta.content + + # Print output. 
+ print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Finish time: {fin_time}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + engine.terminate() + del engine -def test_engine_generate(): +def test_completion(): # Initialize model loading info and KV cache config model = ModelInfo( "dist/Llama-2-7b-chat-hf-q0f16-MLC", @@ -376,13 +109,26 @@ def test_engine_generate(): # Create engine engine = Engine(model, kv_cache_config) - num_requests = 10 - max_tokens = 256 - - # Generate output. - output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens) - ) + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"completion for request {rid}") + for response in engine.completion( + prompt=prompts[rid], + model=model.model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=str(rid), + ): + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + # Print output. + print("Chat completion all finished") for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -391,10 +137,11 @@ def test_engine_generate(): for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") + engine.terminate() + del engine + if __name__ == "__main__": - test_engine_basic() - test_engine_continuous_batching_1() - test_engine_continuous_batching_2() - test_engine_continuous_batching_3() test_engine_generate() + test_chat_completion() + test_completion() diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 45926002ae..e40f477061 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,10 +7,10 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.async_engine import AsyncThreadedEngine +from mlc_llm.serve import AsyncEngine, GenerationConfig, KVCacheConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -26,7 +26,7 @@ def test_batch_generation_with_grammar(): model = ModelInfo(model_path, model_lib_path=model_lib_path) kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncEngine(model, kv_cache_config) prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -76,7 +76,7 @@ def test_batch_generation_with_schema(): model = ModelInfo(model_path, model_lib_path=model_lib_path) kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = Engine(model, kv_cache_config) + engine = SyncEngine(model, kv_cache_config) prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -131,7 +131,7 @@ async def run_async_engine(): model = 
ModelInfo(model_path, model_lib_path=model_lib_path) kv_cache_config = KVCacheConfig(page_size=16) # Create engine - async_engine = AsyncThreadedEngine(model, kv_cache_config, enable_tracing=True) + async_engine = AsyncEngine(model, kv_cache_config, enable_tracing=True) prompts = prompts_list * 20 @@ -152,14 +152,14 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncThreadedEngine, + async_engine: AsyncEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, ): print(f"Start generation task for request {request_id}") rid = int(request_id) - async for delta_outputs in async_engine.generate( + async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): assert len(delta_outputs) == generation_cfg.n diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index 5b23a245f9..e8bcb13ae4 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -1,10 +1,13 @@ -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig, data -from mlc_llm.serve.engine import ModelInfo -from mlc_llm.serve.entrypoints.entrypoint_utils import get_image_from_url +import json +from pathlib import Path +from mlc_llm.serve import GenerationConfig, KVCacheConfig, data +from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine -def get_test_image(): - return get_image_from_url("https://llava-vl.github.io/static/images/view.jpg") + +def get_test_image(config) -> data.ImageData: + return data.ImageData.from_url("https://llava-vl.github.io/static/images/view.jpg", config) def test_engine_generate(): @@ -15,19 +18,21 @@ def test_engine_generate(): ) kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) - + engine = SyncEngine(model, kv_cache_config) max_tokens = 256 + with open(Path(model.model) / "mlc-chat-config.json", "r", encoding="utf-8") as file: + model_config = json.load(file) + prompts = [ [ data.TextData("USER: "), - data.ImageData(get_test_image(), 576), + get_test_image(model_config), data.TextData("\nWhat does this image represent? ASSISTANT:"), ], [ data.TextData("USER: "), - data.ImageData(get_test_image(), 576), + get_test_image(model_config), data.TextData("\nIs there a dog in this image? ASSISTANT:"), ], [data.TextData("USER: What is the meaning of life? 
ASSISTANT:")], diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 828146afc9..403f75d325 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -5,7 +5,6 @@ import numpy as np from mlc_llm.serve import ( - Engine, EngineMode, GenerationConfig, KVCacheConfig, @@ -13,7 +12,8 @@ RequestStreamOutput, data, ) -from mlc_llm.serve.engine import ModelInfo +from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine prompts = [ "What is the meaning of life?", @@ -98,7 +98,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) # Create requests requests = create_requests( @@ -179,7 +179,7 @@ def step(self) -> None: # Create engine timer = CallbackTimer() - engine = Engine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) # Create requests requests = create_requests( @@ -220,7 +220,7 @@ def test_engine_generate(): kv_cache_config = KVCacheConfig(page_size=16) engine_mode = EngineMode(enable_speculative=True) # Create engine - engine = Engine([model, ssm], kv_cache_config, engine_mode) + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 @@ -266,7 +266,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = Engine(model, kv_cache_config, request_stream_callback=fcallback) + engine = SyncEngine(model, kv_cache_config, request_stream_callback=fcallback) # Create requests requests = create_requests( @@ -338,7 +338,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - spec_engine = Engine([model, ssm], kv_cache_config, engine_mode, fcallback) + spec_engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) # Create requests requests = create_requests( diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py new file mode 100644 index 0000000000..3c8ec011ae --- /dev/null +++ b/tests/python/serve/test_serve_sync_engine.py @@ -0,0 +1,402 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +from typing import Callable, List, Optional + +import numpy as np + +from mlc_llm.serve import ( + GenerationConfig, + KVCacheConfig, + Request, + RequestStreamOutput, + data, +) +from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve.sync_engine import SyncEngine + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? 
Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +def create_requests( + num_requests: int, + stop_token_id: Optional[int] = None, + temperature: float = 0.8, + repetition_penalty: float = 1.0, + max_tokens_low: int = 256, + max_tokens_high: int = 257, +) -> List[Request]: + assert num_requests >= 0 and num_requests <= len(prompts) + + stop_token_ids = [stop_token_id] if stop_token_id is not None else [] + requests = [] + for req_id, prompt in zip(range(num_requests), prompts): + max_tokens = np.random.randint(max_tokens_low, max_tokens_high) + requests.append( + Request( + request_id=str(req_id), + inputs=data.TextData(prompt), + generation_config=GenerationConfig( + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + ), + ) + ) + return requests + + +def test_engine_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. + """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations). + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + engine = SyncEngine(model, kv_cache_config, request_stream_callback=fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + assert ( + fin_time == request.generation_config.max_tokens - 1 + ), f"finish time = {fin_time}, max tokens = {request.generation_config.max_tokens - 1}" + + +def test_engine_continuous_batching_2(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the stop token. So each request keeps generating + until having the stop token or reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + if fin_time < num_requests + max_tokens - 2: + print(f"Request {req_id} ends early on the stop token") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_continuous_batching_3(): + """Test engine **with continuous batching**. + + - Add requests randomly between time [0, 200). + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` until all requests finish. + Then check the output of each request. 
+ """ + + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16) + + # Hyperparameters for tests (you can try different combinations) + num_requests = 10 # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + stop_token_id = 2 + max_tokens_low = 64 + max_tokens_high = 192 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + finished_requests: int = 0 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + self.finished_requests += 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + def all_finished(self) -> bool: + return self.finished_requests == num_requests + + # Create engine + timer = CallbackTimer() + engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + stop_token_id=stop_token_id, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Assign the time to add requests to engine + request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] + + # Run steps + while not timer.all_finished(): + timer.step() + + # Add requests to engine + for req_id, add_time in enumerate(request_add_time): + if add_time == timer.timer: + print(f"add request {req_id} at step {timer.timer}") + engine.add_request(requests[req_id]) + + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Finish time: {fin_time}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_generate(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + engine = SyncEngine(model, kv_cache_config) + + num_requests = 10 + max_tokens = 256 + + # Generate output. 
+ output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=7) + ) + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +if __name__ == "__main__": + test_engine_basic() + test_engine_continuous_batching_1() + test_engine_continuous_batching_2() + test_engine_continuous_batching_3() + test_engine_generate() From 5cf700ba9b3eadc85787b48a91af6b037bac4d85 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 5 Apr 2024 11:17:02 -0400 Subject: [PATCH 158/531] [Serving] Separating ThreadedEngine creation and initialization (#2090) This PR separates the creation and initialization of ThreadedEngine for multi-threading use cases. So we can make sure that the ThreadedEngine instance is created before any other operations (such as initialization, running background loop, etc.). --- cpp/serve/threaded_engine.cc | 5 ++--- cpp/serve/threaded_engine.h | 16 +++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 61ce2e51d6..f74517d7bf 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -26,7 +26,7 @@ using namespace tvm::runtime; /*! \brief The implementation of ThreadedEngine. */ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(TVMArgs args) { + void InitBackgroundEngine(TVMArgs args) final { Optional request_stream_callback; try { request_stream_callback = args.At>(4); @@ -232,9 +232,8 @@ TVM_REGISTER_GLOBAL("mlc.serve.create_threaded_engine").set_body_typed([]() { return Module(make_object()); }); -std::unique_ptr CreateThreadedEnginePacked(TVMArgs args) { +std::unique_ptr ThreadedEngine::Create() { std::unique_ptr threaded_engine = std::make_unique(); - threaded_engine->InitBackgroundEngine(args); return std::move(threaded_engine); } diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 90447e28d8..1440a88056 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -28,8 +28,17 @@ using namespace tvm::runtime; */ class ThreadedEngine { public: + /*! \brief Create a ThreadedEngine. */ + static std::unique_ptr Create(); + virtual ~ThreadedEngine() = default; + /*! + * \brief Initialize the threaded engine from packed arguments in TVMArgs. + * \param args The arguments of engine construction. + */ + virtual void InitBackgroundEngine(TVMArgs args) = 0; + /*! \brief Starts the background request processing loop. */ virtual void RunBackgroundLoop() = 0; @@ -50,13 +59,6 @@ class ThreadedEngine { virtual void AbortRequest(const String& request_id) = 0; }; -/*! - * \brief Create a ThreadedEngine from packed arguments in TVMArgs. - * \param args The arguments of engine construction. - * \return The constructed threaded engine in unique pointer. - */ -std::unique_ptr CreateThreadedEnginePacked(TVMArgs args); - } // namespace serve } // namespace llm } // namespace mlc From d6d3d7e6aa798f804aba0cad3eb61ba16a373a8f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 5 Apr 2024 14:46:20 -0400 Subject: [PATCH 159/531] [Serving] Enhance robustness with small KV capacity (#2091) This PR enhances the robustness, which had issue when the KV capacity is small. 
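
Concretely, this change clamps the engine's effective input-sequence limit and
prefill chunk size to the KV cache capacity, and batch decode now checks that at
least one request state entry is available before decoding. A minimal sketch of
the clamping behavior, with hypothetical numbers that are illustrative only and
not part of this patch:

    # Hypothetical values chosen only to illustrate the small-KV-capacity case.
    max_single_sequence_length = 4096  # model's single-sequence limit
    max_total_sequence_length = 512    # small KV cache capacity
    prefill_chunk_size = 1024          # default prefill chunk size

    # Both effective limits are clamped to the KV cache capacity.
    max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length)  # 512
    prefill_chunk_size = min(prefill_chunk_size, max_total_sequence_length)                 # 512
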
--- cpp/serve/engine_actions/batch_decode.cc | 2 ++ python/mlc_llm/serve/engine_base.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index fc830a21ee..94e441279a 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -59,6 +59,8 @@ class BatchDecodeActionObj : public EngineActionObj { // NOTE: Right now we only support decode all the running request states at a time. int num_rsentries = running_rsentries.size(); + ICHECK_GT(num_rsentries, 0) + << "There should be at least one request state entry that can run decode"; // Collect // - the last committed token, // - the request id, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 21bb928df3..248bd1acf2 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -552,12 +552,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.model_config_dicts.append(json.load(file)) self.state = EngineState(enable_tracing) - self.max_input_sequence_length = max_single_sequence_length if kv_cache_config.max_total_sequence_length is None: kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( models, config_file_paths, kv_cache_config.max_num_sequence ) + self.max_input_sequence_length = min( + max_single_sequence_length, kv_cache_config.max_total_sequence_length + ) + prefill_chunk_size = min(prefill_chunk_size, kv_cache_config.max_total_sequence_length) + if kv_cache_config.prefill_chunk_size is None: kv_cache_config.prefill_chunk_size = prefill_chunk_size elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: From a73eae2af34dd245ed51d740591d77dd27398236 Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Fri, 5 Apr 2024 17:25:00 -0400 Subject: [PATCH 160/531] [REST] Update REST API docs (#2092) This updates the rest docs to use `mlc_llm serve` and also adds a quick start section. --- docs/deploy/rest.rst | 89 +++++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 43 deletions(-) diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 621a22fb71..07d39dbfad 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -1,6 +1,6 @@ .. _deploy-rest-api: -Rest API +REST API ======== .. contents:: Table of Contents @@ -8,33 +8,65 @@ Rest API :depth: 2 We provide `REST API `_ -for a user to interact with MLC-Chat in their own programs. +for a user to interact with MLC-LLM in their own programs. -Install MLC-Chat Package +Install MLC-LLM Package ------------------------ -SERVE is a part of the MLC-Chat package, installation instruction for which we be found here :doc:`<../install/mlc_llm>`. +SERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here `. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful: -Verify Installation -^^^^^^^^^^^^^^^^^^^ +.. code:: bash + + mlc_llm serve --help + +You should see serve help message if the installation was successful. + +Quick start +------------ + +This section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command: .. code:: bash - python -m mlc_llm.serve.server --help + mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] + +where ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process `. 
Information about other arguments can be found under :ref:`Launch the server ` section. + +Once you have launched the Server, you can use the API in your own program to send requests. Below is an example of using the API to interact with MLC-LLM in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): + +.. code:: bash + + import requests + + # Get a response using a prompt without streaming + payload = { + "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", + "messages": [ + {"role": "user", "content": "Write a haiku about apples."}, + ], + "stream": False, + # "n": 1, + "max_tokens": 300, + } + r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) + choices = r.json()["choices"] + for choice in choices: + print(f"{choice['message']['content']}\n") -You are expected to see the help information of the MLC SERVE. +------------------------------------------------ -.. _mlcchat_package_build_from_source: + +.. _rest_launch_server: Launch the Server ----------------- -To launch the MLC Server for MLC-Chat, run the following command in your terminal. +To launch the MLC Server for MLC-LLM, run the following command in your terminal. .. code:: bash - python -m mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] + mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme @@ -71,7 +103,7 @@ The REST API provides the following endpoints: ------------------------------------------------ - Get a list of models available for MLC-Chat. + Get a list of models available for MLC-LLM. **Example** @@ -95,7 +127,7 @@ The REST API provides the following endpoints: ------------------------------------------------ - Get a response from MLC-Chat using a prompt, either with or without streaming. + Get a response from MLC-LLM using a prompt, either with or without streaming. **Chat Completion Request Object** @@ -203,35 +235,7 @@ The REST API provides the following endpoints: **Example** -Once you have launched the Server, you can use the API in your own program. Below is an example of using the API to interact with MLC-Chat in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): - -.. code:: bash - - import requests - - # Get a response using a prompt without streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - { - "role": "assistant", - "content": "Hello! 
It's great to hear about your project, MLC LLM.", - }, - {"role": "user", "content": "What is the name of our project?"}, - ], - "stream": False, - # "n": 1, - "max_tokens": 300, - } - r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) - choices = r.json()["choices"] - for choice in choices: - print(f"{choice['message']['content']}\n") - ------------------------------------------------- - -Below is an example of using the API to interact with MLC-Chat in Python with Streaming. +Below is an example of using the API to interact with MLC-LLM in Python with Streaming. .. code:: bash @@ -256,7 +260,6 @@ Below is an example of using the API to interact with MLC-Chat in Python with St ------------------------------------------------ - There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). Below is an example on how to use function calling in Python. .. code:: bash From 466fa8a80303ae7b7015045cbc1fd8fe15ce2f1a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 5 Apr 2024 19:50:59 -0400 Subject: [PATCH 161/531] [DOCS] Clarify vulkan loader dependency (#2095) This PR clarifies the vulkan loader dependecy. Some system may not have the right vulkan loader and we need to install them via conda. --- docs/index.rst | 10 ++++++++++ docs/install/mlc_llm.rst | 9 ++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 485567b37e..2aabd613bf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,6 +74,16 @@ It is recommended to have at least 6GB free VRAM to run it. mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + If you are using windows/linux/steamdeck and would like to use vulkan, + we recommend installing necessary vulkan loader dependency via conda + to avoid vulkan not found issues. + + .. code:: bash + + conda install -c conda-forge gcc libvulkan-loader + + .. tab:: Web Browser `WebLLM `__. MLC LLM generates performant code for WebGPU and WebAssembly, diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 3003abdc72..c6602559ae 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -61,10 +61,17 @@ Select your operating system/compute platform and run the command in your termin .. tab:: Vulkan - Supported in all Linux packages. + Supported in all Linux packages. Checkout the following instructions + to install the latest vulkan loader to avoid vulkan not found issue. .. note:: + + .. code-block:: bash + + conda install -c conda-forge gcc libvulkan-loader + + If encountering issues with GLIBC not found, please install the latest glibc in conda: .. code-block:: bash From a75eb0b2a5c1e2f93eaa4d3a4a9e221bf971be5b Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 7 Apr 2024 02:05:12 +0800 Subject: [PATCH 162/531] [SLM] Add support for Chatglm3 architecture (#2096) This pr enable Chatglm3 model. 
--- python/mlc_llm/conversation_template.py | 20 + python/mlc_llm/model/chatglm3/__init__.py | 0 .../mlc_llm/model/chatglm3/chatglm3_loader.py | 63 +++ .../mlc_llm/model/chatglm3/chatglm3_model.py | 384 ++++++++++++++++++ .../model/chatglm3/chatglm3_quantization.py | 53 +++ python/mlc_llm/model/model.py | 14 + python/mlc_llm/model/model_preset.py | 37 ++ 7 files changed, 571 insertions(+) create mode 100644 python/mlc_llm/model/chatglm3/__init__.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_loader.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_model.py create mode 100644 python/mlc_llm/model/chatglm3/chatglm3_quantization.py diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index e71e6734f7..1b2a06feab 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -475,3 +475,23 @@ def get_conv_template(name: str) -> Optional[Conversation]: system_prefix_token_ids=[1], ) ) + +# GLM +ConvTemplateRegistry.register_conv_template( + Conversation( + name="glm", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={ + "user": "问", + "assistant": "答", + "tool": "问", + }, + seps=["\n\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[64790, 64792], + ) +) diff --git a/python/mlc_llm/model/chatglm3/__init__.py b/python/mlc_llm/model/chatglm3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/chatglm3/chatglm3_loader.py b/python/mlc_llm/model/chatglm3/chatglm3_loader.py new file mode 100644 index 0000000000..677514f491 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_loader.py @@ -0,0 +1,63 @@ +""" +This file specifies how MLC's ChatGLM3 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .chatglm3_model import ChatGLMForCausalLM, GLMConfig + + +def huggingface(model_config: GLMConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : GLMConfig + The configuration of the Baichuan model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = ChatGLMForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + mlc_name = "transformer.embedding.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + ["transformer.embedding.word_embeddings.weight"], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py new file mode 100644 index 0000000000..e4a9f53b15 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -0,0 +1,384 @@ +""" +Implementation for CHATGLM3 architecture. +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class GLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the ChatGLM model.""" + + hidden_size: int + num_layers: int + kv_channels: int + num_attention_heads: int + ffn_hidden_size: int + layernorm_epsilon: float + post_layer_norm: bool + rmsnorm: bool + add_bias_linear: bool + add_qkv_bias: bool + apply_query_key_layer_scaling: bool + multi_query_attention: bool + multi_query_group_num: int + vocab_size: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.vocab_size == 0: + for name in ["padded_vocab_size"]: + if name in self.kwargs: + self.vocab_size = self.kwargs.pop(name) + if self.context_window_size == 0: + for name in ["max_position_embeddings", "seq_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + assert self.tensor_parallel_shards == 1, "ChatGLM currently does not support sharding." 
+ + +# pylint: disable=invalid-name,missing-docstring + + +class GLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GLMConfig): + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.multi_query_attention = config.multi_query_attention + self.num_key_value_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) + self.head_dim = self.hidden_size // self.num_heads + self.query_key_value = nn.Linear( + config.hidden_size, + (2 * self.num_key_value_heads + self.num_heads) * self.head_dim, + bias=config.add_bias_linear or config.add_qkv_bias, + ) + self.dense = nn.Linear( + self.num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + qkv = self.query_key_value(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, h_q), + (b, s, h_q * d), + ) + attn_output = self.dense(output) + return attn_output + + +class GLMMLP(nn.Module): + def __init__(self, config: GLMConfig): + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=config.add_bias_linear, + ) + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + ) + + def swiglu(x): + x = nn.chunk(x, 2, dim=-1) + return nn.silu(x[0]) * x[1] + + self.activation_func = swiglu + + def forward(self, x): + intermediate_parallel = self.dense_h_to_4h(x) + intermediate_parallel = self.activation_func(intermediate_parallel) + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + def __init__(self, config: GLMConfig): + self.self_attention = GLMAttention(config=config) + self.mlp = GLMMLP(config) + self.input_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attention(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = out + hidden_states + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = out + hidden_states + return hidden_states + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, config: GLMConfig): + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. 
+ self.layers = nn.ModuleList([GLMBlock(config) for _ in range(config.num_layers)]) + + if self.post_layer_norm: + if config.rmsnorm: + self.final_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.layernorm_epsilon, bias=False + ) + else: + self.final_layernorm = nn.LayerNorm(config.hidden_size, config.layernorm_epsilon) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class ChatGLMModel(nn.Module): + def __init__(self, config: GLMConfig): + self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.encoder = GLMTransformer(config) + self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + hidden_states = self.encoder(hidden_states, paged_kv_cache) + return hidden_states + + +class ChatGLMForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: GLMConfig): + self.transformer = ChatGLMModel(config) + self.num_hidden_layers = config.num_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) + self.head_dim = self.hidden_size // self.num_attention_heads + self.vocab_size = config.vocab_size + self.rope_theta = 10000 + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.transformer.embedding(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.transformer.output_layer(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + 
logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "softmax_with_temperature": { + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py new file mode 100644 index 0000000000..26b404daa8 --- /dev/null +++ b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's ChatGLM parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from 
tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .chatglm3_model import ChatGLMForCausalLM, GLMConfig + + +def group_quant( + model_config: GLMConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM-architecture model using group quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: GLMConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM-architecture model using FasterTransformer quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: GLMConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a ChatGLM model without quantization.""" + model: nn.Module = ChatGLMForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 119cfded4c..fe9775109a 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -9,6 +9,7 @@ from mlc_llm.quantization.quantization import Quantization from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization +from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization @@ -324,4 +325,17 @@ class Model: "group-quant": rwkv6_quantization.group_quant, }, ), + "chatglm": Model( + name="chatglm", + model=chatglm3_model.ChatGLMForCausalLM, + config=chatglm3_model.GLMConfig, + source={ + "huggingface-torch": chatglm3_loader.huggingface, + "huggingface-safetensor": chatglm3_loader.huggingface, + }, + quantize={ + "no-quant": chatglm3_quantization.no_quant, + "group-quant": chatglm3_quantization.group_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 8e87217d35..3bfe1cb891 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -623,4 +623,41 @@ "vision_feature_select_strategy": "default", "vocab_size": 32064, }, + "chatglm": { + "architectures": ["ChatGLMModel"], + "model_type": "chatglm", + "auto_map": { + "AutoConfig": "configuration_chatglm.ChatGLMConfig", + "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", + "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration", + }, + "add_bias_linear": False, + "add_qkv_bias": True, + "apply_query_key_layer_scaling": True, + "apply_residual_connection_post_layernorm": False, + "attention_dropout": 0.0, + "attention_softmax_in_fp32": True, + "bias_dropout_fusion": True, + "ffn_hidden_size": 13696, + "fp32_residual_connection": False, + "hidden_dropout": 0.0, + "hidden_size": 4096, + "kv_channels": 128, + "layernorm_epsilon": 1e-05, + "multi_query_attention": True, + "multi_query_group_num": 2, + "num_attention_heads": 
32, + "num_layers": 28, + "original_rope": True, + "padded_vocab_size": 65024, + "post_layer_norm": True, + "rmsnorm": True, + "seq_length": 8192, + "use_cache": True, + "torch_dtype": "float16", + "transformers_version": "4.30.2", + "tie_word_embeddings": False, + "eos_token_id": 2, + "pad_token_id": 0, + }, } From 3d564f3ebf3b36e99834832ae5d3e6c0c807bf3e Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Sun, 7 Apr 2024 03:57:24 +0800 Subject: [PATCH 163/531] [Quantization] Add OpenCL device (#2097) This PR adds OpenCL device for weight conversion. --- python/mlc_llm/quantization/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 3edd53959c..c24c9b4271 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -62,7 +62,7 @@ def is_moe_gate(name: str) -> bool: def compile_quantize_func(mod: IRModule, device) -> Callable: """Compile a quantization function for a given device.""" device_type = device.MASK2STR[device.device_type] - if device_type in ["cuda", "rocm", "metal", "vulkan"]: + if device_type in ["cuda", "rocm", "metal", "vulkan", "opencl"]: target = Target.current() if target is None: target = Target.from_device(device) From 61f76c7b4c4e1895bdfcf752222944bdcf74bafb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 6 Apr 2024 19:15:30 -0400 Subject: [PATCH 164/531] [Serving] Support stream=True for Python API (#2098) The previous refactoring PR formalizes the MLC serve Python API but does not respect the `stream` flag properly: no matter if `stream` is True or False, the functions always work in a streaming style. This PR supports the non-stream case. --- python/mlc_llm/serve/engine.py | 553 +++++++++++++++++- python/mlc_llm/serve/engine_base.py | 137 +++++ .../serve/entrypoints/openai_entrypoints.py | 120 +--- tests/python/serve/test_serve_async_engine.py | 110 +++- tests/python/serve/test_serve_engine.py | 90 ++- 5 files changed, 893 insertions(+), 117 deletions(-) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 1f856c907c..2846d0ffc3 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1,8 +1,20 @@ """The MLC LLM Serving Engine.""" +# pylint: disable=too-many-lines + import asyncio import queue -from typing import Any, AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterator, + List, + Literal, + Optional, + Union, + overload, +) from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import data, engine_utils @@ -56,11 +68,13 @@ async def abort(self, request_id: str) -> None: """ self._abort(request_id) + @overload async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], model: str, + stream: Literal[True], frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -70,7 +84,6 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, temperature: float = 1.0, top_p: float = 1.0, tools: Optional[List[Dict[str, Any]]] = None, @@ -80,7 +93,7 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, ) -> 
AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: - """Asynchronous chat completion interface with OpenAI API compatibility. + """Asynchronous streaming chat completion interface with OpenAI API compatibility. The method is a coroutine that streams ChatCompletionStreamResponse that conforms to OpenAI API one at a time via yield. @@ -104,6 +117,99 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals e : BadRequestError BadRequestError is raised when the request is invalid. """ + + @overload + async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Asynchronous non-streaming chat completion interface with OpenAI API compatibility. + The method is a coroutine that streams ChatCompletionStreamResponse + that conforms to OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. 
+ """ if request_id is None: request_id = f"chatcmpl-{engine_utils.random_uuid()}" @@ -142,14 +248,54 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals ), request_id=request_id, ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. + num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) async for response in chatcmpl_generator: - yield response + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + @overload async def completion( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, prompt: Union[str, List[int]], + stream: Literal[True], best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -161,7 +307,6 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, suffix: Optional[str] = None, temperature: float = 1.0, top_p: float = 1.0, @@ -170,7 +315,7 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: - """Asynchronous completion interface with OpenAI API compatibility. + """Asynchronous streaming completion interface with OpenAI API compatibility. The method is a coroutine that streams CompletionResponse that conforms to OpenAI API one at a time via yield. @@ -194,6 +339,99 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals e : BadRequestError BadRequestError is raised when the request is invalid. 
""" + + @overload + async def completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.CompletionResponse: + """Asynchronous non-streaming completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ if request_id is None: request_id = f"cmpl-{engine_utils.random_uuid()}" cmpl_generator = self._handle_completion( @@ -225,8 +463,41 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals ), request_id, ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + async for response in cmpl_generator: - yield response + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) async def _handle_chat_completion( self, request: openai_api_protocol.ChatCompletionRequest, request_id: str @@ -454,11 +725,13 @@ def abort(self, request_id: str) -> None: """ self._ffi["abort_request"](request_id) + @overload def chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], model: str, + stream: Literal[True], frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -468,7 +741,6 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, temperature: float = 1.0, top_p: float = 1.0, tools: Optional[List[Dict[str, Any]]] = None, @@ -478,7 +750,7 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - """Synchronous chat completion interface with OpenAI API compatibility. + """Synchronous streaming chat completion interface with OpenAI API compatibility. The method streams back ChatCompletionStreamResponse that conforms to OpenAI API one at a time via yield. @@ -502,6 +774,97 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals e : BadRequestError BadRequestError is raised when the request is invalid. 
""" + + @overload + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Synchronous non-streaming chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ if request_id is None: request_id = f"chatcmpl-{engine_utils.random_uuid()}" @@ -540,14 +903,54 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals ), request_id=request_id, ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) for response in chatcmpl_generator: - yield response + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + @overload def completion( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, prompt: Union[str, List[int]], + stream: Literal[True], best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -559,7 +962,6 @@ def completion( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, suffix: Optional[str] = None, temperature: float = 1.0, top_p: float = 1.0, @@ -567,8 +969,8 @@ def completion( # pylint: disable=too-many-arguments,too-many-locals ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.CompletionResponse]: - """Synchronous completion interface with OpenAI API compatibility. + ) -> openai_api_protocol.CompletionResponse: + """Synchronous streaming completion interface with OpenAI API compatibility. The method streams back CompletionResponse that conforms to OpenAI API one at a time via yield. @@ -592,6 +994,96 @@ def completion( # pylint: disable=too-many-arguments,too-many-locals e : BadRequestError BadRequestError is raised when the request is invalid. """ + + @overload + def completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous non-streaming completion interface with OpenAI API compatibility. 
+ + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ if request_id is None: request_id = f"cmpl-{engine_utils.random_uuid()}" cmpl_generator = self._handle_completion( @@ -623,8 +1115,41 @@ def completion( # pylint: disable=too-many-arguments,too-many-locals ), request_id, ) + if stream: + # Stream response. + return cmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + for response in cmpl_generator: - yield response + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) def _handle_chat_completion( self, request: openai_api_protocol.ChatCompletionRequest, request_id: str diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 248bd1acf2..fadd38978d 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -2,6 +2,7 @@ # pylint: disable=too-many-lines +import ast import asyncio import json import os @@ -1068,3 +1069,139 @@ def create_completion_suffix_response( ), ) return response + + +def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: + """Convert a (possibly list) of function call string to a list of json objects. + Return None for invalid function call string.""" + + def parse_function_call(call_str: str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None + + if ( + stringified_calls[0] == "[" and stringified_calls[-1] == "]" + ): # hacky way to check if string list + calls = ast.literal_eval(stringified_calls) + else: + calls = [stringified_calls] + function_calls_json = [parse_function_call(call_str) for call_str in calls] + return function_calls_json + + +def process_function_call_output( + output_texts: List[str], finish_reasons: List[str] +) -> Tuple[bool, List[List[openai_api_protocol.ChatToolCall]]]: + """Process the potential function call results outputted by model, + according to the finish reasons. + Return whether the output has function call, and the list of tool calls. 
+ """ + n = len(output_texts) + tool_calls_list: List[List[openai_api_protocol.ChatToolCall]] = [[] for _ in range(n)] + use_function_calling = any(finish_reason == "tool_calls" for finish_reason in finish_reasons) + if use_function_calling: + for i, output_text in enumerate(output_texts): + try: + fn_json_list = convert_function_str_to_json(output_text) + except (SyntaxError, ValueError): + output_text = "Got an invalid function call output from model" + finish_reasons[i] = "error" + else: + tool_calls_list[i] = [ + openai_api_protocol.ChatToolCall( + type="function", + function=openai_api_protocol.ChatFunctionCall( + name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] + ), + ) + for fn_json_obj in fn_json_list + if fn_json_obj is not None + ] + if len(tool_calls_list[i]) == 0: + output_texts[i] = "Got an invalid function call output from model" + finish_reasons[i] = "error" + else: + finish_reasons[i] = "tool_calls" + return use_function_calling, tool_calls_list + + +def wrap_chat_completion_response( # pylint: disable=too-many-arguments + request_id: str, + model: str, + output_texts: List[str], + finish_reasons: List[str], + tool_calls_list: List[List[openai_api_protocol.ChatToolCall]], + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], + use_function_calling: bool, + num_prompt_tokens: int, + num_completion_tokens: int, +) -> openai_api_protocol.ChatCompletionResponse: + """Wrap the non-streaming chat completion results to ChatCompletionResponse instance.""" + return openai_api_protocol.ChatCompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.ChatCompletionResponseChoice( + index=i, + finish_reason=finish_reasons[i], + message=( + openai_api_protocol.ChatCompletionMessage(role="assistant", content=output_text) + if not use_function_calling or finish_reason == "error" + else openai_api_protocol.ChatCompletionMessage( + role="assistant", tool_calls=tool_calls + ) + ), + logprobs=( + openai_api_protocol.LogProbs(content=logprob_results[i]) + if logprob_results is not None + else None + ), + ) + for i, (output_text, finish_reason, tool_calls) in enumerate( + zip(output_texts, finish_reasons, tool_calls_list) + ) + ], + model=model, + system_fingerprint="", + usage=openai_api_protocol.UsageInfo( + prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens + ), + ) + + +def wrap_completion_response( # pylint: disable=too-many-arguments + request_id: str, + model: str, + output_texts: List[str], + finish_reasons: List[str], + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], + num_prompt_tokens: int, + num_completion_tokens: int, +) -> openai_api_protocol.CompletionResponse: + """Wrap the non-streaming completion results to CompletionResponse instance.""" + return openai_api_protocol.CompletionResponse( + id=request_id, + choices=[ + openai_api_protocol.CompletionResponseChoice( + index=i, + finish_reason=finish_reason, + text=output_text, + logprobs=( + openai_api_protocol.LogProbs(content=logprob_results[i]) + if logprob_results is not None + else None + ), + ) + for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) + ], + model=model, + usage=openai_api_protocol.UsageInfo( + prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens + ), + ) diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 0625ea6aae..23a279021f 100644 --- 
a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -1,30 +1,20 @@ """OpenAI API-compatible server entrypoints in MLC LLM""" # pylint: disable=too-many-locals,too-many-return-statements,too-many-statements -import ast from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import AsyncGenerator, List, Optional import fastapi from mlc_llm.protocol import error_protocol from mlc_llm.protocol.openai_api_protocol import ( - ChatCompletionMessage, ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatFunctionCall, - ChatToolCall, CompletionRequest, - CompletionResponse, - CompletionResponseChoice, ListResponse, - LogProbs, LogProbsContent, ModelResponse, - UsageInfo, ) -from mlc_llm.serve import engine_utils +from mlc_llm.serve import engine_base, engine_utils from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() @@ -115,52 +105,20 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_results[choice.index] += choice.logprobs.content assert all(finish_reason is not None for finish_reason in finish_reasons) - return CompletionResponse( - id=request_id, - choices=[ - CompletionResponseChoice( - index=i, - finish_reason=finish_reason, - text=output_text, - logprobs=( - LogProbs(content=logprob_results[i]) if logprob_results is not None else None - ), - ) - for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) - ], + return engine_base.wrap_completion_response( + request_id=request_id, model=request.model, - usage=UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens), + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) ################ v1/chat/completions ################ -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: - """Convert a (possibly list) of function call string to a list of json objects. 
- Return None for invalid function call string.""" - - def parse_function_call(call_str: str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None - - if ( - stringified_calls[0] == "[" and stringified_calls[-1] == "]" - ): # hacky way to check if string list - calls = ast.literal_eval(stringified_calls) - else: - calls = [stringified_calls] - function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json - - @app.post("/v1/chat/completions") async def request_chat_completion( request: ChatCompletionRequest, raw_request: fastapi.Request @@ -235,53 +193,17 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprob_results[choice.index] += choice.logprobs.content assert all(finish_reason is not None for finish_reason in finish_reasons) - - tool_calls_list: List[List[ChatToolCall]] = [[] for _ in range(request.n)] - use_function_calling = any(finish_reason == "tool_calls" for finish_reason in finish_reasons) - if use_function_calling: - for i, output_text in enumerate(output_texts): - try: - fn_json_list = convert_function_str_to_json(output_text) - except (SyntaxError, ValueError): - output_text = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - tool_calls_list[i] = [ - ChatToolCall( - type="function", - function=ChatFunctionCall( - name=fn_json_obj["name"], arguments=fn_json_obj["arguments"] - ), - ) - for fn_json_obj in fn_json_list - if fn_json_obj is not None - ] - if len(tool_calls_list[i]) == 0: - output_texts[i] = "Got an invalid function call output from model" - finish_reasons[i] = "error" - else: - finish_reasons[i] = "tool_calls" - - return ChatCompletionResponse( - id=request_id, - choices=[ - ChatCompletionResponseChoice( - index=i, - finish_reason=finish_reasons[i], - message=( - ChatCompletionMessage(role="assistant", content=output_text) - if not use_function_calling or finish_reason == "error" - else ChatCompletionMessage(role="assistant", tool_calls=tool_calls) - ), - logprobs=( - LogProbs(content=logprob_results[i]) if logprob_results is not None else None - ), - ) - for i, (output_text, finish_reason, tool_calls) in enumerate( - zip(output_texts, finish_reasons, tool_calls_list) - ) - ], + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, model=request.model, - system_fingerprint="", - usage=UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens), + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, ) diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index f87c11547a..cb6a065b41 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -94,12 +94,13 @@ async def test_chat_completion(): async def generate_task(prompt: str, request_id: str): print(f"generate chat completion task for request 
{request_id}") rid = int(request_id) - async for response in async_engine.chat_completion( + async for response in await async_engine.chat_completion( messages=[{"role": "user", "content": prompt}], model=model.model, max_tokens=max_tokens, n=n, request_id=request_id, + stream=True, ): for choice in response.choices: assert choice.delta.role == "assistant" @@ -126,6 +127,56 @@ async def generate_task(prompt: str, request_id: str): del async_engine +async def test_chat_completion_non_stream(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + async_engine = AsyncEngine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 32 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate chat completion task for request {request_id}") + rid = int(request_id) + response = await async_engine.chat_completion( + messages=[{"role": "user", "content": prompt}], + model=model.model, + max_tokens=max_tokens, + n=n, + request_id=request_id, + ) + for choice in response.choices: + assert choice.message.role == "assistant" + output_texts[rid][choice.index] += choice.message.content + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + async def test_completion(): # Initialize model loading info and KV cache config model = ModelInfo( @@ -144,13 +195,14 @@ async def test_completion(): async def generate_task(prompt: str, request_id: str): print(f"generate completion task for request {request_id}") rid = int(request_id) - async for response in async_engine.completion( + async for response in await async_engine.completion( prompt=prompt, model=model.model, max_tokens=max_tokens, n=n, ignore_eos=True, request_id=request_id, + stream=True, ): for choice in response.choices: output_texts[rid][choice.index] += choice.text @@ -163,7 +215,57 @@ async def generate_task(prompt: str, request_id: str): await asyncio.gather(*tasks) # Print output. 
- print("Chat completion all finished") + print("Completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + async_engine.terminate() + del async_engine + + +async def test_completion_non_stream(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + async_engine = AsyncEngine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + async def generate_task(prompt: str, request_id: str): + print(f"generate completion task for request {request_id}") + rid = int(request_id) + response = await async_engine.completion( + prompt=prompt, + model=model.model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=request_id, + ) + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + tasks = [ + asyncio.create_task(generate_task(prompts[i], request_id=str(i))) + for i in range(num_requests) + ] + + await asyncio.gather(*tasks) + + # Print output. + print("Completion all finished") for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -179,4 +281,6 @@ async def generate_task(prompt: str, request_id: str): if __name__ == "__main__": asyncio.run(test_engine_generate()) asyncio.run(test_chat_completion()) + asyncio.run(test_chat_completion_non_stream()) asyncio.run(test_completion()) + asyncio.run(test_completion_non_stream()) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index cece8a1e27..aa54f4cd97 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -80,6 +80,7 @@ def test_chat_completion(): max_tokens=max_tokens, n=n, request_id=str(rid), + stream=True, ): for choice in response.choices: assert choice.delta.role == "assistant" @@ -99,6 +100,48 @@ def test_chat_completion(): del engine +def test_chat_completion_non_stream(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + engine = Engine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 64 + n = 2 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + response = engine.chat_completion( + messages=[{"role": "user", "content": prompts[rid]}], + model=model.model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + ) + for choice in response.choices: + assert choice.message.role == "assistant" + output_texts[rid][choice.index] += choice.message.content + + # Print output. 
+ print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + engine.terminate() + del engine + + def test_completion(): # Initialize model loading info and KV cache config model = ModelInfo( @@ -123,12 +166,55 @@ def test_completion(): n=n, ignore_eos=True, request_id=str(rid), + stream=True, ): for choice in response.choices: output_texts[rid][choice.index] += choice.text # Print output. - print("Chat completion all finished") + print("Completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + engine.terminate() + del engine + + +def test_completion_non_stream(): + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + # Create engine + engine = Engine(model, kv_cache_config) + + num_requests = 2 + max_tokens = 128 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"completion for request {rid}") + response = engine.completion( + prompt=prompts[rid], + model=model.model, + max_tokens=max_tokens, + n=n, + ignore_eos=True, + request_id=str(rid), + ) + for choice in response.choices: + output_texts[rid][choice.index] += choice.text + + # Print output. + print("Completion all finished") for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -144,4 +230,6 @@ def test_completion(): if __name__ == "__main__": test_engine_generate() test_chat_completion() + test_chat_completion_non_stream() test_completion() + test_completion_non_stream() From 50766fd09b7f589ec9c5806ea87f80e285312100 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 7 Apr 2024 08:31:24 -0400 Subject: [PATCH 165/531] [Serving][Refactor] OpenAI API Python interface alignment (#2099) This PR aligns the Python API of chat completions and completions MLC serve with the OpenAI Python package https://github.com/openai/openai-python. Specifically, say we first create an engine or async engine, then we can use entrance `engine.chat.completions.create(...)` for chat completions. We will add more use examples in the codebase after another few refactors. --- python/mlc_llm/serve/engine.py | 967 ++++++++++++------ tests/python/serve/test_serve_async_engine.py | 8 +- tests/python/serve/test_serve_engine.py | 8 +- 3 files changed, 655 insertions(+), 328 deletions(-) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 2846d0ffc3..b822285d44 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -4,6 +4,8 @@ import asyncio import queue +import sys +import weakref from typing import ( Any, AsyncGenerator, @@ -29,47 +31,31 @@ logger = logging.getLogger(__name__) -class AsyncEngine(engine_base.EngineBase): - """The AsyncEngine in MLC LLM that provides the asynchronous - interfaces with regard to OpenAI API. 
- - Parameters - ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. +class Chat: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" - engine_mode : Optional[EngineMode] - The Engine execution mode. + def __init__(self, engine: weakref.ReferenceType) -> None: + assert isinstance(engine(), (AsyncEngine, Engine)) + self.completions = ( + AsyncChatCompletion(engine) # type: ignore + if isinstance(engine(), AsyncEngine) + else ChatCompletion(engine) # type: ignore + ) - enable_tracing : bool - A boolean indicating if to enable event logging for requests. - """ - def __init__( - self, - models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, - enable_tracing: bool = False, - ) -> None: - super().__init__("async", models, kv_cache_config, engine_mode, enable_tracing) +class AsyncChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async chat completions.""" - async def abort(self, request_id: str) -> None: - """Generation abortion interface. + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncEngine"] + else: + engine: weakref.ReferenceType - Parameter - --------- - request_id : str - The id of the request to abort. - """ - self._abort(request_id) + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine @overload - async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], @@ -119,7 +105,7 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals """ @overload - async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], @@ -168,7 +154,7 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals BadRequestError is raised when the request is invalid. """ - async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], @@ -210,87 +196,218 @@ async def chat_completion( # pylint: disable=too-many-arguments,too-many-locals e : BadRequestError BadRequestError is raised when the request is invalid. 
""" - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ), + return await self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, request_id=request_id, ) - if stream: - # Stream response. - return chatcmpl_generator - # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 - output_texts = ["" for _ in range(n)] - finish_reasons: List[Optional[str]] = [None for _ in range(n)] - logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( - [[] for _ in range(n)] if logprobs else None - ) - async for response in chatcmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens - for choice in response.choices: - assert isinstance(choice.delta.content, str) - output_texts[choice.index] += choice.delta.content - if choice.finish_reason is not None and finish_reasons[choice.index] is None: - finish_reasons[choice.index] = choice.finish_reason - if choice.logprobs is not None: - assert logprob_results is not None - logprob_results[ # pylint: disable=unsupported-assignment-operation - choice.index - ] += choice.logprobs.content - assert all(finish_reason is not None for finish_reason in finish_reasons) - use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons - ) - return engine_base.wrap_chat_completion_response( - request_id=request_id, + +class ChatCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to chat completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["Engine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + stream: Literal[True], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + tools: 
Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Synchronous streaming chat completion interface with OpenAI API compatibility. + The method streams back ChatCompletionStreamResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : ChatCompletionStreamResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/streaming for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.ChatCompletionResponse: + """Synchronous non-streaming chat completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Returns + ------ + response : ChatCompletionResponse + The chat completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/chat/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + Iterator[openai_api_protocol.ChatCompletionStreamResponse], + openai_api_protocol.ChatCompletionResponse, + ]: + """Synchronous chat completion interface with OpenAI API compatibility. 
+ + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return self.engine()._chat_completion( # pylint: disable=protected-access + messages=messages, model=model, - output_texts=output_texts, - finish_reasons=finish_reasons, - tool_calls_list=tool_calls_list, - logprob_results=logprob_results, - use_function_calling=use_function_calling, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, ) + +class AsyncCompletion: # pylint: disable=too-few-public-methods + """The proxy class to direct to async completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["AsyncEngine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + @overload - async def completion( # pylint: disable=too-many-arguments,too-many-locals + async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, @@ -341,7 +458,7 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals """ @overload - async def completion( # pylint: disable=too-many-arguments,too-many-locals + async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, @@ -383,13 +500,419 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals See mlc_llm/protocol/openai_api_protocol.py or https://platform.openai.com/docs/api-reference/completions/object for specification. - Raises - ------ - e : BadRequestError - BadRequestError is raised when the request is invalid. - """ + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + async def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.CompletionResponse, Any], + openai_api_protocol.CompletionResponse, + ]: + """Asynchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. 
+ + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return await self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, + ) + + +class Completion: # pylint: disable=too-few-public-methods + """The proxy class to direct to completions.""" + + if sys.version_info >= (3, 9): + engine: weakref.ReferenceType["Engine"] + else: + engine: weakref.ReferenceType + + def __init__(self, engine: weakref.ReferenceType) -> None: + self.engine = engine + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + stream: Literal[True], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> openai_api_protocol.CompletionResponse: + """Synchronous streaming completion interface with OpenAI API compatibility. + The method streams back CompletionResponse that conforms to + OpenAI API one at a time via yield. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Yields + ------ + stream_response : CompletionResponse + The stream response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + @overload + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Literal[False] = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous non-streaming completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. 
+ A random one will be generated if it is not given. + + Returns + ------ + response : CompletionResponse + The completion response conforming to OpenAI API. + See mlc_llm/protocol/openai_api_protocol.py or + https://platform.openai.com/docs/api-reference/completions/object for specification. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + + def create( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + model: str, + prompt: Union[str, List[int]], + best_of: int = 1, + echo: bool = False, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: int = 16, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + suffix: Optional[str] = None, + temperature: float = 1.0, + top_p: float = 1.0, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: + """Synchronous completion interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/completions/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + return self.engine()._completion( # pylint: disable=protected-access + model=model, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + suffix=suffix, + temperature=temperature, + top_p=top_p, + user=user, + ignore_eos=ignore_eos, + response_format=response_format, + request_id=request_id, + ) + + +class AsyncEngine(engine_base.EngineBase): + """The AsyncEngine in MLC LLM that provides the asynchronous + interfaces with regard to OpenAI API. + + Parameters + ---------- + models : Union[ModelInfo, List[ModelInfo]] + One or a list of model info (specifying which models to load and + which device to load to) to launch the engine. + + kv_cache_config : KVCacheConfig + The configuration of the paged KV cache. + + engine_mode : Optional[EngineMode] + The Engine execution mode. + + enable_tracing : bool + A boolean indicating if to enable event logging for requests. + """ + + def __init__( + self, + models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + enable_tracing: bool = False, + ) -> None: + super().__init__("async", models, kv_cache_config, engine_mode, enable_tracing) + self.chat = Chat(weakref.ref(self)) + self.completions = AsyncCompletion(weakref.ref(self)) + + async def abort(self, request_id: str) -> None: + """Generation abortion interface. + + Parameter + --------- + request_id : str + The id of the request to abort. 
+ """ + self._abort(request_id) + + async def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Union[ + AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], + openai_api_protocol.ChatCompletionResponse, + ]: + """Asynchronous chat completion internal interface with OpenAI API compatibility. + + See https://platform.openai.com/docs/api-reference/chat/create for specification. + + Parameters + ---------- + request_id : Optional[str] + The optional request id. + A random one will be generated if it is not given. + + Raises + ------ + e : BadRequestError + BadRequestError is raised when the request is invalid. + """ + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ), + request_id=request_id, + ) + if stream: + # Stream response. + return chatcmpl_generator + # Normal response. 
+ num_prompt_tokens = 0 + num_completion_tokens = 0 + output_texts = ["" for _ in range(n)] + finish_reasons: List[Optional[str]] = [None for _ in range(n)] + logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( + [[] for _ in range(n)] if logprobs else None + ) + async for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + + assert all(finish_reason is not None for finish_reason in finish_reasons) + use_function_calling, tool_calls_list = engine_base.process_function_call_output( + output_texts, finish_reasons + ) + return engine_base.wrap_chat_completion_response( + request_id=request_id, + model=model, + output_texts=output_texts, + finish_reasons=finish_reasons, + tool_calls_list=tool_calls_list, + logprob_results=logprob_results, + use_function_calling=use_function_calling, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) - async def completion( # pylint: disable=too-many-arguments,too-many-locals + async def _completion( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, @@ -417,7 +940,7 @@ async def completion( # pylint: disable=too-many-arguments,too-many-locals AsyncGenerator[openai_api_protocol.CompletionResponse, Any], openai_api_protocol.CompletionResponse, ]: - """Asynchronous completion interface with OpenAI API compatibility. + """Asynchronous completion internal interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/completions/create for specification. @@ -714,6 +1237,8 @@ def __init__( enable_tracing: bool = False, ) -> None: super().__init__("sync", models, kv_cache_config, engine_mode, enable_tracing) + self.chat = Chat(weakref.ref(self)) + self.completions = Completion(weakref.ref(self)) def abort(self, request_id: str) -> None: """Generation abortion interface. @@ -725,105 +1250,7 @@ def abort(self, request_id: str) -> None: """ self._ffi["abort_request"](request_id) - @overload - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - stream: Literal[True], - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - """Synchronous streaming chat completion interface with OpenAI API compatibility. - The method streams back ChatCompletionStreamResponse that conforms to - OpenAI API one at a time via yield. 
- - See https://platform.openai.com/docs/api-reference/chat/create for specification. - - Parameters - ---------- - request_id : Optional[str] - The optional request id. - A random one will be generated if it is not given. - - Yields - ------ - stream_response : ChatCompletionStreamResponse - The stream response conforming to OpenAI API. - See mlc_llm/protocol/openai_api_protocol.py or - https://platform.openai.com/docs/api-reference/chat/streaming for specification. - - Raises - ------ - e : BadRequestError - BadRequestError is raised when the request is invalid. - """ - - @overload - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> openai_api_protocol.ChatCompletionResponse: - """Synchronous non-streaming chat completion interface with OpenAI API compatibility. - - See https://platform.openai.com/docs/api-reference/chat/create for specification. - - Parameters - ---------- - request_id : Optional[str] - The optional request id. - A random one will be generated if it is not given. - - Returns - ------ - response : ChatCompletionResponse - The chat completion response conforming to OpenAI API. - See mlc_llm/protocol/openai_api_protocol.py or - https://platform.openai.com/docs/api-reference/chat/object for specification. - - Raises - ------ - e : BadRequestError - BadRequestError is raised when the request is invalid. - """ - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], @@ -850,7 +1277,7 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals Iterator[openai_api_protocol.ChatCompletionStreamResponse], openai_api_protocol.ChatCompletionResponse, ]: - """Synchronous chat completion interface with OpenAI API compatibility. + """Synchronous chat completion internal interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/chat/create for specification. 
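For context, a minimal usage sketch of the realigned entry points, mirroring the updated serve tests; the import locations, model paths, and prompt are placeholders:

from mlc_llm.serve import Engine, KVCacheConfig          # assumed import locations
from mlc_llm.serve.engine_base import ModelInfo          # assumed import location

model = ModelInfo(
    "dist/Llama-2-7b-chat-hf-q0f16-MLC",                 # placeholder weights path
    model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so",
)
engine = Engine(model, KVCacheConfig(page_size=16, max_total_sequence_length=4096))

# Non-streaming: a single ChatCompletionResponse is returned.
response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model.model,
    max_tokens=64,
)
print(response.choices[0].message.content)

# Streaming: ChatCompletionStreamResponse chunks are yielded one at a time.
for chunk in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    model=model.model,
    max_tokens=64,
    stream=True,
):
    for choice in chunk.choices:
        print(choice.delta.content, end="", flush=True)

engine.terminate()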
@@ -944,107 +1371,7 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals num_completion_tokens=num_completion_tokens, ) - @overload - def completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - model: str, - prompt: Union[str, List[int]], - stream: Literal[True], - best_of: int = 1, - echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> openai_api_protocol.CompletionResponse: - """Synchronous streaming completion interface with OpenAI API compatibility. - The method streams back CompletionResponse that conforms to - OpenAI API one at a time via yield. - - See https://platform.openai.com/docs/api-reference/completions/create for specification. - - Parameters - ---------- - request_id : Optional[str] - The optional request id. - A random one will be generated if it is not given. - - Yields - ------ - stream_response : CompletionResponse - The stream response conforming to OpenAI API. - See mlc_llm/protocol/openai_api_protocol.py or - https://platform.openai.com/docs/api-reference/completions/object for specification. - - Raises - ------ - e : BadRequestError - BadRequestError is raised when the request is invalid. - """ - - @overload - def completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - model: str, - prompt: Union[str, List[int]], - best_of: int = 1, - echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Literal[False] = False, - suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.CompletionResponse]: - """Synchronous non-streaming completion interface with OpenAI API compatibility. - - See https://platform.openai.com/docs/api-reference/completions/create for specification. - - Parameters - ---------- - request_id : Optional[str] - The optional request id. - A random one will be generated if it is not given. - - Returns - ------ - response : CompletionResponse - The completion response conforming to OpenAI API. - See mlc_llm/protocol/openai_api_protocol.py or - https://platform.openai.com/docs/api-reference/completions/object for specification. - - Raises - ------ - e : BadRequestError - BadRequestError is raised when the request is invalid. 
- """ - - def completion( # pylint: disable=too-many-arguments,too-many-locals + def _completion( # pylint: disable=too-many-arguments,too-many-locals self, *, model: str, @@ -1069,7 +1396,7 @@ def completion( # pylint: disable=too-many-arguments,too-many-locals response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, ) -> Iterator[openai_api_protocol.CompletionResponse]: - """Synchronous completion interface with OpenAI API compatibility. + """Synchronous completion internal interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/completions/create for specification. diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index cb6a065b41..4da72c5deb 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -94,7 +94,7 @@ async def test_chat_completion(): async def generate_task(prompt: str, request_id: str): print(f"generate chat completion task for request {request_id}") rid = int(request_id) - async for response in await async_engine.chat_completion( + async for response in await async_engine.chat.completions.create( messages=[{"role": "user", "content": prompt}], model=model.model, max_tokens=max_tokens, @@ -145,7 +145,7 @@ async def test_chat_completion_non_stream(): async def generate_task(prompt: str, request_id: str): print(f"generate chat completion task for request {request_id}") rid = int(request_id) - response = await async_engine.chat_completion( + response = await async_engine.chat.completions.create( messages=[{"role": "user", "content": prompt}], model=model.model, max_tokens=max_tokens, @@ -195,7 +195,7 @@ async def test_completion(): async def generate_task(prompt: str, request_id: str): print(f"generate completion task for request {request_id}") rid = int(request_id) - async for response in await async_engine.completion( + async for response in await async_engine.completions.create( prompt=prompt, model=model.model, max_tokens=max_tokens, @@ -246,7 +246,7 @@ async def test_completion_non_stream(): async def generate_task(prompt: str, request_id: str): print(f"generate completion task for request {request_id}") rid = int(request_id) - response = await async_engine.completion( + response = await async_engine.completions.create( prompt=prompt, model=model.model, max_tokens=max_tokens, diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index aa54f4cd97..eccf1facda 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -74,7 +74,7 @@ def test_chat_completion(): for rid in range(num_requests): print(f"chat completion for request {rid}") - for response in engine.chat_completion( + for response in engine.chat.completions.create( messages=[{"role": "user", "content": prompts[rid]}], model=model.model, max_tokens=max_tokens, @@ -117,7 +117,7 @@ def test_chat_completion_non_stream(): for rid in range(num_requests): print(f"chat completion for request {rid}") - response = engine.chat_completion( + response = engine.chat.completions.create( messages=[{"role": "user", "content": prompts[rid]}], model=model.model, max_tokens=max_tokens, @@ -159,7 +159,7 @@ def test_completion(): for rid in range(num_requests): print(f"completion for request {rid}") - for response in engine.completion( + for response in engine.completions.create( prompt=prompts[rid], model=model.model, max_tokens=max_tokens, @@ -202,7 +202,7 @@ def 
test_completion_non_stream(): for rid in range(num_requests): print(f"completion for request {rid}") - response = engine.completion( + response = engine.completions.create( prompt=prompts[rid], model=model.model, max_tokens=max_tokens, From fb24fcfc1bc18c5fd79d977e147deec4c48bac2a Mon Sep 17 00:00:00 2001 From: Hangrui Cao <50705298+DiegoCao@users.noreply.github.com> Date: Sun, 7 Apr 2024 15:35:20 -0400 Subject: [PATCH 166/531] [DOC] fix small python env install error (#2102) Fixed one slight issue of tvm install: would require specify python=3.11 on the platform otherwise might encounter python not found error. --- docs/install/tvm.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 7fbd3d08ad..849152cce6 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -160,7 +160,8 @@ While it is generally recommended to always use the prebuilt TVM Unity, if you r conda create -n tvm-build-venv -c conda-forge \ "llvmdev>=15" \ "cmake>=3.24" \ - git + git \ + python=3.11 # enter the build environment conda activate tvm-build-venv From cc8b7476cc0aeabb3311715295303f8d09546b11 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Sun, 7 Apr 2024 23:52:34 -0400 Subject: [PATCH 167/531] [JSONFFIEngine] Initial implementation of JSONFFIEngine (#2101) This PR introduces initial support for the JSONFFIEngine. The request is supposed to be a JSON string in the [Chat completion request body format](https://platform.openai.com/docs/api-reference/chat/create). The output (input to the callback function provided) is a list of JSON strings in the [Chat completion chunk object format](https://platform.openai.com/docs/api-reference/chat/streaming). There is still functionality to be added, which will be added in follow-up PRs. 1. Support for other input datatypes (image, etc.) 2. Applying conversation template to input 3. Function calling and tools support 4. Generation config parameters support 5. Independent text streamers for each request 6. 
logprobs support --- Co-authored-by: Ruihang Lai --- cpp/json_ffi/json_ffi_engine.cc | 216 +++++++++++++ cpp/json_ffi/json_ffi_engine.h | 56 ++++ cpp/json_ffi/openai_api_protocol.cc | 224 ++++++++++++++ cpp/json_ffi/openai_api_protocol.h | 168 ++++++++++ cpp/metadata/json_parser.h | 49 +++ cpp/serve/config.cc | 20 ++ cpp/serve/config.h | 8 + cpp/serve/engine.h | 1 + .../mlc_llm/protocol/openai_api_protocol.py | 2 +- tests/python/json_ffi/test_json_ffi_engine.py | 289 ++++++++++++++++++ 10 files changed, 1032 insertions(+), 1 deletion(-) create mode 100644 cpp/json_ffi/json_ffi_engine.cc create mode 100644 cpp/json_ffi/json_ffi_engine.h create mode 100644 cpp/json_ffi/openai_api_protocol.cc create mode 100644 cpp/json_ffi/openai_api_protocol.h create mode 100644 tests/python/json_ffi/test_json_ffi_engine.py diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc new file mode 100644 index 0000000000..489e2e5339 --- /dev/null +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -0,0 +1,216 @@ +#include "json_ffi_engine.h" + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace tvm::runtime; + +JSONFFIEngine::JSONFFIEngine() { engine_ = serve::ThreadedEngine::Create(); } + +bool JSONFFIEngine::ChatCompletion(std::string request_json_str, std::string request_id) { + bool success = this->AddRequest(request_json_str, request_id); + if (!success) { + this->StreamBackError(request_id); + } + return success; +} + +void JSONFFIEngine::StreamBackError(std::string request_id) { + ChatCompletionMessage delta; + delta.content = std::vector>{ + {{"type", "text"}, {"text", this->err_}}}; + delta.role = Role::assistant; + + ChatCompletionStreamResponseChoice choice; + choice.finish_reason = FinishReason::error; + choice.index = 0; + choice.delta = delta; + + ChatCompletionStreamResponse response; + response.id = request_id; + response.choices = std::vector{choice}; + response.model = "json_ffi"; // TODO: Return model name from engine (or from args) + response.system_fingerprint = ""; + + this->request_stream_callback_(Array{picojson::value(response.ToJSON()).serialize()}); +} + +bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) { + std::optional optional_request = + ChatCompletionRequest::FromJSON(request_json_str, &err_); + if (!optional_request.has_value()) { + return false; + } + ChatCompletionRequest request = optional_request.value(); + // Create Request + // TODO: Check if request_id is present already + + // inputs + // TODO: Apply conv template + Array inputs; + for (const auto& message : request.messages) { + if (message.content.has_value()) { + for (const auto& content : message.content.value()) { + if (content.find("type") == content.end()) { + err_ += "Content should have a type field"; + return false; + } + std::string type = content.at("type"); + if (type == "text") { + if (content.find("text") == content.end()) { + err_ += "Content should have a text field"; + return false; + } + std::string text = content.at("text"); + inputs.push_back(TextData(text)); + } else { + err_ += "Content type not supported"; + return false; + } + } + } + } + + // generation_cfg + Optional generation_cfg = GenerationConfig::FromJSON(request_json_str, &err_); + if (!generation_cfg.defined()) { + return false; + } + + Request engine_request(request_id, inputs, generation_cfg.value()); + this->engine_->AddRequest(engine_request); + + return true; +} + +bool JSONFFIEngine::Abort(std::string request_id) { + 
this->engine_->AbortRequest(request_id); + return true; +} + +std::string JSONFFIEngine::GetLastError() { return err_; } + +void JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); } + +JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); } + +class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { + public: + TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); + TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); + TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); + TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); + TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", + &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); + TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); + if (_name == "init_background_engine") { + return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { + SelfPtr self = static_cast(_self.get()); + + std::string tokenizer_path = args.At(1); + self->streamer_ = TextStreamer(Tokenizer::FromPath(tokenizer_path)); + + // Callback wrapper + Optional request_stream_callback; + try { + request_stream_callback = args.At>(4); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; + } + + CHECK(request_stream_callback.defined()) + << "JSONFFIEngine requires request stream callback function, but it is not given."; + self->request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [self](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + Array responses = self->GetResponseFromStreamOutput(delta_outputs); + self->request_stream_callback_(responses); + }; + + std::vector values{args.values, args.values + args.size()}; + std::vector type_codes{args.type_codes, args.type_codes + args.size()}; + TVMArgsSetter setter(values.data(), type_codes.data()); + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + setter(4, request_stream_callback); + self->engine_->InitBackgroundEngine(TVMArgs(values.data(), type_codes.data(), args.size())); + }); + } + TVM_MODULE_VTABLE_END(); + + void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } + + void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } + + Array GetResponseFromStreamOutput(Array delta_outputs) { + std::unordered_map> response_map; + for (const auto& delta_output : delta_outputs) { + std::string request_id = delta_output->request_id; + if (response_map.find(request_id) == response_map.end()) { + response_map[request_id] = std::vector(); + } + ChatCompletionStreamResponseChoice choice; + + if (delta_output->group_finish_reason.size() != 1) { + // Only support n = 1 in ChatCompletionStreamResponse for now + this->err_ += "Group finish reason should have exactly one element"; + } + Optional finish_reason = delta_output->group_finish_reason[0]; + if (finish_reason.defined()) { + if (finish_reason.value() == "stop") { + choice.finish_reason = FinishReason::stop; + } else if (finish_reason.value() == "length") { + choice.finish_reason = FinishReason::length; + } else if (finish_reason.value() == "tool_calls") { + choice.finish_reason = FinishReason::tool_calls; + } else if (finish_reason.value() == "error") { + choice.finish_reason = FinishReason::error; + } + } else { + choice.finish_reason = 
std::nullopt; + } + + choice.index = response_map[request_id].size(); + + ChatCompletionMessage delta; + // Size of delta_output->group_delta_token_ids Array should be 1 + IntTuple delta_token_ids = delta_output->group_delta_token_ids[0]; + std::vector delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end()); + delta.content = std::vector>(); + delta.content.value().push_back(std::unordered_map{ + {"type", "text"}, {"text", this->streamer_->Put(delta_token_ids_vec)}}); + + delta.role = Role::assistant; + + choice.delta = delta; + + response_map[request_id].push_back(choice); + } + + Array response_arr; + for (const auto& [request_id, choices] : response_map) { + ChatCompletionStreamResponse response; + response.id = request_id; + response.choices = choices; + response.model = "json_ffi"; // TODO: Return model name from engine (or from args) + response.system_fingerprint = ""; + response_arr.push_back(picojson::value(response.ToJSON()).serialize()); + } + return response_arr; + } +}; + +TVM_REGISTER_GLOBAL("mlc.json_ffi.CreateJSONFFIEngine").set_body_typed([]() { + return Module(make_object()); +}); + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h new file mode 100644 index 0000000000..83013b5876 --- /dev/null +++ b/cpp/json_ffi/json_ffi_engine.h @@ -0,0 +1,56 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/json_ffi_engine.h + * \brief The header of JSON FFI engine in MLC LLM. + */ +#ifndef MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ +#define MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ + +#include + +#include + +#include "../serve/threaded_engine.h" +#include "../streamer.h" +#include "openai_api_protocol.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace tvm::runtime; +using namespace mlc::llm::serve; + +/*! + * \brief // Todo: document this class, fields and member functions + */ +class JSONFFIEngine { + public: + JSONFFIEngine(); + + ~JSONFFIEngine(); + + bool ChatCompletion(std::string request_json_str, std::string request_id); + + bool AddRequest(std::string request_json_str, std::string request_id); + + void StreamBackError(std::string request_id); + + bool Abort(std::string request_id); + + std::string GetLastError(); + + void ExitBackgroundLoop(); + + protected: + std::unique_ptr engine_; + std::string err_; + PackedFunc request_stream_callback_; + TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_ diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc new file mode 100644 index 0000000000..41378fc3e0 --- /dev/null +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -0,0 +1,224 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/openai_api_protocol.cc + * \brief The implementation of OpenAI API Protocol in MLC LLM. 
+ */ +#include "openai_api_protocol.h" + +#include "../metadata/json_parser.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +std::optional ChatCompletionMessage::FromJSON(const picojson::value& json, + std::string* err) { + if (!json.is()) { + *err += "Input is not a valid JSON object"; + return std::nullopt; + } + picojson::object json_obj = json.get(); + + ChatCompletionMessage message; + + // content + picojson::array content_arr; + if (!json::ParseJSONField(json_obj, "content", content_arr, err, true)) { + return std::nullopt; + } + std::vector > content; + for (const auto& item : content_arr) { + if (!item.is()) { + *err += "Content item is not an object"; + return std::nullopt; + } + std::unordered_map item_map; + picojson::object item_obj = item.get(); + for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); ++i) { + item_map[i->first] = i->second.to_str(); + } + content.push_back(item_map); + } + message.content = content; + + // role + std::string role_str; + if (!json::ParseJSONField(json_obj, "role", role_str, err, true)) { + return std::nullopt; + } + if (role_str == "system") { + message.role = Role::system; + } else if (role_str == "user") { + message.role = Role::user; + } else if (role_str == "assistant") { + message.role = Role::assistant; + } else if (role_str == "tool") { + message.role = Role::tool; + } else { + *err += "Invalid role"; + return std::nullopt; + } + + // name + std::string name; + if (json::ParseJSONField(json_obj, "name", name, err, false)) { + message.name = name; + } + + // TODO: tool_calls and tool_call_id + + return message; +} + +std::optional ChatCompletionRequest::FromJSON( + const picojson::object& json_obj, std::string* err) { + ChatCompletionRequest request; + + // messages + picojson::array messages_arr; + if (!json::ParseJSONField(json_obj, "messages", messages_arr, err, true)) { + return std::nullopt; + } + std::vector messages; + for (const auto& item : messages_arr) { + std::optional message = ChatCompletionMessage::FromJSON(item, err); + if (!message.has_value()) { + return std::nullopt; + } + messages.push_back(message.value()); + } + request.messages = messages; + + // model + std::string model; + if (!json::ParseJSONField(json_obj, "model", model, err, true)) { + return std::nullopt; + } + request.model = model; + + // frequency_penalty + double frequency_penalty; + if (json::ParseJSONField(json_obj, "frequency_penalty", frequency_penalty, err, false)) { + request.frequency_penalty = frequency_penalty; + } + + // presence_penalty + double presence_penalty; + if (json::ParseJSONField(json_obj, "presence_penalty", presence_penalty, err, false)) { + request.presence_penalty = presence_penalty; + } + + // TODO: Other parameters + + return request; +} + +std::optional ChatCompletionRequest::FromJSON(const std::string& json_str, + std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!json_obj.has_value()) { + return std::nullopt; + } + return ChatCompletionRequest::FromJSON(json_obj.value(), err); +} + +picojson::object ChatCompletionMessage::ToJSON() { + picojson::object obj; + picojson::array content_arr; + for (const auto& item : this->content.value()) { + picojson::object item_obj; + for (const auto& pair : item) { + item_obj[pair.first] = picojson::value(pair.second); + } + content_arr.push_back(picojson::value(item_obj)); + } + obj["content"] = picojson::value(content_arr); + if (this->role == Role::system) { + obj["role"] = 
picojson::value("system"); + } else if (this->role == Role::user) { + obj["role"] = picojson::value("user"); + } else if (this->role == Role::assistant) { + obj["role"] = picojson::value("assistant"); + } else if (this->role == Role::tool) { + obj["role"] = picojson::value("tool"); + } + if (name.has_value()) { + obj["name"] = picojson::value(name.value()); + } + return obj; +} + +picojson::object ChatCompletionResponseChoice::ToJSON() { + picojson::object obj; + if (!this->finish_reason.has_value()) { + obj["finish_reason"] = picojson::value(); + } else { + if (this->finish_reason == FinishReason::stop) { + obj["finish_reason"] = picojson::value("stop"); + } else if (this->finish_reason == FinishReason::length) { + obj["finish_reason"] = picojson::value("length"); + } else if (this->finish_reason == FinishReason::tool_calls) { + obj["finish_reason"] = picojson::value("tool_calls"); + } else if (this->finish_reason == FinishReason::error) { + obj["finish_reason"] = picojson::value("error"); + } + } + obj["index"] = picojson::value((int64_t)this->index); + obj["message"] = picojson::value(this->message.ToJSON()); + return obj; +} + +picojson::object ChatCompletionStreamResponseChoice::ToJSON() { + picojson::object obj; + if (!this->finish_reason.has_value()) { + obj["finish_reason"] = picojson::value(); + } else { + if (this->finish_reason.value() == FinishReason::stop) { + obj["finish_reason"] = picojson::value("stop"); + } else if (this->finish_reason.value() == FinishReason::length) { + obj["finish_reason"] = picojson::value("length"); + } else if (this->finish_reason.value() == FinishReason::tool_calls) { + obj["finish_reason"] = picojson::value("tool_calls"); + } else if (this->finish_reason.value() == FinishReason::error) { + obj["finish_reason"] = picojson::value("error"); + } + } + + obj["index"] = picojson::value((int64_t)this->index); + obj["delta"] = picojson::value(this->delta.ToJSON()); + return obj; +} + +picojson::object ChatCompletionResponse::ToJSON() { + picojson::object obj; + obj["id"] = picojson::value(this->id); + picojson::array choices_arr; + for (auto& choice : this->choices) { + choices_arr.push_back(picojson::value(choice.ToJSON())); + } + obj["choices"] = picojson::value(choices_arr); + obj["created"] = picojson::value((int64_t)this->created); + obj["model"] = picojson::value(this->model); + obj["system_fingerprint"] = picojson::value(this->system_fingerprint); + obj["object"] = picojson::value(this->object); + return obj; +} + +picojson::object ChatCompletionStreamResponse::ToJSON() { + picojson::object obj; + obj["id"] = picojson::value(this->id); + picojson::array choices_arr; + for (auto& choice : this->choices) { + choices_arr.push_back(picojson::value(choice.ToJSON())); + } + obj["choices"] = picojson::value(choices_arr); + obj["created"] = picojson::value((int64_t)this->created); + obj["model"] = picojson::value(this->model); + obj["system_fingerprint"] = picojson::value(this->system_fingerprint); + obj["object"] = picojson::value(this->object); + return obj; +} + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h new file mode 100644 index 0000000000..1579b5f337 --- /dev/null +++ b/cpp/json_ffi/openai_api_protocol.h @@ -0,0 +1,168 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/openai_api_protocol.h + * \brief The header of OpenAI API Protocol in MLC LLM. 
+ */ +#ifndef MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H +#define MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H + +#include +#include +#include +#include +#include + +#include "picojson.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +enum class Role { system, user, assistant, tool }; +enum class Type { text, json_object, function }; +enum class FinishReason { stop, length, tool_calls, error }; + +// TODO: Implement the following class +class ChatFunction { + public: + std::optional description = std::nullopt; + std::string name; + std::unordered_map + parameters; // Assuming parameters are string key-value pairs + + static std::optional FromJSON(const picojson::value& json, std::string* err); +}; + +// TODO: Implement the following class +class ChatTool { + public: + Type type = Type::function; + ChatFunction function; + + static std::optional FromJSON(const picojson::value& json, std::string* err); +}; + +// TODO: Implement the following class +class ChatFunctionCall { + public: + std::string name; + std::optional> arguments = + std::nullopt; // Assuming arguments are string key-value pairs +}; + +// TODO: Implement the following class +class ChatToolCall { + public: + std::string id; // TODO: python code initializes this to an random string + Type type = Type::function; + ChatFunctionCall function; +}; + +class ChatCompletionMessage { + public: + std::optional>> content = + std::nullopt; // Assuming content is a list of string key-value pairs + Role role; + std::optional name = std::nullopt; + std::optional> tool_calls = std::nullopt; // TODO: Implement this + std::optional tool_call_id = std::nullopt; // TODO: Implement this + + static std::optional FromJSON(const picojson::value& json, + std::string* err); + picojson::object ToJSON(); +}; + +class RequestResponseFormat { + public: + Type type = Type::text; + std::optional json_schema = std::nullopt; +}; + +class ChatCompletionRequest { + public: + std::vector messages; + std::string model; + double frequency_penalty = 0.0; + double presence_penalty = 0.0; + bool logprobs = false; + int top_logprobs = 0; + std::optional> logit_bias = std::nullopt; + std::optional max_tokens = std::nullopt; + int n = 1; + std::optional seed = std::nullopt; + std::optional> stop = std::nullopt; + bool stream = false; + double temperature = 1.0; + double top_p = 1.0; + std::optional> tools = std::nullopt; + std::optional tool_choice = std::nullopt; + std::optional user = std::nullopt; + bool ignore_eos = false; + // RequestResponseFormat response_format; //TODO: implement this + + /*! + * \brief Create a ChatCompletionRequest instance from the given JSON object. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const picojson::object& json_obj, + std::string* err); + /*! + * \brief Parse and create a ChatCompletionRequest instance from the given JSON string. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. 
+ */ + static std::optional FromJSON(const std::string& json_str, + std::string* err); + + // TODO: check_penalty_range, check_logit_bias, check_logprobs +}; + +class ChatCompletionResponseChoice { + public: + std::optional finish_reason; + int index = 0; + ChatCompletionMessage message; + // TODO: logprobs + + picojson::object ToJSON(); +}; + +class ChatCompletionStreamResponseChoice { + public: + std::optional finish_reason; + int index = 0; + ChatCompletionMessage delta; + // TODO: logprobs + + picojson::object ToJSON(); +}; + +class ChatCompletionResponse { + public: + std::string id; + std::vector choices; + int created = static_cast(std::time(nullptr)); + std::string model; + std::string system_fingerprint; + std::string object = "chat.completion"; + // TODO: usage_info + + picojson::object ToJSON(); +}; + +class ChatCompletionStreamResponse { + public: + std::string id; + std::vector choices; + int created = static_cast(std::time(nullptr)); + std::string model; + std::string system_fingerprint; + std::string object = "chat.completion.chunk"; + + picojson::object ToJSON(); +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index 14f622f2c8..f6ff10e1ac 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -10,6 +10,8 @@ #include #include +#include + namespace mlc { namespace llm { namespace json { @@ -20,6 +22,53 @@ namespace json { * \return The parsed JSON object. */ picojson::object ParseToJsonObject(const std::string& json_str); + +// Todo(mlc-team): implement "Result" class for JSON parsing with error collection. +/*! + * \brief Parse input JSON string into JSON dict. + * Any error will be dumped to the input error string. + */ +inline std::optional LoadJSONFromString(const std::string& json_str, + std::string* err) { + ICHECK_NOTNULL(err); + picojson::value json; + *err = picojson::parse(json, json_str); + if (!json.is()) { + *err += "The input JSON string does not correspond to a JSON dict."; + return std::nullopt; + } + return json.get(); +} + +/*! + * \brief // Todo(mlc-team): document this function. + * \tparam T + * \param json_obj + * \param field + * \param value + * \param err + * \param required + * \return + */ +template +inline bool ParseJSONField(const picojson::object& json_obj, const std::string& field, T& value, + std::string* err, bool required) { + // T can be int, double, bool, string, picojson::array + if (json_obj.count(field)) { + if (!json_obj.at(field).is()) { + *err += "Field " + field + " is not of type " + typeid(T).name() + "\n"; + return false; + } + value = json_obj.at(field).get(); + } else { + if (required) { + *err += "Field " + field + " is required\n"; + return false; + } + } + return true; +} + /*! * \brief Lookup a JSON object by a key, and convert it to a given type. * \param json The JSON object to look up. 
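The `LoadJSONFromString` / `ParseJSONField` helpers added above are what the JSON FFI layer (e.g. `ChatCompletionMessage::FromJSON` above) uses to pull typed fields out of a request body. A minimal usage sketch, assuming the elided template parameter is `template <typename T>` (dispatching through picojson's `is<T>()`/`get<T>()`) and using an invented request body and include path:

```
// Sketch only: exercises the new helpers in cpp/metadata/json_parser.h.
// The JSON body and field names here are made up for illustration.
#include <iostream>
#include <optional>
#include <string>

#include "cpp/metadata/json_parser.h"  // assumed include path, relative to repo root

int main() {
  std::string err;
  std::optional<picojson::object> obj =
      mlc::llm::json::LoadJSONFromString(R"({"model": "demo", "temperature": 0.7})", &err);
  if (!obj.has_value()) {
    std::cerr << err << std::endl;
    return 1;
  }

  std::string model;
  double temperature = 1.0;
  // Required field: returns false and appends to `err` when missing or mistyped.
  if (!mlc::llm::json::ParseJSONField(obj.value(), "model", model, &err, /*required=*/true)) {
    std::cerr << err << std::endl;
    return 1;
  }
  // Optional field: the default value stays untouched when the key is absent.
  mlc::llm::json::ParseJSONField(obj.value(), "temperature", temperature, &err,
                                 /*required=*/false);
  std::cout << model << " @ temperature " << temperature << std::endl;
  return 0;
}
```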
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 3465de402e..0c69296326 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -8,6 +8,8 @@ #include +#include "../json_ffi/openai_api_protocol.h" +#include "../metadata/json_parser.h" #include "data.h" namespace mlc { @@ -158,6 +160,24 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } +Optional GenerationConfig::FromJSON(const std::string& json_str, + std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !json_obj.has_value()) { + return NullOpt; + } + ObjectPtr n = make_object(); + + // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. + + if (!err->empty()) { + return NullOpt; + } + GenerationConfig gen_config; + gen_config.data_ = std::move(n); + return gen_config; +} + String GenerationConfigNode::AsJSONString() const { picojson::object config; config["n"] = picojson::value(static_cast(this->n)); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index c406e55125..0c3402b2ca 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -9,6 +9,8 @@ #include #include +#include + namespace mlc { namespace llm { namespace serve { @@ -57,6 +59,12 @@ class GenerationConfig : public ObjectRef { public: explicit GenerationConfig(String config_json_str); + /*! + * \brief Parse the generation config from the given JSON string. + * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + */ + static Optional FromJSON(const std::string& json_str, std::string* err); + TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 9ff38bdc42..973be50093 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -11,6 +11,7 @@ #include "data.h" #include "event_trace_recorder.h" #include "request.h" +#include "request_state.h" namespace mlc { namespace llm { diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 6f5754dee1..1cbf0bd228 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -329,7 +329,7 @@ class ChatCompletionResponseChoice(BaseModel): class ChatCompletionStreamResponseChoice(BaseModel): - finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "error"]] = None index: int = 0 delta: ChatCompletionMessage logprobs: Optional[LogProbs] = None diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py new file mode 100644 index 0000000000..0d8448c9c5 --- /dev/null +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -0,0 +1,289 @@ +# pylint: disable=chained-comparison,line-too-long,missing-docstring, +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import error_protocol, openai_api_protocol +from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig, engine_utils +from mlc_llm.serve.engine_base import ( + EngineMode, + ModelInfo, + _estimate_max_total_sequence_length, + _process_model_args, +) +from mlc_llm.tokenizer import Tokenizer + +prompts = [ + "What is the meaning of life?", + "Introduce the history of Pittsburgh to me. 
Please elaborate in detail.", + "Write a three-day Seattle travel plan. Please elaborate in detail.", + "What is Alaska famous of? Please elaborate in detail.", + "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", + "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", + "Why is Vitamin D important to human beings? Please elaborate in detail.", + "Where is milk tea originated from? Please elaborate in detail.", + "Where is the southernmost place in United States? Please elaborate in detail.", + "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", +] + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + models: Union[ModelInfo, List[ModelInfo]], + kv_cache_config: KVCacheConfig, + engine_mode: Optional[EngineMode] = None, + ) -> None: + if isinstance(models, ModelInfo): + models = [models] + ( + model_args, + config_file_paths, + tokenizer_path, + max_single_sequence_length, + prefill_chunk_size, + self.conv_template, + ) = _process_model_args(models) + + self.model_config_dicts = [] + for i, model in enumerate(models): + # model_args: + # [model_lib_path, model_path, device.device_type, device.device_id] * N + model.model_lib_path = model_args[i * (len(model_args) // len(models))] + with open(config_file_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + self.state = EngineState() + + if kv_cache_config.max_total_sequence_length is None: + kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( + models, config_file_paths, kv_cache_config.max_num_sequence + ) + self.max_input_sequence_length = min( + max_single_sequence_length, kv_cache_config.max_total_sequence_length + ) + prefill_chunk_size = min(prefill_chunk_size, kv_cache_config.max_total_sequence_length) + + if kv_cache_config.prefill_chunk_size is None: + kv_cache_config.prefill_chunk_size = prefill_chunk_size + elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: + raise ValueError( + f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " + f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " + "models. Please specify a smaller prefill chunk size." 
+ ) + + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(tokenizer_path) + if engine_mode is None: + # The default engine mode: non-speculative + engine_mode = EngineMode() + + def _background_loop(): + self._ffi["init_background_engine"]( + max_single_sequence_length, + tokenizer_path, + kv_cache_config.asjson(), + engine_mode.asjson(), + self.state.get_request_stream_callback(), + None, + *model_args, + ) + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. + self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = 
bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + +def test_chat_completion(engine: JSONFFIEngine): + num_requests = 2 + max_tokens = 64 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + for response in engine.chat_completion( + messages=[{"role": "user", "content": [{"type": "text", "text": prompts[rid]}]}], + model=model.model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + assert isinstance(choice.delta.content[0], Dict) + assert choice.delta.content[0]["type"] == "text" + output_texts[rid][choice.index] += choice.delta.content[0]["text"] + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +def test_malformed_request(engine: JSONFFIEngine): + for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): + assert len(response.choices) == 1 + assert response.choices[0].finish_reason == "error" + + +if __name__ == "__main__": + # Initialize model loading info and KV cache config + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=1024) + engine = JSONFFIEngine(model, kv_cache_config) + + test_chat_completion(engine) + test_malformed_request(engine) + + engine.terminate() + del engine From 95d268bf1c072206a6ae4e51143fbfc263c0d7b6 Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Mon, 8 Apr 2024 20:36:59 +0100 Subject: [PATCH 168/531] [Model] Use tanh approximation of GeLU in Gemma MLP (#2106) This is in line with the implementation in the [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L183) library. Also, the [gemma-1.1](https://huggingface.co/google/gemma-1.1-2b-it/blob/main/config.json#L10) model config. 
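For reference, exact GeLU is `0.5 * x * (1 + erf(x / sqrt(2)))`, while `gelu_pytorch_tanh` uses the usual tanh approximation `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`. The standalone snippet below just compares the two numerically; it is an illustration, not code taken from MLC or transformers:

```
// Exact GeLU vs. the tanh approximation ("gelu_pytorch_tanh"), for comparison.
#include <cmath>
#include <cstdio>

static double gelu_exact(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

static double gelu_tanh(double x) {
  const double kSqrt2OverPi = std::sqrt(2.0 / 3.14159265358979323846);
  return 0.5 * x * (1.0 + std::tanh(kSqrt2OverPi * (x + 0.044715 * x * x * x)));
}

int main() {
  for (double x : {-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0}) {
    std::printf("x=% .2f  exact=% .6f  tanh=% .6f\n", x, gelu_exact(x), gelu_tanh(x));
  }
  return 0;
}
```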
--- python/mlc_llm/model/gemma/gemma_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 5950ab2972..118f3ce856 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -39,7 +39,7 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): - if self.hidden_act != "gelu": + if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"): raise ValueError("Only GeLU is supported as the activation for gemma.") if self.attention_bias: raise ValueError('Only "False" attention_bias is supported for gemma') @@ -115,7 +115,7 @@ def __init__(self, config: GemmaConfig): def forward(self, x: Tensor): concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) - return self.down_proj(op.gelu(x1) * x2) + return self.down_proj(op.gelu(x1, approximate="tanh") * x2) class GemmaAttention(nn.Module): # pylint: disable=too-many-instance-attributes From 36d0e6aca1288123791c8650133582d768d356a6 Mon Sep 17 00:00:00 2001 From: Git bot Date: Mon, 8 Apr 2024 20:08:21 +0000 Subject: [PATCH 169/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 5400532c4b..6ce8430b7f 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 5400532c4ba37e8a30fcaac488c2ecb05a307e4f +Subproject commit 6ce8430b7f8b894789e9d6a12e5fe3231290cd9c From 3e71b70ac98b985404bca39b03c77daa0f7b5017 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 9 Apr 2024 15:22:03 -0400 Subject: [PATCH 170/531] [Quantization] Stricter checks for MoE gate (#2109) This PR strenthens the MoE gate checks to include checking number of experts, given the real MoE gate router layer's output feature number is the number of experts and is usually very small. This PR comes from a regression that there is a layer in RWKV6 that ends with name "gate" is not for MoE at all. --- python/mlc_llm/quantization/awq_quantization.py | 6 +++++- python/mlc_llm/quantization/ft_quantization.py | 2 +- python/mlc_llm/quantization/group_quantization.py | 2 +- python/mlc_llm/quantization/utils.py | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/mlc_llm/quantization/awq_quantization.py b/python/mlc_llm/quantization/awq_quantization.py index 1d7cddbfa6..d51f0a6020 100644 --- a/python/mlc_llm/quantization/awq_quantization.py +++ b/python/mlc_llm/quantization/awq_quantization.py @@ -117,7 +117,11 @@ def visit_module(self, name: str, node: nn.Module) -> Any: The new node to replace current node. 
""" - if isinstance(node, nn.Linear) and not is_final_fc(name) and not is_moe_gate(name): + if ( + isinstance(node, nn.Linear) + and not is_final_fc(name) + and not is_moe_gate(name, node) + ): return AWQQuantizeLinear.from_linear(node, self.config) return self.visit(name, node) diff --git a/python/mlc_llm/quantization/ft_quantization.py b/python/mlc_llm/quantization/ft_quantization.py index b6b1da100f..4a15846096 100644 --- a/python/mlc_llm/quantization/ft_quantization.py +++ b/python/mlc_llm/quantization/ft_quantization.py @@ -147,7 +147,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any: group_quantize = self.config.fallback_group_quantize() self.quant_map.map_func[weight_name] = group_quantize.quantize_weight return GroupQuantizeLinear.from_linear(node, group_quantize) - if not is_moe_gate(name): + if not is_moe_gate(name, node): self.quant_map.map_func[weight_name] = self.config.quantize_weight return FTQuantizeLinear.from_linear(node, self.config) if isinstance(node, nn.Embedding): diff --git a/python/mlc_llm/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py index 1da5174721..1a9dd82519 100644 --- a/python/mlc_llm/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py @@ -113,7 +113,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any: if ( isinstance(node, nn.Linear) and (not is_final_fc(name) or self.config.quantize_final_fc) - and not is_moe_gate(name) + and not is_moe_gate(name, node) ): weight_name = f"{name}.weight" self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index c24c9b4271..fdc50ff74d 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -54,9 +54,9 @@ def is_final_fc(name: str) -> bool: return name in ["head", "lm_head", "lm_head.linear", "embed_out"] -def is_moe_gate(name: str) -> bool: +def is_moe_gate(name: str, node: nn.Linear) -> bool: """Check whether the parameter is the MoE gate layer.""" - return name.endswith("gate") + return name.endswith("gate") and isinstance(node.out_features, int) and node.out_features < 16 def compile_quantize_func(mod: IRModule, device) -> Callable: From 623ed624f5f0c213d9235c873eb68eb8ad3e1cac Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 10 Apr 2024 00:10:06 +0000 Subject: [PATCH 171/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 6ce8430b7f..c7bdcabd60 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6ce8430b7f8b894789e9d6a12e5fe3231290cd9c +Subproject commit c7bdcabd602f3d882e764232692d1d1eb449d07b From 021c29c8821b435c4159fa71c654f5757c010eec Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Wed, 10 Apr 2024 08:25:54 -0400 Subject: [PATCH 172/531] [LLaVa] Fix allowed text model value in config (#2062) * Llava support vicuna and mistral text models * Support f32 quantization * Lint fix * Use preset if transformers not installed * Rebase on main --------- Co-authored-by: Animesh Bohara --- python/mlc_llm/model/llava/llava_model.py | 118 +++++++++++++--------- python/mlc_llm/serve/data.py | 26 +++-- 2 files changed, 88 insertions(+), 56 deletions(-) diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index 30963f990c..1498c13fdb 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ 
b/python/mlc_llm/model/llava/llava_model.py @@ -23,10 +23,12 @@ from tvm.relax.op import arange, strided_slice from mlc_llm import op as op_ext +from mlc_llm.model.model_preset import MODEL_PRESETS from mlc_llm.nn import PagedKVCache, RopeMode from ...support.config import ConfigBase from ..llama.llama_model import LlamaConfig, LlamaForCasualLM +from ..mistral.mistral_model import MistralConfig, MistralForCasualLM logger = logging.getLogger(__name__) @@ -45,12 +47,15 @@ class LlavaVisionConfig(ConfigBase): # pylint: disable=too-many-instance-attrib patch_size: int projection_dim: int vocab_size: int - dtype: str = "float16" num_channels: int = 3 layer_norm_eps: float = 1e-06 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) +CONFIG_MAP = {"LlamaForCausalLM": LlamaConfig, "MistralForCausalLM": MistralConfig} +ARCHITECTURE_MAP = {"LlamaForCausalLM": LlamaForCasualLM, "MistralForCausalLM": MistralForCasualLM} + + @dataclasses.dataclass class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """ @@ -61,11 +66,12 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes text_config: LlamaConfig vision_config: LlavaVisionConfig vocab_size: int - context_window_size: int = 0 - prefill_chunk_size: int = 0 + context_window_size: int = -1 + sliding_window_size: int = -1 + prefill_chunk_size: int = -1 tensor_parallel_shards: int = 1 - dtype: str = "float16" max_batch_size: int = 1 + text_architecture: str = "LlamaForCausalLM" kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -81,41 +87,54 @@ def __post_init__(self): self.vision_config = LlavaVisionConfig.from_dict(vision_config_dict) text_config_dict: Dict[str, Any] - if isinstance(self.text_config, LlamaConfig): + if isinstance(self.text_config, ConfigBase): text_config_dict = dataclasses.asdict(self.text_config) else: text_config_dict = dict(self.text_config) if "_name_or_path" in text_config_dict: - if text_config_dict["_name_or_path"] == "meta-llama/Llama-2-7b-hf": - text_config_dict["hidden_size"] = text_config_dict.pop("hidden_size", 4096) - text_config_dict["intermediate_size"] = text_config_dict.pop( - "intermediate_size", 11008 - ) - text_config_dict["num_attention_heads"] = text_config_dict.pop( - "num_attention_heads", 32 - ) - text_config_dict["num_hidden_layers"] = text_config_dict.pop( - "num_hidden_layers", 32 - ) - text_config_dict["rms_norm_eps"] = text_config_dict.pop("rms_norm_eps", 1e-06) - text_config_dict["vocab_size"] = text_config_dict.pop("vocab_size", 32064) - text_config_dict["context_window_size"] = text_config_dict.pop( - "context_window_size", 4096 - ) - else: - raise ValueError("Unsupported text model") + hf_config = self.get_hf_config(text_config_dict) + text_config_dict.update(hf_config) + architectures = text_config_dict["architectures"] + assert len(architectures) == 1 + self.text_architecture = architectures[0] else: for k, v in text_config_dict.pop("kwargs", {}).items(): text_config_dict[k] = v - self.text_config = LlamaConfig.from_dict(text_config_dict) - - if self.context_window_size <= 0: - self.context_window_size = self.text_config.context_window_size + self.text_config = CONFIG_MAP[self.text_architecture].from_dict(text_config_dict) + + for k in ["context_window_size", "sliding_window_size", "prefill_chunk_size"]: + if getattr(self, k) <= 0: + if hasattr(self.text_config, k): + setattr(self, k, getattr(self.text_config, k)) + + def get_hf_config(self, text_config_dict: Dict[str, Any]) -> Dict[str, Any]: + 
""" + Get the Hugging Face config of the text model + """ + + hf_config: Dict[str, Any] + try: + # pylint: disable=import-outside-toplevel, import-error + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(text_config_dict["_name_or_path"]).to_dict() + except (ImportError, OSError) as e: + # If transformers is not installed, get the config from preset + # Llama2 is gated so it throws an OSError. Get the config from preset instead + preset_mapping = { + "meta-llama/Llama-2-7b-hf": "llama2_7b", + "meta-llama/Llama-2-13b-hf": "llama2_13b", + "lmsys/vicuna-7b-v1.5": "llama2_7b", + "mistralai/Mistral-7B-v0.1": "mistral_7b", + } + if text_config_dict["_name_or_path"] in preset_mapping: + hf_config = MODEL_PRESETS[preset_mapping[text_config_dict["_name_or_path"]]] + else: + raise ValueError("Unsupported text model") from e - if self.prefill_chunk_size <= 0: - self.prefill_chunk_size = self.text_config.prefill_chunk_size + return hf_config # pylint: disable=missing-docstring @@ -128,21 +147,18 @@ def __init__(self, config: LlavaVisionConfig): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = nn.Parameter((self.embed_dim,), dtype=config.dtype) + self.class_embedding = nn.Parameter((self.embed_dim,)) self.patch_embedding = Conv2D( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False, - dtype=config.dtype, ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding( - num=self.num_positions, dim=self.embed_dim, dtype=config.dtype - ) + self.position_embedding = nn.Embedding(num=self.num_positions, dim=self.embed_dim) def forward(self, pixel_values: Tensor) -> Tensor: batch_size = pixel_values.shape[0] @@ -194,8 +210,8 @@ class CLIPMLP(Module): def __init__(self, config: LlavaVisionConfig): super().__init__() self.activation_fn = LlavaQuickGELU() - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, dtype=config.dtype) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, dtype=config.dtype) + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.fc1(hidden_states) @@ -216,10 +232,10 @@ def __init__(self, config: LlavaVisionConfig): f" and `num_heads`: {self.num_heads})." 
) self.scale = self.head_dim**-0.5 - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def _shape(self, tensor: Tensor, seq_len: int, bsz: int): reshape_tensor = reshape(tensor, shape=(bsz, seq_len, self.num_heads, self.head_dim)) @@ -263,13 +279,9 @@ def __init__(self, config: LlavaVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = CLIPAttention(config) - self.layer_norm1 = nn.LayerNorm( - normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype - ) + self.layer_norm1 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config) - self.layer_norm2 = nn.LayerNorm( - normalized_shape=self.embed_dim, eps=config.layer_norm_eps, dtype=config.dtype - ) + self.layer_norm2 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps) def forward(self, hidden_states: Tensor) -> Tensor: residual = hidden_states @@ -308,9 +320,9 @@ def __init__(self, config: LlavaVisionConfig): super().__init__() embed_dim = config.hidden_size self.embeddings = CLIPVisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = CLIPEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps, dtype=config.dtype) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward(self, pixel_values: Tensor) -> Tensor: hidden_states = self.embeddings(pixel_values) @@ -353,9 +365,15 @@ def __init__(self, config: LlavaConfig): self.config = config self.vision_tower = CLIPVisionModel(config.vision_config) self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCasualLM(config.text_config) + self.language_model = ARCHITECTURE_MAP[config.text_architecture](config.text_config) self.vocab_size = config.vocab_size - self.dtype = config.dtype + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + self.language_model.to(dtype=dtype) + if dtype is not None: + self.dtype = dtype def _embed_input_ids(self, input_ids: Tensor) -> Tensor: return self.language_model.embed(input_ids) diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index b8ffc8da8f..1c56178ad1 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -83,7 +83,7 @@ def __len__(self): return self.embed_size @staticmethod - def from_url(url: str, config: Dict) -> "ImageData": + def from_url(url: str, config: Dict) -> "ImageData": # pylint: disable=too-many-locals """Get the image from the given URL, process and return the image tensor as TVM NDArray.""" # pylint: disable=import-outside-toplevel, import-error @@ -105,23 +105,37 @@ def from_url(url: str, config: Dict) -> "ImageData": else: raise ValueError(f"Unsupported image URL format: {url}") - image_input_size = config["model_config"]["vision_config"]["image_size"] - image_embed_size = ( - image_input_size // 
config["model_config"]["vision_config"]["patch_size"] - ) ** 2 + image_input_size = ImageData.get_input_size(config) + image_embed_size = ImageData.get_embed_size(config) image_processor = CLIPImageProcessor( size={"shortest_edge": image_input_size}, crop_size={"height": image_input_size, "width": image_input_size}, ) + quantization = config["quantization"] + out_dtype = "float16" if "f16" in quantization else "float32" image_features = tvm.nd.array( image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( - "float16" + out_dtype ) ) image_data = ImageData(image_features, image_embed_size) return image_data + @staticmethod + def get_embed_size(config: Dict) -> int: + """Get the image embedding size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + patch_size = config["model_config"]["vision_config"]["patch_size"] + embed_size = (image_size // patch_size) ** 2 + return embed_size + + @staticmethod + def get_input_size(config: Dict) -> int: + """Get the image input size from the model config file.""" + image_size = config["model_config"]["vision_config"]["image_size"] + return image_size + @dataclass class SingleRequestStreamOutput: From c4169d8c8a4afedd06bc9d9b99c3aa65eee4a89e Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 10 Apr 2024 14:10:35 +0000 Subject: [PATCH 173/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c7bdcabd60..d2b00d25cb 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c7bdcabd602f3d882e764232692d1d1eb449d07b +Subproject commit d2b00d25cbaee2df7cf515117bb05220cc872a73 From f832bde67f5149d9a2a0332d72368d22cf64b7b7 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 10 Apr 2024 10:43:18 -0400 Subject: [PATCH 174/531] =?UTF-8?q?Revert=20"Allow=20"mlc=5Fllm=20--host"?= =?UTF-8?q?=20option=20to=20override=20host=20triple=20the=20model=20compi?= =?UTF-8?q?=E2=80=A6"=20(#2115)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 12ca8fdbe2a24f43bbc72241a76735dbad8c2026. Co-authored-by: Mengshiun Yu --- python/mlc_llm/support/auto_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 6e64247ea8..f000cc85b2 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -41,7 +41,7 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T The hint for the host CPU, default is "auto". """ target, build_func = _detect_target_gpu(target_hint) - if target.host is None or host_hint != "auto": + if target.host is None: target = Target(target, host=_detect_target_host(host_hint)) if target.kind.name == "cuda": # Enable thrust for CUDA From 716a5ed56b653d283edf77da724d768eded7303c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 10 Apr 2024 11:09:12 -0400 Subject: [PATCH 175/531] Revert "Auto updated submodule references" (#2117) This reverts commit c4169d8c8a4afedd06bc9d9b99c3aa65eee4a89e which causes CI broken. 
--- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d2b00d25cb..c7bdcabd60 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d2b00d25cbaee2df7cf515117bb05220cc872a73 +Subproject commit c7bdcabd602f3d882e764232692d1d1eb449d07b From 6c48755b205a983283034b5e4ef1fb24cfa0b9cd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 10 Apr 2024 11:31:55 -0400 Subject: [PATCH 176/531] [Metadata] Include picojson rather than forward declaring (#2118) This PR fixes the picojson uses in MLC that conflicts with the latest changes on the picojson side. --- cpp/metadata/model.h | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index 7a3224d28e..2472cb7d36 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -5,6 +5,7 @@ #ifndef MLC_LLM_CPP_MODEL_METADATA_H_ #define MLC_LLM_CPP_MODEL_METADATA_H_ +#include #include #include #include @@ -12,13 +13,6 @@ #include -// Forward declare picojson's value, object and array -namespace picojson { -class value; -using object = std::unordered_map; -using array = std::vector; -} // namespace picojson - namespace mlc { namespace llm { From 39dfa3e1eafd409756f8f1e8f2a9087e9ad46178 Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 10 Apr 2024 15:33:50 +0000 Subject: [PATCH 177/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c7bdcabd60..d2b00d25cb 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c7bdcabd602f3d882e764232692d1d1eb449d07b +Subproject commit d2b00d25cbaee2df7cf515117bb05220cc872a73 From 7f7c01f6e3f397027919889670bb492ac65b6198 Mon Sep 17 00:00:00 2001 From: Git bot Date: Thu, 11 Apr 2024 01:34:27 +0000 Subject: [PATCH 178/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d2b00d25cb..0f67508236 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d2b00d25cbaee2df7cf515117bb05220cc872a73 +Subproject commit 0f67508236158e5c7eb7c906df068e4ed95190f9 From a81514875b35f15ea02ed5437b49e7167b251c2f Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Thu, 11 Apr 2024 19:13:15 +0800 Subject: [PATCH 179/531] [Serving][Grammar] Porting the json schema converter from python to C++ (#2112) [Serve][Grammar] Porting the json schema converter from python to C++ This PR ports the json schema converter from python to C++. It defines the interface: ``` std::string JSONSchemaToEBNF( std::string schema, std::optional indent = std::nullopt, std::optional> separators = std::nullopt, bool strict_mode = true); ``` And uses it in BNFGrammar::FromSchema. This helps cases where python cannot be deployed. 
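A rough usage sketch of the new C++ entry points (assuming the elided template arguments above are `std::optional<int>` and `std::optional<std::pair<std::string, std::string>>`, and using an invented schema); the same conversion is presumably reachable from Python through the updated `python/mlc_llm/serve/grammar.py`:

```
// Sketch only: convert a JSON schema to EBNF, then build the grammar object
// used for grammar-guided generation. Include paths are assumed.
#include <iostream>
#include <string>

#include "cpp/serve/grammar/grammar.h"
#include "cpp/serve/grammar/json_schema_converter.h"

int main() {
  std::string schema = R"({
    "type": "object",
    "properties": {
      "name": {"type": "string"},
      "age": {"type": "integer"}
    },
    "required": ["name"]
  })";

  // indent=2: the produced grammar accepts JSON pretty-printed with 2 spaces;
  // passing std::nullopt instead would force single-line JSON output.
  std::string ebnf = mlc::llm::serve::JSONSchemaToEBNF(schema, /*indent=*/2);
  std::cout << ebnf << std::endl;

  // Or build the BNFGrammar directly; strict_mode (default true) rejects
  // properties/items that the schema does not list.
  mlc::llm::serve::BNFGrammar grammar =
      mlc::llm::serve::BNFGrammar::FromSchema(schema, /*indent=*/2);
  return 0;
}
```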
--- cpp/serve/grammar/grammar.cc | 51 +- cpp/serve/grammar/grammar.h | 28 +- cpp/serve/grammar/grammar_parser.cc | 10 +- cpp/serve/grammar/grammar_parser.h | 4 +- cpp/serve/grammar/grammar_serializer.cc | 4 +- cpp/serve/grammar/grammar_serializer.h | 6 +- cpp/serve/grammar/json_schema_converter.cc | 987 ++++++++++++++++++ cpp/serve/grammar/json_schema_converter.h | 44 + python/mlc_llm/serve/__init__.py | 1 - python/mlc_llm/serve/grammar.py | 50 +- python/mlc_llm/serve/json_schema_converter.py | 742 ------------- .../serve/test_json_schema_converter.py | 125 ++- 12 files changed, 1205 insertions(+), 847 deletions(-) create mode 100644 cpp/serve/grammar/json_schema_converter.cc create mode 100644 cpp/serve/grammar/json_schema_converter.h delete mode 100644 python/mlc_llm/serve/json_schema_converter.py diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c4d6445c7e..c8d760538c 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -8,6 +8,7 @@ #include "grammar_parser.h" #include "grammar_serializer.h" #include "grammar_simplifier.h" +#include "json_schema_converter.h" namespace mlc { namespace llm { @@ -20,7 +21,7 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const String& ebnf_string, const String& main_rule, +BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule, bool normalize, bool simplify) { auto grammar = EBNFParser::Parse(ebnf_string, main_rule); if (normalize) { @@ -34,7 +35,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); }); -BNFGrammar BNFGrammar::FromJSON(const String& json_string) { +BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) { return BNFJSONParser::Parse(json_string); } @@ -42,33 +43,31 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String jso return BNFGrammar::FromJSON(json_string); }); -BNFGrammar BNFGrammar::FromSchema(const String& schema, int indent, - Optional> separators, bool strict_mode) { - static const PackedFunc* json_schema_to_ebnf = Registry::Get("mlc.serve.json_schema_to_ebnf"); - CHECK(json_schema_to_ebnf != nullptr) << "mlc.serve.json_schema_to_ebnf is not registered."; - - String ebnf_string; - - // Convert the indent parameter to NullOpt for sending it to the PackedFunc. - if (indent == -1) { - // The conversion from TVMRetValue to String is ambiguous, so we call the conversion function - // explicitly - ebnf_string = - ((*json_schema_to_ebnf)(schema, Optional(NullOpt), separators, strict_mode) - . 
- operator String()); +BNFGrammar BNFGrammar::FromSchema(const std::string& schema, std::optional indent, + std::optional> separators, + bool strict_mode) { + return FromEBNFString(JSONSchemaToEBNF(schema, indent, separators, strict_mode)); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args, TVMRetValue* rv) { + std::optional indent; + if (args[1].type_code() != kTVMNullptr) { + indent = args[1]; } else { - ebnf_string = (*json_schema_to_ebnf)(schema, indent, separators, strict_mode).operator String(); - ; + indent = std::nullopt; } - return FromEBNFString(ebnf_string); -} -TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema") - .set_body_typed([](const String& schema, int indent, Optional> separators, - bool strict_mode) { - return BNFGrammar::FromSchema(schema, indent, separators, strict_mode); - }); + std::optional> separators; + if (args[2].type_code() != kTVMNullptr) { + Array separators_arr = args[2]; + CHECK(separators_arr.size() == 2); + separators = std::make_pair(separators_arr[0], separators_arr[1]); + } else { + separators = std::nullopt; + } + + *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]); +}); const std::string kJSONGrammarString = R"( main ::= ( diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index 545a4e08a0..ba15e58af3 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -183,33 +184,38 @@ class BNFGrammar : public ObjectRef { * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. * Not implemented yet. */ - static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule = "main", - bool normalize = true, bool simplify = true); + static BNFGrammar FromEBNFString(const std::string& ebnf_string, + const std::string& main_rule = "main", bool normalize = true, + bool simplify = true); /*! * \brief Construct a BNF grammar from the dumped JSON string. * \param json_string The JSON-formatted string. This string should have the same format as * the result of BNFGrammarJSONSerializer::ToString. */ - static BNFGrammar FromJSON(const String& json_string); + static BNFGrammar FromJSON(const std::string& json_string); /*! * \brief Construct a BNF grammar from the json schema string. The schema string should be in the * format of the schema of a JSON file. We will parse the schema and generate a BNF grammar. * \param schema The schema string. - * \param indent The number of spaces for indentation. If -1, the output will be in one line. - * Default: -1. + * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be + * in one line. Default: std::nullopt. * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"}, - * {", ", ": "}. If NullOpt, the default separators will be used: {",", ": "} when the indent - * is not -1, and {", ", ": "} otherwise. Default: NullOpt. + * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the + * indent is not -1, and {", ", ": "} otherwise. This follows the convention in python + * json.dumps(). Default: std::nullopt. * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not - * allow unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + * allow properties and items that is not specified in the schema. 
This is equivalent to + * setting unevaluatedProperties and unevaluatedItems to false. + * * This helps LLM to generate accurate output in the grammar-guided generation with JSON * schema. Default: true. */ - static BNFGrammar FromSchema(const String& schema, int indent = -1, - Optional> separators = NullOpt, - bool strict_mode = true); + static BNFGrammar FromSchema( + const std::string& schema, std::optional indent = std::nullopt, + std::optional> separators = std::nullopt, + bool strict_mode = true); /*! * \brief Get the grammar of standard JSON format. We have built-in support for JSON. diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index ba9ac80135..1ece99099e 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -16,7 +16,7 @@ namespace serve { class EBNFParserImpl { public: /*! \brief The logic of parsing the grammar string. */ - BNFGrammar DoParse(String ebnf_string, String main_rule); + BNFGrammar DoParse(std::string ebnf_string, std::string main_rule); private: using Rule = BNFGrammarNode::Rule; @@ -192,7 +192,7 @@ int32_t EBNFParserImpl::ParseString() { std::vector character_classes; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { - ThrowParseError("String should not contain newline"); + ThrowParseError("There should be no newline character in a string literal"); } auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { @@ -391,7 +391,7 @@ void EBNFParserImpl::ResetStringIterator(const char* cur) { in_parentheses_ = false; } -BNFGrammar EBNFParserImpl::DoParse(String ebnf_string, String main_rule) { +BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rule) { ResetStringIterator(ebnf_string.c_str()); BuildRuleNameToId(); @@ -412,12 +412,12 @@ BNFGrammar EBNFParserImpl::DoParse(String ebnf_string, String main_rule) { return builder_.Get(main_rule); } -BNFGrammar EBNFParser::Parse(String ebnf_string, String main_rule) { +BNFGrammar EBNFParser::Parse(std::string ebnf_string, std::string main_rule) { EBNFParserImpl parser; return parser.DoParse(ebnf_string, main_rule); } -BNFGrammar BNFJSONParser::Parse(String json_string) { +BNFGrammar BNFJSONParser::Parse(std::string json_string) { auto node = make_object(); auto grammar_json = json::ParseToJsonObject(json_string); auto rules_json = json::Lookup(grammar_json, "rules"); diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index be36f40459..4d10e8eb0d 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -37,7 +37,7 @@ class EBNFParser { * \param main_rule The name of the main rule. Default is "main". * \return The parsed grammar. */ - static BNFGrammar Parse(String ebnf_string, String main_rule = "main"); + static BNFGrammar Parse(std::string ebnf_string, std::string main_rule = "main"); /*! * \brief The exception thrown when parsing fails. @@ -58,7 +58,7 @@ class BNFJSONParser { * \param json_string The JSON string. * \return The parsed BNF grammar. 
*/ - static BNFGrammar Parse(String json_string); + static BNFGrammar Parse(std::string json_string); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index a057921f61..fd41517863 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -107,7 +107,7 @@ std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr return PrintRuleExpr(rule_expr[0]) + "*"; } -String BNFGrammarPrinter::ToString() { +std::string BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); for (auto i = 0; i < num_rules; ++i) { @@ -120,7 +120,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToString").set_body_typed([](const BNFG return BNFGrammarPrinter(grammar).ToString(); }); -String BNFGrammarJSONSerializer::ToString() { +std::string BNFGrammarJSONSerializer::ToString() { picojson::object grammar_json; picojson::array rules_json; diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 5837ce2bf6..8746b1f6ae 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -27,7 +27,7 @@ class BNFGrammarSerializer { explicit BNFGrammarSerializer(const BNFGrammar& grammar) : grammar_(grammar) {} /*! \brief Serialize the grammar to string. */ - virtual String ToString() = 0; + virtual std::string ToString() = 0; protected: const BNFGrammar& grammar_; @@ -50,7 +50,7 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { explicit BNFGrammarPrinter(const BNFGrammar& grammar) : BNFGrammarSerializer(grammar) {} /*! \brief Print the complete grammar. */ - String ToString() final; + std::string ToString() final; /*! \brief Print a rule. */ std::string PrintRule(const Rule& rule); @@ -102,7 +102,7 @@ class BNFGrammarJSONSerializer : public BNFGrammarSerializer { * \brief Dump the raw representation of the AST to a JSON file. * \param prettify Whether to format the JSON string. If false, all whitespaces will be removed. */ - String ToString() final; + std::string ToString() final; private: bool prettify_; diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc new file mode 100644 index 0000000000..93d693f3c6 --- /dev/null +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -0,0 +1,987 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/json_schema_converter.cc + */ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief Manage the indent and separator for the generation of EBNF grammar. + * \param indent The number of spaces for each indent. If it is std::nullopt, there will be no + * indent or newline. + * \param separator The separator between different elements in json. Examples include "," and ", ". + */ +class IndentManager { + public: + IndentManager(std::optional indent, const std::string& separator) + : enable_newline_(indent.has_value()), + indent_(indent.value_or(0)), + separator_(separator), + total_indent_(0), + is_first_({true}) {} + + /*! \brief Enter a new indent level. */ + void StartIndent() { + total_indent_ += indent_; + is_first_.push_back(true); + } + + /*! \brief Exit the current indent level. */ + void EndIndent() { + total_indent_ -= indent_; + is_first_.pop_back(); + } + + /*! 
+ * \brief Get the next separator in the current level. When first called in the current + * level, the starting separator will be returned. When called again, the middle separator will be + * returned. When called with `is_end=True`, the ending separator will be returned. + * \param is_end Get the separator for the end of the current level. + * \example + * \code + * IndentManager indent_manager(2, ", "); + * indent_manager.StartIndent(); + * indent_manager.GetSep(); // get the start separator: "\"\n \"" + * indent_manager.GetSep(); // get the middle separator: "\",\n \"" + * indent_manager.GetSep(true); // get the end separator: "\"\n\"" + * \endcode + */ + std::string NextSeparator(bool is_end = false); + + /*! \brief Get the separator itself. */ + std::string GetBareSeparator() { return separator_; } + + private: + bool enable_newline_; + int indent_; + std::string separator_; + int total_indent_; + std::vector is_first_; + friend class JSONSchemaToEBNFConverter; +}; + +std::string IndentManager::NextSeparator(bool is_end) { + std::string res = ""; + if (!is_first_.back() && !is_end) { + res += separator_; + } + is_first_.back() = false; + + if (enable_newline_) { + res += "\\n"; + } + + if (!is_end) { + res += std::string(total_indent_, ' '); + } else { + res += std::string(total_indent_ - indent_, ' '); + } + + return "\"" + res + "\""; +} + +/*! + * \brief Convert JSON schema string to EBNF grammar string. The parameters follow + * JSONSchemaToEBNF(). + * + * \note About the representation of json schema in this converter. JSON schema could be two types: + * bool (true or false) or dict (a json dict) containing attributes. We use picojson::value to + * represent the json schema. + */ +class JSONSchemaToEBNFConverter { + public: + JSONSchemaToEBNFConverter( + const picojson::value& json_schema, std::optional indent = std::nullopt, + std::optional> separators = std::nullopt, + bool strict_mode = false); + + /*! \brief The main method. Convert the JSON schema to EBNF grammar string. */ + std::string Convert(); + + private: + // The name of the basic rules + inline static const std::string kBasicAny = "basic_any"; + inline static const std::string kBasicInteger = "basic_integer"; + inline static const std::string kBasicNumber = "basic_number"; + inline static const std::string kBasicString = "basic_string"; + inline static const std::string kBasicBoolean = "basic_boolean"; + inline static const std::string kBasicNull = "basic_null"; + inline static const std::string kBasicArray = "basic_array"; + inline static const std::string kBasicObject = "basic_object"; + + // The name of the helper rules to construct basic rules + inline static const std::string kBasicEscape = "basic_escape"; + inline static const std::string kBasicStringSub = "basic_string_sub"; + + /*! \brief Add the basic rules to the rules list and the basic_rules_cache. */ + void AddBasicRules(); + + /*! \brief Add helper rules for the basic rules. */ + void AddHelperRules(); + + /*! \brief Create a rule for the given schema and name, and add it to the basic_rules_cache. */ + void CreateBasicRule(const picojson::value& schema, const std::string& name); + + /*! \brief Get the index for the schema in the cache. Keys that do not effect the validation + * will be ignored when finding the corresponding cache rule. */ + std::string GetSchemaCacheIndex(const picojson::value& schema); + + /*! + * \brief Create a rule with the given schema and rule name hint. + * \returns The name of the rule will be returned. 
That is not necessarily the same as the + * rule_name_hint due to the caching mechanism. + */ + std::string CreateRuleFromSchema(const picojson::value& schema, + const std::string& rule_name_hint); + + /*! \brief Get the next separator in the current level from the indent manager. */ + std::string NextSeparator(bool is_end = false); + + /*! \brief Warn if any keyword is existing in the schema but not supported. */ + static void WarnUnsupportedKeywords(const picojson::value& schema, + const std::vector& keywords); + + /*! \brief Warn if any keyword is existing in the object but not supported. */ + static void WarnUnsupportedKeywords(const picojson::object& schema, + const std::vector& keywords); + + /*! \brief Visit the schema and return the rule body for later constructing the rule. */ + std::string VisitSchema(const picojson::value& schema, const std::string& rule_name); + + /*! \brief Visit a reference schema. */ + std::string VisitRef(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Get the schema from the URI. */ + picojson::value URIToSchema(const picojson::value& uri); + + /*! \brief Visit a const schema. */ + std::string VisitConst(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit an enum schema. */ + std::string VisitEnum(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Convert the JSON string to a printable string that can be shown in BNF. */ + std::string JSONStrToPrintableStr(const std::string& json_str); + + /*! \brief Visit an anyOf schema. */ + std::string VisitAnyOf(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a true schema that can match anything. */ + std::string VisitAny(const picojson::value& schema, const std::string& rule_name); + + /*! \brief Visit an integer schema. */ + std::string VisitInteger(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a number schema. */ + std::string VisitNumber(const picojson::object& schema, const std::string& rule_name); + /*! \brief Visit a string schema. */ + std::string VisitString(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a boolean schema. */ + std::string VisitBoolean(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Visit a null schema. */ + std::string VisitNull(const picojson::object& schema, const std::string& rule_name); + + /*! + * \brief Visit an array schema. + * \example + * Schema: + * \code + * { + * "type": "array", + * "prefixItems": [ + * {"type": "boolean"}, + * {"type": "integer"} + * ], + * "items": { + * "type": "string" + * } + * } + * \endcode + * Rule (not considering the indent): + * \code + * main ::= "[" basic_boolean ", " basic_integer (", " basic_string)* "]" + * \endcode + */ + std::string VisitArray(const picojson::object& schema, const std::string& rule_name); + + /*! + * \brief Visit an object schema. + * \example + * Schema: + * \code + * { + * "type": "object", + * "properties": { + * "a": {"type": "string"}, + * "b": {"type": "integer"} + * }, + * "required": ["a"], + * "additionalProperties": true + * } + * \endcode + * + * Rule (not considering the indent): + * \code + * main ::= "{" "a" ":" basic_string (", " "b" ":" basic_integer)* + * (", " basic_string ": " basic_any)* "}" + * \endcode + + * We need special handling when all properties are optional, since the handling of separators + * is tricky in this case. E.g. 
+ + * Schema: + * \code + * { + * "type": "object", + * "properties": { + * "a": {"type": "string"}, + * "b": {"type": "integer"}, + * "c": {"type": "boolean"} + * }, + * "additionalProperties": true + * } + * \endcode + * + * Rule (indent=2): + * \code + * main ::= "{" ("\n " (a main_sub_1 | b main_sub_2 | c main_sub_3 | d main_sub_3) + * "\n" | "") "}" + * main_sub_1 ::= ",\n " b r2 | r2 + * main_sub_2 ::= ",\n " c r3 | r3 + * main_sub_3 ::= (",\n " d)* + * \endcode + */ + std::string VisitObject(const picojson::object& schema, const std::string& rule_name); + + /*! \brief Get the pattern for a property in the object schema. */ + std::string GetPropertyPattern(const std::string& prop_name, const picojson::value& prop_schema, + const std::string& rule_name, int idx); + + /*! \brief Get the pattern for the additional/unevaluated properties in the object schema. */ + std::string GetOtherPropertyPattern(const std::string& key_pattern, + const picojson::value& prop_schema, + const std::string& rule_name, + const std::string& rule_name_suffix); + + /*! \brief Get the partial rule for the properties when all properties are optional. See the + * example in VisitObject(). */ + std::string GetPartialRuleForPropertiesAllOptional( + const std::vector>& properties, + const picojson::value& additional, const std::string& rule_name, + const std::string& additional_suffix = ""); + + /*! + * \brief Get the partial rule for the properties when some properties are required. See the + * example in VisitObject(). + * + * The constructed rule should be: + * \code + * start_separator (optional_property separator)? (optional_property separator)? ... + * first_required_property (separator optional_property)? separator required_property ... + * end_separator + * \endcode + * + * i.e. Before the first required property, all properties are in the form + * (property separator) ; and after the first required property, all properties are in the form + * (separator property) . */ + std::string GetPartialRuleForPropertiesContainRequired( + const std::vector>& properties, + const std::unordered_set& required, const std::string& rule_name); + + // The indent manager to get separators + std::unique_ptr indentManager_; + // The root JSON schema + picojson::value json_schema_; + // Whether to use strict mode in conversion. See JSONSchemaToEBNF(). + bool strict_mode_; + // The colon separator + std::string colon_; + // The rules constructed + std::vector> rules_; + // The cache for basic rules. Mapping from the key of schema returned by GetSchemaCacheIndex() + // to the basic rule name. + std::map basic_rules_cache_; +}; + +JSONSchemaToEBNFConverter::JSONSchemaToEBNFConverter( + const picojson::value& json_schema, std::optional indent, + std::optional> separators, bool strict_mode) + : json_schema_(json_schema), strict_mode_(strict_mode) { + if (!separators.has_value()) { + separators = (indent == std::nullopt) ? 
std::make_pair(", ", ": ") : std::make_pair(",", ": "); + } + indentManager_ = std::make_unique(indent, separators->first); + colon_ = separators->second; + + AddBasicRules(); +} + +std::string JSONSchemaToEBNFConverter::Convert() { + CreateRuleFromSchema(json_schema_, "main"); + std::string res; + for (auto& rule : rules_) { + res += rule.first + " ::= " + rule.second + "\n"; + } + return res; +} + +void JSONSchemaToEBNFConverter::AddBasicRules() { + bool past_strict_mode = strict_mode_; + strict_mode_ = false; + + auto past_indent_manager = std::move(indentManager_); + indentManager_ = + std::make_unique(std::nullopt, past_indent_manager->GetBareSeparator()); + + AddHelperRules(); + CreateBasicRule(picojson::value(true), kBasicAny); + basic_rules_cache_[GetSchemaCacheIndex(picojson::value(picojson::object()))] = kBasicAny; + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("integer")}}), + kBasicInteger); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("number")}}), + kBasicNumber); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("string")}}), + kBasicString); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("boolean")}}), + kBasicBoolean); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("null")}}), kBasicNull); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("array")}}), + kBasicArray); + CreateBasicRule(picojson::value(picojson::object{{"type", picojson::value("object")}}), + kBasicObject); + + strict_mode_ = past_strict_mode; + indentManager_ = std::move(past_indent_manager); +} + +void JSONSchemaToEBNFConverter::AddHelperRules() { + rules_.push_back(std::make_pair( + kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); + rules_.push_back(std::make_pair(kBasicStringSub, "\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + + " | \"\\\\\" " + kBasicEscape + " " + + kBasicStringSub)); +} + +void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, + const std::string& name) { + std::string rule_name = CreateRuleFromSchema(schema, name); + basic_rules_cache_[GetSchemaCacheIndex(schema)] = rule_name; +} + +std::string JSONSchemaToEBNFConverter::NextSeparator(bool is_end) { + return indentManager_->NextSeparator(is_end); +} + +void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::value& schema, + const std::vector& keywords) { + if (schema.is()) { + return; + } + + ICHECK(schema.is()); + WarnUnsupportedKeywords(schema.get(), keywords); +} + +void JSONSchemaToEBNFConverter::WarnUnsupportedKeywords(const picojson::object& schema, + const std::vector& keywords) { + for (const auto& keyword : keywords) { + if (schema.find(keyword) != schema.end()) { + LOG(WARNING) << "Keyword " << keyword << " is not supported in schema " + << picojson::value(schema); + } + } +} + +std::string JSONSchemaToEBNFConverter::CreateRuleFromSchema(const picojson::value& schema, + const std::string& rule_name_hint) { + std::string idx = GetSchemaCacheIndex(schema); + if (basic_rules_cache_.count(idx)) { + return basic_rules_cache_[idx]; + } + + rules_.push_back(std::make_pair(rule_name_hint, VisitSchema(schema, rule_name_hint))); + return rule_name_hint; +} + +std::string JSONSchemaToEBNFConverter::GetSchemaCacheIndex(const picojson::value& schema) { + // Keys that do not effect the validation + static const std::unordered_set kSkippedKeys = { + "title", "default", "description", 
"examples", "deprecated", + "readOnly", "writeOnly", "$comment", "$schema", + }; + if (schema.is()) { + // remove skipped keys and sort key by lexicographical order + std::string result = "{"; + std::vector> sorted_kv; + for (const auto& kv : schema.get()) { + if (kSkippedKeys.count(kv.first) == 0) { + sorted_kv.push_back(kv); + } + } + std::sort(sorted_kv.begin(), sorted_kv.end(), + [](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; }); + int idx = 0; + for (const auto& [key, value] : sorted_kv) { + if (idx != 0) { + result += ","; + } + ++idx; + result += "\"" + key + "\":" + GetSchemaCacheIndex(value); + } + return result + "}"; + } else if (schema.is()) { + std::string result = "["; + int idx = 0; + for (const auto& item : schema.get()) { + if (idx != 0) { + result += ","; + } + ++idx; + result += GetSchemaCacheIndex(item); + } + return result + "]"; + } + // If the object is neither an array nor an object, return it directly + return schema.serialize(false); +} + +std::string JSONSchemaToEBNFConverter::VisitSchema(const picojson::value& schema, + const std::string& rule_name) { + if (schema.is()) { + ICHECK(schema.get()); + return VisitAny(schema, rule_name); + } + + WarnUnsupportedKeywords(schema, { + "allof", + "oneof", + "not", + "if", + "then", + "else", + "dependentRequired", + "dependentSchemas", + }); + + ICHECK(schema.is()); + + const auto& schema_obj = schema.get(); + + if (schema_obj.count("$ref")) { + return VisitRef(schema_obj, rule_name); + } else if (schema_obj.count("const")) { + return VisitConst(schema_obj, rule_name); + } else if (schema_obj.count("enum")) { + return VisitEnum(schema_obj, rule_name); + } else if (schema_obj.count("anyOf")) { + return VisitAnyOf(schema_obj, rule_name); + } else if (schema_obj.count("type")) { + const std::string& type = schema_obj.at("type").get(); + if (type == "integer") { + return VisitInteger(schema_obj, rule_name); + } else if (type == "number") { + return VisitNumber(schema_obj, rule_name); + } else if (type == "string") { + return VisitString(schema_obj, rule_name); + } else if (type == "boolean") { + return VisitBoolean(schema_obj, rule_name); + } else if (type == "null") { + return VisitNull(schema_obj, rule_name); + } else if (type == "array") { + return VisitArray(schema_obj, rule_name); + } else if (type == "object") { + return VisitObject(schema_obj, rule_name); + } else { + LOG(FATAL) << "Unsupported type " << type << " in schema " << schema; + } + } + + // If no above keyword is detected, we treat it as any + return VisitAny(schema, rule_name); +} + +std::string JSONSchemaToEBNFConverter::VisitRef(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("$ref")); + picojson::value new_schema = URIToSchema(schema.at("$ref")); + if (!new_schema.is()) { + picojson::object new_schema_obj = new_schema.get(); + for (const auto& [k, v] : schema) { + if (k != "$ref") { + new_schema_obj[k] = v; + } + } + new_schema = picojson::value(new_schema_obj); + } + return VisitSchema(new_schema, rule_name); +} + +picojson::value JSONSchemaToEBNFConverter::URIToSchema(const picojson::value& uri) { + if (uri.get().substr(0, 8) == "#/$defs/") { + return json_schema_.get("$defs").get(uri.get().substr(8)); + } + LOG(WARNING) << "Now only support URI starting with '#/$defs/' but got " << uri; + return picojson::value(true); +} + +std::string JSONSchemaToEBNFConverter::VisitConst(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("const")); + // TODO(yixin): 
Customize serialize to support indent logics + return "\"" + JSONStrToPrintableStr(schema.at("const").serialize()) + "\""; +} + +std::string JSONSchemaToEBNFConverter::VisitEnum(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("enum")); + std::string result = ""; + int idx = 0; + for (auto value : schema.at("enum").get()) { + if (idx != 0) { + result += " | "; + } + ++idx; + result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")"; + } + return result; +} + +std::string JSONSchemaToEBNFConverter::JSONStrToPrintableStr(const std::string& json_str) { + static const std::vector> kReplaceMapping = {{"\\", "\\\\"}, + {"\"", "\\\""}}; + std::string result = json_str; + for (const auto& [k, v] : kReplaceMapping) { + size_t pos = 0; + while ((pos = result.find(k, pos)) != std::string::npos) { + result.replace(pos, k.length(), v); + pos += v.length(); + } + } + return result; +} + +std::string JSONSchemaToEBNFConverter::VisitAnyOf(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("anyOf")); + std::string result = ""; + int idx = 0; + for (auto anyof_schema : schema.at("anyOf").get()) { + if (idx != 0) { + result += " | "; + } + result += CreateRuleFromSchema(anyof_schema, rule_name + "_case_" + std::to_string(idx)); + ++idx; + } + return result; +} + +std::string JSONSchemaToEBNFConverter::VisitAny(const picojson::value& schema, + const std::string& rule_name) { + // Note integer is a subset of number, so we don't need to add integer here + return kBasicNumber + " | " + kBasicString + " | " + kBasicBoolean + " | " + kBasicNull + " | " + + kBasicArray + " | " + kBasicObject; +} + +std::string JSONSchemaToEBNFConverter::VisitInteger(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "integer"); + WarnUnsupportedKeywords(schema, { + "multipleOf", + "minimum", + "maximum", + "exclusiveMinimum", + "exclusiveMaximum", + }); + return "(\"0\" | \"-\"? [1-9] [0-9]*) \".0\"?"; +} + +std::string JSONSchemaToEBNFConverter::VisitNumber(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "number"); + WarnUnsupportedKeywords(schema, { + "multipleOf", + "minimum", + "maximum", + "exclusiveMinimum", + "exclusiveMaximum", + }); + return "(\"0\" | \"-\"? [1-9] [0-9]*) (\".\" [0-9]+)? ([eE] [+-]? 
[0-9]+)?"; +} + +std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "string"); + WarnUnsupportedKeywords(schema, { + "minLength", + "maxLength", + "pattern", + "format", + }); + return "[\"] " + kBasicStringSub + " [\"]"; +} + +std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "boolean"); + return "\"true\" | \"false\""; +} + +std::string JSONSchemaToEBNFConverter::VisitNull(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "null"); + return "\"null\""; +} + +std::string JSONSchemaToEBNFConverter::VisitArray(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "array"); + WarnUnsupportedKeywords(schema, { + "uniqueItems", + "contains", + "minContains", + "maxContains", + "minItems", + "maxItems", + }); + + std::string result = "\"[\""; + + indentManager_->StartIndent(); + + // 1. Handle prefix items + if (schema.count("prefixItems")) { + const auto& prefix_items = schema.at("prefixItems").get(); + for (int i = 0; i < prefix_items.size(); ++i) { + ICHECK(prefix_items[i].is()); + result += " " + NextSeparator() + " "; + result += CreateRuleFromSchema(prefix_items[i], rule_name + "_item_" + std::to_string(i)); + } + } + + // 2. Find additional items + picojson::value additional_item = picojson::value(false); + std::string additional_suffix = ""; + + if (schema.count("items") && (!schema.at("items").is() || schema.at("items").get())) { + additional_item = schema.at("items"); + additional_suffix = "items"; + } + + // If items is specified in the schema, we don't need to consider unevaluatedItems + if (schema.count("items") == 0) { + picojson::value unevaluated = schema.count("unevaluatedItems") ? schema.at("unevaluatedItems") + : picojson::value(!strict_mode_); + if (!unevaluated.is() || unevaluated.get()) { + additional_item = unevaluated; + additional_suffix = "uneval"; + } + } + + // 3. 
Handle additional items and the end separator + bool could_be_empty = false; + if (additional_item.is() && !additional_item.get()) { + result += " " + NextSeparator(true); + } else { + std::string additional_pattern = + CreateRuleFromSchema(additional_item, rule_name + "_" + additional_suffix); + if (schema.count("prefixItems")) { + result += " (" + NextSeparator() + " " + additional_pattern + ")* "; + result += NextSeparator(true); + } else { + result += " " + NextSeparator() + " " + additional_pattern + " ("; + result += NextSeparator() + " " + additional_pattern + ")* "; + result += NextSeparator(true); + could_be_empty = true; + } + } + + indentManager_->EndIndent(); + + result += " \"]\""; + + if (could_be_empty) { + result = "(" + result + ") | \"[]\""; + } + + return result; +} + +std::string JSONSchemaToEBNFConverter::GetPropertyPattern(const std::string& prop_name, + const picojson::value& prop_schema, + const std::string& rule_name, int idx) { + // the outer quote is for the string in EBNF grammar, and the inner quote is for + // the string in JSON + std::string key = "\"\\\"" + prop_name + "\\\"\""; + std::string colon = "\"" + colon_ + "\""; + std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_prop_" + std::to_string(idx)); + return key + " " + colon + " " + value; +} + +std::string JSONSchemaToEBNFConverter::GetOtherPropertyPattern( + const std::string& key_pattern, const picojson::value& prop_schema, + const std::string& rule_name, const std::string& rule_name_suffix) { + std::string colon = "\"" + colon_ + "\""; + std::string value = CreateRuleFromSchema(prop_schema, rule_name + "_" + rule_name_suffix); + return key_pattern + " " + colon + " " + value; +} + +std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesAllOptional( + const std::vector>& properties, + const picojson::value& additional, const std::string& rule_name, + const std::string& additional_suffix) { + ICHECK(properties.size() >= 1); + + std::string first_sep = NextSeparator(); + std::string mid_sep = NextSeparator(); + std::string last_sep = NextSeparator(true); + + std::string res = ""; + + std::vector prop_patterns; + int idx = 0; + for (const auto& [prop_name, prop_schema] : properties) { + prop_patterns.push_back(GetPropertyPattern(prop_name, prop_schema, rule_name, idx)); + ++idx; + } + + std::vector rule_names(properties.size(), ""); + + // construct the last rule + std::string additional_prop_pattern; + if (!additional.is() || additional.get()) { + additional_prop_pattern = + GetOtherPropertyPattern(kBasicString, additional, rule_name, additional_suffix); + std::string last_rule_body = "(" + mid_sep + " " + additional_prop_pattern + ")*"; + std::string last_rule_name = rule_name + "_part_" + std::to_string(properties.size() - 1); + rules_.push_back(std::make_pair(last_rule_name, last_rule_body)); + rule_names.back() = last_rule_name; + } else { + rule_names.back() = "\"\""; + } + + // construct 0~(len(properties) - 2) rules + for (int i = properties.size() - 2; i >= 0; --i) { + const std::string& prop_pattern = prop_patterns[i + 1]; + const std::string& last_rule_name = rule_names[i + 1]; + std::string cur_rule_body = + last_rule_name + " | " + mid_sep + " " + prop_pattern + " " + last_rule_name; + std::string cur_rule_name = rule_name + "_part_" + std::to_string(i); + rules_.push_back(std::make_pair(cur_rule_name, cur_rule_body)); + rule_names[i] = cur_rule_name; + } + + // construct the main rule + for (int i = 0; i < properties.size(); ++i) { + if (i != 0) { + res += " | 
"; + } + res += "(" + prop_patterns[i] + " " + rule_names[i] + ")"; + } + + if (!additional.is() || additional.get()) { + res += " | " + additional_prop_pattern + " " + rule_names.back(); + } + + // add separators and the empty string option + res = first_sep + " (" + res + ") " + last_sep; + return res; +} + +std::string JSONSchemaToEBNFConverter::GetPartialRuleForPropertiesContainRequired( + const std::vector>& properties, + const std::unordered_set& required, const std::string& rule_name) { + // Find the index of the first required property + int first_required_idx = properties.size(); + for (int i = 0; i < properties.size(); ++i) { + if (required.count(properties[i].first)) { + first_required_idx = i; + break; + } + } + ICHECK(first_required_idx < properties.size()); + + std::string res = NextSeparator(); + + // Handle the properties before the first required property + for (int i = 0; i < first_required_idx; ++i) { + const auto& [prop_name, prop_schema] = properties[i]; + ICHECK(!prop_schema.is() || prop_schema.get()); + std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); + res += " (" + property_pattern + " " + NextSeparator() + ")?"; + } + + // Handle the first required property + const auto& [prop_name, prop_schema] = properties[first_required_idx]; + std::string property_pattern = + GetPropertyPattern(prop_name, prop_schema, rule_name, first_required_idx); + res += " " + property_pattern; + + // Handle the properties after the first required property + for (int i = first_required_idx + 1; i < properties.size(); ++i) { + const auto& [prop_name, prop_schema] = properties[i]; + ICHECK(!prop_schema.is() || prop_schema.get()); + std::string property_pattern = GetPropertyPattern(prop_name, prop_schema, rule_name, i); + if (required.count(prop_name)) { + res += " " + NextSeparator() + " " + property_pattern; + } else { + res += " (" + NextSeparator() + " " + property_pattern + ")?"; + } + } + + return res; +} + +std::string JSONSchemaToEBNFConverter::VisitObject(const picojson::object& schema, + const std::string& rule_name) { + ICHECK(schema.count("type")); + ICHECK(schema.at("type").get() == "object"); + WarnUnsupportedKeywords(schema, { + "patternProperties", + "minProperties", + "maxProperties", + "propertyNames", + }); + + std::string result = "\"{\""; + + // could_be_empty will be set to True when the rule could be "{}". We will handle this case at + // last, and handle non-empty cases before that. + bool could_be_empty = false; + + indentManager_->StartIndent(); + + // 1. Handle properties + std::vector> properties; + if (schema.count("properties")) { + auto properties_obj = schema.at("properties").get(); + for (const auto& key : properties_obj.ordered_keys()) { + properties.push_back({key, properties_obj.at(key)}); + } + } + + std::unordered_set required; + if (schema.count("required")) { + for (const auto& required_prop : schema.at("required").get()) { + required.insert(required_prop.get()); + } + } + + // 2. Find additional properties + picojson::value additional_property = picojson::value(false); + std::string additional_suffix = ""; + + if (schema.count("additionalProperties") && (!schema.at("additionalProperties").is() || + schema.at("additionalProperties").get())) { + additional_property = schema.at("additionalProperties"); + additional_suffix = "addl"; + } + + if (schema.count("additionalProperties") == 0) { + picojson::value unevaluated = schema.count("unevaluatedProperties") + ? 
schema.at("unevaluatedProperties") + : picojson::value(!strict_mode_); + if (!unevaluated.is() || unevaluated.get()) { + additional_property = unevaluated; + additional_suffix = "uneval"; + } + } + + bool is_all_properties_optional = + std::all_of(properties.begin(), properties.end(), + [&](const auto& prop) { return required.count(prop.first) == 0; }); + + if (is_all_properties_optional && properties.size() > 0) { + // 3.1 Case 1: properties are defined and all properties are optional + result += " " + GetPartialRuleForPropertiesAllOptional(properties, additional_property, + rule_name, additional_suffix); + could_be_empty = true; + } else if (properties.size() > 0) { + // 3.2 Case 2: properties are defined and some properties are required + result += " " + GetPartialRuleForPropertiesContainRequired(properties, required, rule_name); + if (!additional_property.is() || additional_property.get()) { + std::string other_property_pattern = + GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); + result += " (" + NextSeparator() + " " + other_property_pattern + ")*"; + } + result += " " + NextSeparator(true); + } else if (!additional_property.is() || additional_property.get()) { + // 3.3 Case 3: no properties are defined and additional properties are allowed + std::string other_property_pattern = + GetOtherPropertyPattern(kBasicString, additional_property, rule_name, additional_suffix); + result += " " + NextSeparator() + " " + other_property_pattern + " ("; + result += NextSeparator() + " " + other_property_pattern + ")* "; + result += NextSeparator(true); + could_be_empty = true; + } + + indentManager_->EndIndent(); + + result += " \"}\""; + if (could_be_empty) { + result = "(" + result + ") | \"{}\""; + } + + return result; +}; + +std::string JSONSchemaToEBNF(std::string schema, std::optional indent, + std::optional> separators, + bool strict_mode) { + picojson::value schema_value; + std::string err = picojson::parse(schema_value, schema); + if (!err.empty()) { + LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << schema; + } + JSONSchemaToEBNFConverter converter(schema_value, indent, separators, strict_mode); + return converter.Convert(); +} + +TVM_REGISTER_GLOBAL("mlc.serve.DebugJSONSchemaToEBNF").set_body([](TVMArgs args, TVMRetValue* rv) { + std::optional indent; + if (args[1].type_code() != kTVMNullptr) { + indent = args[1]; + } else { + indent = std::nullopt; + } + + std::optional> separators; + if (args[2].type_code() != kTVMNullptr) { + Array separators_arr = args[2]; + CHECK(separators_arr.size() == 2); + separators = std::make_pair(separators_arr[0], separators_arr[1]); + } else { + separators = std::nullopt; + } + + *rv = JSONSchemaToEBNF(args[0], indent, separators, args[3]); +}); + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/grammar/json_schema_converter.h b/cpp/serve/grammar/json_schema_converter.h new file mode 100644 index 0000000000..22c730aa41 --- /dev/null +++ b/cpp/serve/grammar/json_schema_converter.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/grammar/json_grammar_converter.h + * \brief The header for translating JSON schema to EBNF grammar. + */ + +#ifndef MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ +#define MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ + +#include +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief Convert JSON schema string to EBNF grammar string. 
+ * \param json_schema The JSON schema string. + * \param indent The number of spaces for indentation. If set to std::nullopt, the output will be + * in one line. Default: std::nullopt. + * \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"}, + * {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the + * indent is not -1, and {", ", ": "} otherwise. This follows the convention in python json.dumps(). + * Default: std::nullopt. + * \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not + * allow properties and items that is not specified in the schema. This is equivalent to + * setting unevaluatedProperties and unevaluatedItems to false. + * + * This helps LLM to generate accurate output in the grammar-guided generation with JSON + * schema. Default: true. + * \returns The EBNF grammar string. + */ +std::string JSONSchemaToEBNF( + std::string schema, std::optional indent = std::nullopt, + std::optional> separators = std::nullopt, + bool strict_mode = true); + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_GRAMMAR_JSON_SCHEMA_CONVERTER_H_ diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index e165128ea3..7043cb75c7 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -6,6 +6,5 @@ from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncEngine, Engine from .grammar import BNFGrammar, GrammarStateMatcher -from .json_schema_converter import json_schema_to_ebnf from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index d640c62da2..d5ad862a42 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -137,11 +137,13 @@ def from_schema( separators : Optional[Tuple[str, str]] Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). If None, the default separators will be used: (",", ": ") when the indent is not None, - and (", ", ": ") otherwise. Default: None. + and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. strict_mode : bool Whether to use strict mode. In strict mode, the generated grammar will not allow - unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. + properties and items that is not specified in the schema. This is equivalent to + setting unevaluatedProperties and unevaluatedItems to false. + This helps LLM to generate accurate output in the grammar-guided generation with JSON schema. Default: True. @@ -150,9 +152,8 @@ def from_schema( grammar : BNFGrammar The generated BNF grammar. """ - indent_converted = -1 if indent is None else indent return _ffi_api.BNFGrammarFromSchema( # type: ignore # pylint: disable=no-member - schema, indent_converted, separators, strict_mode + schema, indent, separators, strict_mode ) @staticmethod @@ -166,6 +167,47 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_json_schema_to_ebnf( + schema: str, + *, + indent: Optional[int] = None, + separators: Optional[Tuple[str, str]] = None, + strict_mode: bool = True + ) -> str: + """Convert JSON schema string to EBNF grammar string. For test purposes. 
+ + Parameters + ---------- + json_schema : str + The JSON schema string. + + indent : Optional[int] + The number of spaces for indentation. If None, the output will be in one line. + Default: None. + + separators : Optional[Tuple[str, str]] + Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). + If None, the default separators will be used: (",", ": ") when the indent is not None, + and (", ", ": ") otherwise. This follows the convention in json.dumps(). Default: None. + + strict_mode : bool + Whether to use strict mode. In strict mode, the generated grammar will not allow + properties and items that is not specified in the schema. This is equivalent to + setting unevaluatedProperties and unevaluatedItems to false. + + This helps LLM to generate accurate output in the grammar-guided generation with JSON + schema. Default: True. + + Returns + ------- + ebnf_string : str + The EBNF grammar string. + """ + return _ffi_api.DebugJSONSchemaToEBNF( # type: ignore # pylint: disable=no-member + schema, indent, separators, strict_mode + ) + @tvm._ffi.register_object("mlc.serve.GrammarStateMatcher") # pylint: disable=protected-access class GrammarStateMatcher(Object): diff --git a/python/mlc_llm/serve/json_schema_converter.py b/python/mlc_llm/serve/json_schema_converter.py deleted file mode 100644 index 9a4af6176e..0000000000 --- a/python/mlc_llm/serve/json_schema_converter.py +++ /dev/null @@ -1,742 +0,0 @@ -# mypy: disable-error-code="operator,union-attr,index" -"""Utility to convert JSON schema to EBNF grammar. Helpful for the grammar-guided generation.""" -import json -import logging -from typing import Any, Dict, List, Optional, Tuple, Union - -from tvm._ffi import register_func - -SchemaType = Union[Dict[str, Any], bool] -""" -JSON schema specification defines the schema type could be a dictionary or a boolean value. -""" - - -class _IndentManager: - """Manage the indent and separator for the generation of EBNF grammar. - - Parameters - ---------- - indent : Optional[int] - The number of spaces for each indent. If it is None, there will be no indent or newline. - - separator : str - The separator between different elements in json. Examples include "," and ", ". - """ - - def __init__(self, indent: Optional[int], separator: str): - self.enable_newline = indent is not None - self.indent = indent or 0 - self.separator = separator - self.total_indent = 0 - self.is_first = [True] - - def __enter__(self): - """Enter a new indent level.""" - self.total_indent += self.indent - self.is_first.append(True) - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Exit the current indent level.""" - self.total_indent -= self.indent - self.is_first.pop() - - def get_sep(self, is_end: bool = False) -> str: - """Get the separator according to the current state. When first called in the current level, - the starting separator will be returned. When called again, the middle separator will be - returned. When called with `is_end=True`, the ending separator will be returned. - - Parameters - ---------- - is_end : bool - Get the separator for the end of the current level. - - Examples - -------- - >>> indent_manager = IndentManager(2, ", ") - >>> with indent_manager: - ... print(indent_manager.get_sep()) # get the start separator - ... print(indent_manager.get_sep()) # get the middle separator - ... 
print(indent_manager.get_sep(is_end=True)) # get the end separator - - Output: (double quotes are included in the string for EBNF construction) - '"\n "' - '",\n "' - '"\n"' - """ - res = "" - - if not self.is_first[-1] and not is_end: - res += self.separator - self.is_first[-1] = False - - if self.enable_newline: - res += "\\n" - - if not is_end: - res += self.total_indent * " " - else: - res += (self.total_indent - self.indent) * " " - - return f'"{res}"' - - -# pylint: disable=unused-argument,too-few-public-methods -class _JSONSchemaToEBNFConverter: - """Convert JSON schema string to EBNF grammar string. The parameters follow - `json_schema_to_ebnf()`. - """ - - def __init__( - self, - json_schema: SchemaType, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = False, - ): - self.json_schema = json_schema - self.strict_mode = strict_mode - - if separators is None: - separators = (", ", ": ") if indent is None else (",", ": ") - assert len(separators) == 2 - self.indent_manager = _IndentManager(indent, separators[0]) - self.colon = separators[1] - - self.rules: List[Tuple[str, str]] = [] - self.basic_rules_cache: Dict[str, str] = {} - self._add_basic_rules() - - def convert(self) -> str: - """Main method. Convert the JSON schema to EBNF grammar string.""" - self._create_rule_with_schema(self.json_schema, "main") - res = "" - for rule_name, rule in self.rules: - res += f"{rule_name} ::= {rule}\n" - return res - - # The name of the basic rules - BASIC_ANY = "basic_any" - BASIC_INTEGER = "basic_integer" - BASIC_NUMBER = "basic_number" - BASIC_STRING = "basic_string" - BASIC_BOOLEAN = "basic_boolean" - BASIC_NULL = "basic_null" - BASIC_ARRAY = "basic_array" - BASIC_OBJECT = "basic_object" - - # The name of the helper rules to construct basic rules - BASIC_ESCAPE = "basic_escape" - BASIC_STRING_SUB = "basic_string_sub" - - def _add_basic_rules(self): - """Add the basic rules to the rules list and the basic_rules_cache.""" - past_strict_mode = self.strict_mode - self.strict_mode = False - past_indent_manager = self.indent_manager - self.indent_manager = _IndentManager(None, past_indent_manager.separator) - - self._add_helper_rules() - self._create_basic_rule(True, self.BASIC_ANY) - self.basic_rules_cache[self._get_schema_cache_index({})] = self.BASIC_ANY - self._create_basic_rule({"type": "integer"}, self.BASIC_INTEGER) - self._create_basic_rule({"type": "number"}, self.BASIC_NUMBER) - self._create_basic_rule({"type": "string"}, self.BASIC_STRING) - self._create_basic_rule({"type": "boolean"}, self.BASIC_BOOLEAN) - self._create_basic_rule({"type": "null"}, self.BASIC_NULL) - self._create_basic_rule({"type": "array"}, self.BASIC_ARRAY) - self._create_basic_rule({"type": "object"}, self.BASIC_OBJECT) - - self.strict_mode = past_strict_mode - self.indent_manager = past_indent_manager - - def _add_helper_rules(self): - """Add helper rules for the basic rules.""" - self.rules.append( - ( - self.BASIC_ESCAPE, - '["\\\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]', - ) - ) - self.rules.append( - ( - self.BASIC_STRING_SUB, - f'"" | [^"\\\\\\r\\n] {self.BASIC_STRING_SUB} | ' - f'"\\\\" {self.BASIC_ESCAPE} {self.BASIC_STRING_SUB}', - ) - ) - - def _create_basic_rule(self, schema: SchemaType, name: str): - """Create a rule for the given schema and name, and add it to the basic_rules_cache.""" - rule_name = self._create_rule_with_schema(schema, name) - self.basic_rules_cache[self._get_schema_cache_index(schema)] = rule_name - - def 
_get_sep(self, is_end: bool = False): - """Get the separator from the indent manager.""" - return self.indent_manager.get_sep(is_end) - - @staticmethod - def _warn_unsupported_keywords(schema: SchemaType, keywords: Union[str, List[str]]): - """Warn if any keyword is existing in the schema but not supported.""" - if isinstance(schema, bool): - return - if isinstance(keywords, str): - keywords = [keywords] - for keyword in keywords: - if keyword in schema: - logging.warning("Keyword %s is not supported in schema %s", keyword, schema) - - def _create_rule_with_schema(self, schema: SchemaType, rule_name_hint: str) -> str: - """Create a rule with the given schema and rule name hint. - - Returns - ------- - The name of the rule will be returned. That is not necessarily the same as the - rule_name_hint due to the caching mechanism. - """ - idx = self._get_schema_cache_index(schema) - if idx in self.basic_rules_cache: - return self.basic_rules_cache[idx] - - assert isinstance(rule_name_hint, str) - - self.rules.append((rule_name_hint, self._visit_schema(schema, rule_name_hint))) - return rule_name_hint - - # The keywords that will be ignored when finding the cached rule for a schema - SKIPPED_KEYS = [ - "title", - "default", - "description", - "examples", - "deprecated", - "readOnly", - "writeOnly", - "$comment", - "$schema", - ] - - @staticmethod - def _remove_skipped_keys_recursive(obj: Any) -> Any: - """Remove the skipped keys from the schema recursively.""" - if isinstance(obj, dict): - return { - k: _JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(v) - for k, v in obj.items() - if k not in _JSONSchemaToEBNFConverter.SKIPPED_KEYS - } - if isinstance(obj, list): - return [_JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(v) for v in obj] - return obj - - def _get_schema_cache_index(self, schema: SchemaType) -> str: - """Get the index for the schema in the cache.""" - return json.dumps( - _JSONSchemaToEBNFConverter._remove_skipped_keys_recursive(schema), - sort_keys=True, - indent=None, - ) - - # pylint: disable=too-many-return-statements,too-many-branches - def _visit_schema(self, schema: SchemaType, rule_name: str) -> str: - """Visit the schema and return the rule body for later constructing the rule.""" - assert schema is not False - if schema is True: - return self._visit_any(schema, rule_name) - - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, - [ - "allof", - "oneof", - "not", - "if", - "then", - "else", - "dependentRequired", - "dependentSchemas", - ], - ) - - if "$ref" in schema: - return self._visit_ref(schema, rule_name) - if "const" in schema: - return self._visit_const(schema, rule_name) - if "enum" in schema: - return self._visit_enum(schema, rule_name) - if "anyOf" in schema: - return self._visit_anyof(schema, rule_name) - if "type" in schema: - type_obj = schema["type"] - if type_obj == "integer": - return self._visit_integer(schema, rule_name) - if type_obj == "number": - return self._visit_number(schema, rule_name) - if type_obj == "string": - return self._visit_string(schema, rule_name) - if type_obj == "boolean": - return self._visit_boolean(schema, rule_name) - if type_obj == "null": - return self._visit_null(schema, rule_name) - if type_obj == "array": - return self._visit_array(schema, rule_name) - if type_obj == "object": - return self._visit_object(schema, rule_name) - raise ValueError(f"Unsupported type {schema['type']}") - # no keyword is detected, we treat it as any - return self._visit_any(schema, rule_name) - - def _visit_ref(self, 
schema: SchemaType, rule_name: str) -> str: - """Visit a reference schema.""" - assert "$ref" in schema - new_schema = self._uri_to_schema(schema["$ref"]).copy() - if not isinstance(new_schema, bool): - new_schema.update({k: v for k, v in schema.items() if k != "$ref"}) - return self._visit_schema(new_schema, rule_name) - - def _uri_to_schema(self, uri: str) -> SchemaType: - """Get the schema from the URI.""" - if uri.startswith("#/$defs/"): - return self.json_schema["$defs"][uri[len("#/$defs/") :]] - logging.warning("Now only support URI starting with '#/$defs/' but got %s", uri) - return True - - def _visit_const(self, schema: SchemaType, rule_name: str) -> str: - """Visit a const schema.""" - assert "const" in schema - return '"' + self._json_str_to_printable_str(json.dumps(schema["const"])) + '"' - - def _visit_enum(self, schema: SchemaType, rule_name: str) -> str: - """Visit an enum schema.""" - assert "enum" in schema - res = "" - for i, enum_value in enumerate(schema["enum"]): - if i != 0: - res += " | " - res += '("' + self._json_str_to_printable_str(json.dumps(enum_value)) + '")' - return res - - REPLACE_MAPPING = { - "\\": "\\\\", - '"': '\\"', - } - - def _json_str_to_printable_str(self, json_str: str) -> str: - """Convert the JSON string to a printable string in BNF.""" - for k, v in self.REPLACE_MAPPING.items(): - json_str = json_str.replace(k, v) - return json_str - - def _visit_anyof(self, schema: SchemaType, rule_name: str) -> str: - """Visit an anyOf schema.""" - assert "anyOf" in schema - res = "" - for i, anyof_schema in enumerate(schema["anyOf"]): - if i != 0: - res += " | " - res += self._create_rule_with_schema(anyof_schema, f"{rule_name}_{i}") - return res - - def _visit_any(self, schema: SchemaType, rule_name: str) -> str: - """Visit a true schema that can match anything.""" - # note integer is a subset of number, so we don't need to add integer here - return ( - f"{self.BASIC_NUMBER} | {self.BASIC_STRING} | {self.BASIC_BOOLEAN} | " - f"{self.BASIC_NULL} | {self.BASIC_ARRAY} | {self.BASIC_OBJECT}" - ) - - def _visit_integer(self, schema: SchemaType, rule_name: str) -> str: - """Visit an integer schema.""" - assert schema["type"] == "integer" - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, ["multipleOf", "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum"] - ) - return '("0" | "-"? [1-9] [0-9]*) ".0"?' - - def _visit_number(self, schema: SchemaType, rule_name: str) -> str: - """Visit a number schema.""" - assert schema["type"] == "number" - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, ["multipleOf", "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum"] - ) - return '("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?' - - def _visit_string(self, schema: SchemaType, rule_name: str) -> str: - """Visit a string schema.""" - assert schema["type"] == "string" - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, ["minLength", "maxLength", "pattern", "format"] - ) - return f'["] {self.BASIC_STRING_SUB} ["]' - - def _visit_boolean(self, schema: SchemaType, rule_name: str) -> str: - """Visit a boolean schema.""" - assert schema["type"] == "boolean" - - return '"true" | "false"' - - def _visit_null(self, schema: SchemaType, rule_name: str) -> str: - """Visit a null schema.""" - assert schema["type"] == "null" - - return '"null"' - - def _visit_array(self, schema: SchemaType, rule_name: str) -> str: - """Visit an array schema. 
- - Examples - -------- - Schema: - { - "type": "array", - "prefixItems": [ - {"type": "boolean"}, - {"type": "integer"} - ], - "items": { - "type": "string" - } - } - - Rule (not considering the indent): - main ::= "[" basic_boolean ", " basic_integer (", " basic_string)* "]" - """ - assert schema["type"] == "array" - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, - ["uniqueItems", "contains", "minContains", "maxContains", "minItems", "maxItems"], - ) - - res = '"["' - could_be_empty = False - - with self.indent_manager: - # 1. Handle prefix items - prefix_items = schema.get("prefixItems", []) - if len(prefix_items) > 0: - for i, prefix_item in enumerate(prefix_items): - assert prefix_item is not False - item = self._create_rule_with_schema(prefix_item, f"{rule_name}_{i}") - res += f" {self._get_sep()} {item}" - - # 2. Find additional items - additional_item = None - additional_suffix = "" - - items = schema.get("items", False) - if items is not False: - additional_item = items - additional_suffix = "item" - - # if items is in the schema, we don't need to consider unevaluatedItems - unevaluated = schema.get("unevaluatedItems", not self.strict_mode) - if "items" not in schema and unevaluated is not False: - additional_item = unevaluated - additional_suffix = "uneval" - - # 3. Handle additional items and the end separator - if additional_item is None: - res += f" {self._get_sep(is_end=True)}" - else: - additional_pattern = self._create_rule_with_schema( - additional_item, f"{rule_name}_{additional_suffix}" - ) - if len(prefix_items) > 0: - res += ( - f" ({self._get_sep()} {additional_pattern})* {self._get_sep(is_end=True)}" - ) - else: - res += ( - f" {self._get_sep()} {additional_pattern} ({self._get_sep()} " - f"{additional_pattern})* {self._get_sep(is_end=True)}" - ) - could_be_empty = True - - res += ' "]"' - - if could_be_empty: - res = f'({res}) | "[]"' - - return res - - def _visit_object(self, schema: SchemaType, rule_name: str) -> str: - """Visit an object schema. - - Examples - -------- - Schema: - { - "type": "object", - "properties": { - "a": {"type": "string"}, - "b": {"type": "integer"} - }, - "required": ["a"], - "additionalProperties": true - } - - Rule (not considering the indent): - main ::= "{" "a" ":" basic_string (", " "b" ":" basic_integer)* - (", " basic_string ": " basic_any)* "}" - - We need special handling when all properties are optional, since the handling of separators - is tricky in this case. E.g. - - Schema: - { - "type": "object", - "properties": { - "a": {"type": "string"}, - "b": {"type": "integer"}, - "c": {"type": "boolean"} - }, - "additionalProperties": true - } - - Rule (indent=2): - main ::= "{" ("\n " (a main_sub_1 | b main_sub_2 | c main_sub_3 | d main_sub_3) - "\n" | "") "}" - main_sub_1 ::= ",\n " b r2 | r2 - main_sub_2 ::= ",\n " c r3 | r3 - main_sub_3 ::= (",\n " d)* - """ - assert schema["type"] == "object" - _JSONSchemaToEBNFConverter._warn_unsupported_keywords( - schema, ["patternProperties", "minProperties", "maxProperties", "propertyNames"] - ) - - res = '"{"' - # Set could_be_empty to True when the rule could be "{}". We will handle this case at last, - # and handle non-empty cases before that. - could_be_empty = False - # Now we only consider the required list for the properties field - required = schema.get("required", []) - - with self.indent_manager: - # 1. 
Find additional properties - additional_property = None - additional_suffix = "" - - additional = schema.get("additionalProperties", False) - if additional is not False: - additional_property = additional - additional_suffix = "add" - - unevaluated = schema.get("unevaluatedProperties", not self.strict_mode) - if "additionalProperties" not in schema and unevaluated is not False: - additional_property = unevaluated - additional_suffix = "uneval" - - # 2. Handle properties - properties_obj = schema.get("properties", {}) - properties = list(properties_obj.items()) - - properties_all_optional = all(prop_name not in required for prop_name, _ in properties) - if properties_all_optional and len(properties) > 0: - # 3.1 Case 1: properties are defined and all properties are optional - res += " " + self._get_partial_rule_for_properties_all_optional( - properties, additional_property, rule_name, additional_suffix - ) - could_be_empty = True - elif len(properties) > 0: - # 3.2 Case 2: properties are defined and some properties are required - res += " " + self._get_partial_rule_for_properties_contain_required( - properties, required, rule_name - ) - if additional_property is not None: - other_property_pattern = self._get_other_property_pattern( - self.BASIC_STRING, additional_property, rule_name, additional_suffix - ) - res += f" ({self._get_sep()} {other_property_pattern})*" - res += " " + self._get_sep(is_end=True) - elif additional_property is not None: - # 3.3 Case 3: no properties are defined and additional properties are allowed - other_property_pattern = self._get_other_property_pattern( - self.BASIC_STRING, additional_property, rule_name, additional_suffix - ) - res += ( - f" {self._get_sep()} {other_property_pattern} ({self._get_sep()} " - f"{other_property_pattern})* {self._get_sep(is_end=True)}" - ) - could_be_empty = True - - res += ' "}"' - - if could_be_empty: - res = f'({res}) | "{{}}"' - return res - - def _get_property_pattern(self, prop_name: str, prop_schema: SchemaType, rule_name: str) -> str: - """Get the pattern for a property in the object schema.""" - # the outer quote is for the string in EBNF grammar, and the inner quote is for - # the string in JSON - key = f'"\\"{prop_name}\\""' - colon = f'"{self.colon}"' - value = self._create_rule_with_schema(prop_schema, rule_name + "_" + prop_name) - return f"{key} {colon} {value}" - - def _get_other_property_pattern( - self, key_pattern: str, prop_schema: SchemaType, rule_name: str, rule_name_suffix: str - ) -> str: - """Get the pattern for the additional/unevaluated properties in the object schema.""" - colon = f'"{self.colon}"' - value = self._create_rule_with_schema(prop_schema, rule_name + "_" + rule_name_suffix) - return f"{key_pattern} {colon} {value}" - - # pylint: disable=too-many-locals - def _get_partial_rule_for_properties_all_optional( - self, - properties: List[Tuple[str, SchemaType]], - additional: Optional[SchemaType], - rule_name: str, - additional_suffix: str = "", - ) -> str: - """Get the partial rule for the properties when all properties are optional. 
See the - above example.""" - assert len(properties) >= 1 - - first_sep = self._get_sep() - mid_sep = self._get_sep() - last_sep = self._get_sep(is_end=True) - - res = "" - - prop_patterns = [ - self._get_property_pattern(prop_name, prop_schema, rule_name) - for prop_name, prop_schema in properties - ] - - rule_names = [None] * len(properties) - - # construct the last rule - if additional is not None: - additional_prop_pattern = self._get_other_property_pattern( - self.BASIC_STRING, additional, rule_name, additional_suffix - ) - last_rule_body = f"({mid_sep} {additional_prop_pattern})*" - last_rule_name = f"{rule_name}_sub_{len(properties)-1}" - self.rules.append((last_rule_name, last_rule_body)) - rule_names[-1] = last_rule_name # type: ignore - else: - rule_names[-1] = '""' # type: ignore - - # construct 0~(len(properties) - 2) rules - for i in reversed(range(0, len(properties) - 1)): - prop_pattern = prop_patterns[i + 1] - last_rule_name = rule_names[i + 1] - cur_rule_body = f"{last_rule_name} | {mid_sep} {prop_pattern} {last_rule_name}" - cur_rule_name = f"{rule_name}_sub_{i}" - self.rules.append((cur_rule_name, cur_rule_body)) - rule_names[i] = cur_rule_name # type: ignore - - # construct the main rule - for i, prop_pattern in enumerate(prop_patterns): - if i != 0: - res += " | " - res += f"({prop_pattern} {rule_names[i]})" - - if additional is not None: - res += f" | {additional_prop_pattern} {rule_names[-1]}" - - # add separators and the empty string option - res = f"{first_sep} ({res}) {last_sep}" - return res - - def _get_partial_rule_for_properties_contain_required( - self, - properties: List[Tuple[str, SchemaType]], - required: List[str], - rule_name: str, - ) -> str: - """Get the partial rule for the properties when some properties are required. See the - above example. - - The constructed rule should be: - - start_separator (optional_property separator)? (optional_property separator)? ... - first_required_property (separator optional_property)? separator required_property ... - end_separator - - i.e. Before the first required property, all properties are in the form - (property separator); and after the first required property, all properties are in the form - (separator property). - """ - - # Find the index of the first required property - first_required_idx = next( - (i for i, (prop_name, _) in enumerate(properties) if prop_name in required), - len(properties), - ) - assert first_required_idx < len(properties) - - res = self._get_sep() - - # Handle the properties before the first required property - for prop_name, prop_schema in properties[:first_required_idx]: - assert prop_schema is not False - property_pattern = self._get_property_pattern(prop_name, prop_schema, rule_name) - res += f" ({property_pattern} {self._get_sep()})?" - - # Handle the first required property - property_pattern = self._get_property_pattern( - properties[first_required_idx][0], properties[first_required_idx][1], rule_name - ) - res += f" {property_pattern}" - - # Handle the properties after the first required property - for prop_name, prop_schema in properties[first_required_idx + 1 :]: - assert prop_schema is not False - property_pattern = self._get_property_pattern(prop_name, prop_schema, rule_name) - if prop_name in required: - res += f" {self._get_sep()} {property_pattern}" - else: - res += f" ({self._get_sep()} {property_pattern})?" 
- - return res - - -def json_schema_to_ebnf( - json_schema: str, - *, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -) -> str: - """Convert JSON schema string to EBNF grammar string. - - Parameters - ---------- - json_schema : str - The JSON schema string. - - indent : Optional[int] - The number of spaces for each indent. If it is None, there will be no indent or newline. - The indent and separators parameters follow the same convention as - `json.dumps()`. - - separators : Optional[Tuple[str, str]] - The separator between different elements in json. Examples include "," and ", ". - - strict_mode : bool - Whether to use strict mode. In strict mode, the generated grammar will not allow - unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default. - This helps LLM to generate accurate output in the grammar-guided generation with JSON - schema. - """ - json_schema_schema = json.loads(json_schema) - return _JSONSchemaToEBNFConverter(json_schema_schema, indent, separators, strict_mode).convert() - - -@register_func("mlc.serve.json_schema_to_ebnf") -def json_schema_to_ebnf_register( - json_schema: str, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -) -> str: - """To register json_schema_to_ebnf in ffi, we need to create an equivalent function without - keyword-only arguments.""" - return json_schema_to_ebnf( - json_schema, indent=indent, separators=separators, strict_mode=strict_mode - ) diff --git a/tests/python/serve/test_json_schema_converter.py b/tests/python/serve/test_json_schema_converter.py index 822199977c..84dbd2cb7b 100644 --- a/tests/python/serve/test_json_schema_converter.py +++ b/tests/python/serve/test_json_schema_converter.py @@ -5,7 +5,7 @@ import tvm.testing from pydantic import BaseModel, Field, TypeAdapter -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher, json_schema_to_ebnf +from mlc_llm.serve import BNFGrammar, GrammarStateMatcher def check_schema_with_grammar( @@ -16,7 +16,7 @@ def check_schema_with_grammar( strict_mode: bool = True, ): schema_str = json.dumps(schema, indent=2) - grammar = json_schema_to_ebnf( + grammar = BNFGrammar.debug_json_schema_to_ebnf( schema_str, indent=indent, separators=separators, strict_mode=strict_mode ) assert grammar == expected_grammar @@ -25,17 +25,14 @@ def check_schema_with_grammar( def check_schema_with_json( schema: Dict[str, Any], json_str: str, - check_accepted=True, + check_accepted: bool = True, indent: Optional[int] = None, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True, ): - schema_str = json.dumps(schema, indent=2) - - ebnf_grammar_str = json_schema_to_ebnf( - schema_str, indent=indent, separators=separators, strict_mode=strict_mode + ebnf_grammar = BNFGrammar.from_schema( + json.dumps(schema, indent=2), indent=indent, separators=separators, strict_mode=strict_mode ) - ebnf_grammar = BNFGrammar.from_ebnf_string(ebnf_grammar_str) matcher = GrammarStateMatcher(ebnf_grammar) if check_accepted: @@ -47,7 +44,7 @@ def check_schema_with_json( def check_schema_with_instance( schema: Dict[str, Any], instance: BaseModel, - check_accepted=True, + check_accepted: bool = True, indent: Optional[int] = None, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True, @@ -78,14 +75,14 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " 
basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_any_array_field ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -main_array_field ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_tuple_field_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_tuple_field ::= "[" "" basic_string ", " basic_integer ", " main_tuple_field_2 "" "]" -main_object_field ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_nested_object_field_add ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_nested_object_field ::= ("{" "" basic_string ": " main_nested_object_field_add (", " basic_string ": " main_nested_object_field_add)* "" "}") | "{}" -main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_any_array_field ", " "\"array_field\"" ": " main_array_field ", " "\"tuple_field\"" ": " main_tuple_field ", " "\"object_field\"" ": " main_object_field ", " "\"nested_object_field\"" ": " main_nested_object_field "" "}" +main_prop_3 ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +main_prop_4 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" +main_prop_5_item_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" +main_prop_5 ::= "[" "" basic_string ", " basic_integer ", " main_prop_5_item_2 "" "]" +main_prop_6 ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_prop_7_addl ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" +main_prop_7 ::= ("{" "" basic_string ": " main_prop_7_addl (", " basic_string ": " main_prop_7_addl)* "" "}") | "{}" +main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_prop_3 ", " "\"array_field\"" ": " main_prop_4 ", " "\"tuple_field\"" ": " main_prop_5 ", " "\"object_field\"" ": " main_prop_6 ", " "\"nested_object_field\"" ": " main_prop_7 "" "}" """ schema = MainModel.model_json_schema() @@ -134,11 +131,11 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_array_field ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_tuple_field_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_tuple_field ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_tuple_field_2 "\n " "]" -main_object_field ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" -main ::= "{" "\n " "\"array_field\"" ": " main_array_field ",\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"object_field\"" ": " main_object_field "\n" "}" +main_prop_0 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" +main_prop_1_item_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" +main_prop_1 ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_prop_1_item_2 "\n " "]" +main_prop_2 ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" +main ::= "{" "\n " "\"array_field\"" ": " main_prop_0 ",\n " "\"tuple_field\"" ": " main_prop_1 ",\n 
" "\"object_field\"" ": " main_prop_2 "\n" "}" """ instance = MainModel( @@ -171,10 +168,10 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_tuple_field_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" -main_tuple_field ::= "[" "\n " basic_string ",\n " main_tuple_field_1 (",\n " basic_any)* "\n " "]" -main_foo_field ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" -main ::= "{" "\n " "\"tuple_field\"" ": " main_tuple_field ",\n " "\"foo_field\"" ": " main_foo_field (",\n " basic_string ": " basic_any)* "\n" "}" +main_prop_0_item_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" +main_prop_0 ::= "[" "\n " basic_string ",\n " main_prop_0_item_1 (",\n " basic_any)* "\n " "]" +main_prop_1 ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" +main ::= "{" "\n " "\"tuple_field\"" ": " main_prop_0 ",\n " "\"foo_field\"" ": " main_prop_1 (",\n " basic_string ": " basic_any)* "\n" "}" """ instance_json = """{ @@ -220,12 +217,12 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_bars ::= "\"a\"" -main_str_values ::= "\"a\\n\\r\\\"\"" -main_foo ::= ("\"a\"") | ("\"b\"") | ("\"c\"") -main_values ::= ("1") | ("\"a\"") | ("true") -main_field ::= ("\"foo\"") | ("\"bar\"") -main ::= "{" "" "\"bars\"" ": " main_bars ", " "\"str_values\"" ": " main_str_values ", " "\"foo\"" ": " main_foo ", " "\"values\"" ": " main_values ", " "\"field\"" ": " main_field "" "}" +main_prop_0 ::= "\"a\"" +main_prop_1 ::= "\"a\\n\\r\\\"\"" +main_prop_2 ::= ("\"a\"") | ("\"b\"") | ("\"c\"") +main_prop_3 ::= ("1") | ("\"a\"") | ("true") +main_prop_4 ::= ("\"foo\"") | ("\"bar\"") +main ::= "{" "" "\"bars\"" ": " main_prop_0 ", " "\"str_values\"" ": " main_prop_1 ", " "\"foo\"" ": " main_prop_2 ", " "\"values\"" ": " main_prop_3 ", " "\"field\"" ": " main_prop_4 "" "}" """ schema = MainModel.model_json_schema() @@ -251,9 +248,9 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_opt_bool ::= basic_boolean | basic_null -main_size ::= basic_number | basic_null -main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_opt_bool ", ")? "\"size\"" ": " main_size (", " "\"name\"" ": " basic_string)? "" "}" +main_prop_1 ::= basic_boolean | basic_null +main_prop_2 ::= basic_number | basic_null +main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_prop_1 ", ")? "\"size\"" ": " main_prop_2 (", " "\"name\"" ": " basic_string)? 
"" "}" """ schema = MainModel.model_json_schema() @@ -286,9 +283,9 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_sub_1 ::= "" | ", " "\"num\"" ": " basic_number "" -main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" +main_part_1 ::= "" | ", " "\"num\"" ": " basic_number "" +main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 +main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" """ schema = MainModel.model_json_schema() @@ -310,10 +307,10 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_sub_2 ::= (", " basic_string ": " basic_any)* -main_sub_1 ::= main_sub_2 | ", " "\"num\"" ": " basic_number main_sub_2 -main_sub_0 ::= main_sub_1 | ", " "\"state\"" ": " basic_boolean main_sub_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_sub_0) | ("\"state\"" ": " basic_boolean main_sub_1) | ("\"num\"" ": " basic_number main_sub_2) | basic_string ": " basic_any main_sub_2) "" "}") | "{}" +main_part_2 ::= (", " basic_string ": " basic_any)* +main_part_1 ::= main_part_2 | ", " "\"num\"" ": " basic_number main_part_2 +main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 +main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number main_part_2) | basic_string ": " basic_any main_part_2) "" "}") | "{}" """ check_schema_with_grammar(schema, ebnf_grammar_non_strict, strict_mode=False) @@ -376,12 +373,12 @@ class MainModel(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_foo_size ::= basic_number | basic_null -main_foo ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_foo_size)? "" "}" -main_bars_item_sub_0 ::= "" | ", " "\"banana\"" ": " basic_string "" -main_bars_item ::= ("{" "" (("\"apple\"" ": " basic_string main_bars_item_sub_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" -main_bars ::= ("[" "" main_bars_item (", " main_bars_item)* "" "]") | "[]" -main ::= "{" "" "\"foo\"" ": " main_foo ", " "\"bars\"" ": " main_bars "" "}" +main_prop_0_prop_1 ::= basic_number | basic_null +main_prop_0 ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_prop_0_prop_1)? 
"" "}" +main_prop_1_items_part_0 ::= "" | ", " "\"banana\"" ": " basic_string "" +main_prop_1_items ::= ("{" "" (("\"apple\"" ": " basic_string main_prop_1_items_part_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" +main_prop_1 ::= ("[" "" main_prop_1_items (", " main_prop_1_items)* "" "]") | "[]" +main ::= "{" "" "\"foo\"" ": " main_prop_0 ", " "\"bars\"" ": " main_prop_1 "" "}" """ schema = MainModel.model_json_schema() @@ -412,9 +409,9 @@ class Dog(BaseModel): basic_null ::= "null" basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" -main_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" -main ::= main_0 | main_1 +main_case_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" +main_case_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" +main ::= main_case_0 | main_case_1 """ check_schema_with_grammar(model_schema, ebnf_grammar) @@ -450,6 +447,32 @@ class MainModel(BaseModel): instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) check_schema_with_json(MainModel.model_json_schema(by_alias=True), instance_str) + # property name contains space + class MainModelSpace(BaseModel): + test: Literal["abc"] = Field(..., alias="name 1") + + ebnf_grammar_space = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] +basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub +basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object +basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? +basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
+basic_string ::= ["] basic_string_sub ["] +basic_boolean ::= "true" | "false" +basic_null ::= "null" +basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" +basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" +main_prop_0 ::= "\"abc\"" +main ::= "{" "" "\"name 1\"" ": " main_prop_0 "" "}" +""" + + check_schema_with_grammar(MainModelSpace.model_json_schema(), ebnf_grammar_space) + + instance_space = MainModelSpace(**{"name 1": "abc"}) + instance_space_str = json.dumps( + instance_space.model_dump(mode="json", round_trip=True, by_alias=True) + ) + check_schema_with_json(MainModelSpace.model_json_schema(by_alias=True), instance_space_str) + if __name__ == "__main__": tvm.testing.main() From 9b71443b490f4bfbe78878f5230ab50fa238e566 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 11 Apr 2024 09:39:23 -0700 Subject: [PATCH 180/531] [Model] Use R.topk/cumsum for mixtral (#2107) --- python/mlc_llm/op/moe_misc.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index 6dc7f33265..ff5e50c60c 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -5,9 +5,6 @@ from tvm import te, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T -from tvm.target import Target -from tvm.topi.cuda.scan import inclusive_scan -from tvm.topi.cuda.sort import topk as topi_topk # mypy: disable-error-code="attr-defined,name-defined" # pylint: disable=line-too-long,too-many-locals,invalid-name @@ -120,7 +117,9 @@ def topk_softmax_func( Tensor.placeholder([batch_size, 2], index_dtype), ), ) - expert_score, expert_indices = op.tensor_expr_op(topi_topk, "topk", args=[x, k, -1, "both", False, index_dtype]) # type: ignore[list-item] + expert_score, expert_indices = op.topk( + x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype + ) expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype) return expert_score, expert_indices @@ -203,14 +202,8 @@ def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor: .permute_dims(1, 0) .reshape(batch_size * num_local_experts) ) - with Target.current(allow_none=True) or Target( - { - "kind": "cuda", - "max_num_threads": 1024, - "arch": "sm_50", - } - ): - return op.tensor_expr_op(inclusive_scan, "cumsum", args=[expert_mask, 0, "int32"]) # type: ignore[list-item] + + return op.cumsum(expert_mask, axis=0, exclusive=False, dtype="int32") def get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tuple[Tensor, Tensor]: From 880c68a00d6138590b206e9d8703d4bee9047c82 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 11 Apr 2024 22:26:48 -0700 Subject: [PATCH 181/531] Enable flashinfer when group_size == 6 (#2124) --- python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py | 2 +- python/mlc_llm/op/attention.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 20e4c7bdd9..d9d478cd1f 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -155,7 +155,7 @@ def create_flashinfer_paged_kv_cache( in self.metadata["model_type"] ) # filter by attention group size - or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8] + or kwargs["num_attention_heads"] // 
kwargs["num_key_value_heads"] not in [1, 4, 6, 8] ): return diff --git a/python/mlc_llm/op/attention.py b/python/mlc_llm/op/attention.py index 801dbd66ba..dc41a5f5ef 100644 --- a/python/mlc_llm/op/attention.py +++ b/python/mlc_llm/op/attention.py @@ -103,12 +103,12 @@ def _fallback(): and k.dtype == "float16" and v.dtype == "float16" ): - if group_size not in [1, 4, 8]: + if group_size not in [1, 4, 6, 8]: global WARN_FLASHINFER_GROUP_SIZE # pylint: disable=global-statement if not WARN_FLASHINFER_GROUP_SIZE: WARN_FLASHINFER_GROUP_SIZE = True logger.warning( - "FlashInfer only supports group size in [1, 4, 8], but got %d. Skip and " + "FlashInfer only supports group size in [1, 4, 6, 8], but got %d. Skip and " "fallback to default implementation.", group_size, ) From 4dfb9f070fe865ea8299a85e4e6557e6a2042785 Mon Sep 17 00:00:00 2001 From: ZCHNO Date: Fri, 12 Apr 2024 22:56:44 +0800 Subject: [PATCH 182/531] [SpecDecode] Support Eagle in speculative decoding (#2080) 1. Add Eagle-Llama-7b-chat model support. 2. Add speculative decoding support with Eagle. --- cpp/serve/config.cc | 18 +- cpp/serve/config.h | 13 +- cpp/serve/engine.cc | 50 +- cpp/serve/engine_actions/action.h | 52 ++ cpp/serve/engine_actions/eagle_batch_draft.cc | 230 +++++++ .../engine_actions/eagle_batch_verify.cc | 364 +++++++++++ .../eagle_new_request_prefill.cc | 568 ++++++++++++++++++ .../engine_actions/new_request_prefill.cc | 4 +- cpp/serve/function_table.cc | 10 +- cpp/serve/function_table.h | 9 + cpp/serve/logit_processor.cc | 4 +- cpp/serve/model.cc | 425 +++++++++++++ cpp/serve/model.h | 100 +++ cpp/serve/request_state.cc | 12 +- cpp/serve/request_state.h | 9 +- cpp/serve/sampler/cpu_sampler.cc | 31 +- .../mlc_llm/compiler_pass/attach_sampler.py | 10 +- python/mlc_llm/model/eagle/__init__.py | 0 python/mlc_llm/model/eagle/eagle_loader.py | 172 ++++++ python/mlc_llm/model/eagle/eagle_model.py | 242 ++++++++ .../mlc_llm/model/eagle/eagle_quantization.py | 70 +++ python/mlc_llm/model/llama/llama_model.py | 135 ++++- python/mlc_llm/model/model.py | 17 + python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 37 +- .../serve/test_serve_async_engine_spec.py | 10 +- tests/python/serve/test_serve_engine_spec.py | 276 ++++++++- 27 files changed, 2797 insertions(+), 73 deletions(-) create mode 100644 cpp/serve/engine_actions/eagle_batch_draft.cc create mode 100644 cpp/serve/engine_actions/eagle_batch_verify.cc create mode 100644 cpp/serve/engine_actions/eagle_new_request_prefill.cc create mode 100644 python/mlc_llm/model/eagle/__init__.py create mode 100644 python/mlc_llm/model/eagle/eagle_loader.py create mode 100644 python/mlc_llm/model/eagle/eagle_model.py create mode 100644 python/mlc_llm/model/eagle/eagle_quantization.py diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 0c69296326..62394c4b21 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -299,16 +299,16 @@ String KVCacheConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineModeNode); -EngineMode::EngineMode(bool enable_speculative, int spec_draft_length) { +EngineMode::EngineMode(int spec_draft_length, int speculative_mode) { ObjectPtr n = make_object(); - n->enable_speculative = enable_speculative; n->spec_draft_length = spec_draft_length; + n->speculative_mode = SpeculativeMode(speculative_mode); data_ = std::move(n); } EngineMode::EngineMode(const std::string& config_str) { - bool enable_speculative = false; int spec_draft_length = 4; + int speculative_mode = 0; picojson::value config_json; std::string err = 
picojson::parse(config_json, config_str); @@ -318,25 +318,25 @@ EngineMode::EngineMode(const std::string& config_str) { // Get json fields. picojson::object config = config_json.get(); - if (config.count("enable_speculative")) { - CHECK(config["enable_speculative"].is()); - enable_speculative = config["enable_speculative"].get(); - } if (config.count("spec_draft_length")) { CHECK(config["spec_draft_length"].is()); spec_draft_length = config["spec_draft_length"].get(); } + if (config.count("speculative_mode")) { + CHECK(config["speculative_mode"].is()); + speculative_mode = config["speculative_mode"].get(); + } ObjectPtr n = make_object(); - n->enable_speculative = enable_speculative; n->spec_draft_length = spec_draft_length; + n->speculative_mode = SpeculativeMode(speculative_mode); data_ = std::move(n); } String EngineModeNode::AsJSONString() const { picojson::object config; - config["enable_speculative"] = picojson::value(static_cast(this->enable_speculative)); config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); + config["speculative_mode"] = picojson::value(static_cast(this->speculative_mode)); return picojson::value(config).serialize(true); } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 0c3402b2ca..bee0af5561 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -98,13 +98,20 @@ class KVCacheConfig : public ObjectRef { /****************** Engine Mode ******************/ +/*! \brief The speculative mode. */ +enum class SpeculativeMode : int { + kDisable = 0, + kSmallDraft = 1, + kEagle = 2, +}; + /*! \brief The configuration of engine execution mode. */ class EngineModeNode : public Object { public: - /* Whether the speculative decoding mode is enabled */ - bool enable_speculative; /* The number of tokens to generate in speculative proposal (draft) */ int spec_draft_length; + /* The speculative mode. */ + SpeculativeMode speculative_mode; String AsJSONString() const; @@ -116,7 +123,7 @@ class EngineModeNode : public Object { class EngineMode : public ObjectRef { public: - explicit EngineMode(bool enable_speculative, int spec_draft_length); + explicit EngineMode(int spec_draft_length, int speculative_mode); explicit EngineMode(const std::string& config_str); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index abb5c7b6c7..d9530c22fe 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -80,10 +80,11 @@ class EngineImpl : public Engine { << ", is smaller than the pre-defined max single sequence length, " << this->max_single_sequence_length_; this->models_.push_back(model); - this->model_workspaces_.push_back(ModelWorkspace{model->AllocEmbeddingTensor()}); + this->model_workspaces_.push_back( + ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); } int max_num_tokens = kv_cache_config_->max_num_sequence; - if (engine_mode_->enable_speculative) { + if (engine_mode_->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_mode_->spec_draft_length; } LogitProcessor logit_processor = @@ -91,21 +92,40 @@ class EngineImpl : public Engine { Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); // Step 3. Initialize engine actions that represent state transitions. - if (this->engine_mode_->enable_speculative) { + if (this->engine_mode_->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. 
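      // For reference (not prescriptive), an engine-mode JSON string that takes this
      // branch could look like {"spec_draft_length": 4, "speculative_mode": 1} for
      // small-draft speculation or {"speculative_mode": 2} for Eagle, following the
      // EngineMode parser in config.cc above (0 = disable, 1 = small draft, 2 = Eagle).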
ICHECK_GT(this->models_.size(), 1U); - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_mode_, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, - this->engine_mode_->spec_draft_length), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, - this->trace_recorder_)}; + switch (this->engine_mode_->speculative_mode) { + case SpeculativeMode::kEagle: + this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + this->kv_cache_config_, // + this->engine_mode_, // + this->trace_recorder_), + EngineAction::EagleBatchDraft( + this->models_, logit_processor, sampler, this->model_workspaces_, + this->trace_recorder_, this->engine_mode_->spec_draft_length), + EngineAction::EagleBatchVerify( + this->models_, logit_processor, sampler, this->model_workspaces_, + this->kv_cache_config_, this->trace_recorder_)}; + break; + default: + this->actions_ = { + EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + this->kv_cache_config_, // + this->engine_mode_, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->trace_recorder_, + this->engine_mode_->spec_draft_length), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + this->kv_cache_config_, this->trace_recorder_)}; + } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // logit_processor, // diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index e355168365..1385befddf 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -66,6 +66,23 @@ class EngineAction : public ObjectRef { std::vector model_workspaces, KVCacheConfig kv_cache_config, EngineMode engine_mode, Optional trace_recorder); + /*! + * \brief Create the action that prefills requests in the `waiting_queue` + * of the engine state. + * \param models The models to run prefill in. + * \param logit_processor The logit processor. + * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param kv_cache_config The KV cache config to help decide prefill is doable. + * \param engine_mode The engine operation mode. + * \param trace_recorder The event trace recorder for requests. + * \return The created action object. + */ + static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, EngineMode engine_mode, + Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -97,6 +114,23 @@ class EngineAction : public ObjectRef { Sampler sampler, Optional trace_recorder, int draft_length = 4); + /*! + * \brief Create the action that runs one-step speculative draft proposal for + * requests in the `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \param models The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. 
+ * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param trace_recorder The event trace recorder for requests. + * \param draft_length The number of draft proposal rounds. + * \return The created action object. + */ + static EngineAction EagleBatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + Optional trace_recorder, + int draft_length = 4); + /*! * \brief Create the action that runs one-step speculative verification for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -112,6 +146,24 @@ class EngineAction : public ObjectRef { Sampler sampler, KVCacheConfig kv_cache_config, Optional trace_recorder); + /*! + * \brief Create the action that runs one-step speculative verification for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + * \param models The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param kv_cache_config The KV cache config to help decide verify is doable. + * \param trace_recorder The event trace recorder for requests. + * \return The created action object. + */ + static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, + Optional trace_recorder); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); }; diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc new file mode 100644 index 0000000000..50393c38a2 --- /dev/null +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -0,0 +1,230 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_batch_draft.cc + */ + +#include + +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that runs draft proposal for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + */ +class EagleBatchDraftActionObj : public EngineActionObj { + public: + explicit EagleBatchDraftActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + Optional trace_recorder, int draft_length) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + trace_recorder_(std::move(trace_recorder)), + draft_length_(draft_length) { + ICHECK_GT(draft_length_, 0); + } + + Array Step(EngineState estate) final { + // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. + if (models_.size() != 2 || estate->running_queue.empty()) { + return {}; + } + + // Preempt request state entries when decode cannot apply. 
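+    // (PreemptLastRunningRequestStateEntry evicts the lowest-priority running entry
+    //  from the KV cache; it is dropped from the local list only when it is the entry
+    //  at the back, and the loop re-checks CanDecode with the smaller batch.)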
+ std::vector running_rsentries = GetRunningRequestStateEntries(estate); + while (!CanDecode(running_rsentries.size())) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + running_rsentries.pop_back(); + } + } + + auto tstart = std::chrono::high_resolution_clock::now(); + + int num_rsentries = running_rsentries.size(); + Array request_ids; + std::vector request_internal_ids; + Array generation_cfg; + std::vector rngs; + request_ids.reserve(num_rsentries); + request_internal_ids.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + request_ids.push_back(rsentry->request->id); + request_internal_ids.push_back(rsentry->mstates[0]->internal_id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); + } + + // The first model doesn't get involved in draft proposal. + for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { + // Collect + // - the last committed token, + // - the request model state + // of each request. + std::vector input_tokens; + Array mstates; + input_tokens.reserve(num_rsentries); + mstates.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : running_rsentries) { + mstates.push_back(rsentry->mstates[model_id]); + } + // draft_length_ rounds of draft proposal. + NDArray hidden_states_nd{nullptr}; + ObjectRef last_hidden_states{nullptr}; + ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; + // Concat last hidden_states + std::vector previous_hidden_on_device; + for (int i = 0; i < num_rsentries; ++i) { + previous_hidden_on_device.push_back(mstates[i]->draft_last_hidden_on_device.back()); + } + hidden_states_nd = + models_[model_id]->ConcatLastHidden(previous_hidden_on_device, &hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 2); + ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); + hidden_states_nd = hidden_states_nd.CreateView( + {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + last_hidden_states = hidden_states_nd; + // The first draft token has been generated in prefill/verify stage + for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { + // prepare new input tokens + input_tokens.clear(); + for (int i = 0; i < num_rsentries; ++i) { + ICHECK(!mstates[i]->draft_output_tokens.empty()); + input_tokens.push_back(mstates[i]->draft_output_tokens.back().sampled_token_id.first); + } + + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + ObjectRef embeddings = + models_[model_id]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states_nd = + models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); + last_hidden_states = hidden_states_nd; + NDArray logits; + if (models_[model_id]->CanGetLogits()) { + logits = models_[model_id]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + /*seq_len*/ 1); + } else { + // - Use base model's head. 
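+          // (If the draft model does not expose its own output head, the base model's
+          //  head is reused to map the draft hidden states to logits.)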
+ logits = + models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], num_rsentries); + ICHECK_EQ(logits->shape[1], 1); + + // - Update logits. + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + + // - Sample tokens. + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), num_rsentries); + + // - Add draft token to the state. + for (int i = 0; i < num_rsentries; ++i) { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; + + return {}; + } + + private: + /*! \brief Check if the input requests can be decoded under conditions. */ + bool CanDecode(int num_rsentries) { + // The first model is not involved in draft proposal. + for (int model_id = 1; model_id < static_cast(models_.size()); ++model_id) { + // Check if the model has enough available pages. + int num_available_pages = models_[model_id]->GetNumAvailablePages(); + if (num_rsentries > num_available_pages) { + return false; + } + } + return true; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! \brief The model to run draft generation in speculative decoding. */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; + /*! 
\brief Draft proposal length */ + int draft_length_; +}; + +EngineAction EngineAction::EagleBatchDraft(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + Optional trace_recorder, + int draft_length) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(trace_recorder), draft_length)); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc new file mode 100644 index 0000000000..0c2040db9d --- /dev/null +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -0,0 +1,364 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_batch_verify.cc + */ + +#include + +#include +#include +#include + +#include "../../random.h" +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that runs verification for requests in the + * `running_queue` of engine state. Preempt low-priority requests + * accordingly when it is impossible to decode all the running requests. + */ +class EagleBatchVerifyActionObj : public EngineActionObj { + public: + explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, std::vector model_workspaces, + KVCacheConfig kv_cache_config, + Optional trace_recorder) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + kv_cache_config_(std::move(kv_cache_config)), + trace_recorder_(std::move(trace_recorder)), + rng_(RandomGenerator::GetInstance()) {} + + Array Step(EngineState estate) final { + // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests. + if (models_.size() != 2 || estate->running_queue.empty()) { + return {}; + } + + const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), draft_lengths.size()); + if (rsentries.empty()) { + return {}; + } + + int num_rsentries = rsentries.size(); + Array request_ids = + rsentries.Map([](const RequestStateEntry& rstate) { return rstate->request->id; }); + auto tstart = std::chrono::high_resolution_clock::now(); + + // - Get embedding and run verify. 
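+    // For each request this gathers its last committed token followed by all of its
+    // draft tokens, so the verify model scores every draft position in a single pass
+    // and the sampler can accept a prefix of the draft plus at most one bonus token.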
+ std::vector request_internal_ids; + std::vector all_tokens_to_verify; + Array verify_request_mstates; + Array generation_cfg; + std::vector rngs; + std::vector> draft_output_tokens; + std::vector> draft_output_prob_dist; + request_internal_ids.reserve(num_rsentries); + all_tokens_to_verify.reserve(total_draft_length); + verify_request_mstates.reserve(num_rsentries); + rngs.reserve(num_rsentries); + generation_cfg.reserve(num_rsentries); + draft_output_tokens.reserve(num_rsentries); + draft_output_prob_dist.reserve(num_rsentries); + + for (int i = 0; i < num_rsentries; ++i) { + RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; + RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; + request_internal_ids.push_back(verify_mstate->internal_id); + ICHECK(!draft_lengths.empty()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); + // the last committed token + all the draft tokens but the last one. + all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { + all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + } + verify_request_mstates.push_back(verify_mstate); + generation_cfg.push_back(rsentries[i]->request->generation_cfg); + rngs.push_back(&rsentries[i]->rng); + draft_output_tokens.push_back(draft_mstate->draft_output_tokens); + CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU); + draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); + } + + std::vector cum_verify_lengths = {0}; + cum_verify_lengths.reserve(num_rsentries + 1); + std::vector verify_lengths; + for (int i = 0; i < num_rsentries; ++i) { + // Add one committed token. + verify_lengths.push_back(draft_lengths[i] + 1); + cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths.back()); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); + ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( + {IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); + + RECORD_EVENT(trace_recorder_, request_ids, "start verify"); + ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden( + embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]); + NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( + fused_hidden_states, request_internal_ids, verify_lengths); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + NDArray logits = + models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); + RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], 1); + ICHECK_EQ(logits->shape[1], cum_verify_lengths[num_rsentries]); + + // - Update logits. + logits = + logits.CreateView({cum_verify_lengths[num_rsentries], logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, + request_ids, &cum_verify_lengths, &draft_output_tokens); + + // - Compute probability distributions. 
+ NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( + logits, generation_cfg, request_ids, &cum_verify_lengths); + + std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( + probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, + draft_output_prob_dist); + ICHECK_EQ(sample_results_arr.size(), num_rsentries); + + std::vector last_hidden_states; + for (int i = 0; i < num_rsentries; ++i) { + const std::vector& sample_results = sample_results_arr[i]; + int accept_length = sample_results.size(); + ICHECK_GE(accept_length, 1); + for (SampleResult sample_result : sample_results) { + rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); + rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); + } + estate->stats.total_accepted_length += accept_length - 1; + // - Minus one because the last draft token has no kv cache entry + // - Take max with 0 in case of all accepted. + int rollback_length = + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); + // rollback kv cache + // NOTE: when number of small models is more than 1 (in the future), + // it is possible to re-compute prefill for the small models. + if (rollback_length > 0) { + models_[verify_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length); + // Draft model rollback minus one because verify uses one more token. + models_[draft_model_id_]->PopNFromKVCache( + rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); + } + // clear the draft model state entries + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = + GetTokenHidden(hidden_states, (cum_verify_lengths[i] + accept_length - 1)); + last_hidden_states.push_back(last_hidden_on_device); + } + + { + // One step draft for the following steps + NDArray hidden_states_nd{nullptr}; + ObjectRef next_hidden_states = model_workspaces_[draft_model_id_].hidden_states; + // Concat last hidden_states + hidden_states_nd = + models_[draft_model_id_]->ConcatLastHidden(last_hidden_states, &next_hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 2); + ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); + hidden_states_nd = hidden_states_nd.CreateView( + {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + + std::vector input_tokens; + Array mstates; + input_tokens.reserve(num_rsentries); + mstates.reserve(num_rsentries); + for (const RequestStateEntry& rsentry : rsentries) { + mstates.push_back(rsentry->mstates[draft_model_id_]); + } + for (int i = 0; i < num_rsentries; ++i) { + ICHECK(!mstates[i]->committed_tokens.empty()); + input_tokens.push_back(mstates[i]->committed_tokens.back().sampled_token_id.first); + } + + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. 
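+      // (This single draft decode seeds the next round: the first draft token is
+      //  produced here right after verification, which is why the proposal loop in
+      //  eagle_batch_draft.cc starts at draft_id = 1.)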
+ RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states_nd = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, + request_internal_ids); + + if (models_[draft_model_id_]->CanGetLogits()) { + logits = models_[draft_model_id_]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + /*seq_len*/ 1); + } else { + // - Use base model's head. + logits = + models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + } + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->shape[0], num_rsentries); + ICHECK_EQ(logits->shape[1], 1); + + // - Update logits. + logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); + logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + + // - Sample tokens. + // Fill range [0, num_rsentries) into `sample_indices`. + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), num_rsentries); + + // - Add draft token to the state. + for (int i = 0; i < num_rsentries; ++i) { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; + + return estate->running_queue; + } + + private: + struct DraftRequestStateEntries { + /*! \brief The request state entries to verify. */ + Array draft_rsentries; + /*! \brief The draft length of each request state. */ + std::vector draft_lengths; + /*! \brief The total draft length. */ + int total_draft_length; + }; + + /*! + * \brief Decide whether to run verify for the draft of each request. + * \param estate The engine state. + * \return The drafts to verify, together with their respective + * state and input length. + */ + DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { + std::vector draft_lengths; + int total_draft_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); + + // Preempt the request state entries that cannot fit the large model for verification. 
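+    // Each entry needs ceil(draft_length / page_size) additional KV pages in the
+    // verify model; entries are preempted from the back until the total fits into
+    // the available pages.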
+ std::vector running_rsentries = GetRunningRequestStateEntries(estate); + std::vector num_page_requirement; + num_page_requirement.reserve(running_rsentries.size()); + for (const RequestStateEntry& rsentry : running_rsentries) { + int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); + int num_require_pages = + (draft_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + draft_lengths.push_back(draft_length); + num_page_requirement.push_back(num_require_pages); + total_draft_length += draft_length; + total_required_pages += num_require_pages; + } + while (!CanVerify(total_required_pages)) { + RequestStateEntry preempted = + PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + if (preempted.same_as(running_rsentries.back())) { + total_draft_length -= draft_lengths.back(); + total_required_pages -= num_page_requirement.back(); + draft_lengths.pop_back(); + num_page_requirement.pop_back(); + running_rsentries.pop_back(); + } + } + + return {running_rsentries, draft_lengths, total_draft_length}; + } + + bool CanVerify(int num_required_pages) { + int num_available_pages = models_[0]->GetNumAvailablePages(); + return num_required_pages <= num_available_pages; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! + * \brief The model to run decode in. When there are multiple + * models, the `Step` function of the created action will not take effect. + */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief The kv cache config. */ + KVCacheConfig kv_cache_config_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; + /*! \brief Random number generator. */ + RandomGenerator& rng_; + /*! \brief The ids of verify/draft models. */ + const int verify_model_id_ = 0; + const int draft_model_id_ = 1; + const float eps_ = 1e-5; +}; + +EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, + Optional trace_recorder) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(kv_cache_config), std::move(trace_recorder))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc new file mode 100644 index 0000000000..90c8ac3be8 --- /dev/null +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -0,0 +1,568 @@ +/*! 
+ * Copyright (c) 2023 by Contributors + * \file serve/engine_actions/eagle_new_request_prefill.cc + */ + +#include + +#include "../config.h" +#include "../model.h" +#include "../sampler/sampler.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The action that prefills requests in the `waiting_queue` of + * the engine state. + */ +class EagleNewRequestPrefillActionObj : public EngineActionObj { + public: + explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, + Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, EngineMode engine_mode, + Optional trace_recorder) + : models_(std::move(models)), + logit_processor_(std::move(logit_processor)), + sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + kv_cache_config_(std::move(kv_cache_config)), + engine_mode_(std::move(engine_mode)), + trace_recorder_(std::move(trace_recorder)) {} + + Array Step(EngineState estate) final { + // - Find the requests in `waiting_queue` that can prefill in this step. + std::vector prefill_inputs; + { + NVTXScopedRange nvtx_scope("NewRequestPrefill getting requests"); + prefill_inputs = GetRequestStateEntriesToPrefill(estate); + if (prefill_inputs.empty()) { + return {}; + } + } + + int num_rsentries = prefill_inputs.size(); + auto tstart = std::chrono::high_resolution_clock::now(); + + // - Update status of request states from pending to alive. + Array request_ids; + std::vector rstates_of_entries; + std::vector status_before_prefill; + request_ids.reserve(num_rsentries); + rstates_of_entries.reserve(num_rsentries); + status_before_prefill.reserve(num_rsentries); + for (const PrefillInput& prefill_input : prefill_inputs) { + const RequestStateEntry& rsentry = prefill_input.rsentry; + const Request& request = rsentry->request; + RequestState request_rstate = estate->GetRequestState(request); + request_ids.push_back(request->id); + status_before_prefill.push_back(rsentry->status); + rsentry->status = RequestStateStatus::kAlive; + + if (status_before_prefill.back() == RequestStateStatus::kPending) { + // - Add the request to running queue if the request state + // status was pending and all its request states were pending. + bool alive_state_existed = false; + for (const RequestStateEntry& rsentry_ : request_rstate->entries) { + if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { + alive_state_existed = true; + } + } + if (!alive_state_existed) { + estate->running_queue.push_back(request); + } + } + rstates_of_entries.push_back(std::move(request_rstate)); + } + + // - Get embedding and run prefill for each model. + std::vector prefill_lengths; + prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); + NDArray hidden_states_for_input{nullptr}; + NDArray hidden_states_for_sample{nullptr}; + NDArray logits_for_sample{nullptr}; + // A map used to record the entry and child_idx pair needed to fork sequence. + // The base model (id 0) should record all the pairs and all the small models + // fork sequences according to this map. 
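To make the role of that map concrete, here is a minimal sketch of the bookkeeping it performs, with hypothetical names (not code from this patch): the base model records which child entries it forked, and the later draft models fork exactly the same children.

#include <unordered_map>
#include <unordered_set>

// entry index in this prefill batch -> indices of child entries to fork from it
using ForkMap = std::unordered_map<int, std::unordered_set<int>>;

// Record a fork decision the first time it is made (model 0), and let later models
// detect that the same child must be forked again.
inline bool ShouldForkChild(ForkMap& fork_map, int entry_idx, int child_idx,
                            bool child_has_no_committed_tokens) {
  if (child_has_no_committed_tokens || fork_map[entry_idx].count(child_idx)) {
    fork_map[entry_idx].insert(child_idx);
    return true;
  }
  return false;
}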
+ std::unordered_map> fork_rsentry_child_map; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + std::vector request_internal_ids; + request_internal_ids.reserve(num_rsentries); + ObjectRef embeddings = model_workspaces_[model_id].embeddings; + int cum_prefill_length = 0; + bool single_input = + num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + RequestModelState mstate = rsentry->mstates[model_id]; + auto [input_data, input_length] = + ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length); + if (prefill_lengths[i] == -1) { + prefill_lengths[i] = input_length; + } else { + ICHECK_EQ(prefill_lengths[i], input_length); + } + + ICHECK(mstate->draft_output_tokens.empty()); + ICHECK(mstate->draft_output_prob_dist.empty()); + if (status_before_prefill[i] == RequestStateStatus::kPending) { + // Add the sequence to the model, or fork the sequence from its parent. + if (rsentry->parent_idx == -1) { + models_[model_id]->AddNewSequence(mstate->internal_id); + } else { + models_[model_id]->ForkSequence( + rstates_of_entries[i]->entries[rsentry->parent_idx]->mstates[model_id]->internal_id, + mstate->internal_id); + } + // Enable sliding window for the sequence if it is not a parent. + if (rsentry->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); + } + } + request_internal_ids.push_back(mstate->internal_id); + RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding"); + // Speculative models shift left the input tokens by 1 when base model has committed tokens. + // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens. + int embed_offset = + prefill_inputs[i].rsentry->mstates[model_id]->committed_tokens.empty() ? 0 : 1; + for (int j = 0; j < static_cast(input_data.size()); ++j) { + if (j == static_cast(input_data.size()) - 1) { + std::vector tail_tokens; + TokenData tk_data = Downcast(input_data[j]); + CHECK(tk_data.defined()); + for (int k = embed_offset; k < static_cast(tk_data->token_ids.size()); ++k) { + tail_tokens.push_back(tk_data->token_ids[k]); + } + embeddings = models_[model_id]->TokenEmbed( + {tail_tokens.begin(), tail_tokens.end()}, + /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[j]->GetLength(); + cum_prefill_length -= embed_offset; + } else { + embeddings = input_data[i]->GetEmbedding( + models_[model_id], + /*dst=*/!single_input ? 
&model_workspaces_[model_id].embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[j]->GetLength(); + } + } + if (embed_offset > 0) { + std::vector new_tokens = {prefill_inputs[i] + .rsentry->mstates[model_id] + ->committed_tokens.back() + .sampled_token_id.first}; + embeddings = + models_[model_id]->TokenEmbed({new_tokens.begin(), new_tokens.end()}, + /*dst=*/&model_workspaces_[model_id].embeddings, + /*offset=*/cum_prefill_length); + cum_prefill_length += new_tokens.size(); + } + RECORD_EVENT(trace_recorder_, rsentry->request->id, "finish embedding"); + } + + RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); + ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden( + fused_hidden_states, request_internal_ids, prefill_lengths); + RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], cum_prefill_length); + + if (model_id == 0) { + // We only need to sample for model 0 in prefill. + hidden_states_for_input = hidden_states; + } + + // Whether to use base model to get logits. + int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; + hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden( + hidden_states, request_internal_ids, prefill_lengths); + logits_for_sample = + models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); + ICHECK_EQ(hidden_states_for_sample->ndim, 3); + ICHECK_EQ(hidden_states_for_sample->shape[0], 1); + ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries); + + // - Update logits. + ICHECK(logits_for_sample.defined()); + Array generation_cfg; + Array mstates_for_logitproc; + generation_cfg.reserve(num_rsentries); + mstates_for_logitproc.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { + generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); + mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]); + } + logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]}, + logits_for_sample->dtype); + logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, + mstates_for_logitproc, request_ids); + + // - Compute probability distributions. + NDArray probs_on_device = + logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + + // - Sample tokens. + // For prefill_inputs which have children, sample + // one token for each rstate that is depending. + // Otherwise, sample a token for the current rstate. + std::vector sample_indices; + std::vector rsentries_for_sample; + std::vector rngs; + sample_indices.reserve(num_rsentries); + rsentries_for_sample.reserve(num_rsentries); + rngs.reserve(num_rsentries); + request_ids.clear(); + generation_cfg.clear(); + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + for (int child_idx : rsentry->child_indices) { + // Only use base model to judge if we need to add child entries. 
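A rough sketch of the intent behind the embed_offset handling earlier in this hunk, with hypothetical names (not code from this patch): once the base model has committed a token, the draft model's input stream is the same token stream shifted left by one, so the first pending token is dropped and the newly committed token is appended at the end.

#include <vector>

// Eagle-style left shift: pair the draft model's hidden state of position t with the
// embedding of token t + 1 by dropping the first pending token (embed_offset == 1)
// and appending the token the base model just committed.
inline std::vector<long> ShiftDraftInputs(std::vector<long> pending_tokens,
                                          long last_committed_token) {
  if (!pending_tokens.empty()) {
    pending_tokens.erase(pending_tokens.begin());
  }
  pending_tokens.push_back(last_committed_token);
  return pending_tokens;
}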
+ if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() || + fork_rsentry_child_map[i].count(child_idx)) { + // If rstates_of_entries[i]->entries[child_idx] has no committed token, + // the prefill of the current rsentry will unblock + // rstates_of_entries[i]->entries[child_idx], + // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. + fork_rsentry_child_map[i].insert(child_idx); + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); + + if (model_id == 0) { + ICHECK(rstates_of_entries[i]->entries[child_idx]->status == + RequestStateStatus::kPending); + rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; + } + int64_t child_internal_id = + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; + models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, + child_internal_id); + // Enable sliding window for the child sequence if the child is not a parent. + if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); + } + } + } + if (rsentry->child_indices.empty()) { + // If rsentry has no child, we sample a token for itself. + sample_indices.push_back(i); + rsentries_for_sample.push_back(rsentry); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rsentry->rng); + } + } + std::vector prob_dist; + std::vector sample_results = sampler_->BatchSampleTokens( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); + + // - Update the committed tokens of states. + // - If a request is first-time prefilled, set the prefill finish time. + auto tnow = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + if (model_id == 0) { + for (int mid = 0; mid < static_cast(models_.size()); ++mid) { + rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); + } + // Only base model trigger timing records. + if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { + rsentries_for_sample[i]->tprefill_finish = tnow; + } + } else { + // - Slice hidden_states_for_sample + NDArray last_hidden_on_device = GetTokenHidden(hidden_states_for_sample, i); + CHECK(i < static_cast(prob_dist.size())); + CHECK(prob_dist[i].defined()); + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], prob_dist[i], + last_hidden_on_device); + estate->stats.total_draft_length += 1; + } + } + } + + auto tend = std::chrono::high_resolution_clock::now(); + estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; + + // - Remove the request from waiting queue if all its request states + // are now alive and have no remaining chunked inputs. 
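The sampling fan-out above can be summarized compactly: each prefilled entry draws one token per dependent child from its own probability row, or one token for itself when it has no children. A minimal sketch with hypothetical names (not code from this patch):

#include <vector>

// Build the sampling fan-out for a prefill batch: entry i contributes one sample per
// dependent child (all drawn from probability row i), or one sample for itself when
// it has no children. The returned indices select rows of the probability tensor.
inline std::vector<int> BuildSampleIndices(const std::vector<int>& num_dependent_children) {
  std::vector<int> sample_indices;
  for (int i = 0; i < static_cast<int>(num_dependent_children.size()); ++i) {
    int copies = num_dependent_children[i] > 0 ? num_dependent_children[i] : 1;
    for (int c = 0; c < copies; ++c) sample_indices.push_back(i);
  }
  return sample_indices;  // e.g. children counts {2, 0, 1} -> indices {0, 0, 1, 2}
}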
+ std::vector processed_requests; + { + processed_requests.reserve(num_rsentries); + std::unordered_set dedup_map; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { + continue; + } + dedup_map.insert(rsentry->request.get()); + processed_requests.push_back(rsentry->request); + + bool pending_state_exists = false; + for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { + if (rsentry_->status == RequestStateStatus::kPending || + !rsentry_->mstates[0]->inputs.empty()) { + pending_state_exists = true; + break; + } + } + if (!pending_state_exists) { + auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), + rsentry->request); + ICHECK(it != estate->waiting_queue.end()); + estate->waiting_queue.erase(it); + } + } + } + return processed_requests; + } + + private: + /*! \brief The class of request state entry and its maximum allowed length for prefill. */ + struct PrefillInput { + RequestStateEntry rsentry; + int max_prefill_length; + }; + + /*! + * \brief Find one or multiple request state entries to run prefill. + * \param estate The engine state. + * \return The request entries to prefill, together with their input lengths. + */ + std::vector GetRequestStateEntriesToPrefill(EngineState estate) { + if (estate->waiting_queue.empty()) { + // No request to prefill. + return {}; + } + + std::vector prefill_inputs; + + // - Try to prefill pending requests. + int total_input_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[0]->GetNumAvailablePages(); + int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); + int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); + + int num_prefill_rsentries = 0; + for (const Request& request : estate->waiting_queue) { + RequestState rstate = estate->GetRequestState(request); + bool prefill_stops = false; + for (const RequestStateEntry& rsentry : rstate->entries) { + // A request state entry can be prefilled only when: + // - it has inputs, and + // - it has no parent or its parent is alive and has no remaining input. + if (rsentry->mstates[0]->inputs.empty() || + (rsentry->parent_idx != -1 && + (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || + !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { + continue; + } + + int input_length = rsentry->mstates[0]->GetInputLength(); + int num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + // - Attempt 1. Check if the entire request state entry can fit for prefill. + if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length}); + num_prefill_rsentries += 1 + rsentry->child_indices.size(); + continue; + } + total_input_length -= input_length; + total_required_pages -= num_require_pages; + + // - Attempt 2. Check if the request state entry can partially fit by input chunking. 
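Attempt 2 below caps the entry's contribution by the remaining prefill budget for this step; as a minimal arithmetic sketch (hypothetical name, not code from this patch):

#include <algorithm>

// Tokens this entry can still contribute via input chunking, given how many tokens
// have already been scheduled for prefill in this step.
inline int ChunkableLength(int input_length, int total_input_length, int prefill_chunk_size) {
  return std::min(input_length, prefill_chunk_size - total_input_length);
}
// Example: chunk size 2048, 1500 tokens already scheduled, a 900-token entry
// -> only 548 tokens are prefilled now; the rest is kept for a later step.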
+ ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); + input_length = + std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); + num_require_pages = + (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (input_length > 0 && + CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length}); + num_prefill_rsentries += 1 + rsentry->child_indices.size(); + } + + // - Prefill stops here. + prefill_stops = true; + break; + } + if (prefill_stops) { + break; + } + } + + return prefill_inputs; + } + + /*! \brief Check if the input requests can be prefilled under conditions. */ + bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, + int num_required_pages, int num_available_pages, int current_total_seq_len, + int num_running_rsentries) { + ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); + + // No exceeding of the maximum allowed requests that can + // run simultaneously. + int spec_factor = engine_mode_->speculative_mode != SpeculativeMode::kDisable + ? engine_mode_->spec_draft_length + : 1; + if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > + std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { + return false; + } + + // NOTE: The conditions are heuristic and can be revised. + // Cond 1: total input length <= prefill chunk size. + // Cond 2: at least one decode can be performed after prefill. + // Cond 3: number of total tokens after 8 times of decode does not + // exceed the limit, where 8 is a watermark number can + // be configured and adjusted in the future. + int new_batch_size = num_running_rsentries + num_prefill_rsentries; + return total_input_length <= kv_cache_config_->prefill_chunk_size && + num_required_pages + new_batch_size <= num_available_pages && + current_total_seq_len + total_input_length + 8 * new_batch_size <= + kv_cache_config_->max_total_sequence_length; + } + + /*! + * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ + std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, + int max_prefill_length) { + if (mstate->inputs.empty()) { + } + ICHECK(!mstate->inputs.empty()); + std::vector inputs; + int cum_input_length = 0; + inputs.reserve(mstate->inputs.size()); + for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { + inputs.push_back(mstate->inputs[i]); + int input_length = mstate->inputs[i]->GetLength(); + cum_input_length += input_length; + // Case 0. the cumulative input length does not reach the maximum prefill length. + if (cum_input_length < max_prefill_length) { + continue; + } + + // Case 1. the cumulative input length equals the maximum prefill length. 
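For reference, the admission heuristic in CanPrefill above condenses to a single predicate; a simplified restatement with hypothetical parameter names (not code from this patch):

#include <algorithm>

// Prefill is admitted only if (a) the speculative-adjusted batch stays within the
// sequence/chunk limits, (b) the chunk fits, (c) at least one decode can follow, and
// (d) roughly 8 further decode steps will not overflow the KV cache.
inline bool CanPrefillSketch(int running, int incoming, int spec_factor, int max_num_sequence,
                             int prefill_chunk_size, int total_input_length,
                             int required_pages, int available_pages,
                             int current_total_seq_len, int max_total_seq_len) {
  int new_batch = running + incoming;
  if (new_batch * spec_factor > std::min(max_num_sequence, prefill_chunk_size)) return false;
  return total_input_length <= prefill_chunk_size &&
         required_pages + new_batch <= available_pages &&
         current_total_seq_len + total_input_length + 8 * new_batch <= max_total_seq_len;
}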
+ if (cum_input_length == max_prefill_length) { + if (i == static_cast(mstate->inputs.size()) - 1) { + // - If `i` is the last input, we just copy and reset `mstate->inputs`. + mstate->inputs.clear(); + } else { + // - Otherwise, set the new input array. + mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Case 2. cum_input_length > max_prefill_length + // The input `i` itself needs chunking if it is TokenData, + // or otherwise it cannot be chunked. + Data input = mstate->inputs[i]; + inputs.pop_back(); + cum_input_length -= input_length; + const auto* token_input = input.as(); + if (token_input == nullptr) { + // Cannot chunk the input. + if (i != 0) { + mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Split the token data into two parts. + // Return the first part for prefill, and keep the second part. + int chunked_input_length = max_prefill_length - cum_input_length; + ICHECK_GT(input_length, chunked_input_length); + TokenData chunked_input(IntTuple{token_input->token_ids.begin(), + token_input->token_ids.begin() + chunked_input_length}); + TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, + token_input->token_ids.end()}); + inputs.push_back(chunked_input); + cum_input_length += chunked_input_length; + std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + remaining_inputs.insert(remaining_inputs.begin(), remaining_input); + mstate->inputs = remaining_inputs; + return {inputs, cum_input_length}; + } + + ICHECK(false) << "Cannot reach here"; + } + + /*! + * \brief Get one item from a hidden_states array, which corresponds to the last token. + * \param hidden_states The hidden_states of all the tokens. + * \param token_pos The desired token position in the sequence. + * \return The desired token's hidden_states + */ + NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { + ICHECK_EQ(hidden_states->ndim, 3); + NDArray last_hidden_on_device = + NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); + + int64_t ndata = hidden_states->shape[2]; + const int16_t* __restrict p_hidden = + static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + + (token_pos * ndata); + + last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); + return last_hidden_on_device; + } + + /*! \brief The models to run prefill in. */ + Array models_; + /*! \brief The logit processor. */ + LogitProcessor logit_processor_; + /*! \brief The sampler to sample new tokens. */ + Sampler sampler_; + /*! \brief Workspace of each model. */ + std::vector model_workspaces_; + /*! \brief The KV cache config to help decide prefill is doable. */ + KVCacheConfig kv_cache_config_; + /*! \brief The engine operation mode. */ + EngineMode engine_mode_; + /*! \brief Event trace recorder. 
*/ + Optional trace_recorder_; +}; + +EngineAction EngineAction::EagleNewRequestPrefill(Array models, + LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + KVCacheConfig kv_cache_config, + EngineMode engine_mode, + Optional trace_recorder) { + return EngineAction(make_object( + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_mode), + std::move(trace_recorder))); +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 5ff8ee923e..288bb9ad83 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -360,7 +360,9 @@ class NewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. - int spec_factor = engine_mode_->enable_speculative ? engine_mode_->spec_draft_length : 1; + int spec_factor = engine_mode_->speculative_mode != SpeculativeMode::kDisable + ? engine_mode_->spec_draft_length + : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { return false; diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index a7f878c1ba..21835566b3 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -228,7 +228,16 @@ void FunctionTable::_InitFunctions() { this->prefill_func_ = mod_get_func("batch_prefill"); this->decode_func_ = mod_get_func("batch_decode"); this->verify_func_ = mod_get_func("batch_verify"); + this->single_batch_prefill_to_last_hidden_func_ = mod_get_func("prefill_to_last_hidden_states"); + this->single_batch_decode_to_last_hidden_func_ = mod_get_func("decode_to_last_hidden_states"); + this->prefill_to_last_hidden_func_ = mod_get_func("batch_prefill_to_last_hidden_states"); + this->decode_to_last_hidden_func_ = mod_get_func("batch_decode_to_last_hidden_states"); + this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states"); + this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states"); Module mod = this->use_disco ? 
this->disco_mod->DebugGetFromRemote(0) : this->local_vm; + this->get_logits_func_ = mod->GetFunction("get_logits", true); + this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true); + this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); @@ -276,7 +285,6 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape) { - ICHECK(host_array->device.device_type == DLDeviceType::kDLCPU); if (this->use_disco) { Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 29d9d82fbc..195f79264e 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -72,6 +72,15 @@ struct FunctionTable { PackedFunc prefill_func_; PackedFunc decode_func_; PackedFunc verify_func_; + PackedFunc single_batch_prefill_to_last_hidden_func_; + PackedFunc single_batch_decode_to_last_hidden_func_; + PackedFunc prefill_to_last_hidden_func_; + PackedFunc decode_to_last_hidden_func_; + PackedFunc verify_to_last_hidden_func_; + PackedFunc fuse_embed_hidden_func_; + PackedFunc get_logits_func_; + PackedFunc batch_get_logits_func_; + PackedFunc batch_select_last_hidden_func_; PackedFunc softmax_func_; PackedFunc apply_logit_bias_func_; PackedFunc apply_penalty_func_; diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 9dc4b1b9c5..f7190d50ac 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -289,7 +289,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); } } if (num_token_to_process != 1) { @@ -368,7 +368,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); } } if (token_number != 1) { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 5ebf26a061..fa4a4bf09a 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -116,6 +116,223 @@ class ModelImpl : public ModelObj { } } + bool CanGetLogits() final { + return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); + } + + NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) final { + NVTXScopedRange nvtx_scope("GetLogits"); + CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], batch_size); + ICHECK_EQ(hidden_states->shape[1], seq_len); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + 
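One design point worth noting around the new get_logits functions: a draft model may ship without its own lm_head, in which case the engine actions fall back to the base model's head. A generic sketch of that dispatch, with hypothetical names (not code from this patch):

#include <cstddef>
#include <vector>

// Pick which model's head produces logits: a draft model that cannot produce logits
// itself (CanGetLogits() == false) borrows the base model's head at index 0.
// `Head` stands in for whatever callable maps hidden states to logits.
template <typename Head>
const Head& SelectHead(const std::vector<Head>& heads, std::size_t model_id,
                       const std::vector<bool>& can_get_logits) {
  return can_get_logits[model_id] ? heads[model_id] : heads[0];
}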
ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray logits; + logits = Downcast(ret); + CHECK(logits.defined()); + // logits: (b * s, v) + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], batch_size * seq_len); + return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); + } + + NDArray BatchGetLogits(const ObjectRef& last_hidden_states, const std::vector& seq_ids, + const std::vector& lengths) { + NVTXScopedRange nvtx_scope("BatchGetLogits"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + int* p_logit_pos = static_cast(logit_pos_arr_->data); + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + p_logit_pos[i] = total_length - 1; + } + NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + ObjectRef logit_pos_dref_or_nd = + ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + + CHECK(ft_.batch_get_logits_func_.defined()) + << "`batch_get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], total_length); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = + ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray logits; + logits = Downcast(ret); + CHECK(logits.defined()); + // logits: (b * s, v) + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], num_sequences); + return logits.CreateView({1, num_sequences, logits->shape[1]}, logits->dtype); + } + + NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) { + NVTXScopedRange nvtx_scope("BatchSelectLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + int* p_logit_pos = static_cast(logit_pos_arr_->data); + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + p_logit_pos[i] = total_length - 1; + } + NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + ObjectRef logit_pos_dref_or_nd = + ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + + CHECK(ft_.batch_select_last_hidden_func_.defined()) + << "`batch_select_last_hidden_states` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd; + CHECK(!last_hidden_states->IsInstance()); + // hidden_states: (b, s, h) + NDArray hidden_states = Downcast(last_hidden_states); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(hidden_states->ndim, 3); + 
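The logit_pos computation used by BatchGetLogits and BatchSelectLastHidden packs sequences back to back, so each sequence's last token sits at its cumulative length minus one in the flattened tensor. A minimal sketch (hypothetical name, not code from this patch):

#include <vector>

// Positions of each sequence's last token in a (total_length, hidden_size) tensor
// where sequences are packed contiguously.
inline std::vector<int> LastTokenPositions(const std::vector<int>& lengths) {
  std::vector<int> positions;
  int total = 0;
  for (int len : lengths) {
    total += len;
    positions.push_back(total - 1);
  }
  return positions;  // e.g. lengths {3, 5, 2} -> positions {2, 7, 9}
}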
ICHECK_EQ(hidden_states->shape[0], 1); + ICHECK_EQ(hidden_states->shape[1], total_length); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + + hidden_states_dref_or_nd = + hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + ObjectRef ret = + ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + + NDArray hidden; + hidden = Downcast(ret); + // hidden: (b * s, v) + ICHECK_EQ(hidden->ndim, 2); + ICHECK_EQ(hidden->shape[0], num_sequences); + return hidden.CreateView({1, num_sequences, hidden->shape[1]}, hidden->dtype); + } + + NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) final { + NVTXScopedRange nvtx_scope("ConcatLastHidden"); + + CHECK(dst->defined()); + + int cum_length = 0; + ICHECK_GE(hidden_states.size(), 1); + for (auto hidden : hidden_states) { + ICHECK_EQ(hidden->ndim, 1); + // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. + hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); + // Reuse the copy embedding function + ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length); + cum_length += 1; + } + NDArray ret = Downcast(*dst); + ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); + return ret; + } + + ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, + int batch_size, int seq_len) final { + NVTXScopedRange nvtx_scope("FuseEmbedHidden"); + + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype); + + if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { + // Model has no support for fuse_embed_hidden_states or this is the first model (base model) + return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); + } + } else { + ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); + + if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { + // Model has no support for fuse_embed_hidden_states or this is the first model (base model) + ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + return ft_.nd_view_func_(embeddings, embedding_shape); + } + } + + NDArray hidden_states = Downcast(previous_hidden_states); + CHECK(hidden_states.defined()); + ICHECK_EQ(hidden_states->ndim, 3); + ICHECK_EQ(hidden_states->shape[0], batch_size); + ICHECK_EQ(hidden_states->shape[1], seq_len); + ICHECK_EQ(hidden_states->shape[2], hidden_size_); + ICHECK_EQ(hidden_states->device.device_type, device_.device_type); + ICHECK_EQ(hidden_states->device.device_id, device_.device_id); + NDArray hidden_states_2d = + hidden_states.CreateView({batch_size * seq_len, hidden_size_}, 
hidden_states->dtype); + auto hidden_states_dref_or_nd = + ft_.CopyToWorker0(hidden_states_2d, "hidden_states_2d", + {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); + + ObjectRef ret = + ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, hidden_states_dref_or_nd, params_); + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + if (!ret->IsInstance()) { + NDArray fused = Downcast(ret); + return fused.CreateView({batch_size, seq_len, hidden_size_}, fused->dtype); + } else { + ShapeTuple fused_shape{batch_size, seq_len, hidden_size_}; + return ft_.nd_view_func_(ret, fused_shape); + } + } + NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchPrefill"); @@ -187,6 +404,74 @@ class ModelImpl : public ModelObj { return logits; } + NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchPrefillToLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + } + + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + } else { + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + } + + CHECK(ft_.prefill_to_last_hidden_func_.defined()) + << "`prefill_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; + + // Begin forward with the sequence ids and new lengths. 
+ IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(lengths.begin(), lengths.end()); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, logit_pos, kv_cache, params + ObjectRef ret; + if (seq_ids.size() == 1) { + CHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined()) + << "`single_batch_prefill_to_last_hidden_states` function is not found in the model."; + ret = ft_.single_batch_prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); + } else { + ret = ft_.prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + } + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (1, total_length, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], 1); + ICHECK_EQ(last_hidden_states->shape[1], total_length); + return last_hidden_states; + } + NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecode"); int num_sequence = seq_ids.size(); @@ -247,6 +532,67 @@ class ModelImpl : public ModelObj { return logits; } + NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) final { + NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); + int num_sequence = seq_ids.size(); + + CHECK(ft_.decode_to_last_hidden_func_.defined()) + << "`batch_decode_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; + + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); + ICHECK_EQ(hidden_states_nd->shape[1], 1); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({num_sequence, 1, hidden_size_}, hidden_states_nd->dtype); + } else { + ShapeTuple hidden_states_shape{num_sequence, 1, hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + } + + // Reserve in KV cache for the lengths of the input. + // Begin forward with the sequence ids and new lengths. 
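A note on the shape convention in these functions: prefill and verify pack the whole batch as (1, total_length, hidden_size) with explicit per-sequence lengths, while decode uses (batch, 1, hidden_size) and announces a length of one per sequence before forwarding, as sketched below with hypothetical names (not code from this patch):

#include <cstddef>
#include <vector>

// Per-sequence new lengths announced to the KV cache before a decode step:
// every running sequence appends exactly one token.
inline std::vector<long> DecodeStepLengths(std::size_t num_sequences) {
  return std::vector<long>(num_sequences, 1);
}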
+ IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, kv_cache, params + ObjectRef ret; + if (seq_ids.size() == 1) { + CHECK(ft_.single_batch_decode_to_last_hidden_func_.defined()) + << "`decode_to_last_hidden_states` function is not found in the model."; + ret = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); + } else { + ret = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + } + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (b, 1, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], num_sequence); + ICHECK_EQ(last_hidden_states->shape[1], 1); + return last_hidden_states; + } + NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchVerify"); @@ -307,6 +653,65 @@ class ModelImpl : public ModelObj { return logits; } + NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { + NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); + CHECK(!seq_ids.empty()); + CHECK_EQ(seq_ids.size(), lengths.size()); + int num_sequences = seq_ids.size(); + int total_length = 0; + for (int i = 0; i < num_sequences; ++i) { + total_length += lengths[i]; + } + + CHECK(ft_.verify_to_last_hidden_func_.defined()) + << "`batch_verify_to_last_hidden_states` function is not found in the model."; + ICHECK(ft_.kv_cache_begin_forward_func_.defined()); + ICHECK(ft_.kv_cache_end_forward_func_.defined()); + ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; + + ObjectRef hidden_states_dref_or_nd; + if (!hidden_states->IsInstance()) { + // hidden_states: (1, n, h) + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + hidden_states_dref_or_nd = + hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + } else { + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + } + + // Begin forward with the sequence ids and new lengths. 
+ IntTuple seq_ids_tuple(seq_ids); + IntTuple lengths_tuple(lengths.begin(), lengths.end()); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + + // args: embeddings, logit_pos, kv_cache, params + ObjectRef ret = ft_.verify_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); + NDArray last_hidden_states; + if (ft_.use_disco) { + Array result = Downcast(ret)->DebugGetFromRemote(0); + last_hidden_states = Downcast(result[0]); + } else { + last_hidden_states = Downcast>(ret)[0]; + } + if (trace_enabled_) { + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + ft_.kv_cache_end_forward_func_(kv_cache_); + + // hidden_states: (1, total_length, v) + ICHECK_EQ(last_hidden_states->ndim, 3); + ICHECK_EQ(last_hidden_states->shape[0], 1); + ICHECK_EQ(last_hidden_states->shape[1], total_length); + return last_hidden_states; + } + /*********************** KV Cache Management ***********************/ LogitProcessor CreateLogitProcessor(int max_num_token, @@ -400,6 +805,26 @@ class ModelImpl : public ModelObj { return embedding; } + ObjectRef AllocHiddenStatesTensor() final { + // Allocate the hidden_states tensor. + // Use the same function as embeddings. + ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); + // Get the shape of the hidden_states tensor for hidden size. + ShapeTuple hidden_states_shape; + if (ft_.use_disco) { + ICHECK(hidden_states->IsInstance()); + ObjectRef shape_ref = ft_.nd_get_shape_func_(hidden_states); + hidden_states_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + hidden_states_shape = hidden_states_nd.Shape(); + } + ICHECK_EQ(hidden_states_shape.size(), 2); + ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); + this->hidden_size_ = hidden_states_shape[1]; + return hidden_states; + } + void Reset() final { // Reset the KV cache. if (kv_cache_.defined()) { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 4e57d499ef..79619acbe6 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -39,6 +39,11 @@ struct ModelWorkspace { * model parallelism is not enabled, or a DRef when using tensor model parallelism. */ ObjectRef embeddings{nullptr}; + /*! + * \brief The hidden_states tensor. It can be either an NDArray when tensor + * model parallelism is not enabled, or a DRef when using tensor model parallelism. + */ + ObjectRef hidden_states{nullptr}; }; /*! @@ -91,6 +96,61 @@ class ModelObj : public Object { */ virtual ObjectRef ImageEmbed(const NDArray& image, ObjectRef* dst = nullptr, int offset = 0) = 0; + /*! + * \brief Fuse the embeddings and hidden_states. + * \param embeddings The embedding of the input to be prefilled. + * \param previous_hidden_states The hidden_states from previous base model. + * \param batch_size Batch size. + * \param seq_len Sequence length. + * \return The fused hidden_states. + */ + virtual ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, + const ObjectRef& previous_hidden_states, int batch_size, + int seq_len) = 0; + + /*! + * \brief Return if the model has lm_head so that we can get logits. + */ + virtual bool CanGetLogits() = 0; + + /*! + * \brief Compute logits for last hidden_states. + * \param last_hidden_states The last hidden_states to compute logits for. + * \param batch_size The batch size of last_hidden_states + * \param seq_len The length of tokens in last_hidden_states + * \return The computed logits. 
+ */ + virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; + + /*! + * \brief Compute logits for last hidden_states in a batch. + * \param last_hidden_states The last hidden_states to compute logits for. + * \param seq_ids The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The computed logits. + */ + virtual NDArray BatchGetLogits(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + + /*! + * \brief Select desired hidden_states for last hidden_states in a batch. + * \param last_hidden_states The last hidden_states to select from. + * \param seq_ids The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The last hidden_states for the batch. + */ + virtual NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + + /*! + * \brief Concat a list of 1D hidden_states to 2D tensor. + * \param hidden_states The hidden_states to concat. + * \param dst The copy destination. + */ + virtual NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) = 0; + /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -103,6 +163,18 @@ class ModelObj : public Object { virtual NDArray BatchPrefill(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; + /*! + * \brief Batch prefill function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of the input to be prefilled. + * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to prefill. + * \return The hidden_states for the next token. + */ + virtual NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + /*! * \brief Batch decode function. Embedding in, logits out. * The embedding order of sequences in `embeddings` follows @@ -113,6 +185,16 @@ class ModelObj : public Object { */ virtual NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) = 0; + /*! + * \brief Batch decode function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of last generated token in the entire batch. + * \param seq_id The id of the sequence in the KV cache. + * \return The hidden_states for the next token for each sequence in the batch. + */ + virtual NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) = 0; + /*! * \brief Batch verify function. Embedding in, logits out. * \param embeddings The embedding of the input to be verified. @@ -126,6 +208,21 @@ class ModelObj : public Object { virtual NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) = 0; + /*! + * \brief Batch verify function. Input hidden_states are computed from + * input embeddings and previous hidden_states, output last hidden_states. + * \param hidden_states The hidden_states of the input to be verified. + * \param seq_id The id of the sequence in the KV cache. + * \param lengths The length of each sequence to verify. 
+ * \return The hidden_states for the draft token for each sequence in the batch. + * \note The function runs for **every** sequence in the batch. + * That is to say, it does not accept "running a verify step for a subset + * of the full batch". + */ + virtual NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; + /*********************** KV Cache Management ***********************/ /*! @@ -188,6 +285,9 @@ class ModelObj : public Object { /*! \brief Allocate an embedding tensor with the prefill chunk size. */ virtual ObjectRef AllocEmbeddingTensor() = 0; + /*! \brief Allocate an hidden_states tensor with the prefill chunk size. */ + virtual ObjectRef AllocHiddenStatesTensor() = 0; + /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 2a035ad387..b1f5ae27a2 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -59,9 +59,11 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { } } -void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist) { +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist, + NDArray last_hidden_on_device) { draft_output_tokens.push_back(std::move(sampled_token)); draft_output_prob_dist.push_back(std::move(prob_dist)); + draft_last_hidden_on_device.push_back(std::move(last_hidden_on_device)); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } @@ -116,14 +118,6 @@ RequestStateEntry::RequestStateEntry( DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length) { - // - Case 0. There is remaining draft output ==> Unfinished - // All draft outputs are supposed to be processed before finish. - for (RequestModelState mstate : this->mstates) { - if (!mstate->draft_output_tokens.empty()) { - return {{}, {}, Optional()}; - } - } - std::vector return_token_ids; std::vector logprob_json_strs; Optional finish_reason; diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 7764a38c3e..950bb6e290 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -70,6 +70,12 @@ class RequestModelStateNode : public Object { * and draft outputs in speculative inference settings. */ std::vector draft_output_prob_dist; + /*! + * \brief The last hidden_states used to get probs in drafting. + * \note We only need this value when we have multiple parallel small models + * and draft outputs in speculative inference settings. + */ + std::vector draft_last_hidden_on_device; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -95,7 +101,8 @@ class RequestModelStateNode : public Object { /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(SampleResult sampled_token, NDArray prob_dist); + void AddDraftToken(SampleResult sampled_token, NDArray prob_dist, + NDArray draft_last_hidden_on_device = NDArray()); /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ void RemoveLastDraftToken(); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. 
*/ diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index e1316e57f0..02b7e2a81d 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -22,7 +22,8 @@ namespace serve { * The input is a batch of distributions, and we use `unit_offset` to specify * which distribution to sample from. * \param prob The input batch of probability distributions. - * \param unit_offset The offset specifying which distribution to sample from. + * \param unit_offset The offset specifying which distribution to output + * \param input_prob_offset The offset specifying which distribution to sample from. * \param top_p The top-p value of sampling. * \param uniform_sample The random number in [0, 1] for sampling. * \param output_prob_dist Optional pointer to store the corresponding probability distribution of @@ -31,7 +32,8 @@ namespace serve { * \note This function is an enhancement of SampleTopPFromProb in TVM Unity. * We will upstream the enhancement after it gets stable. */ -TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, double uniform_sample, +TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_offset, double top_p, + double uniform_sample, std::vector* output_prob_dist = nullptr) { // prob: (*, v) // The prob array may have arbitrary ndim and shape. @@ -50,10 +52,11 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, double top_p, do int64_t ndata = prob->shape[prob->ndim - 1]; const float* __restrict p_prob = - static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata); + static_cast(__builtin_assume_aligned(prob->data, 4)) + (input_prob_offset * ndata); constexpr double one = 1.0f - 1e-5f; if (output_prob_dist) { + ICHECK_LT(unit_offset, static_cast(output_prob_dist->size())); if (!(*output_prob_dist)[unit_offset].defined()) { (*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0}); } @@ -294,7 +297,7 @@ class CPUSampler : public SamplerObj { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, sample_indices[i], + probs_host, i, sample_indices[i], generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber(), output_prob_dist); if (output_prob_dist == nullptr) { @@ -341,7 +344,9 @@ class CPUSampler : public SamplerObj { [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; - for (int cur_token_idx = 0; cur_token_idx < verify_end - verify_start; ++cur_token_idx) { + int cur_token_idx = 0; + // Sub 1 to ignore the last prediction. + for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) { float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size; int cur_token = draft_output_tokens[i][cur_token_idx].sampled_token_id.first; float q_value = draft_output_tokens[i][cur_token_idx].sampled_token_id.second; @@ -383,7 +388,7 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution SampleResult sample_result; sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, + probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, generation_cfg[i]->temperature < eps_ ? 
0.0 : generation_cfg[i]->top_p, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( @@ -391,6 +396,20 @@ class CPUSampler : public SamplerObj { sample_results[i].push_back(sample_result); break; } + // if cur_token_idx == verify_end - verify_start - 1 + // all draft tokens are accepted + // we sample a new token + if (cur_token_idx == verify_end - verify_start - 1) { + SampleResult sample_result; + // sample a new token from the original distribution + sample_result.sampled_token_id = SampleTopPFromProb( + probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, + rngs[i]->GetRandomNumber()); + sample_result.top_prob_tokens = ComputeTopProbs( + probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + sample_results[i].push_back(sample_result); + } }, 0, num_sequence); RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 2d28730a9b..78d44b0086 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -29,7 +29,15 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod bb = relax.BlockBuilder(mod) - vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] + # Prefill method exists in base models. + # Prefill_to_last_hidden method exists in base model and speculative small models + if "prefill" in mod: + vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] + else: + assert ( + "prefill_to_last_hidden_states" in mod + ), "Everay model should either has 'prefill' or 'prefill_to_last_hidden_states' method" + vocab_size = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0].shape[-1] gv_names = [ gv.name_hint for gv in [ diff --git a/python/mlc_llm/model/eagle/__init__.py b/python/mlc_llm/model/eagle/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/eagle/eagle_loader.py b/python/mlc_llm/model/eagle/eagle_loader.py new file mode 100644 index 0000000000..36ffee8a6c --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_loader.py @@ -0,0 +1,172 @@ +""" +This file specifies how MLC's EAGLE parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .eagle_model import EagleConfig, EagleForCasualLM +from .eagle_quantization import awq_quant + + +def huggingface(model_config: EagleConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : EagleConfig + The configuration of the Eagle model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = EagleForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: EagleConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : EagleConfig + The configuration of the Eagle model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate( + [q, k, v], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate( + [gate, up], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py new file mode 100644 index 0000000000..ba647604de --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -0,0 +1,242 @@ +""" +Implementation for EAGLE architecture. 
+""" + +import dataclasses +from typing import Optional + +from tvm import tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.model.llama.llama_model import LlamaAttention, LlamaConfig, LlamaFFN +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class EagleConfig(LlamaConfig): + """Configuration of the Eagle model.""" + + +# pylint: disable=invalid-name,missing-docstring + + +class EagleDecoderLayer(nn.Module): + def __init__(self, config: EagleConfig, index: int): + rms_norm_eps = config.rms_norm_eps + self.self_attn = LlamaAttention(config) + self.mlp = LlamaFFN(config) + self.index = index + if self.index != 0: + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + if self.index != 0: + hidden_states = self.input_layernorm(hidden_states) + out = self.self_attn(hidden_states, paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class EagleForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: EagleConfig): + # Put the model definition here to align with EAGLE's original structure + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size) + self.layers = nn.ModuleList( + [EagleDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.fc = nn.Linear( + in_features=2 * config.hidden_size, out_features=config.hidden_size, bias=True + ) + + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor): + hidden_states = op.concat([input_embed, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + return hidden_states + + def forward_to_last_hidden_states(self, hidden_states: Tensor, 
paged_kv_cache: PagedKVCache): + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + return hidden_states + + def forward(self, input_embed: Tensor, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = self.fuse_embed_hidden_states(input_embed, hidden_states) + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + hidden_states: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.embed_tokens(input_ids) + + def prefill_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_prefill_to_last_hidden_states( + self, + hidden_states: Tensor, + paged_kv_cache: PagedKVCache, + ): + hidden_states = self.batch_forward(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, hidden_states: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward(hidden_states, paged_kv_cache) + return hidden_states, paged_kv_cache + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "fuse_embed_hidden_states": { + "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + 
"hidden_states": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/eagle/eagle_quantization.py b/python/mlc_llm/model/eagle/eagle_quantization.py new file mode 100644 index 0000000000..a926f7d9dd --- /dev/null +++ b/python/mlc_llm/model/eagle/eagle_quantization.py @@ -0,0 +1,70 @@ +"""This file specifies how MLC's Eagle parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize + +from .eagle_model import EagleConfig, EagleForCasualLM + + +def group_quant( + model_config: EagleConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using group quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: EagleConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using FasterTransformer quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: EagleConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: EagleConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Eagle model without quantization.""" + model: nn.Module = EagleForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 2ae5500c6d..7a01cc20de 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -224,15 +224,43 @@ def batch_forward( hidden_states = self.model(input_embeds, paged_kv_cache) if 
logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) + return self.get_logits(hidden_states) + + def batch_forward_to_last_hidden_states( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def get_logits(self, hidden_states: Tensor): + op_ext.configure() logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits - def embed(self, input_ids: Tensor): + def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() if self.tensor_parallel_shards > 1: - input_ids = op.ccl_broadcast_from_worker0(input_ids) - return self.model.embed_tokens(input_ids) + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return self.get_logits(hidden_states) + + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = op.take(hidden_states, logit_positions, axis=0) + return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() @@ -243,20 +271,28 @@ def _index(x: te.Tensor): # x[:-1,:] hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) - logits = self.lm_head(hidden_states) - if logits.dtype != "float32": - logits = logits.astype("float32") + logits = self.get_logits(hidden_states) return logits, paged_kv_cache + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): @@ -273,6 +309,24 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache + def batch_prefill_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_verify_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + 
): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) @@ -309,6 +363,29 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_get_logits": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_select_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), @@ -325,6 +402,22 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "prefill_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "batch_prefill": { "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), @@ -350,6 +443,30 @@ def get_default_spec(self): "effect_mode": "none", }, }, + "batch_prefill_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "softmax_with_temperature": { "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), "temperature": nn.spec.Tensor(["batch_size"], "float32"), diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index fe9775109a..595d7ba9a3 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -10,6 +10,7 @@ from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization +from .eagle import eagle_loader, eagle_model, eagle_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization from .gpt2 import gpt2_loader, gpt2_model, gpt2_quantization from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization @@ -338,4 +339,20 
@@ class Model: "group-quant": chatglm3_quantization.group_quant, }, ), + "eagle": Model( + name="eagle", + model=eagle_model.EagleForCasualLM, + config=eagle_model.EagleConfig, + source={ + "huggingface-torch": eagle_loader.huggingface, + "huggingface-safetensor": eagle_loader.huggingface, + "awq": eagle_loader.awq, + }, + quantize={ + "no-quant": eagle_quantization.no_quant, + "group-quant": eagle_quantization.group_quant, + "ft-quant": eagle_quantization.ft_quant, + "awq": eagle_quantization.awq_quant, + }, + ), } diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 7043cb75c7..764ec44198 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. import base -from .config import EngineMode, GenerationConfig, KVCacheConfig +from .config import EngineMode, GenerationConfig, KVCacheConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncEngine, Engine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index e539ec7e56..32460d2dde 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,5 +1,6 @@ """Configuration dataclasses used in MLC LLM serving""" +import enum import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Literal, Optional @@ -162,25 +163,53 @@ def from_json(json_str: str) -> "KVCacheConfig": return KVCacheConfig(**json.loads(json_str)) +class SpeculativeMode(enum.Enum): + """The speculative mode.""" + + DISABLE = 0 + SMALL_DRAFT = 1 + EAGLE = 2 + + +def speculative_mode_to_int(speculative_mode: SpeculativeMode): + """Convert speculative mode to int value + + Parameters + ---------- + speculative_mode (SpeculativeMode): + the speculative mode + """ + if speculative_mode == SpeculativeMode.DISABLE: + return 0 + if speculative_mode == SpeculativeMode.SMALL_DRAFT: + return 1 + if speculative_mode == SpeculativeMode.EAGLE: + return 2 + raise RuntimeError("Unknown speculative mode.") + + @dataclass class EngineMode: """The Engine execution mode. Parameters ---------- - enable_speculative : bool - Whether the speculative decoding mode is enabled, default False. spec_draft_length : int The number of tokens to generate in speculative proposal (draft), default 4. + + speculative_mode: SpeculativeMode + The speculative mode. 
""" - enable_speculative: bool = False spec_draft_length: int = 4 + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE def asjson(self) -> str: """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) + dt = asdict(self) + dt["speculative_mode"] = speculative_mode_to_int(self.speculative_mode) + return json.dumps(dt) @staticmethod def from_json(json_str: str) -> "EngineMode": diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index b142bce7ae..dc0d0c1c7f 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,7 +3,13 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncEngine, EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve import ( + AsyncEngine, + EngineMode, + GenerationConfig, + KVCacheConfig, + SpeculativeMode, +) from mlc_llm.serve.engine_base import ModelInfo prompts = [ @@ -31,7 +37,7 @@ async def test_engine_generate(): model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) # Create engine async_engine = AsyncEngine([llm, ssm], kv_cache_config, engine_mode) diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 403f75d325..49a55e3ed0 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -10,6 +10,7 @@ KVCacheConfig, Request, RequestStreamOutput, + SpeculativeMode, data, ) from mlc_llm.serve.engine_base import ModelInfo @@ -77,8 +78,74 @@ def test_engine_basic(): "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) + + # Hyperparameters for tests (you can try different combinations). 
+ num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 256 # [32, 128, 256] + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + engine.step() + + for req_id, output in enumerate(outputs): + print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + + +def test_engine_eagle_basic(): + """Test engine **without continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have the same max_tokens. This means all requests + will end together. + - Engine keeps running `step` for estimated number of steps (number of + requests + max_tokens - 1). Then check the output of each request. + - Use Eagle model as speculative model + """ + + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Eagle-llama2-7b-chat-q0f16-MLC", + model_lib_path="dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(spec_draft_length=2, speculative_mode=SpeculativeMode.EAGLE) # Hyperparameters for tests (you can try different combinations). 
num_requests = len(prompts) # [4, 8, 10] @@ -143,8 +210,92 @@ def test_engine_continuous_batching_1(): "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) + + # Hyperparameters for tests (you can try different combinations) + num_requests = len(prompts) # [4, 8, 10] + temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.00 # [1.0, 1.01] + max_tokens_low = 128 + max_tokens_high = 384 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + finish_time = [None] * num_requests + + # Define the callback class for request generation results + class CallbackTimer: + timer: int = -1 + + def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + if stream_outputs[0].finish_reason is not None: + print(f"Request {request_id} finished at step {self.timer}.") + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + finish_time[int(request_id)] = self.timer + + return fcallback + + def step(self) -> None: + self.timer += 1 + + # Create engine + timer = CallbackTimer() + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens_low, + max_tokens_high=max_tokens_high, + ) + + # Add all requests to engine + for request in requests: + engine.add_request(request) + + num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 + # Run steps + for step in range(num_steps): + timer.step() + assert timer.timer == step + engine.step() + + for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): + print(f"Prompt {req_id}: {request.inputs[0]}") + print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") + # assert fin_time == request.generation_config.max_tokens - 1 + + +def test_engine_eagle_continuous_batching_1(): + """Test engine **with continuous batching**. + + - Add all requests to the engine altogether in the beginning. + - All requests have a random maximum generation length. So each + request keeps generating until reaching the maximum length. + - Engine keeps running `step` for estimated number of steps (number of + requests + the maximum max_tokens - 1). Then check the output + of each request. 
+ """ + + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC", + model_lib_path="dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.EAGLE) # Hyperparameters for tests (you can try different combinations) num_requests = len(prompts) # [4, 8, 10] @@ -217,8 +368,39 @@ def test_engine_generate(): "dist/Llama-2-7b-chat-hf-q0f16-MLC", model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) + # Create engine + engine = SyncEngine([model, ssm], kv_cache_config, engine_mode) + + num_requests = 10 + max_tokens = 256 + + # Generate output. + output_texts, _ = engine.generate( + prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) + ) + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +def test_engine_eagle_generate(): + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC", + model_lib_path="dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so", + ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(speculative_mode=SpeculativeMode.EAGLE) # Create engine engine = SyncEngine([model, ssm], kv_cache_config, engine_mode) @@ -246,7 +428,7 @@ def test_engine_efficiency(): "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Hyperparameters for tests (you can try different combinations). num_requests = 1 # [4, 8, 10] @@ -317,8 +499,80 @@ def test_engine_spec_efficiency(): "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(enable_speculative=True, spec_draft_length=6) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(spec_draft_length=6, speculative_mode=SpeculativeMode.SMALL_DRAFT) + + # Hyperparameters for tests (you can try different combinations). 
+ num_requests = 1 # [4, 8, 10] + temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] + repetition_penalty = 1.0 # [1.0, 1.01] + max_tokens: int = 512 + np.random.seed(0) + + # Output list + outputs = [[] for _ in range(num_requests)] + + # Define the callback function for request generation results + def fcallback(delta_outputs: List[RequestStreamOutput]): + for delta_output in delta_outputs: + request_id, stream_outputs = delta_output.unpack() + assert len(stream_outputs) == 1 + outputs[int(request_id)] += stream_outputs[0].delta_token_ids + + # Create engine + spec_engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + + # Create requests + requests = create_requests( + num_requests, + temperature=temperature, + repetition_penalty=repetition_penalty, + max_tokens_low=max_tokens, + max_tokens_high=max_tokens + 1, + ) + + # Add all requests to engine + for request in requests: + spec_engine.add_request(request) + + num_steps = num_requests + max_tokens - 1 + # Run steps + for step in range(num_steps): + spec_engine.step() + + for eg, name in zip([spec_engine], ["Speculative Decoding"]): + stats = eg.stats() + print("engine name:", name) + if name == "Speculative Decoding": + print("total draft tokens:", stats["total_draft_tokens"]) + print("total accepted tokens:", stats["total_accepted_tokens"]) + print( + "Accept rate:", + stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), + ) + print("engine total decode time:", stats["engine_total_decode_time"]) + print() + + +def test_engine_eagle_spec_efficiency(): + """Test engine speculative decoding efficiency.""" + + # Initialize model loading info and KV cache config + ssm = ModelInfo( + "dist/Eagle-llama2-7b-chat-q0f16-MLC", + model_lib_path="dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so", + ) + # If Flashinfer allows head_dim < 128, we can test this model + # ssm = ModelInfo( + # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC", + # model_lib_path="dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so", + # ) + model = ModelInfo( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + ) + kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) + engine_mode = EngineMode(spec_draft_length=6, speculative_mode=SpeculativeMode.EAGLE) # Hyperparameters for tests (you can try different combinations). num_requests = 1 # [4, 8, 10] @@ -374,7 +628,11 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): if __name__ == "__main__": test_engine_basic() + test_engine_eagle_basic() test_engine_continuous_batching_1() + test_engine_eagle_continuous_batching_1() test_engine_generate() + test_engine_eagle_generate() test_engine_efficiency() test_engine_spec_efficiency() + test_engine_eagle_spec_efficiency() From 65e4a56ddb3939bf3746a132b087d4a905bd4cf4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 12 Apr 2024 17:09:12 -0400 Subject: [PATCH 183/531] [Pass] Attach non-negative TIR var attributes (#2125) This PR attaches the attributes of `tir.non_negative_var` for memory planning. 
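As background for readers, the sketch below shows the general mechanism of attaching such shape-variable hints to a Relax function so that memory planning can reason about symbolic shapes. It assumes TVM Unity's TVMScript/Relax Python API; the toy module, the variable names, and the bound value are illustrative only, and the authoritative change is the diff that follows.

```
# Minimal sketch: attach shape-variable hints as function attributes on a Relax
# function. The module, names, and numbers here are illustrative placeholders.
import tvm
from tvm.script import ir as I, relax as R


@I.ir_module
class ToyModule:  # stand-in for a compiled model module
    @R.function
    def main(
        x: R.Tensor(("batch_size", "vocab_size"), "float32")
    ) -> R.Tensor(("batch_size", "vocab_size"), "float32"):
        return x


mod = ToyModule
mod["main"] = (
    mod["main"]
    .with_attr("tir_var_upper_bound", {"batch_size": 80})  # existing upper-bound hint (illustrative value)
    .with_attr("tir_non_negative_var", ["vocab_size"])     # the attribute this PR attaches
)
print(mod["main"].attrs)
```

Attaching the hint as a function attribute keeps it alongside the existing `tir_var_upper_bound` attribute, so downstream passes can read both from the same place; the two compiler passes below do exactly this.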
---
 python/mlc_llm/compiler_pass/attach_sampler.py      | 7 ++++++-
 python/mlc_llm/compiler_pass/attach_support_info.py | 5 ++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py
index 78d44b0086..1b7b0328a9 100644
--- a/python/mlc_llm/compiler_pass/attach_sampler.py
+++ b/python/mlc_llm/compiler_pass/attach_sampler.py
@@ -20,6 +20,7 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):
             "num_samples": max_batch_size,
             "num_positions": 6 * max_batch_size,
         }
+        self.non_negative_var = ["vocab_size"]
         self.target = target

     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
@@ -50,7 +51,11 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
         mod = bb.finalize()
         for gv_name in gv_names:
-            mod[gv_name] = mod[gv_name].with_attr("tir_var_upper_bound", self.variable_bounds)
+            mod[gv_name] = (
+                mod[gv_name]
+                .with_attr("tir_var_upper_bound", self.variable_bounds)
+                .with_attr("tir_non_negative_var", self.non_negative_var)
+            )
         return mod

diff --git a/python/mlc_llm/compiler_pass/attach_support_info.py b/python/mlc_llm/compiler_pass/attach_support_info.py
index dbeb621fdc..f4a332f115 100644
--- a/python/mlc_llm/compiler_pass/attach_support_info.py
+++ b/python/mlc_llm/compiler_pass/attach_support_info.py
@@ -13,12 +13,15 @@ class AttachVariableBounds:  # pylint: disable=too-few-public-methods
     def __init__(self, variable_bounds: Dict[str, int]):
         # Specifically for RWKV workloads, which contains -1 max_seq_len
         self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0}
+        self.non_negative_var = ["vocab_size"]

     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         """Entrypoint"""
         for g_var, func in mod.functions_items():
             if isinstance(func, relax.Function):
-                mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds)
+                mod[g_var] = func.with_attr("tir_var_upper_bound", self.variable_bounds).with_attr(
+                    "tir_non_negative_var", self.non_negative_var
+                )
         return mod

From 8e8a92170d7dda76c4fee146cc8bae86f1326387 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 12 Apr 2024 19:28:09 -0400
Subject: [PATCH 184/531] [Serving][Refactor] Engine constructor interface refactor (#2126)

This PR is a refactor of the engine's constructor interface and the serve CLI interface. This PR introduces the "mode" argument for engine, which has options "local", "interactive" and "server". The choice of mode will affect the automatically inferred value of `max_batch_size`, `max_total_sequence_length` and `prefill_chunk_size` (only effective when arguments are not specified; once an argument is specified, we will not override it). For detailed specification of the mode, please check out the CLI help messages in `mlc_llm/help.py` or the engine constructor in `mlc_llm/serve/engine.py`. No matter which mode is chosen, we will print out the current mode and the values of these arguments, for people to understand the settings of the engine. We also provide hints on how to adjust the mode.
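To make the "infer unless specified" rule concrete, here is a small self-contained sketch of the idea. It is not the implementation in `engine_base.py`; the real code derives these numbers from the model and the available GPU memory, and only the "local" entries below match the log that follows (the "interactive" and "server" entries are placeholders).

```
# Illustrative sketch of mode-based defaults: each limit the caller leaves
# unspecified falls back to the mode's default, and explicit arguments always win.
from typing import Dict, Optional

_MODE_DEFAULTS: Dict[str, Dict[str, int]] = {
    # Only the "local" values mirror the example log; the others are placeholders.
    "local": {"max_batch_size": 4, "max_total_sequence_length": 4096, "prefill_chunk_size": 4096},
    "interactive": {"max_batch_size": 1, "max_total_sequence_length": 4096, "prefill_chunk_size": 4096},
    "server": {"max_batch_size": 80, "max_total_sequence_length": 65536, "prefill_chunk_size": 4096},
}


def resolve_engine_limits(
    mode: str,
    max_batch_size: Optional[int] = None,
    max_total_sequence_length: Optional[int] = None,
    prefill_chunk_size: Optional[int] = None,
) -> Dict[str, int]:
    """Use the mode's default for every limit the caller did not specify."""
    defaults = _MODE_DEFAULTS[mode]
    overrides = {
        "max_batch_size": max_batch_size,
        "max_total_sequence_length": max_total_sequence_length,
        "prefill_chunk_size": prefill_chunk_size,
    }
    return {key: value if value is not None else defaults[key] for key, value in overrides.items()}


print(resolve_engine_limits("local"))                    # all three inferred from the mode
print(resolve_engine_limits("local", max_batch_size=8))  # an explicit argument is never overridden
```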
For example, ``` [2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC [2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q0f16-MLC/mlc-chat-config.json [2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so [2024-04-12 16:12:26] INFO chat_module.py:379: Using model folder: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC [2024-04-12 16:12:26] INFO chat_module.py:380: Using mlc chat config: /home/ruihang/Workspace/mlc-llm/dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json [2024-04-12 16:12:26] INFO chat_module.py:529: Using library model: dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so [2024-04-12 16:12:29] INFO engine_base.py:382: Engine mode is "local". Max batch size is set to 4. Max KV cache token capacity is set to 4096. Prefill chunk size is set to 4096. [2024-04-12 16:12:29] INFO engine_base.py:387: Estimated total single GPU memory usage: 21543.74 MB (Parameters: 16467.64 MB. KVCache: 4450.07 MB. Temporary buffer: 626.03 MB). The actual usage might be slightly larger than the estimated number. [2024-04-12 16:12:29] INFO engine_base.py:398: Please switch to mode "server" if you want to use more GPU memory and support more concurrent requests. ``` After the refactor, we bring the speculative decoding to the serve CLI so that people can use multiple models and run speculative decoding with the server launched in CLI (which was not doable before). --- cpp/serve/config.cc | 14 +- cpp/serve/config.h | 16 +- cpp/serve/engine.cc | 34 +- cpp/serve/engine.h | 4 +- cpp/serve/engine_actions/action.h | 9 +- .../eagle_new_request_prefill.cc | 15 +- .../engine_actions/new_request_prefill.cc | 15 +- python/mlc_llm/cli/serve.py | 26 +- python/mlc_llm/help.py | 40 ++ python/mlc_llm/interface/serve.py | 28 +- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 56 +-- python/mlc_llm/serve/engine.py | 228 ++++++++-- python/mlc_llm/serve/engine_base.py | 408 ++++++++++++------ python/mlc_llm/serve/server/popen_server.py | 34 +- python/mlc_llm/serve/sync_engine.py | 90 ++-- tests/python/json_ffi/test_json_ffi_engine.py | 96 +++-- tests/python/serve/benchmark.py | 19 +- tests/python/serve/evaluate_engine.py | 20 +- tests/python/serve/test_serve_async_engine.py | 86 ++-- .../serve/test_serve_async_engine_spec.py | 32 +- tests/python/serve/test_serve_engine.py | 86 ++-- .../python/serve/test_serve_engine_grammar.py | 20 +- tests/python/serve/test_serve_engine_image.py | 20 +- tests/python/serve/test_serve_engine_spec.py | 252 ++++++----- tests/python/serve/test_serve_sync_engine.py | 88 ++-- 26 files changed, 1091 insertions(+), 647 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 62394c4b21..ec9694ca1e 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -295,18 +295,18 @@ String KVCacheConfigNode::AsJSONString() const { return picojson::value(config).serialize(true); } -/****************** EngineMode ******************/ +/****************** EngineConfig ******************/ -TVM_REGISTER_OBJECT_TYPE(EngineModeNode); +TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); -EngineMode::EngineMode(int spec_draft_length, int speculative_mode) { - ObjectPtr n = make_object(); +EngineConfig::EngineConfig(int spec_draft_length, int speculative_mode) { + ObjectPtr n = make_object(); 
n->spec_draft_length = spec_draft_length; n->speculative_mode = SpeculativeMode(speculative_mode); data_ = std::move(n); } -EngineMode::EngineMode(const std::string& config_str) { +EngineConfig::EngineConfig(const std::string& config_str) { int spec_draft_length = 4; int speculative_mode = 0; @@ -327,13 +327,13 @@ EngineMode::EngineMode(const std::string& config_str) { speculative_mode = config["speculative_mode"].get(); } - ObjectPtr n = make_object(); + ObjectPtr n = make_object(); n->spec_draft_length = spec_draft_length; n->speculative_mode = SpeculativeMode(speculative_mode); data_ = std::move(n); } -String EngineModeNode::AsJSONString() const { +String EngineConfigNode::AsJSONString() const { picojson::object config; config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); config["speculative_mode"] = picojson::value(static_cast(this->speculative_mode)); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index bee0af5561..214e9ccdd9 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -105,8 +105,8 @@ enum class SpeculativeMode : int { kEagle = 2, }; -/*! \brief The configuration of engine execution mode. */ -class EngineModeNode : public Object { +/*! \brief The configuration of engine execution config. */ +class EngineConfigNode : public Object { public: /* The number of tokens to generate in speculative proposal (draft) */ int spec_draft_length; @@ -115,19 +115,19 @@ class EngineModeNode : public Object { String AsJSONString() const; - static constexpr const char* _type_key = "mlc.serve.EngineMode"; + static constexpr const char* _type_key = "mlc.serve.EngineConfig"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(EngineModeNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(EngineConfigNode, Object); }; -class EngineMode : public ObjectRef { +class EngineConfig : public ObjectRef { public: - explicit EngineMode(int spec_draft_length, int speculative_mode); + explicit EngineConfig(int spec_draft_length, int speculative_mode); - explicit EngineMode(const std::string& config_str); + explicit EngineConfig(const std::string& config_str); - TVM_DEFINE_OBJECT_REF_METHODS(EngineMode, ObjectRef, EngineModeNode); + TVM_DEFINE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; } // namespace serve diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index d9530c22fe..7f764d3fb6 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -45,7 +45,7 @@ class EngineImpl : public Engine { /********************** Engine Management **********************/ explicit EngineImpl(int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, + const String& kv_cache_config_json_str, const String& engine_config_json_str, Optional request_stream_callback, Optional trace_recorder, const std::vector>& model_infos) { @@ -57,7 +57,7 @@ class EngineImpl : public Engine { ? 
max_single_sequence_length : std::numeric_limits::max(); this->kv_cache_config_ = KVCacheConfig(kv_cache_config_json_str, max_single_sequence_length); - this->engine_mode_ = EngineMode(engine_mode_json_str); + this->engine_config_ = EngineConfig(engine_config_json_str); this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); @@ -84,29 +84,29 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); } int max_num_tokens = kv_cache_config_->max_num_sequence; - if (engine_mode_->speculative_mode != SpeculativeMode::kDisable) { - max_num_tokens *= engine_mode_->spec_draft_length; + if (engine_config_->speculative_mode != SpeculativeMode::kDisable) { + max_num_tokens *= engine_config_->spec_draft_length; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); // Step 3. Initialize engine actions that represent state transitions. - if (this->engine_mode_->speculative_mode != SpeculativeMode::kDisable) { + if (this->engine_config_->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. ICHECK_GT(this->models_.size(), 1U); - switch (this->engine_mode_->speculative_mode) { + switch (this->engine_config_->speculative_mode) { case SpeculativeMode::kEagle: this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // logit_processor, // sampler, // this->model_workspaces_, // this->kv_cache_config_, // - this->engine_mode_, // + this->engine_config_, // this->trace_recorder_), EngineAction::EagleBatchDraft( this->models_, logit_processor, sampler, this->model_workspaces_, - this->trace_recorder_, this->engine_mode_->spec_draft_length), + this->trace_recorder_, this->engine_config_->spec_draft_length), EngineAction::EagleBatchVerify( this->models_, logit_processor, sampler, this->model_workspaces_, this->kv_cache_config_, this->trace_recorder_)}; @@ -118,11 +118,11 @@ class EngineImpl : public Engine { sampler, // this->model_workspaces_, // this->kv_cache_config_, // - this->engine_mode_, // + this->engine_config_, // this->trace_recorder_), EngineAction::BatchDraft(this->models_, logit_processor, sampler, this->trace_recorder_, - this->engine_mode_->spec_draft_length), + this->engine_config_->spec_draft_length), EngineAction::BatchVerify(this->models_, logit_processor, sampler, this->kv_cache_config_, this->trace_recorder_)}; } @@ -132,7 +132,7 @@ class EngineImpl : public Engine { sampler, // this->model_workspaces_, // this->kv_cache_config_, // - this->engine_mode_, // + this->engine_config_, // this->trace_recorder_), EngineAction::BatchDecode(this->models_, logit_processor, sampler, this->trace_recorder_)}; @@ -289,7 +289,7 @@ class EngineImpl : public Engine { EngineState estate_; // Configurations and singletons KVCacheConfig kv_cache_config_; - EngineMode engine_mode_; + EngineConfig engine_config_; int max_single_sequence_length_; Tokenizer tokenizer_; std::vector token_table_; @@ -309,11 +309,11 @@ class EngineImpl : public Engine { std::unique_ptr Engine::Create( int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, + const String& kv_cache_config_json_str, const String& engine_config_json_str, 
Optional request_stream_callback, Optional trace_recorder, const std::vector>& model_infos) { return std::make_unique( - max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, engine_mode_json_str, + max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, engine_config_json_str, request_stream_callback, std::move(trace_recorder), model_infos); } @@ -333,7 +333,7 @@ std::unique_ptr CreateEnginePacked(TVMArgs args) { int max_single_sequence_length; std::string tokenizer_path; std::string kv_cache_config_json_str; - std::string engine_mode_json_str; + std::string engine_config_json_str; Optional request_stream_callback; Optional trace_recorder; std::vector> model_infos; @@ -344,7 +344,7 @@ std::unique_ptr CreateEnginePacked(TVMArgs args) { max_single_sequence_length = args.At(0); tokenizer_path = args.At(1); kv_cache_config_json_str = args.At(2); - engine_mode_json_str = args.At(3); + engine_config_json_str = args.At(3); request_stream_callback = args.At>(4); trace_recorder = args.At>(5); for (int i = 0; i < num_models; ++i) { @@ -359,7 +359,7 @@ std::unique_ptr CreateEnginePacked(TVMArgs args) { LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; } return Engine::Create(max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, - engine_mode_json_str, request_stream_callback, std::move(trace_recorder), + engine_config_json_str, request_stream_callback, std::move(trace_recorder), model_infos); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 973be50093..cb31304b5b 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -54,7 +54,7 @@ class Engine { * sequence length supported by the engine. * \param tokenizer_path The tokenizer path on disk. * \param kv_cache_config_json_str The KV cache config in JSON string. - * \param engine_mode_json_str The Engine execution mode in JSON string. + * \param engine_config_json_str The Engine execution configuration in JSON string. * \param request_stream_callback The request stream callback function to * stream back generated output for requests. * \param trace_recorder Event trace recorder for requests. @@ -67,7 +67,7 @@ class Engine { */ static std::unique_ptr Create( int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_mode_json_str, + const String& kv_cache_config_json_str, const String& engine_config_json_str, Optional request_stream_callback, Optional trace_recorder, const std::vector>& model_infos); diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 1385befddf..1c2387e834 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -57,14 +57,14 @@ class EngineAction : public ObjectRef { * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. * \param kv_cache_config The KV cache config to help decide prefill is doable. - * \param engine_mode The engine operation mode. + * \param engine_config The engine operation mode. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder); /*! 
* \brief Create the action that prefills requests in the `waiting_queue` @@ -74,14 +74,15 @@ class EngineAction : public ObjectRef { * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. * \param kv_cache_config The KV cache config to help decide prefill is doable. - * \param engine_mode The engine operation mode. + * \param engine_config The engine operation mode. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 90c8ac3be8..7ed84feb86 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -24,14 +24,15 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), kv_cache_config_(std::move(kv_cache_config)), - engine_mode_(std::move(engine_mode)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} Array Step(EngineState estate) final { @@ -421,8 +422,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. - int spec_factor = engine_mode_->speculative_mode != SpeculativeMode::kDisable - ? engine_mode_->spec_draft_length + int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable + ? engine_config_->spec_draft_length : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { @@ -546,7 +547,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { /*! \brief The KV cache config to help decide prefill is doable. */ KVCacheConfig kv_cache_config_; /*! \brief The engine operation mode. */ - EngineMode engine_mode_; + EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; }; @@ -555,11 +556,11 @@ EngineAction EngineAction::EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, KVCacheConfig kv_cache_config, - EngineMode engine_mode, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_mode), + std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_config), std::move(trace_recorder))); } diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 288bb9ad83..1e7d798c26 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -23,14 +23,14 @@ class NewRequestPrefillActionObj : public EngineActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), kv_cache_config_(std::move(kv_cache_config)), - engine_mode_(std::move(engine_mode)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} Array Step(EngineState estate) final { @@ -360,8 +360,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. - int spec_factor = engine_mode_->speculative_mode != SpeculativeMode::kDisable - ? engine_mode_->spec_draft_length + int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable + ? engine_config_->spec_draft_length : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { @@ -465,7 +465,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief The KV cache config to help decide prefill is doable. */ KVCacheConfig kv_cache_config_; /*! \brief The engine operation mode. */ - EngineMode engine_mode_; + EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; }; @@ -473,11 +473,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineMode engine_mode, + KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_mode), + std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_config), std::move(trace_recorder))); } diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 4ad2319390..48a72327e2 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -4,6 +4,7 @@ from mlc_llm.help import HELP from mlc_llm.interface.serve import serve +from mlc_llm.serve.config import EngineConfig from mlc_llm.support.argparse import ArgumentParser @@ -29,15 +30,28 @@ def main(argv): help=HELP["model_lib_path"] + ' (default: "%(default)s")', ) parser.add_argument( - "--max-batch-size", - type=int, - default=80, - help=HELP["max_batch_size"] + ' (default: "%(default)s")', + "--mode", + type=str, + choices=["local", "interactive", "server"], + default="local", + help=HELP["mode_serve"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--additional-models", type=str, nargs="*", help=HELP["additional_models_serve"] ) + parser.add_argument("--max-batch-size", type=int, help=HELP["max_batch_size"]) parser.add_argument( "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] + ) + parser.add_argument( + "--engine-config", + type=EngineConfig.from_str, + help=HELP["engine_config_serve"] + ' (default: "%(default)s")', + ) parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"]) parser.add_argument( "--host", @@ -76,9 +90,13 @@ def main(argv): model=parsed.model, device=parsed.device, model_lib_path=parsed.model_lib_path, + mode=parsed.mode, + additional_models=parsed.additional_models, max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + gpu_memory_utilization=parsed.gpu_memory_utilization, + engine_config=parsed.engine_config, enable_tracing=parsed.enable_tracing, host=parsed.host, port=parsed.port, diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 13335c99c1..ffea30c303 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -159,4 +159,44 @@ to get the Chrome Trace. For example, "curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "dist/llama"}'" """.strip(), + "mode_serve": """ +The engine mode in MLC LLM. We provide three preset modes: "local", "interactive" and "server". +The default mode is "local". +The choice of mode decides the values of "--max-batch-size", "--max-total-seq-length" and +"--prefill-chunk-size" when they are not explicitly specified. +1. Mode "local" refers to the local server deployment which has low request concurrency. 
+ So the max batch size will be set to 4, and max total sequence length and prefill chunk size + are set to the context window size (or sliding window size) of the model. +2. Mode "interactive" refers to the interactive use of server, which has at most 1 concurrent + request. So the max batch size will be set to 1, and max total sequence length and prefill + chunk size are set to the context window size (or sliding window size) of the model. +3. Mode "server" refers to the large server use case which may handle many concurrent request + and want to use GPU memory as much as possible. In this mode, we will automatically infer + the largest possible max batch size and max total sequence length. +You can manually specify arguments "--max-batch-size", "--max-total-seq-length" and +"--prefill-chunk-size" to override the automatic inferred values. +""".strip(), + "additional_models_serve": """ +The model paths and (optional) model library paths of additional models (other than the main model). +When engine is enabled with speculative decoding, additional models are needed. +The way of specifying additional models is: +"--additional-models model_path_1 model_path_2 ..." or +"--additional-models model_path_1:model_lib_path_1 model_path_2 ...". +When the model lib path of a model is not given, JIT model compilation will be activated +to compile the model automatically. +""", + "gpu_memory_utilization_serve": """ +A number in (0, 1) denoting the fraction of GPU memory used by the server in total. +It is used to infer to maximum possible KV cache capacity. +When it is unspecified, it defaults to 0.90. +Under mode "local" or "interactive", the actual memory usage may be significantly smaller than +this number. Under mode "server", the actual memory usage may be slightly larger than this number. +""", + "engine_config_serve": """ +The Engine execution configuration. +Currently speculative decoding mode is specified via engine config. +For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to +specify the eagle-style speculative decoding. +Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. 
+""", } diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index df64488a72..bdbb633414 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -1,13 +1,14 @@ """Python entrypoint of serve.""" -from typing import Any, Optional +from typing import Any, List, Literal, Optional import fastapi import uvicorn from fastapi.middleware.cors import CORSMiddleware from mlc_llm.protocol import error_protocol -from mlc_llm.serve import config, engine, engine_base +from mlc_llm.serve import engine +from mlc_llm.serve.config import EngineConfig from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -16,9 +17,13 @@ def serve( model: str, device: str, model_lib_path: Optional[str], - max_batch_size: int, + mode: Literal["local", "interactive", "server"], + additional_models: List[str], + max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + engine_config: Optional[EngineConfig], enable_tracing: bool, host: str, port: int, @@ -28,19 +33,20 @@ def serve( allow_headers: Any, ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" - # Initialize model loading info and KV cache config - model_info = engine_base.ModelInfo( + # Create engine and start the background loop + async_engine = engine.AsyncEngine( model=model, - model_lib_path=model_lib_path, device=device, - ) - kv_cache_config = config.KVCacheConfig( - max_num_sequence=max_batch_size, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + gpu_memory_utilization=gpu_memory_utilization, + engine_config=engine_config, + enable_tracing=enable_tracing, ) - # Create engine and start the background loop - async_engine = engine.AsyncEngine(model_info, kv_cache_config, enable_tracing=enable_tracing) with ServerContext() as server_context: server_context.add_model(model, async_engine) diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 764ec44198..abbedc911e 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. 
import base -from .config import EngineMode, GenerationConfig, KVCacheConfig, SpeculativeMode +from .config import EngineConfig, GenerationConfig, KVCacheConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncEngine, Engine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 32460d2dde..77bca9b462 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,8 +1,10 @@ """Configuration dataclasses used in MLC LLM serving""" +import argparse import enum import json from dataclasses import asdict, dataclass, field +from io import StringIO from typing import Dict, List, Literal, Optional @@ -163,7 +165,7 @@ def from_json(json_str: str) -> "KVCacheConfig": return KVCacheConfig(**json.loads(json_str)) -class SpeculativeMode(enum.Enum): +class SpeculativeMode(enum.IntEnum): """The speculative mode.""" DISABLE = 0 @@ -171,30 +173,12 @@ class SpeculativeMode(enum.Enum): EAGLE = 2 -def speculative_mode_to_int(speculative_mode: SpeculativeMode): - """Convert speculative mode to int value - - Parameters - ---------- - speculative_mode (SpeculativeMode): - the speculative mode - """ - if speculative_mode == SpeculativeMode.DISABLE: - return 0 - if speculative_mode == SpeculativeMode.SMALL_DRAFT: - return 1 - if speculative_mode == SpeculativeMode.EAGLE: - return 2 - raise RuntimeError("Unknown speculative mode.") - - @dataclass -class EngineMode: - """The Engine execution mode. +class EngineConfig: + """The class of Engine execution configuration. Parameters ---------- - spec_draft_length : int The number of tokens to generate in speculative proposal (draft), default 4. @@ -205,13 +189,37 @@ class EngineMode: spec_draft_length: int = 4 speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE + def __repr__(self) -> str: + out = StringIO() + print(f"spec_draft_length={self.spec_draft_length}", file=out, end="") + print(f";speculative_mode={self.speculative_mode.name}", file=out, end="") + return out.getvalue().rstrip() + def asjson(self) -> str: """Return the config in string of JSON format.""" dt = asdict(self) - dt["speculative_mode"] = speculative_mode_to_int(self.speculative_mode) + dt["speculative_mode"] = int(self.speculative_mode) return json.dumps(dt) @staticmethod - def from_json(json_str: str) -> "EngineMode": + def from_json(json_str: str) -> "EngineConfig": """Construct a config from JSON string.""" - return EngineMode(**json.loads(json_str)) + return EngineConfig(**json.loads(json_str)) + + @staticmethod + def from_str(source: str) -> "EngineConfig": + """Parse engine config from a string.""" + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--spec_draft_length", type=int, default=4) + parser.add_argument( + "--speculative_mode", + type=str, + choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], + default="DISABLE", + ) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return EngineConfig( + spec_draft_length=results.spec_draft_length, + speculative_mode=SpeculativeMode[results.speculative_mode], + ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index b822285d44..99c455f3cd 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -18,9 +18,11 @@ overload, ) +from tvm.runtime import Device + from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import data, engine_utils -from 
mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.request import Request from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -756,28 +758,112 @@ class AsyncEngine(engine_base.EngineBase): Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - - engine_mode : Optional[EngineMode] - The Engine execution mode. + models : str + A path to ``mlc-chat-config.json``, or an MLC model directory that contains + `mlc-chat-config.json`. + It can also be a link to a HF repository pointing to an MLC compiled model. + + device: Union[str, Device] + The device used to deploy the model such as "cuda" or "cuda:0". + Will default to "auto" and detect from local available GPUs if not specified. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over possible paths. + It the model lib path is not found, it will be compiled in a JIT manner. + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent request and want to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatic inferred values. + + additional_models : Optional[List[str]] + The model paths and (optional) model library paths of additional models + (other than the main model). + When engine is enabled with speculative decoding, additional models are needed. + Each string in the list is either in form "model_path" or "model_path:model_lib_path". + When the model lib path of a model is not given, JIT model compilation will + be activated to compile the model automatically. + + max_batch_size : Optional[int] + The maximum allowed batch size set for the KV cache to concurrently support. + + max_total_sequence_length : Optional[int] + The KV cache total token capacity, i.e., the maximum total number of tokens that + the KV cache support. This decides the GPU memory size that the KV cache consumes. + If not specified, system will automatically estimate the maximum capacity based + on the vRAM size on GPU. 
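Stepping back to the EngineConfig helpers introduced in config.py above, the parse/serialize round trip is easiest to see with a small sketch. This assumes the patched mlc_llm package is importable; the flag string follows the format documented for "--engine-config" in help.py.

    from mlc_llm.serve.config import EngineConfig, SpeculativeMode

    # Parse the semicolon-separated flag value into a config object.
    cfg = EngineConfig.from_str("spec_draft_length=4;speculative_mode=EAGLE")
    assert cfg.spec_draft_length == 4
    assert cfg.speculative_mode == SpeculativeMode.EAGLE
    # SpeculativeMode is now an IntEnum, so asjson() can serialize it with int().
    assert int(cfg.speculative_mode) == 2
    # __repr__ reproduces the flag syntax, so str(cfg) can be passed straight
    # back through --engine-config (as PopenServer does later in this patch).
    assert str(cfg) == "spec_draft_length=4;speculative_mode=EAGLE"
    print(cfg.asjson())  # {"spec_draft_length": 4, "speculative_mode": 2}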
+ + prefill_chunk_size : Optional[int] + The maximum number of tokens the model passes for prefill each time. + It should not exceed the prefill chunk size in model config. + If not specified, this defaults to the prefill chunk size in model config. + + gpu_memory_utilization : Optional[float] + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer to maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.90. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + + engine_config : Optional[EngineConfig] + The Engine execution configuration. + Currently speculative decoding mode is specified via engine config. + For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" + to specify the eagle-style speculative decoding. + Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. enable_tracing : bool A boolean indicating if to enable event logging for requests. """ - def __init__( + def __init__( # pylint: disable=too-many-arguments self, - models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, + model: str, + device: Union[str, Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, ) -> None: - super().__init__("async", models, kv_cache_config, engine_mode, enable_tracing) + super().__init__( + "async", + model=model, + device=device, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + gpu_memory_utilization=gpu_memory_utilization, + engine_config=engine_config, + enable_tracing=enable_tracing, + ) self.chat = Chat(weakref.ref(self)) self.completions = AsyncCompletion(weakref.ref(self)) @@ -1215,28 +1301,112 @@ class Engine(engine_base.EngineBase): Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - - engine_mode : Optional[EngineMode] - The Engine execution mode. + models : str + A path to ``mlc-chat-config.json``, or an MLC model directory that contains + `mlc-chat-config.json`. + It can also be a link to a HF repository pointing to an MLC compiled model. + + device: Union[str, Device] + The device used to deploy the model such as "cuda" or "cuda:0". + Will default to "auto" and detect from local available GPUs if not specified. + + model_lib_path : Optional[str] + The full path to the model library file to use (e.g. a ``.so`` file). + If unspecified, we will use the provided ``model`` to search over possible paths. + It the model lib path is not found, it will be compiled in a JIT manner. + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. 
+ We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent request and want to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatic inferred values. + + additional_models : Optional[List[str]] + The model paths and (optional) model library paths of additional models + (other than the main model). + When engine is enabled with speculative decoding, additional models are needed. + Each string in the list is either in form "model_path" or "model_path:model_lib_path". + When the model lib path of a model is not given, JIT model compilation will + be activated to compile the model automatically. + + max_batch_size : Optional[int] + The maximum allowed batch size set for the KV cache to concurrently support. + + max_total_sequence_length : Optional[int] + The KV cache total token capacity, i.e., the maximum total number of tokens that + the KV cache support. This decides the GPU memory size that the KV cache consumes. + If not specified, system will automatically estimate the maximum capacity based + on the vRAM size on GPU. + + prefill_chunk_size : Optional[int] + The maximum number of tokens the model passes for prefill each time. + It should not exceed the prefill chunk size in model config. + If not specified, this defaults to the prefill chunk size in model config. + + gpu_memory_utilization : Optional[float] + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer to maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.90. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + + engine_config : Optional[EngineConfig] + The Engine execution configuration. + Currently speculative decoding mode is specified via engine config. + For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" + to specify the eagle-style speculative decoding. + Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. enable_tracing : bool A boolean indicating if to enable event logging for requests. 
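The constructor change documented above replaces the old ModelInfo/KVCacheConfig pair with plain keyword arguments. A hedged before/after sketch; the model path and override values below are placeholders, not taken from this patch.

    # Before this patch (removed API), shown only for contrast:
    #   engine = Engine(
    #       models=ModelInfo("dist/llama", model_lib_path="dist/llama/lib.so"),
    #       kv_cache_config=KVCacheConfig(max_num_sequence=4, max_total_sequence_length=4096),
    #   )
    #
    # After this patch:
    from mlc_llm.serve import Engine

    engine = Engine(
        "dist/llama",                    # placeholder MLC model directory
        device="cuda",
        mode="local",                    # "local" | "interactive" | "server"
        max_total_sequence_length=4096,  # optional; inferred from GPU memory if omitted
    )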
""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, - models: Union[engine_base.ModelInfo, List[engine_base.ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, + model: str, + device: Union[str, Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, ) -> None: - super().__init__("sync", models, kv_cache_config, engine_mode, enable_tracing) + super().__init__( + "sync", + model=model, + device=device, + model_lib_path=model_lib_path, + mode=mode, + additional_models=additional_models, + max_batch_size=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + gpu_memory_utilization=gpu_memory_utilization, + engine_config=engine_config, + enable_tracing=enable_tracing, + ) self.chat = Chat(weakref.ref(self)) self.completions = Completion(weakref.ref(self)) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index fadd38978d..421cd187f7 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -5,7 +5,6 @@ import ast import asyncio import json -import os import queue import subprocess import sys @@ -21,7 +20,7 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.config import EngineConfig, GenerationConfig, KVCacheConfig from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -45,50 +44,49 @@ class ModelInfo: or a full path to a model directory (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - device : str - The device where to run the model. - It can be "auto", "device_name" (e.g., "cuda") or - "device_name:device_id" (e.g., "cuda:1"). - - model_lib_path : str + model_lib_path : Optional[str] The path to the compiled library of the model. E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" """ model: str - model_lib_path: str - device: Device = "auto" # type: ignore + model_lib_path: Optional[str] = None + + +def _parse_models( + model: str, model_lib_path: Optional[str], additional_models: Optional[List[str]] +) -> List[ModelInfo]: + """Parse the specified model paths and model lib paths. + Return a list of ModelInfo, which is a wrapper class of the model path + lib path. - def __post_init__(self): - if isinstance(self.device, str): - self.device = detect_device(self.device) - assert isinstance(self.device, Device) + Each additional model is expected to follow the format of either + "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB_PATH}". 
+ """ + models = [ModelInfo(model, model_lib_path)] + if additional_models is not None: + for additional_model in additional_models: + splits = additional_model.split(":", maxsplit=1) + if len(splits) == 2: + models.append(ModelInfo(splits[0], splits[1])) + else: + models.append(ModelInfo(splits[0])) + return models def _process_model_args( - models: List[ModelInfo], -) -> Tuple[List[Any], List[str], str, int, int, Conversation]: + models: List[ModelInfo], device: tvm.runtime.Device +) -> Tuple[List[Any], List[str], str, Conversation]: """Process the input ModelInfo to get the engine initialization arguments.""" - max_single_sequence_length = int(1e9) - prefill_chunk_size = int(1e9) tokenizer_path: Optional[str] = None conversation: Optional[Conversation] = None config_file_paths: List[str] = [] def _convert_model_info(model: ModelInfo) -> List[Any]: - nonlocal max_single_sequence_length, prefill_chunk_size, tokenizer_path, conversation + nonlocal tokenizer_path, conversation - device = model.device model_path, config_file_path = _get_model_path(model.model) config_file_paths.append(config_file_path) chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if chat_config.context_window_size and chat_config.context_window_size != -1: - max_single_sequence_length = min( - max_single_sequence_length, - chat_config.context_window_size, - ) - if chat_config.prefill_chunk_size: - prefill_chunk_size = min(prefill_chunk_size, chat_config.prefill_chunk_size) if tokenizer_path is None: tokenizer_path = model_path if conversation is None: @@ -121,22 +119,21 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: start=[], ) - assert prefill_chunk_size != int(1e9) assert conversation is not None - return ( - model_args, - config_file_paths, - tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, - conversation, - ) + return model_args, config_file_paths, tokenizer_path, conversation -def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals - models: List[ModelInfo], config_file_paths: List[str], max_num_sequence: int -) -> int: - """Estimate the max total sequence length (capacity) of the KV cache.""" +def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, float, float, int]: + """Estimate the memory usage and the max total sequence length (capacity) + that the KV cache can support. + """ assert len(models) != 0 kv_bytes_per_token = 0 @@ -146,7 +143,9 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals params_bytes = 0 temp_func_bytes = 0 - for model, config_file_path in zip(models, config_file_paths): + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): # Read metadata for the parameter size and the temporary memory size. 
cmd = [ sys.executable, @@ -155,7 +154,7 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals model.model_lib_path, "--print-memory-usage-in-json", "--mlc-chat-config", - config_file_path, + model_config_path, ] usage_str = subprocess.check_output(cmd, universal_newlines=True) usage_json = json.loads(usage_str) @@ -173,16 +172,14 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals kv_cache_metadata = json.loads(kv_cache_metadata_str) # Read model config and compute the kv size per token. - with open(config_file_path, mode="rt", encoding="utf-8") as file: - json_object = json.load(file) - model_config = json_object["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads + model_config = model_config_dict["model_config"] + vocab_size = model_config["vocab_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + num_layers = kv_cache_metadata["num_hidden_layers"] + head_dim = kv_cache_metadata["head_dim"] + num_qo_heads = kv_cache_metadata["num_attention_heads"] + num_kv_heads = kv_cache_metadata["num_key_value_heads"] + hidden_size = head_dim * num_qo_heads kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 kv_aux_workspace_bytes += ( (max_num_sequence + 1) * 88 @@ -200,18 +197,15 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals ) # Get single-card GPU size. - gpu_size_bytes = os.environ.get("MLC_GPU_SIZE_BYTES", default=None) + gpu_size_bytes = device.total_global_memory if gpu_size_bytes is None: - gpu_size_bytes = models[0].device.total_global_memory - if gpu_size_bytes is None: - raise ValueError( - "Cannot read total GPU global memory from device. " - 'Please the GPU memory size in bytes through "MLC_GPU_SIZE_BYTES" env variable.' - ) + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 - max_total_sequence_length = int( + model_max_total_sequence_length = int( ( - int(gpu_size_bytes) * 0.90 + int(gpu_size_bytes) * gpu_memory_utilization - params_bytes - temp_func_bytes - kv_aux_workspace_bytes @@ -220,38 +214,205 @@ def _estimate_max_total_sequence_length( # pylint: disable=too-many-locals ) / kv_bytes_per_token ) - assert max_total_sequence_length > 0, ( - "Cannot estimate KV cache capacity. " - f"The model weight size {params_bytes} may be larger than GPU memory size {gpu_size_bytes}" - ) + if model_max_total_sequence_length <= 0: + raise ValueError( + f"The model weight size {params_bytes} may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + ) - if models[0].device.device_type == Device.kDLMetal: + if device.device_type == Device.kDLMetal: # NOTE: Metal runtime has severe performance issues with large buffers. # To work around the issue, we limit the KV cache capacity to 32768. 
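To make the capacity estimate above concrete, here is a back-of-envelope version with made-up Llama-7B-like numbers. All constants below are assumptions, not values from this patch, and the several workspace terms tracked separately in the code are lumped into one overhead figure.

    # fp16 K + V per token: head_dim * num_kv_heads * num_layers * 2 tensors * 2 bytes;
    # the extra 1.25 follows the per-token term in the formula above.
    num_layers, num_kv_heads, head_dim = 32, 32, 128
    kv_bytes_per_token = head_dim * num_kv_heads * num_layers * 4 + 1.25  # ~512 KiB/token

    params_bytes = 7e9 * 2         # ~14 GB of fp16 weights (assumption)
    overhead_bytes = 2e9           # temp functions + workspaces + KV aux, lumped (assumption)
    gpu_size_bytes = 24e9          # a 24 GB card (assumption)
    gpu_memory_utilization = 0.90  # the default used when the flag is unspecified

    budget = gpu_size_bytes * gpu_memory_utilization - params_bytes - overhead_bytes
    print(int(budget / kv_bytes_per_token))  # roughly 10.7k tokens of KV cache capacity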
- max_total_sequence_length = min(max_total_sequence_length, 32768) + model_max_total_sequence_length = min(model_max_total_sequence_length, 32768) - total_size = ( + total_mem_usage_except_kv_cache = ( params_bytes + temp_func_bytes + kv_aux_workspace_bytes + model_workspace_bytes + logit_processor_workspace_bytes - + kv_bytes_per_token * max_total_sequence_length ) - logger.info( - "%s: %d.", - green('Estimated KVCacheConfig "max_total_sequence_length"'), - max_total_sequence_length, + return ( + total_mem_usage_except_kv_cache, + params_bytes, + kv_bytes_per_token, + kv_aux_workspace_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, + int(model_max_total_sequence_length), + ) + + +def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: + """Read the model config dictionaries, and return the maximum single + sequence length the models can support, the maximum prefill chunk + size the models can support, and the max batch size the models can support. + + Returns + ------- + model_max_single_sequence_length : int + The maximum single sequence length the models can support. + model_max_prefill_chunk_size : int + The maximum prefill chunk size the models can support. + model_max_batch_size : int + The max batch size the models can support. + """ + model_max_single_sequence_length = int(1e9) + model_max_prefill_chunk_size = int(1e9) + model_max_batch_size = int(1e9) + for i, config in enumerate(model_config_dicts): + runtime_context_window_size = config["context_window_size"] + compile_time_context_window_size = config["model_config"]["context_window_size"] + if runtime_context_window_size > compile_time_context_window_size: + raise ValueError( + f"Model {i}'s runtime context window size ({runtime_context_window_size}) is " + "larger than the context window size used at compile time " + f"({compile_time_context_window_size})" + ) + if runtime_context_window_size == -1 and compile_time_context_window_size != -1: + raise ValueError( + f"Model {i}'s runtime context window size (infinite) is " + "larger than the context window size used at compile time " + f"({compile_time_context_window_size})" + ) + if runtime_context_window_size != -1: + model_max_single_sequence_length = min( + model_max_single_sequence_length, runtime_context_window_size + ) + + runtime_prefill_chunk_size = config["prefill_chunk_size"] + compile_time_prefill_chunk_size = config["model_config"]["prefill_chunk_size"] + if runtime_prefill_chunk_size > compile_time_prefill_chunk_size: + raise ValueError( + f"Model {i}'s runtime prefill chunk size ({runtime_prefill_chunk_size}) is " + "larger than the prefill chunk size used at compile time " + f"({compile_time_prefill_chunk_size})" + ) + model_max_prefill_chunk_size = min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) + + model_max_batch_size = min(model_max_batch_size, config["model_config"]["max_batch_size"]) + + assert model_max_prefill_chunk_size != int(1e9) + assert model_max_batch_size != int(1e9) + return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, 
Any]], + model_config_paths: List[str], +) -> Tuple[KVCacheConfig, int]: + """Initialize the KV cache config with user input and GPU memory usage estimation.""" + ( + model_max_single_sequence_length, + model_max_prefill_chunk_size, + model_max_batch_size, + ) = _get_model_config_limit(model_config_dicts) + + logging_msg = 'Engine mode is "' + green(mode) + '". ' + # - max_batch_size + if max_batch_size is None: + max_batch_size = ( + min(4, model_max_batch_size) + if mode == "local" + else (1 if mode == "interactive" else model_max_batch_size) + ) + logging_msg += "Max batch size is set to " + green(str(max_batch_size)) + ". " + else: + logging_msg += "Max batch size " + green(str(max_batch_size)) + " is specified by user. " + # - infer the maximum total sequence length that can fit GPU memory. + ( + total_mem_usage_except_kv_cache, + model_params_bytes, + kv_bytes_per_token, + kv_aux_workspace_bytes, + temp_workspace_bytes, + model_max_total_sequence_length, + ) = _estimate_mem_usage_and_max_total_sequence_length( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + # - max_total_sequence_length + if max_total_sequence_length is None: + if mode == "local": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length, 8192 + ) + elif mode == "interactive": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length + ) + else: + max_total_sequence_length = min( + model_max_total_sequence_length, max_batch_size * model_max_single_sequence_length + ) + logging_msg += ( + "Max KV cache token capacity is set to " + green(str(max_total_sequence_length)) + ". " + ) + else: + logging_msg += ( + "Max KV cache token capacity " + + green(str(max_total_sequence_length)) + + " is specified by user. " + ) + # - prefill_chunk_size + if prefill_chunk_size is None: + if mode in ["local", "interactive"]: + prefill_chunk_size = min( + model_max_prefill_chunk_size, + model_max_total_sequence_length, + model_max_single_sequence_length, + ) + else: + prefill_chunk_size = model_max_prefill_chunk_size + logging_msg += "Prefill chunk size is set to " + green(str(prefill_chunk_size)) + ". " + else: + logging_msg += ( + "Prefill chunk size " + green(str(prefill_chunk_size)) + " is specified by user. " + ) + logger.info(logging_msg) + # - Estimate total GPU memory usage on single GPU. + total_mem_usage = ( + total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token ) logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", + "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", green("Estimated total single GPU memory usage"), - total_size / 1024 / 1024, - params_bytes / 1024 / 1024, + total_mem_usage / 1024 / 1024, + model_params_bytes / 1024 / 1024, (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, - (model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes) / 1024 / 1024, + temp_workspace_bytes / 1024 / 1024, + ) + # - Final messages + if mode in ["local", "interactive"]: + logger.info( + 'Please switch to mode "server" if you want to use more GPU memory ' + "and support more concurrent requests." 
+ ) + else: + logger.info( + 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' + "or do not have many concurrent requests to process." + ) + + return ( + KVCacheConfig( + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + ), + model_max_single_sequence_length, ) - return int(max_total_sequence_length) @dataclass @@ -506,72 +667,63 @@ class EngineBase: # pylint: disable=too-many-instance-attributes,too-few-public from callback functions and yield the processed delta results in the forms of standard API protocols. - Parameters - ---------- - kind : Literal["async", "sync"] - The kind of the engine. "async" for AsyncEngine and "sync" for Engine. - - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - - engine_mode : Optional[EngineMode] - The Engine execution mode. - - enable_tracing : bool - A boolean indicating if to enable event logging for requests. + Checkout subclasses AsyncEngine/Engine for the docstring of constructor parameters. """ def __init__( # pylint: disable=too-many-arguments,too-many-locals self, kind: Literal["async", "sync"], - models: Union[ModelInfo, List[ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, - enable_tracing: bool = False, + model: str, + device: Union[str, tvm.runtime.Device], + model_lib_path: Optional[str], + mode: Literal["local", "interactive", "server"], + additional_models: Optional[List[str]], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + gpu_memory_utilization: Optional[float], + engine_config: Optional[EngineConfig], + enable_tracing: bool, ) -> None: - if isinstance(models, ModelInfo): - models = [models] + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, Device) ( model_args, - config_file_paths, + model_config_paths, tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, self.conv_template, - ) = _process_model_args(models) + ) = _process_model_args(models, device) + # - Load the raw model config into dict self.model_config_dicts = [] - for i, model in enumerate(models): + for i, model_info in enumerate(models): # model_args: # [model_lib_path, model_path, device.device_type, device.device_id] * N - model.model_lib_path = model_args[i * (len(model_args) // len(models))] - with open(config_file_paths[i], "r", encoding="utf-8") as file: + model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - self.state = EngineState(enable_tracing) - - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence - ) + # - Decide the KV cache config based on mode and user input. 
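For quick reference, the max-batch-size defaults picked by _infer_kv_cache_config above when the user does not pass a value can be summarized as a small function. This is a paraphrase of the branches above, not new behaviour; the sequence-length and prefill-chunk-size defaults follow the analogous branches.

    def default_max_batch_size(mode: str, model_max_batch_size: int) -> int:
        if mode == "local":
            return min(4, model_max_batch_size)
        if mode == "interactive":
            return 1
        return model_max_batch_size  # "server": use the compile-time max batch size

    assert default_max_batch_size("local", 80) == 4
    assert default_max_batch_size("interactive", 80) == 1
    assert default_max_batch_size("server", 80) == 80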
+ kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) self.max_input_sequence_length = min( max_single_sequence_length, kv_cache_config.max_total_sequence_length ) - prefill_chunk_size = min(prefill_chunk_size, kv_cache_config.max_total_sequence_length) - - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." - ) + # - Initialize engine state and engine. + self.state = EngineState(enable_tracing) module = tvm.get_global_func("mlc.serve.create_threaded_engine", allow_missing=False)() self._ffi = { key: module[key] @@ -585,16 +737,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ] } self.tokenizer = Tokenizer(tokenizer_path) - if engine_mode is None: + if engine_config is None: # The default engine mode: non-speculative - engine_mode = EngineMode() + engine_config = EngineConfig() def _background_loop(): self._ffi["init_background_engine"]( max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), - engine_mode.asjson(), + engine_config.asjson(), self.state.get_request_stream_callback(kind), self.state.trace_recorder, *model_args, @@ -604,7 +756,7 @@ def _background_loop(): def _background_stream_back_loop(): self._ffi["run_background_stream_back_loop"]() - # Create the background engine-driving thread and start the loop. + # - Create the background engine-driving thread and start the loop. 
self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) self._background_stream_back_loop_thread: threading.Thread = threading.Thread( target=_background_stream_back_loop diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 9529316010..86f92d7602 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -4,10 +4,13 @@ import sys import time from pathlib import Path -from typing import Optional +from typing import List, Literal, Optional, Union import psutil import requests +from tvm.runtime import Device + +from mlc_llm.serve.config import EngineConfig class PopenServer: # pylint: disable=too-many-instance-attributes @@ -17,11 +20,16 @@ class PopenServer: # pylint: disable=too-many-instance-attributes def __init__( # pylint: disable=too-many-arguments self, model: str, - device: str = "auto", + device: Union[str, Device] = "auto", *, model_lib_path: Optional[str] = None, - max_batch_size: int = 80, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, host: str = "127.0.0.1", port: int = 8000, @@ -30,14 +38,19 @@ def __init__( # pylint: disable=too-many-arguments self.model = model self.model_lib_path = model_lib_path self.device = device + self.mode = mode + self.additional_models = additional_models self.max_batch_size = max_batch_size self.max_total_sequence_length = max_total_sequence_length + self.prefill_chunk_size = prefill_chunk_size + self.gpu_memory_utilization = gpu_memory_utilization + self.engine_config = engine_config self.enable_tracing = enable_tracing self.host = host self.port = port self._proc: Optional[subprocess.Popen] = None - def start(self) -> None: + def start(self) -> None: # pylint: disable=too-many-branches """Launch the server in a popen subprocess. Wait until the server becomes ready before return. 
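A hedged sketch of driving the updated PopenServer from a test, using only the options shown in this hunk. The model path is a placeholder, and shutdown handling is left out because it is not part of this diff.

    from mlc_llm.serve.server.popen_server import PopenServer

    server = PopenServer(
        model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",  # placeholder
        mode="server",
        gpu_memory_utilization=0.85,
        max_total_sequence_length=4096,
    )
    server.start()  # returns only after the server answers on the default 127.0.0.1:8000
    # ... issue OpenAI-style HTTP requests against the endpoint here ...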
""" @@ -46,9 +59,20 @@ def start(self) -> None: if self.model_lib_path is not None: cmd += ["--model-lib-path", self.model_lib_path] cmd += ["--device", self.device] - cmd += ["--max-batch-size", str(self.max_batch_size)] + if self.mode is not None: + cmd += ["--mode", self.mode] + if self.additional_models is not None: + cmd += ["--additional-models", *self.additional_models] + if self.max_batch_size is not None: + cmd += ["--max-batch-size", str(self.max_batch_size)] if self.max_total_sequence_length is not None: cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] + if self.prefill_chunk_size is not None: + cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] + if self.engine_config is not None: + cmd += ["--engine-config", str(self.engine_config)] + if self.gpu_memory_utilization is not None: + cmd += ["--gpu-memory-utilization", str(self.gpu_memory_utilization)] if self.enable_tracing: cmd += ["--enable-tracing"] diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index e8bc0288cf..12c55259b6 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -9,16 +9,17 @@ """ import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union import tvm from mlc_llm.serve import data -from mlc_llm.serve.config import EngineMode, GenerationConfig, KVCacheConfig +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.engine_base import ( - ModelInfo, - _estimate_max_total_sequence_length, + _infer_kv_cache_config, + _parse_models, _process_model_args, + detect_device, ) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.serve.request import Request @@ -79,31 +80,66 @@ class SyncEngine: the `set_request_stream_callback` method. Otherwise, the engine will raise exception. - engine_mode : Optional[EngineMode] - The Engine execution mode. + engine_config : Optional[EngineConfig] + The Engine execution configuration. enable_tracing : bool A boolean indicating if to enable event logging for requests. """ - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments,too-many-locals self, - models: Union[ModelInfo, List[ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, - request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, + engine_config: Optional[EngineConfig] = None, + request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, ): - if isinstance(models, ModelInfo): - models = [models] + # - Initialize model loading info. 
+ models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) ( model_args, - config_file_paths, + model_config_paths, tokenizer_path, - max_single_sequence_length, + self.conv_template, + ) = _process_model_args(models, device) + + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + # model_args: + # [model_lib_path, model_path, device.device_type, device.device_id] * N + model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, prefill_chunk_size, - self.conv_template_name, - ) = _process_model_args(models) + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min( + max_single_sequence_length, kv_cache_config.max_total_sequence_length + ) + self._ffi = _create_tvm_module( "mlc.serve.create_engine", ffi_funcs=[ @@ -118,30 +154,16 @@ def __init__( # pylint: disable=too-many-arguments ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None - self.max_input_sequence_length = max_single_sequence_length - - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence - ) - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." 
- ) - if engine_mode is None: + if engine_config is None: # The default engine mode: non-speculative - engine_mode = EngineMode() + engine_config = EngineConfig() self._ffi["init"]( max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), - engine_mode.asjson(), + engine_config.asjson(), request_stream_callback, self.trace_recorder, *model_args, diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 0d8448c9c5..f14d4727b8 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -7,13 +7,14 @@ import tvm -from mlc_llm.protocol import error_protocol, openai_api_protocol -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig, engine_utils +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils from mlc_llm.serve.engine_base import ( - EngineMode, - ModelInfo, - _estimate_max_total_sequence_length, + EngineConfig, + _infer_kv_cache_config, + _parse_models, _process_model_args, + detect_device, ) from mlc_llm.tokenizer import Tokenizer @@ -52,49 +53,57 @@ def _sync_request_stream_callback( class JSONFFIEngine: def __init__( # pylint: disable=too-many-arguments,too-many-locals self, - models: Union[ModelInfo, List[ModelInfo]], - kv_cache_config: KVCacheConfig, - engine_mode: Optional[EngineMode] = None, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + engine_config: Optional[EngineConfig] = None, + gpu_memory_utilization: Optional[float] = None, ) -> None: - if isinstance(models, ModelInfo): - models = [models] + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) ( model_args, - config_file_paths, + model_config_paths, tokenizer_path, - max_single_sequence_length, - prefill_chunk_size, self.conv_template, - ) = _process_model_args(models) + ) = _process_model_args(models, device) + # - Load the raw model config into dict self.model_config_dicts = [] - for i, model in enumerate(models): + for i, model_info in enumerate(models): # model_args: # [model_lib_path, model_path, device.device_type, device.device_id] * N - model.model_lib_path = model_args[i * (len(model_args) // len(models))] - with open(config_file_paths[i], "r", encoding="utf-8") as file: + model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - self.state = EngineState() - - if kv_cache_config.max_total_sequence_length is None: - kv_cache_config.max_total_sequence_length = _estimate_max_total_sequence_length( - models, config_file_paths, kv_cache_config.max_num_sequence - ) + # - Decide the KV cache config based on mode and user input. 
+ kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) self.max_input_sequence_length = min( max_single_sequence_length, kv_cache_config.max_total_sequence_length ) - prefill_chunk_size = min(prefill_chunk_size, kv_cache_config.max_total_sequence_length) - - if kv_cache_config.prefill_chunk_size is None: - kv_cache_config.prefill_chunk_size = prefill_chunk_size - elif kv_cache_config.prefill_chunk_size > prefill_chunk_size: - raise ValueError( - f"The specified prefill chunk size {kv_cache_config.prefill_chunk_size} is " - f"larger than the maximum prefill chunk size {prefill_chunk_size} supported by " - "models. Please specify a smaller prefill chunk size." - ) + # - Initialize engine state and engine. + self.state = EngineState() module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() self._ffi = { key: module[key] @@ -109,16 +118,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ] } self.tokenizer = Tokenizer(tokenizer_path) - if engine_mode is None: + if engine_config is None: # The default engine mode: non-speculative - engine_mode = EngineMode() + engine_config = EngineConfig() def _background_loop(): self._ffi["init_background_engine"]( max_single_sequence_length, tokenizer_path, kv_cache_config.asjson(), - engine_mode.asjson(), + engine_config.asjson(), self.state.get_request_stream_callback(), None, *model_args, @@ -245,7 +254,7 @@ def test_chat_completion(engine: JSONFFIEngine): print(f"chat completion for request {rid}") for response in engine.chat_completion( messages=[{"role": "user", "content": [{"type": "text", "text": prompts[rid]}]}], - model=model.model, + model=model, max_tokens=max_tokens, n=n, request_id=str(rid), @@ -274,13 +283,14 @@ def test_malformed_request(engine: JSONFFIEngine): if __name__ == "__main__": - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + # Create engine. 
+ model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = JSONFFIEngine( + model, + model_lib_path=model_lib_path, + max_total_sequence_length=1024, ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=1024) - engine = JSONFFIEngine(model, kv_cache_config) test_chat_completion(engine) test_malformed_request(engine) diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py index dd6d59c72f..a34b47335c 100644 --- a/tests/python/serve/benchmark.py +++ b/tests/python/serve/benchmark.py @@ -10,9 +10,8 @@ import numpy as np from transformers import AutoTokenizer -from mlc_llm.serve import GenerationConfig, KVCacheConfig +from mlc_llm.serve import GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.engine_base import ModelInfo from mlc_llm.serve.sync_engine import SyncEngine @@ -25,7 +24,6 @@ def _parse_args(): args.add_argument("--device", type=str, default="auto") args.add_argument("--num-prompts", type=int, default=500) args.add_argument("--max-num-sequence", type=int, default=80) - args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) args.add_argument("--json-output", type=bool, default=False) @@ -34,7 +32,6 @@ def _parse_args(): parsed = args.parse_args() parsed.model = os.path.dirname(parsed.model_lib_path) assert parsed.max_num_sequence % 16 == 0 - assert parsed.page_size == 16 return parsed @@ -106,16 +103,16 @@ def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3): def benchmark(args: argparse.Namespace): random.seed(args.seed) - # Initialize model loading info and KV cache config - model = ModelInfo(args.model, args.model_lib_path, args.device) - kv_cache_config = KVCacheConfig( - page_size=args.page_size, - max_num_sequence=args.max_num_sequence, + # Create engine + engine = SyncEngine( + model=args.model, + model_lib_path=args.model_lib_path, + device=args.device, + mode="server", + max_batch_size=args.max_num_sequence, max_total_sequence_length=args.max_total_seq_length, ) - # Create engine - engine = SyncEngine(model, kv_cache_config) # Sample prompts from dataset prompts, generation_config = sample_requests( args.dataset, args.num_prompts, args.model, args.json_output diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 82c9dfa534..0685261806 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -4,8 +4,7 @@ import random from typing import List, Tuple -from mlc_llm.serve import GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import GenerationConfig from mlc_llm.serve.sync_engine import SyncEngine @@ -14,14 +13,12 @@ def _parse_args(): args.add_argument("--model-lib-path", type=str) args.add_argument("--device", type=str, default="auto") args.add_argument("--batch-size", type=int, default=80) - args.add_argument("--page-size", type=int, default=16) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) parsed = args.parse_args() parsed.model = os.path.dirname(parsed.model_lib_path) assert parsed.batch_size % 16 == 0 - assert parsed.page_size == 16 return parsed @@ -43,17 +40,16 @@ def generate_requests( def benchmark(args: argparse.Namespace): random.seed(args.seed) - # Initialize model loading info and KV cache 
config - model = ModelInfo(args.model, args.model_lib_path, args.device) - kv_cache_config = KVCacheConfig( - page_size=args.page_size, - max_num_sequence=args.batch_size, + # Create engine + engine = SyncEngine( + model=args.model, + device=args.device, + model_lib_path=args.model_lib_path, + mode="server", + max_batch_size=args.batch_size, max_total_sequence_length=args.max_total_seq_length, ) - # Create engine - engine = SyncEngine(model, kv_cache_config) - print(args) for num_requests in [1, 2, 4, 8, 16, 32, 64]: if num_requests > args.batch_size: diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 4da72c5deb..afa7081bd7 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,8 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncEngine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import AsyncEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -21,14 +20,15 @@ async def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 10 max_tokens = 256 @@ -77,14 +77,15 @@ async def generate_task( async def test_chat_completion(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 32 @@ -96,7 +97,7 @@ async def generate_task(prompt: str, request_id: str): rid = int(request_id) async for response in await async_engine.chat.completions.create( messages=[{"role": "user", "content": prompt}], - model=model.model, + model=model, max_tokens=max_tokens, n=n, request_id=request_id, @@ -128,14 +129,15 @@ async def generate_task(prompt: str, request_id: str): async def test_chat_completion_non_stream(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + 
max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 32 @@ -147,7 +149,7 @@ async def generate_task(prompt: str, request_id: str): rid = int(request_id) response = await async_engine.chat.completions.create( messages=[{"role": "user", "content": prompt}], - model=model.model, + model=model, max_tokens=max_tokens, n=n, request_id=request_id, @@ -178,14 +180,15 @@ async def generate_task(prompt: str, request_id: str): async def test_completion(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 128 @@ -197,7 +200,7 @@ async def generate_task(prompt: str, request_id: str): rid = int(request_id) async for response in await async_engine.completions.create( prompt=prompt, - model=model.model, + model=model, max_tokens=max_tokens, n=n, ignore_eos=True, @@ -229,14 +232,15 @@ async def generate_task(prompt: str, request_id: str): async def test_completion_non_stream(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - async_engine = AsyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 128 @@ -248,7 +252,7 @@ async def generate_task(prompt: str, request_id: str): rid = int(request_id) response = await async_engine.completions.create( prompt=prompt, - model=model.model, + model=model, max_tokens=max_tokens, n=n, ignore_eos=True, diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index dc0d0c1c7f..f7ccb13a8d 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,14 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import ( - AsyncEngine, - EngineMode, - GenerationConfig, - KVCacheConfig, - SpeculativeMode, -) -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import AsyncEngine, EngineConfig, GenerationConfig, SpeculativeMode prompts = [ "What is the meaning of life?", @@ -27,19 +20,20 @@ async def test_engine_generate(): - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + 
"dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - llm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + async_engine = AsyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), ) - kv_cache_config = KVCacheConfig(page_size=16) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) - # Create engine - async_engine = AsyncEngine([llm, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index eccf1facda..376671a884 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,8 +2,7 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import Engine, GenerationConfig, KVCacheConfig -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import Engine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -20,14 +19,15 @@ def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = Engine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 10 max_tokens = 256 @@ -58,14 +58,15 @@ def test_engine_generate(): def test_chat_completion(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = Engine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 64 @@ -76,7 +77,7 @@ def test_chat_completion(): print(f"chat completion for request {rid}") for response in engine.chat.completions.create( messages=[{"role": "user", "content": prompts[rid]}], - model=model.model, + model=model, max_tokens=max_tokens, n=n, request_id=str(rid), @@ -101,14 +102,15 @@ def test_chat_completion(): def test_chat_completion_non_stream(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = Engine( + 
model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 64 @@ -119,7 +121,7 @@ def test_chat_completion_non_stream(): print(f"chat completion for request {rid}") response = engine.chat.completions.create( messages=[{"role": "user", "content": prompts[rid]}], - model=model.model, + model=model, max_tokens=max_tokens, n=n, request_id=str(rid), @@ -143,14 +145,15 @@ def test_chat_completion_non_stream(): def test_completion(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = Engine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 128 @@ -161,7 +164,7 @@ def test_completion(): print(f"completion for request {rid}") for response in engine.completions.create( prompt=prompts[rid], - model=model.model, + model=model, max_tokens=max_tokens, n=n, ignore_eos=True, @@ -186,14 +189,15 @@ def test_completion(): def test_completion_non_stream(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = Engine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = Engine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 2 max_tokens = 128 @@ -204,7 +208,7 @@ def test_completion_non_stream(): print(f"completion for request {rid}") response = engine.completions.create( prompt=prompts[rid], - model=model.model, + model=model, max_tokens=max_tokens, n=n, ignore_eos=True, diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index e40f477061..1bb985f53a 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,9 +7,8 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncEngine, GenerationConfig, KVCacheConfig +from mlc_llm.serve import AsyncEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.engine_base import ModelInfo from mlc_llm.serve.sync_engine import SyncEngine prompts_list = [ @@ -22,11 +21,8 @@ def test_batch_generation_with_grammar(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = SyncEngine(model, kv_cache_config) + engine = SyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -72,11 +68,8 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, 
model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - engine = SyncEngine(model, kv_cache_config) + engine = SyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -127,11 +120,8 @@ class Schema(BaseModel): async def run_async_engine(): - # Initialize model loading info and KV cache config - model = ModelInfo(model_path, model_lib_path=model_lib_path) - kv_cache_config = KVCacheConfig(page_size=16) # Create engine - async_engine = AsyncEngine(model, kv_cache_config, enable_tracing=True) + async_engine = AsyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -185,8 +175,6 @@ async def generate_task( for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") - print(async_engine.state.trace_recorder.dump_json(), file=open("tmpfiles/tmp.json", "w")) - async_engine.terminate() diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index e8bcb13ae4..f3e13d600b 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -1,8 +1,7 @@ import json from pathlib import Path -from mlc_llm.serve import GenerationConfig, KVCacheConfig, data -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import GenerationConfig, data from mlc_llm.serve.sync_engine import SyncEngine @@ -11,17 +10,18 @@ def get_test_image(config) -> data.ImageData: def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/llava-1.5-7b-hf-q4f16_1-MLC/params", - model_lib_path="dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = SyncEngine(model, kv_cache_config) + model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" + model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) max_tokens = 256 - with open(Path(model.model) / "mlc-chat-config.json", "r", encoding="utf-8") as file: + with open(Path(model) / "mlc-chat-config.json", "r", encoding="utf-8") as file: model_config = json.load(file) prompts = [ diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 49a55e3ed0..818064e423 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -5,7 +5,7 @@ import numpy as np from mlc_llm.serve import ( - EngineMode, + EngineConfig, GenerationConfig, KVCacheConfig, Request, @@ -69,18 +69,6 @@ def test_engine_basic(): requests + max_tokens - 1). Then check the output of each request. """ - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", - ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) - # Hyperparameters for tests (you can try different combinations). 
num_requests = len(prompts) # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -99,7 +87,21 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + ) + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -135,18 +137,6 @@ def test_engine_eagle_basic(): - Use Eagle model as speculative model """ - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Eagle-llama2-7b-chat-q0f16-MLC", - model_lib_path="dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so", - ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(spec_draft_length=2, speculative_mode=SpeculativeMode.EAGLE) - # Hyperparameters for tests (you can try different combinations). num_requests = len(prompts) # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -165,7 +155,21 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" + ) + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(spec_draft_length=2, speculative_mode=SpeculativeMode.EAGLE), + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -201,18 +205,6 @@ def test_engine_continuous_batching_1(): of each request. 
""" - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", - ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) - # Hyperparameters for tests (you can try different combinations) num_requests = len(prompts) # [4, 8, 10] temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] @@ -245,8 +237,22 @@ def step(self) -> None: self.timer += 1 # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + ) timer = CallbackTimer() - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -285,18 +291,6 @@ def test_engine_eagle_continuous_batching_1(): of each request. """ - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Eagle-llama2-7b-chat-q4f16_1-MLC", - model_lib_path="dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so", - ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.EAGLE) - # Hyperparameters for tests (you can try different combinations) num_requests = len(prompts) # [4, 8, 10] temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] @@ -329,8 +323,22 @@ def step(self) -> None: self.timer += 1 # Create engine + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" + ) timer = CallbackTimer() - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, timer.callback_getter()) + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.EAGLE), + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -359,19 +367,21 @@ def step(self) -> None: def test_engine_generate(): - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = 
"dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.SMALL_DRAFT) - # Create engine - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 @@ -390,19 +400,21 @@ def test_engine_generate(): def test_engine_eagle_generate(): - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Eagle-llama2-7b-chat-q4f16_1-MLC", - model_lib_path="dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so", + # Create engine + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(speculative_mode=SpeculativeMode.EAGLE), ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(speculative_mode=SpeculativeMode.EAGLE) - # Create engine - engine = SyncEngine([model, ssm], kv_cache_config, engine_mode) num_requests = 10 max_tokens = 256 @@ -423,13 +435,6 @@ def test_engine_eagle_generate(): def test_engine_efficiency(): """Test engine speculative decoding efficiency.""" - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - # Hyperparameters for tests (you can try different combinations). 
num_requests = 1 # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -448,7 +453,15 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = SyncEngine(model, kv_cache_config, request_stream_callback=fcallback) + model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -485,23 +498,6 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): def test_engine_spec_efficiency(): """Test engine speculative decoding efficiency.""" - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", - ) - # If Flashinfer allows head_dim < 128, we can test this model - # ssm = ModelInfo( - # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC", - # model_lib_path="dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so", - # ) - model = ModelInfo( - "dist/Llama-2-13b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(spec_draft_length=6, speculative_mode=SpeculativeMode.SMALL_DRAFT) - # Hyperparameters for tests (you can try different combinations). num_requests = 1 # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -520,7 +516,28 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - spec_engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + small_model_lib_path = ( + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + ) + # If Flashinfer allows head_dim < 128, we can test this model + # small_model = "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC" + # small_model_lib_path = ( + # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" + # ) + spec_engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig( + spec_draft_length=6, speculative_mode=SpeculativeMode.SMALL_DRAFT + ), + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -557,23 +574,6 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): def test_engine_eagle_spec_efficiency(): """Test engine speculative decoding efficiency.""" - # Initialize model loading info and KV cache config - ssm = ModelInfo( - "dist/Eagle-llama2-7b-chat-q0f16-MLC", - model_lib_path="dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so", - ) - # If Flashinfer allows head_dim < 128, we can test this model - # ssm = ModelInfo( - # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC", - # 
model_lib_path="dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so", - # ) - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) - engine_mode = EngineMode(spec_draft_length=6, speculative_mode=SpeculativeMode.EAGLE) - # Hyperparameters for tests (you can try different combinations). num_requests = 1 # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -592,7 +592,21 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - spec_engine = SyncEngine([model, ssm], kv_cache_config, engine_mode, fcallback) + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" + small_model_lib_path = ( + "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" + ) + spec_engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + additional_models=[small_model + ":" + small_model_lib_path], + engine_config=EngineConfig(spec_draft_length=6, speculative_mode=SpeculativeMode.EAGLE), + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index 3c8ec011ae..4304348095 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -4,14 +4,7 @@ import numpy as np -from mlc_llm.serve import ( - GenerationConfig, - KVCacheConfig, - Request, - RequestStreamOutput, - data, -) -from mlc_llm.serve.engine_base import ModelInfo +from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data from mlc_llm.serve.sync_engine import SyncEngine prompts = [ @@ -67,13 +60,6 @@ def test_engine_basic(): requests + max_tokens - 1). Then check the output of each request. """ - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations). num_requests = 10 # [4, 8, 10] temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] @@ -92,7 +78,14 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): outputs[int(request_id)] += stream_outputs[0].delta_token_ids # Create engine - engine = SyncEngine(model, kv_cache_config, request_stream_callback=fcallback) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=fcallback, + ) # Create requests requests = create_requests( @@ -128,13 +121,6 @@ def test_engine_continuous_batching_1(): of each request. 
""" - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations) num_requests = 10 # [4, 8, 10] temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] @@ -168,7 +154,14 @@ def step(self) -> None: # Create engine timer = CallbackTimer() - engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -209,13 +202,6 @@ def test_engine_continuous_batching_2(): of each request. """ - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations) num_requests = 10 # [4, 8, 10] temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] @@ -249,7 +235,14 @@ def step(self) -> None: # Create engine timer = CallbackTimer() - engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -289,13 +282,6 @@ def test_engine_continuous_batching_3(): Then check the output of each request. 
""" - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16) - # Hyperparameters for tests (you can try different combinations) num_requests = 10 # [4, 8, 10] temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] @@ -335,7 +321,14 @@ def all_finished(self) -> bool: # Create engine timer = CallbackTimer() - engine = SyncEngine(model, kv_cache_config, request_stream_callback=timer.callback_getter()) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + request_stream_callback=timer.callback_getter(), + ) # Create requests requests = create_requests( @@ -369,14 +362,15 @@ def all_finished(self) -> bool: def test_engine_generate(): - # Initialize model loading info and KV cache config - model = ModelInfo( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - model_lib_path="dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ) - kv_cache_config = KVCacheConfig(page_size=16, max_total_sequence_length=4096) # Create engine - engine = SyncEngine(model, kv_cache_config) + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = SyncEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) num_requests = 10 max_tokens = 256 From 8139a47a331b1120b1d6375e72e2569d6886f03e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 13 Apr 2024 09:18:51 -0400 Subject: [PATCH 185/531] [Serving] Revamp engine mode selection logging info (#2128) This PR revamps the logging info for engine mode selection to provide more detailed information and the rationale of different modes. --- python/mlc_llm/serve/engine_base.py | 201 +++++++++++++++++----------- 1 file changed, 124 insertions(+), 77 deletions(-) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 421cd187f7..45ad9f7756 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -296,7 +296,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -314,105 +314,152 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local model_max_batch_size, ) = _get_model_config_limit(model_config_dicts) - logging_msg = 'Engine mode is "' + green(mode) + '". 
' - # - max_batch_size - if max_batch_size is None: - max_batch_size = ( - min(4, model_max_batch_size) - if mode == "local" - else (1 if mode == "interactive" else model_max_batch_size) + def infer_args_under_mode( + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + ) -> Tuple[KVCacheConfig, List[float]]: + logging_msg = "" + # - max_batch_size + if max_batch_size is None: + max_batch_size = ( + min(4, model_max_batch_size) + if mode == "local" + else (1 if mode == "interactive" else model_max_batch_size) + ) + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + # - infer the maximum total sequence length that can fit GPU memory. + ( + total_mem_usage_except_kv_cache, + model_params_bytes, + kv_bytes_per_token, + kv_aux_workspace_bytes, + temp_workspace_bytes, + model_max_total_sequence_length, + ) = _estimate_mem_usage_and_max_total_sequence_length( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, ) - logging_msg += "Max batch size is set to " + green(str(max_batch_size)) + ". " - else: - logging_msg += "Max batch size " + green(str(max_batch_size)) + " is specified by user. " - # - infer the maximum total sequence length that can fit GPU memory. - ( - total_mem_usage_except_kv_cache, - model_params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - temp_workspace_bytes, - model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - # - max_total_sequence_length - if max_total_sequence_length is None: + # - max_total_sequence_length + if max_total_sequence_length is None: + if mode == "local": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length, 8192 + ) + elif mode == "interactive": + max_total_sequence_length = min( + model_max_total_sequence_length, model_max_single_sequence_length + ) + else: + max_total_sequence_length = min( + model_max_total_sequence_length, + max_batch_size * model_max_single_sequence_length, + ) + logging_msg += f"max KV cache token capacity is set to {max_total_sequence_length}, " + else: + logging_msg += ( + f"max KV cache token capacity {max_total_sequence_length} is specified by user. " + ) + # - prefill_chunk_size + if prefill_chunk_size is None: + if mode in ["local", "interactive"]: + prefill_chunk_size = min( + model_max_prefill_chunk_size, + model_max_total_sequence_length, + model_max_single_sequence_length, + ) + else: + prefill_chunk_size = model_max_prefill_chunk_size + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + if mode == "local": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length, 8192 + logging_msg += ( + "We choose small max batch size and KV cache capacity to use less GPU memory." ) elif mode == "interactive": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length - ) + logging_msg += "We fix max batch size to 1 for interactive single sequence use." 
else: - max_total_sequence_length = min( - model_max_total_sequence_length, max_batch_size * model_max_single_sequence_length + logging_msg += ( + "We use as much GPU memory as possible (within the" + " limit of gpu_memory_utilization)." ) - logging_msg += ( - "Max KV cache token capacity is set to " + green(str(max_total_sequence_length)) + ". " - ) - else: - logging_msg += ( - "Max KV cache token capacity " - + green(str(max_total_sequence_length)) - + " is specified by user. " - ) - # - prefill_chunk_size - if prefill_chunk_size is None: - if mode in ["local", "interactive"]: - prefill_chunk_size = min( - model_max_prefill_chunk_size, - model_max_total_sequence_length, - model_max_single_sequence_length, - ) - else: - prefill_chunk_size = model_max_prefill_chunk_size - logging_msg += "Prefill chunk size is set to " + green(str(prefill_chunk_size)) + ". " + logger.info('Under mode "%s", %s', mode, logging_msg) + + # - Construct the KV cache config + # - Estimate total GPU memory usage on single GPU. + return KVCacheConfig( + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + ), [ + total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, + model_params_bytes, + kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, + temp_workspace_bytes, + ] + + # - Infer KV cache config and estimate memory usage for each mode. + local_kv_cache_config, local_mem_usage_list = infer_args_under_mode( + "local", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + interactive_kv_cache_config, interactive_mem_usage_list = infer_args_under_mode( + "interactive", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + server_kv_cache_config, server_mem_usage_list = infer_args_under_mode( + "server", max_batch_size, max_total_sequence_length, prefill_chunk_size + ) + + # - Select the config based on the actual mode. + if mode == "local": + kv_cache_config = local_kv_cache_config + mem_usage_list = local_mem_usage_list + elif mode == "interactive": + kv_cache_config = interactive_kv_cache_config + mem_usage_list = interactive_mem_usage_list else: - logging_msg += ( - "Prefill chunk size " + green(str(prefill_chunk_size)) + " is specified by user. " - ) - logger.info(logging_msg) - # - Estimate total GPU memory usage on single GPU. - total_mem_usage = ( - total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token + kv_cache_config = server_kv_cache_config + mem_usage_list = server_mem_usage_list + + logger.info( + 'The actual engine mode is "%s". So max batch size is %s, ' + "max KV cache token capacity is %s, prefill chunk size is %s.", + green(mode), + green(str(kv_cache_config.max_num_sequence)), + green(str(kv_cache_config.max_total_sequence_length)), + green(str(kv_cache_config.prefill_chunk_size)), ) + logger.info( "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " "The actual usage might be slightly larger than the estimated number.", green("Estimated total single GPU memory usage"), - total_mem_usage / 1024 / 1024, - model_params_bytes / 1024 / 1024, - (kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes) / 1024 / 1024, - temp_workspace_bytes / 1024 / 1024, + *list(mem_usage / 1024 / 1024 for mem_usage in mem_usage_list), ) # - Final messages + override_msg = "Please override the arguments if you have particular values to set." 
if mode in ["local", "interactive"]: logger.info( 'Please switch to mode "server" if you want to use more GPU memory ' - "and support more concurrent requests." + "and support more concurrent requests. %s", + override_msg, ) else: logger.info( 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' - "or do not have many concurrent requests to process." + "or do not have many concurrent requests to process. %s", + override_msg, ) - return ( - KVCacheConfig( - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - prefill_chunk_size=prefill_chunk_size, - ), - model_max_single_sequence_length, - ) + return kv_cache_config, model_max_single_sequence_length @dataclass From a361119184bc1c85ff4d35d7bf22c1fced577c0a Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Sun, 14 Apr 2024 22:29:05 +0800 Subject: [PATCH 186/531] [SLM] Chatglm3 Multi-GPU support (#2123) This PR enables TP for Chatglm3 model. --- .../mlc_llm/model/chatglm3/chatglm3_model.py | 72 ++++++++++++++++--- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py index e4a9f53b15..f7e81019e0 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_model.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -40,6 +41,7 @@ class GLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 + head_dim: int = 0 max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) @@ -65,6 +67,9 @@ def __post_init__(self): "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( "%s defaults to %s (%d)", @@ -82,7 +87,6 @@ def __post_init__(self): bold("context_window_size"), ) self.prefill_chunk_size = self.context_window_size - assert self.tensor_parallel_shards == 1, "ChatGLM currently does not support sharding." 
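The tensor-parallel support in the changes below comes down to dividing the attention heads, the multi-query KV heads, and the FFN width evenly across shards, and annotating the fused query_key_value weight with per-segment shard sizes. A small illustrative Python sketch of that arithmetic, using ChatGLM3-like numbers that are assumptions for the example rather than values read from a real config:

# Assumed example configuration; the numbers are illustrative only.
num_attention_heads = 32
multi_query_group_num = 2          # KV heads under multi-query attention
head_dim = 128
ffn_hidden_size = 13696
tensor_parallel_shards = 2

# Per-shard sizes, mirroring the integer divisions in GLMAttention and GLMMLP.
num_heads = num_attention_heads // tensor_parallel_shards
num_key_value_heads = multi_query_group_num // tensor_parallel_shards
ffn_hidden_size_per_shard = ffn_hidden_size // tensor_parallel_shards

# Row segments of the fused query_key_value weight on each shard,
# matching the segs=[q, k, v] hint given to ShardSingleDim in the diff below.
q = num_heads * head_dim
k = num_key_value_heads * head_dim
v = num_key_value_heads * head_dim

print(f"heads per shard: {num_heads}, kv heads per shard: {num_key_value_heads}")
print(f"fused qkv rows per shard: {q} + {k} + {v} = {q + k + v}")
print(f"ffn hidden size per shard: {ffn_hidden_size_per_shard}")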
# pylint: disable=invalid-name,missing-docstring @@ -91,14 +95,14 @@ def __post_init__(self): class GLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GLMConfig): self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.multi_query_attention = config.multi_query_attention self.num_key_value_heads = ( config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads - ) - self.head_dim = self.hidden_size // self.num_heads + ) // config.tensor_parallel_shards + self.head_dim = config.head_dim self.query_key_value = nn.Linear( config.hidden_size, (2 * self.num_key_value_heads + self.num_heads) * self.head_dim, @@ -123,13 +127,15 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class GLMMLP(nn.Module): def __init__(self, config: GLMConfig): + self.ffn_hidden_size = config.ffn_hidden_size // config.tensor_parallel_shards + self.dense_h_to_4h = nn.Linear( config.hidden_size, - config.ffn_hidden_size * 2, + self.ffn_hidden_size * 2, bias=config.add_bias_linear, ) self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, + self.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, ) @@ -158,13 +164,57 @@ def __init__(self, config: GLMConfig): config.hidden_size, -1, config.layernorm_epsilon, bias=False ) + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attention.num_heads * hd + k = self.self_attention.num_key_value_heads * hd + v = self.self_attention.num_key_value_heads * hd + _set( + self.self_attention.query_key_value.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + if config.add_bias_linear or config.add_qkv_bias: + _set( + self.self_attention.query_key_value.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attention.dense.weight, tp.ShardSingleDim("_shard_dense_weight", dim=1)) + if config.add_bias_linear: + _set(self.self_attention.dense.bias, tp.ShardSingleDim("_shard_dense_bias", dim=0)) + _set( + self.mlp.dense_h_to_4h.weight, + tp.ShardSingleDim("_shard_dense_h_to_4h_weight", dim=0), + ) + if config.add_bias_linear: + _set( + self.mlp.dense_h_to_4h.bias, + tp.ShardSingleDim("_shard_dense_h_to_4h_bias", dim=0), + ) + _set(self.mlp.dense_4h_to_h.weight, tp.ShardSingleDim("_shard_dense_4h_to_h", dim=1)) + if config.add_bias_linear: + _set( + self.mlp.dense_4h_to_h.bias, + tp.ShardSingleDim("_shard_dense_4h_to_h_bias", dim=1), + ) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.self_attention(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = out + hidden_states + hidden_states = self._apply_residual(out, residual=hidden_states) return hidden_states + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + class GLMTransformer(nn.Module): """Transformer class.""" @@ -217,7 +267,7 @@ def __init__(self, config: GLMConfig): if config.multi_query_attention else config.num_attention_heads ) - self.head_dim = 
self.hidden_size // self.num_attention_heads + self.head_dim = config.head_dim self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards @@ -245,6 +295,8 @@ def batch_forward( return logits def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.transformer.embedding(input_ids) def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -273,6 +325,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache

From 661abb2ca96ac76a0a0169e83632c91db08f4f9b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 14 Apr 2024 11:19:53 -0400 Subject: [PATCH 187/531] [Serving] Fix support of large `n` under low max batch size (#2136)

Prior to this PR, due to the improper prefill policy on `n` (parallel generation), the engine will loop forever when a request has `n` larger than the maximum batch size that the engine can support.

This PR fixes the issue by updating the prefill action, and with this PR, even the "interactive" engine mode can well support multiple parallel generations.

After this fix, it is possible that a request requires 10 parallel generations while the max batch size is 1. Given that the shapes of temporary NDArrays in the GPU sampler are determined by the max batch size, the GPU sampler does not natively support sampling 10 tokens at a time. To address this issue, this PR introduces chunking to the GPU sampler. Therefore, in this particular case, the GPU sampler will have chunk size 1, and the 10 required samples will be processed by the GPU sampler one by one in order. Chunking is the minimal change needed to support large `n`.

--- .../eagle_new_request_prefill.cc | 70 ++++++++--- .../engine_actions/new_request_prefill.cc | 111 ++++++++++++------ cpp/serve/sampler/gpu_sampler.cc | 56 +++++++-- 3 files changed, 174 insertions(+), 63 deletions(-)

diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 7ed84feb86..d7a397ce92 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -217,17 +217,21 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { std::vector sample_indices; std::vector rsentries_for_sample; std::vector rngs; + std::vector rsentry_activated; sample_indices.reserve(num_rsentries); rsentries_for_sample.reserve(num_rsentries); rngs.reserve(num_rsentries); + rsentry_activated.reserve(num_rsentries); request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate; for (int child_idx : rsentry->child_indices) { // Only use base model to judge if we need to add child entries.
- if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() || - fork_rsentry_child_map[i].count(child_idx)) { + if (rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending && + (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() || + fork_rsentry_child_map[i].count(child_idx))) { // If rstates_of_entries[i]->entries[child_idx] has no committed token, // the prefill of the current rsentry will unblock // rstates_of_entries[i]->entries[child_idx], @@ -239,6 +243,16 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { generation_cfg.push_back(rsentry->request->generation_cfg); rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); + // We only fork the first `num_child_to_activate` children. + // The children not being forked will be forked via later prefills. + // Usually `num_child_to_activate` is the same as the number of children. + // But it can be fewer subject to the KV cache max num sequence limit. + if (remaining_num_child_to_activate == 0) { + rsentry_activated.push_back(false); + continue; + } + rsentry_activated.push_back(true); + --remaining_num_child_to_activate; if (model_id == 0) { ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); @@ -261,6 +275,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { request_ids.push_back(rsentry->request->id); generation_cfg.push_back(rsentry->request->generation_cfg); rngs.push_back(&rsentry->rng); + rsentry_activated.push_back(true); } } std::vector prob_dist; @@ -275,6 +290,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { if (model_id == 0) { for (int mid = 0; mid < static_cast(models_.size()); ++mid) { rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); + if (!rsentry_activated[i]) { + // When the child rsentry is not activated, + // add the sampled token as an input of the mstate for prefill. + rsentries_for_sample[i]->mstates[mid]->inputs.push_back( + TokenData(std::vector{sample_results[i].sampled_token_id.first})); + } } // Only base model trigger timing records. if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { @@ -332,7 +353,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { /*! \brief The class of request state entry and its maximum allowed length for prefill. */ struct PrefillInput { RequestStateEntry rsentry; - int max_prefill_length; + int max_prefill_length = 0; + int num_child_to_activate = 0; }; /*! @@ -376,11 +398,19 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { total_input_length += input_length; total_required_pages += num_require_pages; // - Attempt 1. Check if the entire request state entry can fit for prefill. 
- if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + bool can_prefill = false; + for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; + --num_child_to_activate) { + if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); + num_prefill_rsentries += 1 + num_child_to_activate; + can_prefill = true; + break; + } + } + if (can_prefill) { continue; } total_input_length -= input_length; @@ -388,18 +418,26 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Attempt 2. Check if the request state entry can partially fit by input chunking. ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); - input_length = - std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); + if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length || + kv_cache_config_->prefill_chunk_size == total_input_length) { + // 1. If the input length can fit the remaining prefill chunk size, + // it means the failure of attempt 1 is not because of the input + // length being too long, and thus chunking does not help. + // 2. If the total input length already reaches the prefill chunk size, + // the current request state entry will not be able to be processed. + // So we can safely return in either case. + prefill_stops = true; + break; + } + input_length = kv_cache_config_->prefill_chunk_size - total_input_length; num_require_pages = (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; total_input_length += input_length; total_required_pages += num_require_pages; - if (input_length > 0 && - CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, + num_available_pages, current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, 0}); + num_prefill_rsentries += 1; } // - Prefill stops here. 
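The admission policy that both prefill actions in this patch implement can be summarized as: first try to admit the whole input while activating as many child entries as the batch and KV cache budget allows, and only fall back to chunking the input afterwards; chunking is skipped when the input already fits the remaining prefill chunk budget, because in that case the rejection was not caused by input length. The sketch below restates that policy in Python pseudocode; the helper `can_prefill(num_entries, total_length)`, the `chunk_size` value, and the function name itself are hypothetical stand-ins for the actual C++ interfaces, not part of this patch.

    def admit_for_prefill(rsentry, input_length, total_input_length, chunk_size, can_prefill):
        # Attempt 1: admit the full input, activating as many children as the budget allows.
        for num_child_to_activate in range(len(rsentry.child_indices), -1, -1):
            if can_prefill(1 + num_child_to_activate, total_input_length + input_length):
                return input_length, num_child_to_activate
        # Attempt 2: chunk the input. Chunking only helps when the input itself
        # overflows the remaining prefill chunk budget.
        remaining = chunk_size - total_input_length
        if remaining >= input_length or remaining == 0:
            return None  # chunking cannot help; stop prefilling here
        if can_prefill(1, total_input_length + remaining):
            return remaining, 0  # prefill a chunk now; children are activated by later prefills
        return None

The GPU sampler change later in this patch follows the same spirit: when the number of requested samples exceeds the sampler's `max_num_sample_`, the sample indices are processed in consecutive chunks of at most `max_num_sample_` each. A minimal sketch, where `sample_chunk` stands in for a single GPU sampling call:

    def batch_sample(sample_indices, sample_chunk, max_num_sample):
        results = []
        for start in range(0, len(sample_indices), max_num_sample):
            results += sample_chunk(sample_indices[start:start + max_num_sample])
        return results
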
diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 1e7d798c26..d70b9d7edc 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -167,9 +167,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { std::vector sample_indices; std::vector rsentries_for_sample; std::vector rngs; + std::vector rsentry_activated; sample_indices.reserve(num_rsentries); rsentries_for_sample.reserve(num_rsentries); rngs.reserve(num_rsentries); + rsentry_activated.reserve(num_rsentries); request_ids.clear(); generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { @@ -179,29 +181,42 @@ class NewRequestPrefillActionObj : public EngineActionObj { continue; } + int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate; for (int child_idx : rsentry->child_indices) { - if (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { - // If rstates_of_entries[i]->entries[child_idx] has no committed token, - // the prefill of the current rsentry will unblock - // rstates_of_entries[i]->entries[child_idx], - // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. - sample_indices.push_back(i); - rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); - request_ids.push_back(rsentry->request->id); - generation_cfg.push_back(rsentry->request->generation_cfg); - rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); - - ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); - rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; - for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { - int64_t child_internal_id = - rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; - models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, - child_internal_id); - // Enable sliding window for the child sequence if the child is not a parent. - if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { - models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); - } + // If rstates_of_entries[i]->entries[child_idx] has no committed token, + // the prefill of the current rsentry will unblock + // rstates_of_entries[i]->entries[child_idx], + // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx]. + if (rstates_of_entries[i]->entries[child_idx]->status != RequestStateStatus::kPending || + !rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) { + continue; + } + sample_indices.push_back(i); + rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]); + request_ids.push_back(rsentry->request->id); + generation_cfg.push_back(rsentry->request->generation_cfg); + rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng); + + ICHECK(rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending); + // We only fork the first `num_child_to_activate` children. + // The children not being forked will be forked via later prefills. + // Usually `num_child_to_activate` is the same as the number of children. + // But it can be fewer subject to the KV cache max num sequence limit. 
+ if (remaining_num_child_to_activate == 0) { + rsentry_activated.push_back(false); + continue; + } + rsentry_activated.push_back(true); + --remaining_num_child_to_activate; + rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive; + for (int model_id = 0; model_id < static_cast(models_.size()); ++model_id) { + int64_t child_internal_id = + rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id; + models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id, + child_internal_id); + // Enable sliding window for the child sequence if the child is not a parent. + if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) { + models_[model_id]->EnableSlidingWindowForSeq(child_internal_id); } } } @@ -212,6 +227,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { request_ids.push_back(rsentry->request->id); generation_cfg.push_back(rsentry->request->generation_cfg); rngs.push_back(&rsentry->rng); + rsentry_activated.push_back(true); } } std::vector sample_results = sampler_->BatchSampleTokens( @@ -224,6 +240,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { mstate->CommitToken(sample_results[i]); + if (!rsentry_activated[i]) { + // When the child rsentry is not activated, + // add the sampled token as an input of the mstate for prefill. + mstate->inputs.push_back( + TokenData(std::vector{sample_results[i].sampled_token_id.first})); + } } if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { rsentries_for_sample[i]->tprefill_finish = tnow; @@ -270,7 +292,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { /*! \brief The class of request state entry and its maximum allowed length for prefill. */ struct PrefillInput { RequestStateEntry rsentry; - int max_prefill_length; + int max_prefill_length = 0; + int num_child_to_activate = 0; }; /*! @@ -314,11 +337,19 @@ class NewRequestPrefillActionObj : public EngineActionObj { total_input_length += input_length; total_required_pages += num_require_pages; // - Attempt 1. Check if the entire request state entry can fit for prefill. - if (CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + bool can_prefill = false; + for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; + --num_child_to_activate) { + if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); + num_prefill_rsentries += 1 + num_child_to_activate; + can_prefill = true; + break; + } + } + if (can_prefill) { continue; } total_input_length -= input_length; @@ -326,18 +357,26 @@ class NewRequestPrefillActionObj : public EngineActionObj { // - Attempt 2. Check if the request state entry can partially fit by input chunking. 
ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); - input_length = - std::min(input_length, kv_cache_config_->prefill_chunk_size - total_input_length); + if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length || + kv_cache_config_->prefill_chunk_size == total_input_length) { + // 1. If the input length can fit the remaining prefill chunk size, + // it means the failure of attempt 1 is not because of the input + // length being too long, and thus chunking does not help. + // 2. If the total input length already reaches the prefill chunk size, + // the current request state entry will not be able to be processed. + // So we can safely return in either case. + prefill_stops = true; + break; + } + input_length = kv_cache_config_->prefill_chunk_size - total_input_length; num_require_pages = (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; total_input_length += input_length; total_required_pages += num_require_pages; - if (input_length > 0 && - CanPrefill(estate, num_prefill_rsentries + 1 + rsentry->child_indices.size(), - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length}); - num_prefill_rsentries += 1 + rsentry->child_indices.size(); + if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, + num_available_pages, current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, 0}); + num_prefill_rsentries += 1; } // - Prefill stops here. diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index a290e64b4d..b376523dac 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -92,6 +92,7 @@ class GPUSampler : public SamplerObj { NVTXScopedRange nvtx_scope("BatchSampleTokens"); // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs."; CHECK_EQ(probs_on_device->ndim, 2); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; @@ -100,6 +101,50 @@ class GPUSampler : public SamplerObj { ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); + // Since `num_samples` may be larger than `max_num_sample_` in some cases, + // we apply chunking to support large `num_samples`. 
+ std::vector sample_results; + if (num_samples <= max_num_sample_) { + sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs); + } else { + for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { + int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); + std::vector sample_indices_chunk(sample_indices.begin() + chunk_start, + sample_indices.begin() + chunk_end); + Array generation_cfg_chunk(generation_cfg.begin() + chunk_start, + generation_cfg.begin() + chunk_end); + std::vector rngs_chunk(rngs.begin() + chunk_start, + rngs.begin() + chunk_end); + std::vector sample_results_chunk = ChunkSampleTokensImpl( + probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk); + sample_results.insert(sample_results.end(), sample_results_chunk.begin(), + sample_results_chunk.end()); + } + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + + std::vector> BatchVerifyDraftTokens( + NDArray probs_on_device, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + LOG(FATAL) << "GPU sampler does not support batch verification for now."; + } + + private: + std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& generation_cfg, // + const std::vector& rngs) { + // probs_on_device: (n, v) + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + // - Generate random numbers. // Copy the random numbers and sample indices. auto [uniform_samples_device, sample_indices_device] = @@ -148,20 +193,9 @@ class GPUSampler : public SamplerObj { SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens}); } - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); return sample_results; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { - LOG(FATAL) << "GPU sampler does not support batch verification for now."; - } - - private: /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. */ std::pair CopySamplesAndIndicesToGPU(const std::vector& sample_indices, const std::vector& rngs, From 3403a4e981da254751f43964436476be75740511 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 15 Apr 2024 10:21:12 -0400 Subject: [PATCH 188/531] [Docs] Revamp landing page with Engine Python API and server (#2137) This PR revamps the landing documentation page. * The Python API panel is changed from showing ChatModule to showing Engine. * A new panel "REST Server" is added to show a quick start example of launching REST server and send request. * A "what to do next" section is introduced at the bottom of the landing page. Todo items for future PR: * add the page of Python API with Engine. * revamp weight conversion page. * revamp model library compilation page. 
--- docs/compilation/compile_models.rst | 2 +- docs/compilation/convert_weights.rst | 2 +- docs/deploy/javascript.rst | 2 +- .../{python.rst => python_chat_module.rst} | 18 ++-- docs/deploy/python_engine.rst | 15 +++ docs/index.rst | 99 ++++++++++++++----- examples/python/sample_mlc_engine.py | 17 ++++ python/mlc_llm/__init__.py | 2 + 8 files changed, 121 insertions(+), 36 deletions(-) rename docs/deploy/{python.rst => python_chat_module.rst} (96%) create mode 100644 docs/deploy/python_engine.rst create mode 100644 examples/python/sample_mlc_engine.py diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index b30076f018..00beb5cc4d 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -21,7 +21,7 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get familiarized with the high-level terminologies. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index 2507687c21..aa65256fd6 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -24,7 +24,7 @@ This can be extended to, e.g.: Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. Finally, we strongly recommend you to read :ref:`project-overview` first to get familiarized with the high-level terminologies. diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index 57f192f61a..bd92908cff 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -1,6 +1,6 @@ .. _webllm-runtime: -WebLLM and Javascript API +WebLLM and JavaScript API ========================= .. contents:: Table of Contents diff --git a/docs/deploy/python.rst b/docs/deploy/python_chat_module.rst similarity index 96% rename from docs/deploy/python.rst rename to docs/deploy/python_chat_module.rst index 38cdec2f85..5776e29138 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python_chat_module.rst @@ -1,15 +1,21 @@ -.. _deploy-python: +.. _deploy-python-chat-module: -Python API -========== +Python API (Chat Module) +======================== + +.. note:: + ❗ The Python API with :class:`mlc_llm.ChatModule` introduced in this page will be + deprecated in the near future. + Please go to :ref:`deploy-python-engine` for the latest Python API with complete + OpenAI API support. .. contents:: Table of Contents :local: :depth: 2 -We expose Python API for the MLC-Chat for easy integration into other Python projects. +We expose ChatModule Python API for the MLC-LLM for easy integration into other Python projects. 
-The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels via +The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via the :doc:`installation page <../install/mlc_llm>`. Instead of following this page, you could also checkout the following tutorials in @@ -340,7 +346,7 @@ We provide an example below. API Reference ------------- -User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-Chat model. +User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-LLM model. The :class:`mlc_llm.ChatModule` class provides the following methods: .. currentmodule:: mlc_llm diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst new file mode 100644 index 0000000000..60b9acc4a0 --- /dev/null +++ b/docs/deploy/python_engine.rst @@ -0,0 +1,15 @@ +.. _deploy-python-engine: + +Python API +========== + +.. note:: + This page introduces the Python API with Engine in MLC LLM. + If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, + please go to :ref:`deploy-python-chat-module` + +.. contents:: Table of Contents + :local: + :depth: 2 + +🚧 Under construction... diff --git a/docs/index.rst b/docs/index.rst index 2aabd613bf..721d9c227c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,49 +17,75 @@ It is recommended to have at least 6GB free VRAM to run it. .. tab:: Python - **Install MLC LLM Python**. :doc:`MLC LLM ` is available via pip. + **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. It is always recommended to install it in an isolated conda virtual environment. - **Download pre-quantized weights**. The commands below download the int4-quantized Llama2-7B from HuggingFace: + **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM: - .. code:: bash + .. code:: python - git lfs install && mkdir dist/ - git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ - dist/Llama-2-7b-chat-hf-q4f16_1-MLC + from mlc_llm import Engine - **Download pre-compiled model library**. The pre-compiled model library is available as below: + # Create engine + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = Engine(model) - .. code:: bash + # Run chat completion in OpenAI API. + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs + engine.terminate() - **Run in Python.** The following Python script showcases the Python API of MLC LLM and its stream capability: + .. Todo: link the colab notebook when ready: - .. code:: python + **Documentation and tutorial.** Python API reference and its tutorials are :doc:`available online `. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg + :width: 600 + :align: center + + MLC LLM Python API + + .. tab:: REST Server + + **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``. 
- from mlc_llm import ChatModule - from mlc_llm.callback import StreamToStdout + .. code:: shell - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - cm.generate(prompt="What is the meaning of life?", progress_callback=StreamToStdout(callback_interval=2)) + mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - **Colab walkthrough.** A Jupyter notebook on `Colab `_ - is provided with detailed walkthrough of the Python API. + **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), + open a new shell and send a request via the following command: - **Documentation and tutorial.** Python API reference and its tutorials are `available online `_. + .. code:: shell - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg + curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions + + **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial. + Our REST API has complete OpenAI API support. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg :width: 600 :align: center - MLC LLM Python API + Send HTTP request to REST server in MLC LLM .. tab:: Command Line @@ -149,6 +175,25 @@ It is recommended to have at least 6GB free VRAM to run it. MLC LLM on Android +What to Do Next +--------------- + +- Depending on your use case, check out our API documentation and tutorial pages: + + - :ref:`webllm-runtime` + - :ref:`deploy-rest-api` + - :ref:`deploy-cli` + - :ref:`deploy-python-engine` + - :ref:`deploy-ios` + - :ref:`deploy-android` + - :ref:`deploy-ide-integration` + +- Deploy your local model: check out :ref:`convert-weights-via-MLC` to convert your model weights to MLC format. +- Deploy models to Web or build iOS/Android apps on your own: check out :ref:`compile-model-libraries` to compile the models into binary libraries. +- Customize model optimizations: check out :ref:`compile-model-libraries`. +- Report any problem or ask any question: open new issues in our `GitHub repo `_. + + .. toctree:: :maxdepth: 1 :caption: Get Started @@ -165,7 +210,7 @@ It is recommended to have at least 6GB free VRAM to run it. deploy/javascript.rst deploy/rest.rst deploy/cli.rst - deploy/python.rst + deploy/python_engine.rst deploy/ios.rst deploy/android.rst deploy/ide_integration.rst diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py new file mode 100644 index 0000000000..9c65bd4c51 --- /dev/null +++ b/examples/python/sample_mlc_engine.py @@ -0,0 +1,17 @@ +from mlc_llm import Engine + +# Create engine +model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" +engine = Engine(model) + +# Run chat completion in OpenAI API. 
+for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, +): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) +print("\n") + +engine.terminate() diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index f577e0308e..b891323a5a 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -2,6 +2,8 @@ MLC Chat is the app runtime of MLC LLM. """ + from . import protocol, serve from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ +from .serve import AsyncEngine, Engine From 4cbda040176dce418838908244ce8c0bf569d94f Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 16 Apr 2024 12:54:58 +0800 Subject: [PATCH 189/531] [Target] Update Target tags (#2141) The commit updates the target tags, in order to identify the different SoC hardware targets for further target-specific optimizations. Meanwhile, update the vulkan support for int64. --- python/mlc_llm/support/auto_target.py | 45 ++++++++++++++++++- .../python/integration/test_model_compile.py | 5 ++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index f000cc85b2..3cf49c43ba 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -193,6 +193,24 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): return build +def _build_android_so(): + def build(mod: IRModule, args: "CompileArgs", pipeline=None): + output = args.output + mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=False) + assert output.suffix == ".so" + relax.build( + mod, + target=args.target, + pipeline=pipeline, + system_lib=False, + ).export_library( + str(output), + fcompile=ndk.create_shared, + ) + + return build + + def _build_webgpu(): def build(mod: IRModule, args: "CompileArgs", pipeline=None): output = args.output @@ -330,7 +348,9 @@ def detect_system_lib_prefix( prefix_hint : str The hint for the system lib prefix. 
""" - if prefix_hint == "auto" and target_hint in ["iphone", "android"]: + if prefix_hint == "auto" and ( + target_hint.startswith("iphone") or target_hint.startswith("android") + ): prefix = f"{model_name}_{quantization}_".replace("-", "_") logger.warning( "%s is automatically picked from the filename, %s, this allows us to use the filename " @@ -370,6 +390,28 @@ def detect_system_lib_prefix( }, "build": _build_android, }, + "android:adreno": { + "target": { + "kind": "opencl", + "device": "adreno", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android, + }, + "android:adreno-so": { + "target": { + "kind": "opencl", + "device": "adreno", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android_so, + }, "metal:x86-64": { "target": { "kind": "metal", @@ -419,6 +461,7 @@ def detect_system_lib_prefix( "max_shared_memory_per_block": 32768, "thread_warp_size": 1, "supports_float16": 1, + "supports_int64": 1, "supports_int16": 1, "supports_int8": 1, "supports_8bit_buffer": 1, diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py index 2f136f3f16..3ec70b61b3 100644 --- a/tests/python/integration/test_model_compile.py +++ b/tests/python/integration/test_model_compile.py @@ -39,12 +39,13 @@ "max_num_threads": 256, "max_shared_memory_per_block": 32768, "thread_warp_size": 1, - "supports_int16": 1, "supports_float32": 1, + "supports_float16": 1, + "supports_int64": 1, "supports_int32": 1, + "supports_int16": 1, "supports_int8": 1, "supports_16bit_buffer": 1, - "supports_float16": 1, }, "metal": "metal", "wasm": "webgpu", From 8f33c30d1d5459275abe9d1e9f28478e2d04be08 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 16 Apr 2024 21:39:24 +0800 Subject: [PATCH 190/531] [Util] Support debug debug_compare (#2142) --- python/mlc_llm/testing/debug_chat.py | 2 +- python/mlc_llm/testing/debug_compare.py | 249 ++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 python/mlc_llm/testing/debug_compare.py diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index a88f3d68b8..2a70154bba 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -132,7 +132,7 @@ def __call__(self, func, name, before_run, ret_val, *args): class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public-methods """A chat interface used only for debugging purpose. - It debugs autoregressive decoding fully in Python via the prefill and + It debugs auto-regressive decoding fully in Python via the prefill and decode interface. It supports debugging instrument (either default or customized) to dump intermediate values for each VM function call. 
diff --git a/python/mlc_llm/testing/debug_compare.py b/python/mlc_llm/testing/debug_compare.py new file mode 100644 index 0000000000..b3487e3e48 --- /dev/null +++ b/python/mlc_llm/testing/debug_compare.py @@ -0,0 +1,249 @@ +"""Debug compiled models with TVM instrument""" + +import os +from pathlib import Path +from typing import Dict, List, Set, Tuple + +import tvm +from tvm import rpc, runtime +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument + +from mlc_llm.help import HELP +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.testing.debug_chat import DebugChat + + +def _print_as_table(sorted_list): + print("=" * 100) + print( + "Name".ljust(50) + + "Time (ms)".ljust(12) + + "Count".ljust(8) + + "Total time (ms)".ljust(18) + + "Percentage (%)" + ) + total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000 + for record in sorted_list: + time = record[1][0] * 1000 + weighted_time = time * record[1][1] + percentage = weighted_time / total_time * 100 + print( + record[0].ljust(50) + + f"{time:.4f}".ljust(12) + + str(record[1][1]).ljust(8) + + f"{weighted_time:.4f}".ljust(18) + + f"{percentage:.2f}" + ) + print(f"Total time: {total_time:.4f} ms") + + +class LibCompare(LibCompareVMInstrument): + """The default debug instrument to use if users don't specify + a customized one. + + This debug instrument will dump the arguments and output of each + VM Call instruction into a .npz file. It will also alert the user + if any function outputs are NaN or INF. + + Parameters + ---------- + mod: runtime.Module + The module of interest to be validated. + + device: runtime.Device + The device to run the target module on. + + time_eval: bool + Whether to time evaluate the functions. + + rtol: float + rtol used in validation + + atol: float + atol used in validation + """ + + def __init__( # pylint: disable=too-many-arguments, unused-argument + self, + mod: runtime.Module, + device: runtime.Device, + debug_dir: Path, + time_eval: bool = True, + rtol: float = 1e-2, + atol: float = 1, + skip_rounds: int = 0, + ): + super().__init__(mod, device, True, rtol, atol) + self.time_eval = time_eval + self.time_eval_results: Dict[str, Tuple[float, int]] = {} + self.visited: Set[str] = set([]) + self.skip_rounds = skip_rounds + self.counter = 0 + + def reset(self, debug_dir: Path): # pylint: disable=unused-argument + """Reset the state of the Instrument class + + Note + ---- + `debug_dir` is not used in this class. 
+ + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + _print_as_table( + sorted( + self.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + self.time_eval_results = {} + self.visited = set([]) + self.counter = 0 + + def skip_instrument(self, func, name, before_run, ret_val, *args): + if name.startswith("shape_func"): + return True + if self.counter < self.skip_rounds: + self.counter += 1 + print(f"[{self.counter}] Skip validating {name}..") + return True + if name in self.visited: + if self.time_eval and name in self.time_eval_results: + record = self.time_eval_results[name] + self.time_eval_results[name] = (record[0], record[1] + 1) + return True + self.visited.add(name) + return False + + def compare( + self, + name: str, + ref_args: List[tvm.nd.NDArray], + new_args: List[tvm.nd.NDArray], + ret_indices: List[int], + ): + super().compare(name, ref_args, new_args, ret_indices) + + if self.time_eval and name not in self.time_eval_results: + res = self.mod.time_evaluator( + name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 + )(*new_args) + self.time_eval_results[name] = (res.mean, 1) + print(f"Time-eval result {name} on {self.device}:\n {res}") + + +def get_instrument(args): + """Get the debug instrument from the CLI arguments""" + if args.cmp_device is None: + assert args.cmp_lib_path is None, "cmp_lib_path must be None if cmp_device is None" + args.cmp_device = args.device + args.cmp_lib_path = args.model_lib_path + + if args.cmp_device == "iphone": + assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone" + proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1") + proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090")) + sess = rpc.connect(proxy_host, proxy_port, "iphone") + sess.upload(args.cmp_lib_path) + lib = sess.load_module(os.path.basename(args.cmp_lib_path)) + cmp_device = sess.metal() + elif args.cmp_device == "android": + assert args.cmp_lib_path.endswith(".so"), "Require a so file for Android" + tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0") + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190")) + tracker = rpc.connect_tracker(tracker_host, tracker_port) + sess = tracker.request("android") + sess.upload(args.cmp_lib_path) + lib = sess.load_module(os.path.basename(args.cmp_lib_path)) + cmp_device = sess.cl(0) + else: + lib = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.cmp_device}.so", + ) + ) + cmp_device = tvm.device(args.cmp_device) + + return LibCompare( + lib, + cmp_device, + time_eval=args.time_eval, + debug_dir=Path(args.debug_dir), + ) + + +def main(): + """The main function to start a DebugChat CLI""" + + parser = ArgumentParser("MLC LLM Chat Debug Tool") + parser.add_argument( + "prompt", + type=str, + help="The user input prompt.", + ) + parser.add_argument( + "--generate-len", type=int, help="Number of output tokens to generate.", required=True + ) + parser.add_argument( + "--model", + type=str, + help="An MLC model directory that contains `mlc-chat-config.json`", + required=True, + ) + parser.add_argument( + "--model-lib-path", + type=str, + help="The full path to the model library file to use (e.g. 
a ``.so`` file).", + required=True, + ) + parser.add_argument( + "--debug-dir", + type=str, + help="The output folder to store the dumped debug files.", + required=True, + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_compile"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--cmp-device", + type=str, + default="none", + ) + parser.add_argument( + "--cmp-lib-path", + type=str, + default="none", + ) + parser.add_argument( + "--time-eval", + action="store_true", + help="Whether to time evaluate the functions.", + ) + parsed = parser.parse_args() + instrument = get_instrument(parsed) + debug_chat = DebugChat( + model=parsed.model, + model_lib_path=parsed.model_lib_path, + debug_dir=Path(parsed.debug_dir), + device=parsed.device, + debug_instrument=instrument, + ) + debug_chat.generate(parsed.prompt, parsed.generate_len) + # Only print decode for now + _print_as_table( + sorted( + instrument.time_eval_results.items(), + key=lambda x: -(x[1][0] * x[1][1]), + ) + ) + + +if __name__ == "__main__": + main() From 3d25d9da762aab7cd89bfffb8b310f515b2ddabb Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 16 Apr 2024 19:25:33 -0700 Subject: [PATCH 191/531] [Minor][SpecInfer] Fix Optional FC Bias for Mixtral Eagle Model (#2146) * Add optional fc bias for mixtral. * Fix lint. --- python/mlc_llm/model/eagle/eagle_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index ba647604de..355618df09 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -22,6 +22,8 @@ class EagleConfig(LlamaConfig): """Configuration of the Eagle model.""" + bias: bool = True # Whether to use bias in the fc layers + # pylint: disable=invalid-name,missing-docstring @@ -77,7 +79,7 @@ def __init__(self, config: EagleConfig): [EagleDecoderLayer(config, i) for i in range(config.num_hidden_layers)] ) self.fc = nn.Linear( - in_features=2 * config.hidden_size, out_features=config.hidden_size, bias=True + in_features=2 * config.hidden_size, out_features=config.hidden_size, bias=config.bias ) self.num_hidden_layers = config.num_hidden_layers From 2de2875a77b4eef9dc2b086f0e1e0b13bbcf2ec1 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 17 Apr 2024 05:52:10 -0700 Subject: [PATCH 192/531] [Serving] fix hardcoded host and port in popen_server (#2147) --- python/mlc_llm/serve/server/popen_server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 86f92d7602..08f5dc229e 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -1,5 +1,6 @@ """The MLC LLM server launched in a subprocess.""" +import os import subprocess import sys import time @@ -79,13 +80,15 @@ def start(self) -> None: # pylint: disable=too-many-branches cmd += ["--host", self.host] cmd += ["--port", str(self.port)] process_path = str(Path(__file__).resolve().parents[4]) - self._proc = subprocess.Popen(cmd, cwd=process_path) # pylint: disable=consider-using-with + self._proc = subprocess.Popen( # pylint: disable=consider-using-with + cmd, cwd=process_path, env=os.environ + ) # NOTE: DO NOT USE `stdout=subprocess.PIPE, stderr=subprocess.PIPE` # in subprocess.Popen here. PIPE has a fixed-size buffer with may block # and hang forever. # Try to query the server until it is ready. 
- openai_v1_models_url = "http://127.0.0.1:8000/v1/models" + openai_v1_models_url = f"http://{self.host}:{str(self.port)}/v1/models" query_result = None timeout = 60 attempts = 0.0

From 8c673b47f576b0cf85b3a22c4a009a034617832b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 17 Apr 2024 08:52:52 -0400 Subject: [PATCH 193/531] [Docs] Introductory tutorial (#2145)

This PR updates the documentation with an introduction tutorial. The landing page now directs to the quick start page and the tutorial. --- docs/_static/img/project-workflow.svg | 1173 +++++++++++++++++ docs/community/faq.rst | 2 +- docs/compilation/get-vicuna-weight.rst | 68 - docs/conf.py | 2 - .../mlc_chat_config.rst | 5 +- docs/get_started/intro.rst | 311 +++++ docs/get_started/project_overview.rst | 4 +- docs/get_started/quick_start.rst | 190 +++ docs/index.rst | 196 +-- docs/prebuilt_models.rst | 4 +- 10 files changed, 1691 insertions(+), 264 deletions(-) create mode 100644 docs/_static/img/project-workflow.svg delete mode 100644 docs/compilation/get-vicuna-weight.rst rename docs/{get_started => deploy}/mlc_chat_config.rst (99%) create mode 100644 docs/get_started/intro.rst create mode 100644 docs/get_started/quick_start.rst diff --git a/docs/_static/img/project-workflow.svg b/docs/_static/img/project-workflow.svg new file mode 100644 index 0000000000..eac1313a44 --- /dev/null +++ b/docs/_static/img/project-workflow.svg @@ -0,0 +1,1173 @@ + [1,173 lines of SVG markup for the MLC LLM project workflow figure omitted] diff --git a/docs/community/faq.rst b/docs/community/faq.rst index 3913dd9639..4bc6f9deb8 100644 --- a/docs/community/faq.rst +++ b/docs/community/faq.rst @@ -6,7 +6,7 @@ Frequently Asked Questions This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries! ... How can I customize the temperature, and repetition penalty of models? - Please check our :doc:`/get_started/mlc_chat_config` tutorial. + Please check our :ref:`configure-mlc-chat-json` tutorial. ... What's the quantization algorithm MLC-LLM using? Please check our :doc:`/compilation/configure_quantization` tutorial. diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst deleted file mode 100644 index 2ea4ba5d97..0000000000 --- a/docs/compilation/get-vicuna-weight.rst +++ /dev/null @@ -1,68 +0,0 @@ -Getting Vicuna Weights -====================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. - -Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves. - -In this tutorial, we will show how to apply the delta weights to LLaMA weights to get Vicuna weights. - -Install FastChat ---------------- - -FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip. - - ..
code-block:: bash - - pip install fschat - -Download HuggingFace LLaMA Weights ----------------------------------- - -The HuggingFace LLaMA weights are hosted using Git-LFS. Therefore, it is necessary to install Git-LFS first (you can ignore this step if git-lfs is already installed). - -.. code-block:: bash - - conda install git-lfs - git lfs install - -Then download the weights (both the LLaMA weight and Vicuna delta weight): - -.. code-block:: bash - - git clone https://huggingface.co/decapoda-research/llama-7b-hf - git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 - - -There is a name misalignment issue in the LLaMA weights and Vicuna delta weights. -Please follow these steps to modify the content of the "config.json" file: - -.. code-block:: bash - - sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json - -Then use ``fschat`` to apply the delta to LLaMA weights - -.. code-block:: bash - - python3 -m fastchat.model.apply_delta \ - --base-model-path llama-7b-hf \ - --target-model-path vicuna-7b-v1.1 \ - --delta-path vicuna-7b-delta-v1.1 - -You will get the Vicuna weights in ``vicuna-7b-v1.1`` folder, which can be used as input of MLC-LLM to further compile models. - - -(Optional) Move Vicuna Weights to dist folder ---------------------------------------------- - -The default model path of MLC-LLM is ``dist`` folder. Therefore, it is recommended to move the Vicuna weights to ``dist`` folder. - -.. code-block:: bash - - mkdir -p dist/models - mv vicuna-7b-v1.1 dist/models/vicuna-7b-v1.1 diff --git a/docs/conf.py b/docs/conf.py index 0f7ed19014..7743ef2985 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,8 +9,6 @@ sys.path.insert(0, os.path.abspath("../python")) sys.path.insert(0, os.path.abspath("../")) autodoc_mock_imports = ["torch"] -# do not load mlc-llm.so in docs -os.environ["SKIP_LOADING_MLCLLM_SO"] = "1" # General information about the project. project = "mlc-llm" diff --git a/docs/get_started/mlc_chat_config.rst b/docs/deploy/mlc_chat_config.rst similarity index 99% rename from docs/get_started/mlc_chat_config.rst rename to docs/deploy/mlc_chat_config.rst index 482e68d368..948d50bddd 100644 --- a/docs/get_started/mlc_chat_config.rst +++ b/docs/deploy/mlc_chat_config.rst @@ -1,7 +1,7 @@ .. _configure-mlc-chat-json: -Configure MLCChat in JSON -========================= +Customize MLC Config File in JSON +================================= ``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes: @@ -81,6 +81,7 @@ can be customized to change the behavior of the model.** Legacy ``mlc-chat-config.json`` may specify a string for this field to look up a registered conversation template. It will be deprecated in the future. Re-generate config using the latest version of mlc_llm to make sure this field is a complete JSON object. + The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. ``temperature`` diff --git a/docs/get_started/intro.rst b/docs/get_started/intro.rst new file mode 100644 index 0000000000..c76457647a --- /dev/null +++ b/docs/get_started/intro.rst @@ -0,0 +1,311 @@ +.. _introduction-to-mlc-llm: + +Introduction to MLC LLM +======================= + +.. contents:: Table of Contents + :local: + :depth: 2 + +Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal +deployment solution that allows native deployment of any large language models with native APIs +with compiler acceleration. 
+The mission of this project is to enable everyone to develop, optimize and deploy AI models +natively on everyone's devices with ML compilation techniques. + +This page is a quick tutorial to introduce how to try out MLC LLM, and the core steps to +deploy your own models with MLC LLM. + +Installation +------------ + +:ref:`MLC LLM ` is available via pip. +It is always recommended to install it in an isolated conda virtual environment. + +To verify the installation, activate your virtual environment, run + +.. code:: bash + + python -c "import mlc_llm; print(mlc_llm.__path__)" + +You are expected to see the installation path of MLC LLM Python package. + + +Chat CLI +-------- + +As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. +The simplest command to run MLC chat is a one-liner command: + +.. code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + +It may take 1-2 minutes for the first time running this command. +After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. + +.. code:: + + You can use the following special commands: + /help print the special commands + /exit quit the cli + /stats print out the latest stats (token/sec) + /reset restart a fresh chat + /set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). + Multi-line input: Use escape+enter to start a new line. + + [INST]: What's the meaning of life? + [/INST]: + Ah, a question that has puzzled philosophers and theologians for centuries! ... + + +The figure below shows what run under the hood of this chat CLI command. +For the first time running the command, there are three major phases. + +- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-2 model from `Hugging Face `_ and saves it to your local cache directory. +- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-2 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. +- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-2 model. + +We cache the pre-quantized model weights and compiled model library locally. +Therefore, phase 1 and 2 will only execute **once** over multiple runs. + +.. figure:: /_static/img/project-workflow.svg + :width: 700 + :align: center + :alt: Project Workflow + + Workflow in MLC LLM + +| + +.. _introduction-to-mlc-llm-python-api: + +Python API +---------- + +In the second example, we run the Llama-2 model with the chat completion Python API of MLC LLM. +You can save the code below into a Python file and run it. + +.. code:: python + + from mlc_llm import Engine + + # Create engine + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = Engine(model) + + # Run chat completion in OpenAI API. + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + +.. 
figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg + :width: 500 + :align: center + + MLC LLM Python API + +This code example first creates an :class:`mlc_llm.Engine` instance with the the 4-bit quantized Llama-2 model. +**The Python API of** :class:`mlc_llm.Engine` **if fully compatible with OpenAI API**, +which means you can use :class:`mlc_llm.Engine` in the same way of using `OpenAI's Python package `_ +for both synchronous and asynchronous generation. + +In this code example, we use the synchronous chat completion interface and iterate over +all the stream responses. +If you want to run without streaming, you can run + +.. code:: python + + response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +You can also try different arguments supported in `OpenAI chat completion API `_. + + +REST Server +----------- + +For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model +for OpenAI chat completion requests. +The server can be launched in command line with + +.. code:: bash + + mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + +The server is hooked at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` +to set a different host and port. +When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), +we can open a new shell and send a cURL request via the following command: + +.. code:: bash + + curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions + +The server will process this request and send back the response. +Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass argument ``"stream": true`` +to request for stream responses. + + +Deploy Your Own Model +--------------------- + +So far we have been using pre-converted models weights from Hugging Face. +This section introduces the core workflow regarding how you can *run your own models with MLC LLM*. + +We use the `Phi-2 `_ as the example model. +Assuming the Phi-2 model is downloaded and placed under ``models/phi-2``, +there are two major steps to prepare your own models. + +- **Step 1. Generate MLC config.** The first step is to generate the configuration file of MLC LLM. + + .. code:: bash + + export LOCAL_MODEL_PATH=models/phi-2 # The path where the model resides locally. + export MLC_MODEL_PATH=dist/phi-2-MLC/ # The path where to place the model processed by MLC. + export QUANTIZATION=q0f16 # The choice of quantization. + export CONV_TEMPLATE=phi-2 # The choice of conversation template. + mlc_llm gen_config $LOCAL_MODEL_PATH \ + --quantization $QUANTIZATION \ + --conv-template $CONV_TEMPLATE \ + -o $MLC_MODEL_PATH + + The config generation command takes in the local model path, the target path of MLC output, + the conversation template name in MLC and the quantization name in MLC. + Here the quantization ``q0f16`` means float16 without quantization, + and the conversation template ``phi-2`` is the Phi-2 model's template in MLC. + + If you want to enable tensor parallelism on multiple GPUs, add argument + ``--tensor-parallel-shards $NGPU`` to the config generation command. 
+
+  - `The full list of supported quantization in MLC `_. You can try different quantization methods with MLC LLM. Typical quantization methods include ``q4f16_1`` for 4-bit group quantization and ``q4f16_ft`` for 4-bit FasterTransformer format quantization.
+  - `The full list of conversation templates in MLC `_.
+
+- **Step 2. Convert model weights.** In this step, we convert the model weights to MLC format.
+
+  .. code:: bash
+
+     mlc_llm convert_weight $LOCAL_MODEL_PATH \
+         --quantization $QUANTIZATION \
+         -o $MLC_MODEL_PATH
+
+  This step consumes the raw model weights and converts them to the MLC format.
+  The converted weights will be stored under ``$MLC_MODEL_PATH``,
+  which is the same directory where the config file generated in Step 1 resides.
+
+Now, we can try to run your own model with chat CLI:
+
+.. code:: bash
+
+   mlc_llm chat $MLC_MODEL_PATH
+
+For the first run, model compilation will be triggered automatically to optimize the
+model for GPU accelerate and generate the binary model library.
+The chat interface will be displayed after model compilation finishes.
+By simply replacing the model string ``HF://xxx`` with ``$MLC_MODEL_PATH``,
+you can also use this model in Python API, MLC serve and other use scenarios.
+
+(Optional) Compile Model Manually
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In previous sections, model libraries are compiled automatically when the runtime
+chat module or :class:`mlc_llm.Engine` launches,
+which is what we call "JIT (Just-in-Time) model compilation".
+In some cases (e.g., web / mobile deployment), it is beneficial to manually compile the model libraries,
+so that we can deploy LLMs on platforms that come with no compiler environment,
+with only the compiled model libraries being shipped.
+Below is an example command of compiling model libraries in MLC LLM:
+
+.. code:: bash
+
+   export MODEL_LIB_PATH=$MLC_MODEL_PATH/lib.so  # ".dylib" for Intel Macs.
+                                                 # ".dll" for Windows.
+                                                 # ".wasm" for web.
+                                                 # ".tar" for iPhone/Android.
+   mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB_PATH
+
+At runtime, we need to specify this model library path to use it. For example,
+
+.. code:: bash
+
+   # For chat CLI
+   mlc_llm chat $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH
+   # For REST server
+   mlc_llm serve $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH
+
+.. code:: python
+
+   from mlc_llm import Engine
+
+   # For Python API
+   model = "models/phi-2"
+   model_lib_path = "models/phi-2/lib.so"
+   engine = Engine(model, model_lib_path=model_lib_path)
+
+:ref:`compile-model-libraries` introduces the model compilation command in detail,
+where you can find instructions and example commands to compile models for different
+hardware backends, such as WebGPU, iOS and Android.
+
+Universal Deployment
+--------------------
+
+MLC LLM is high-performance universal deployment solution for large language models.
+The examples we ran above use your native local GPU environment (CUDA, ROCm or Metal).
+
+If your local environment is CUDA or ROCm, we can quickly try out the command below
+to experience the universal deployment.
+This command launches chat CLI with Vulkan runtime rather than CUDA/ROCm runtime.
+
+.. 
code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan + +Summary +------- + +To briefly summarize this page, + +- we went through three examples (chat CLI, Python API, and REST server) of running language models with MLC LLM, +- we introduced how to generate MLC config file and convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models manually. +- we showcased the universal deployment of MLC LLM. + +What to Do Next +--------------- + +Next, you can check out the pages below for quick start examples and more detailed information. + +- :ref:`Quick start examples ` for Python API, chat CLI, REST server, web browser, iOS and Android. +- Depending on your use case, check out our API documentation and tutorial pages: + + - :ref:`webllm-runtime` + - :ref:`deploy-rest-api` + - :ref:`deploy-cli` + - :ref:`deploy-python-engine` + - :ref:`deploy-ios` + - :ref:`deploy-android` + - :ref:`deploy-ide-integration` + +- :ref:`Convert model weight to MLC format `, if you want to run your own models. +- :ref:`Compile model libraries `, if you want to deploy to web/iOS/Android or control the model optimizations. +- Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/get_started/project_overview.rst b/docs/get_started/project_overview.rst index 2b6ff7495a..ef631e40c8 100644 --- a/docs/get_started/project_overview.rst +++ b/docs/get_started/project_overview.rst @@ -52,7 +52,7 @@ There are several ways to prepare the model weights and model lib. A default chat config usually comes with the model weight directory. You can further customize the system prompt, temperature, and other options by modifying the JSON file. MLC chat runtimes also provide API to override these options during model reload. -Please refer to :doc:`/get_started/mlc_chat_config` for more details. +Please refer to :ref:`configure-mlc-chat-json` for more details. Runtime Flow Overview @@ -82,7 +82,7 @@ Thank you for reading and learning the high-level concepts. Moving next, feel free to check out documents on the left navigation panel and learn about topics you are interested in. -- :doc:`/get_started/mlc_chat_config` shows how to configure specific chat behavior. +- :ref:`configure-mlc-chat-json` shows how to configure specific chat behavior. - Build and Deploy App section contains guides to build apps and platform-specific MLC chat runtimes. - Compile models section provides guidelines to convert model weights and produce model libs. diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst new file mode 100644 index 0000000000..93d0f8bb3f --- /dev/null +++ b/docs/get_started/quick_start.rst @@ -0,0 +1,190 @@ +.. _quick-start: + +Quick Start +=========== + +Examples +-------- + +To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +It is recommended to have at least 6GB free VRAM to run it. + +.. tabs:: + + .. tab:: Python + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM: + + .. code:: python + + from mlc_llm import Engine + + # Create engine + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + engine = Engine(model) + + # Run chat completion in OpenAI API. 
+ for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + + .. Todo: link the colab notebook when ready: + + **Documentation and tutorial.** Python API reference and its tutorials are :ref:`available online `. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg + :width: 600 + :align: center + + MLC LLM Python API + + .. tab:: REST Server + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``. + + .. code:: shell + + mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), + open a new shell and send a request via the following command: + + .. code:: shell + + curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions + + **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial. + Our REST API has complete OpenAI API support. + + .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg + :width: 600 + :align: center + + Send HTTP request to REST server in MLC LLM + + .. tab:: Command Line + + **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. + It is always recommended to install it in an isolated conda virtual environment. + + For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. + + **Run in command line**. + + .. code:: bash + + mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + + + If you are using windows/linux/steamdeck and would like to use vulkan, + we recommend installing necessary vulkan loader dependency via conda + to avoid vulkan not found issues. + + .. code:: bash + + conda install -c conda-forge gcc libvulkan-loader + + + .. tab:: Web Browser + + `WebLLM `__. MLC LLM generates performant code for WebGPU and WebAssembly, + so that LLMs can be run locally in a web browser without server resources. + + **Download pre-quantized weights**. This step is self-contained in WebLLM. + + **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. + + **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. + + .. figure:: https://blog.mlc.ai/img/redpajama/web.gif + :width: 300 + :align: center + + MLC LLM on Web + + .. tab:: iOS + + **Install MLC Chat iOS**. It is available on AppStore: + + .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg + :width: 135 + :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 + + | + + **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. 
+ + **Tutorial and source code**. The source code of the iOS app is fully `open source `__, + and a :ref:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif + :width: 300 + :align: center + + MLC Chat on iOS + + .. tab:: Android + + **Install MLC Chat Android**. A prebuilt is available as an APK: + + .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png + :width: 135 + :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk + + | + + **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + The demo is tested on + + - Samsung S23 with Snapdragon 8 Gen 2 chip + - Redmi Note 12 Pro with Snapdragon 685 + - Google Pixel phones + + **Tutorial and source code**. The source code of the android app is fully `open source `__, + and a :ref:`tutorial ` is included in documentation. + + .. figure:: https://blog.mlc.ai/img/android/android-recording.gif + :width: 300 + :align: center + + MLC LLM on Android + + +What to Do Next +--------------- + +- Check out :ref:`introduction-to-mlc-llm` for the introduction of a complete workflow in MLC LLM. +- Depending on your use case, check out our API documentation and tutorial pages: + + - :ref:`webllm-runtime` + - :ref:`deploy-rest-api` + - :ref:`deploy-cli` + - :ref:`deploy-python-engine` + - :ref:`deploy-ios` + - :ref:`deploy-android` + - :ref:`deploy-ide-integration` + +- `Convert model weight to MLC format `_, if you want to run your own models. +- `Compile model libraries `_, if you want to deploy to web/iOS/Android or control the model optimizations. +- Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/index.rst b/docs/index.rst index 721d9c227c..7160c95b28 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,193 +5,15 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. -.. _get_started: +Quick Start +----------- -Getting Started ---------------- +Check out :ref:`quick-start` for quick start examples of using MLC LLM. -To begin with, try out MLC LLM support for int4-quantized Llama2 7B. -It is recommended to have at least 6GB free VRAM to run it. +Introduction to MLC LLM +----------------------- -.. tabs:: - - .. tab:: Python - - **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM: - - .. code:: python - - from mlc_llm import Engine - - # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = Engine(model) - - # Run chat completion in OpenAI API. - for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, - ): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) - print("\n") - - engine.terminate() - - .. 
Todo: link the colab notebook when ready: - - **Documentation and tutorial.** Python API reference and its tutorials are :doc:`available online `. - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg - :width: 600 - :align: center - - MLC LLM Python API - - .. tab:: REST Server - - **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``. - - .. code:: shell - - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - - **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), - open a new shell and send a request via the following command: - - .. code:: shell - - curl -X POST \ - -H "Content-Type: application/json" \ - -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} - ] - }' \ - http://127.0.0.1:8000/v1/chat/completions - - **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial. - Our REST API has complete OpenAI API support. - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg - :width: 600 - :align: center - - Send HTTP request to REST server in MLC LLM - - .. tab:: Command Line - - **Install MLC LLM**. :doc:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. - - **Run in command line**. - - .. code:: bash - - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - - - If you are using windows/linux/steamdeck and would like to use vulkan, - we recommend installing necessary vulkan loader dependency via conda - to avoid vulkan not found issues. - - .. code:: bash - - conda install -c conda-forge gcc libvulkan-loader - - - .. tab:: Web Browser - - `WebLLM `__. MLC LLM generates performant code for WebGPU and WebAssembly, - so that LLMs can be run locally in a web browser without server resources. - - **Download pre-quantized weights**. This step is self-contained in WebLLM. - - **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. - - **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. - - .. figure:: https://blog.mlc.ai/img/redpajama/web.gif - :width: 300 - :align: center - - MLC LLM on Web - - .. tab:: iOS - - **Install MLC Chat iOS**. It is available on AppStore: - - .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg - :width: 135 - :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 - - | - - **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - - **Tutorial and source code**. The source code of the iOS app is fully `open source `__, - and a :doc:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif - :width: 300 - :align: center - - MLC Chat on iOS - - .. tab:: Android - - **Install MLC Chat Android**. 
A prebuilt is available as an APK: - - .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png - :width: 135 - :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk - - | - - **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - The demo is tested on - - - Samsung S23 with Snapdragon 8 Gen 2 chip - - Redmi Note 12 Pro with Snapdragon 685 - - Google Pixel phones - - **Tutorial and source code**. The source code of the android app is fully `open source `__, - and a :doc:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/android/android-recording.gif - :width: 300 - :align: center - - MLC LLM on Android - - -What to Do Next ---------------- - -- Depending on your use case, check out our API documentation and tutorial pages: - - - :ref:`webllm-runtime` - - :ref:`deploy-rest-api` - - :ref:`deploy-cli` - - :ref:`deploy-python-engine` - - :ref:`deploy-ios` - - :ref:`deploy-android` - - :ref:`deploy-ide-integration` - -- Deploy your local model: check out :ref:`convert-weights-via-MLC` to convert your model weights to MLC format. -- Deploy models to Web or build iOS/Android apps on your own: check out :ref:`compile-model-libraries` to compile the models into binary libraries. -- Customize model optimizations: check out :ref:`compile-model-libraries`. -- Report any problem or ask any question: open new issues in our `GitHub repo `_. +Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a complete workflow in MLC LLM. .. toctree:: @@ -199,8 +21,8 @@ What to Do Next :caption: Get Started :hidden: - get_started/project_overview.rst - get_started/mlc_chat_config.rst + get_started/quick_start.rst + get_started/intro.rst .. toctree:: :maxdepth: 1 @@ -214,6 +36,7 @@ What to Do Next deploy/ios.rst deploy/android.rst deploy/ide_integration.rst + deploy/mlc_chat_config.rst .. toctree:: :maxdepth: 1 @@ -231,7 +54,6 @@ What to Do Next :hidden: prebuilt_models.rst - prebuilt_models_deprecated.rst .. toctree:: :maxdepth: 1 diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index e299f68138..f97909a515 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -44,7 +44,7 @@ We quickly go over how to use prebuilt models for each platform. You can find de **Prebuilt Models on CLI / Python** -For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. +For more, please see :ref:`the CLI page `, and the :ref:`the Python page `. .. collapse:: Click to show details @@ -71,7 +71,7 @@ For more, please see :doc:`the CLI page `, and the :doc:`the Python mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC - To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). .. for a blank line From 9f9436b6f7ef7487c129f984c8a7784b70765296 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 17 Apr 2024 09:58:50 -0400 Subject: [PATCH 194/531] [Serving] Support `DebugCallFuncOnAllAllWorker` and CUDA profiler (#2148) This PR adds a new function `DebugCallFuncOnAllAllWorker` which calls a global function of sigunature `[] -> None` on all distributed workers when tensor parallelism is enabled (or the local session itself if not enabled). 
As the name suggests, this function is only for the debug purpose, and we will not expose any public interface to invoke this function. This PR also introduces the global functions `"mlc.debug_cuda_profiler_start"` and `"mlc.debug_cuda_profiler_stop"`, which enables CUDA profiling when using PopenServer. --- ci/task/pylint.sh | 1 + cpp/serve/engine.cc | 7 + cpp/serve/engine.h | 5 + cpp/serve/function_table.cc | 10 + cpp/serve/function_table.h | 2 + cpp/serve/model.cc | 42 +++-- cpp/serve/model.h | 5 + cpp/serve/threaded_engine.cc | 75 +++++--- cpp/serve/threaded_engine.h | 5 + python/mlc_llm/base.py | 19 ++ python/mlc_llm/serve/engine_base.py | 5 + .../serve/entrypoints/debug_entrypoints.py | 29 +++ tests/python/serve/benchmark.py | 178 ------------------ 13 files changed, 166 insertions(+), 217 deletions(-) delete mode 100644 tests/python/serve/benchmark.py diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh index c4abb81d90..849efe628e 100755 --- a/ci/task/pylint.sh +++ b/ci/task/pylint.sh @@ -8,6 +8,7 @@ export PYTHONPATH="./python":${PYTHONPATH:-""} # TVM Unity is a dependency to this testing pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly +pip install --quiet --pre -U cuda-python pylint --jobs $NUM_THREADS ./python/ pylint --jobs $NUM_THREADS --recursive=y ./tests/python/ diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 7f764d3fb6..c9ca511e85 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -259,6 +259,13 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) but it does not."; } + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + CHECK(!models_.empty()) << "There is no model running in Engine."; + models_[0]->DebugCallFuncOnAllAllWorker(func_name); + } + private: /*! \brief Set the maximum threading backend concurrency. */ void SetThreadMaxConcurrency() { diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index cb31304b5b..581219c350 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -107,6 +107,11 @@ class Engine { * generation results for those finished requests. */ virtual void Step() = 0; + + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; }; /*! 
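A minimal sketch of how the new profiler hooks can be exercised over HTTP, assuming a
server launched via `mlc_llm serve` at the default `http://127.0.0.1:8000`, the debug
entrypoints enabled, and the `cuda-python` package installed (as the CI change above
suggests); the endpoint paths are the ones added in `debug_entrypoints.py` below:

    # Start the process-wise CUDA profiler on the serving engine's workers.
    curl -X POST http://127.0.0.1:8000/debug/cuda_profiler_start
    # ...send the chat completion requests you want to capture...
    # Stop the CUDA profiler.
    curl -X POST http://127.0.0.1:8000/debug/cuda_profiler_stop
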
diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 21835566b3..8a0bcd66c6 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -315,6 +315,16 @@ ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_ } } +void FunctionTable::DebugCallFuncOnAllAllWorker(const String& func_name) const { + if (this->use_disco) { + sess->CallPacked(sess->GetGlobalFunc(func_name)); + } else { + const PackedFunc* func = Registry::Get(func_name); + CHECK(func != nullptr) << "Global function name \"" << func_name << "\" is not found"; + (*func)(); + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 195f79264e..03b0428096 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -52,6 +52,8 @@ struct FunctionTable { ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, ShapeTuple max_reserved_shape); + void DebugCallFuncOnAllAllWorker(const String& func_name) const; + bool use_disco = false; Device local_gpu_device; Session sess{nullptr}; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index fa4a4bf09a..eb35bada38 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -714,24 +714,6 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ - LogitProcessor CreateLogitProcessor(int max_num_token, - Optional trace_recorder) { - return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, - std::move(trace_recorder)); - } - - Sampler CreateSampler(int max_num_sample, int num_models, - Optional trace_recorder) { - if (num_models > 1) { // speculative decoding uses cpu sampler - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } else if (Sampler::SupportGPUSampler(device_)) { - return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, - std::move(trace_recorder)); - } else { - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } - } - void CreateKVCache(KVCacheConfig kv_cache_config) final { IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; @@ -776,6 +758,24 @@ class ModelImpl : public ModelObj { /*********************** Utilities ***********************/ + LogitProcessor CreateLogitProcessor(int max_num_token, + Optional trace_recorder) { + return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } + + Sampler CreateSampler(int max_num_sample, int num_models, + Optional trace_recorder) { + if (num_models > 1) { // speculative decoding uses cpu sampler + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } else if (Sampler::SupportGPUSampler(device_)) { + return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, + std::move(trace_recorder)); + } else { + return Sampler::CreateCPUSampler(std::move(trace_recorder)); + } + } + int EstimateHostCPURequirement() const final { CHECK_NE(num_shards_, -1) << "The model has not been initialized"; return num_shards_ > 1 ? num_shards_ : 0; @@ -832,6 +832,12 @@ class ModelImpl : public ModelObj { } } + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + ft_.DebugCallFuncOnAllAllWorker(func_name); + } + private: /*! \brief Load model configuration from JSON. 
*/ picojson::object LoadModelConfigJSON(const std::string& config_str) { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 79619acbe6..761f936363 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -291,6 +291,11 @@ class ModelObj : public Object { /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + static constexpr const char* _type_key = "mlc.serve.Model"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index f74517d7bf..d79b122125 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -23,6 +23,15 @@ namespace serve { using tvm::Device; using namespace tvm::runtime; +/*! \brief The threaded engine instruction kind. */ +enum class InstructionKind : int { + kAddRequest = 0, + kAbortRequest = 1, + kUnloadEngine = 2, + kReloadEngine = 3, + kDebugCallFuncOnAllAllWorker = 4, +}; + /*! \brief The implementation of ThreadedEngine. */ class ThreadedEngineImpl : public ThreadedEngine { public: @@ -65,7 +74,7 @@ class ThreadedEngineImpl : public ThreadedEngine { bool need_notify = false; { std::lock_guard lock(background_loop_mutex_); - requests_to_add_.push_back(request); + instruction_queue_.emplace_back(InstructionKind::kAddRequest, request); ++pending_request_operation_cnt_; need_notify = engine_waiting_; } @@ -78,7 +87,7 @@ class ThreadedEngineImpl : public ThreadedEngine { bool need_notify = false; { std::lock_guard lock(background_loop_mutex_); - requests_to_abort_.push_back(request_id); + instruction_queue_.emplace_back(InstructionKind::kAbortRequest, request_id); ++pending_request_operation_cnt_; need_notify = engine_waiting_; } @@ -89,8 +98,7 @@ class ThreadedEngineImpl : public ThreadedEngine { void RunBackgroundLoop() final { // The local vectors that load the requests from critical regions. 
- std::vector local_requests_to_add; - std::vector local_requests_to_abort; + std::vector> local_instruction_queue; while (!exit_now_.load(std::memory_order_relaxed)) { { @@ -102,17 +110,26 @@ class ThreadedEngineImpl : public ThreadedEngine { }); engine_waiting_ = false; - local_requests_to_add = requests_to_add_; - local_requests_to_abort = requests_to_abort_; - requests_to_add_.clear(); - requests_to_abort_.clear(); + local_instruction_queue = instruction_queue_; + instruction_queue_.clear(); pending_request_operation_cnt_ = 0; } - for (Request request : local_requests_to_add) { - background_engine_->AddRequest(request); - } - for (String request_id : local_requests_to_abort) { - background_engine_->AbortRequest(request_id); + for (const auto& [kind, arg] : local_instruction_queue) { + if (kind == InstructionKind::kAddRequest) { + background_engine_->AddRequest(Downcast(arg)); + } else if (kind == InstructionKind::kAbortRequest) { + background_engine_->AbortRequest(Downcast(arg)); + } else if (kind == InstructionKind::kUnloadEngine) { + // Todo(mlc-team): implement engine unload + LOG(FATAL) << "Not implemented yet."; + } else if (kind == InstructionKind::kReloadEngine) { + // Todo(mlc-team): implement engine reload + LOG(FATAL) << "Not implemented yet."; + } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) { + background_engine_->DebugCallFuncOnAllAllWorker(Downcast(arg)); + } else { + LOG(FATAL) << "Cannot reach here"; + } } background_engine_->Step(); } @@ -159,6 +176,21 @@ class ThreadedEngineImpl : public ThreadedEngine { request_stream_callback_cv_.notify_one(); } + /************** Debug/Profile **************/ + + void DebugCallFuncOnAllAllWorker(const String& func_name) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kDebugCallFuncOnAllAllWorker, func_name); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + private: /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; @@ -176,17 +208,16 @@ class ThreadedEngineImpl : public ThreadedEngine { /************** Critical Regions **************/ /*! - * \brief The requests to add into the background engine. - * Elements are sended from other threads and consumed by - * the threaded engine in the background loop. - */ - std::vector requests_to_add_; - /*! - * \brief The requests to abort from the background engine. + * \brief The instruction queue for the threaded engine. + * The instructions include: + * - requests to add into the background engine, + * - requests to abort from the background engine, + * - engine unload/reload, + * - and other debugging instructions. * Elements are sended from other threads and consumed by * the threaded engine in the background loop. */ - std::vector requests_to_abort_; + std::vector> instruction_queue_; /*! * \brief The delta outputs to pass through callback. 
* Elements are sended from the background loop thread and @@ -219,6 +250,8 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", &ThreadedEngineImpl::RunBackgroundStreamBackLoop); TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", + &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); if (_name == "init_background_engine") { return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { SelfPtr self = static_cast(_self.get()); diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 1440a88056..2e57afd2a0 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -57,6 +57,11 @@ class ThreadedEngine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; + + /************** Debug/Profile **************/ + + /*! \brief Call the given global function on all workers. Only for debug purpose. */ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; }; } // namespace serve diff --git a/python/mlc_llm/base.py b/python/mlc_llm/base.py index 13c7ba9f84..308426d210 100644 --- a/python/mlc_llm/base.py +++ b/python/mlc_llm/base.py @@ -1,4 +1,5 @@ """Load MLC LLM library and _ffi_api functions.""" + import ctypes import os import sys @@ -23,6 +24,24 @@ def _load_mlc_llm_lib(): return ctypes.CDLL(lib_path[0]), lib_path[0] +@tvm.register_func("mlc.debug_cuda_profiler_start") +def _debug_cuda_profiler_start() -> None: + """Start cuda profiler.""" + import cuda # pylint: disable=import-outside-toplevel + import cuda.cudart # pylint: disable=import-outside-toplevel + + cuda.cudart.cudaProfilerStart() # pylint: disable=c-extension-no-member + + +@tvm.register_func("mlc.debug_cuda_profiler_stop") +def _debug_cuda_profiler_stop() -> None: + """Stop cuda profiler.""" + import cuda # pylint: disable=import-outside-toplevel + import cuda.cudart # pylint: disable=import-outside-toplevel + + cuda.cudart.cudaProfilerStop() # pylint: disable=c-extension-no-member + + # only load once here if SKIP_LOADING_MLCLLM_SO == "0": _LIB, _LIB_PATH = _load_mlc_llm_lib() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 45ad9f7756..e61ab626d6 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -781,6 +781,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "run_background_stream_back_loop", "init_background_engine", "exit_background_loop", + "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(tokenizer_path) @@ -819,6 +820,10 @@ def terminate(self): self._background_loop_thread.join() self._background_stream_back_loop_thread.join() + def _debug_call_func_on_all_worker(self, func_name: str) -> None: + """Call the given global function on all workers. 
Only for debug purpose.""" + self._ffi["debug_call_func_on_all_worker"](func_name) + def process_chat_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.ChatCompletionRequest, diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index fe76696163..af1613c027 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -50,3 +50,32 @@ async def debug_dump_event_trace(request: fastapi.Request): ) return json.loads(async_engine.state.trace_recorder.dump_json()) + + +################ /debug/cuda_profiler_start/end ################ + + +@app.post("/debug/cuda_profiler_start") +async def debug_cuda_profiler_start(_request: fastapi.Request): + """Start the cuda profiler for the engine. Only for debug purpose.""" + server_context: ServerContext = ServerContext.current() + # Since the CUDA profiler is process-wise, call the function for one model is sufficient. + for model in server_context.get_model_list(): + async_engine = server_context.get_engine(model) + async_engine._debug_call_func_on_all_worker( # pylint: disable=protected-access + "mlc.debug_cuda_profiler_start" + ) + break + + +@app.post("/debug/cuda_profiler_stop") +async def debug_cuda_profiler_stop(_request: fastapi.Request): + """Stop the cuda profiler for the engine. Only for debug purpose.""" + server_context: ServerContext = ServerContext.current() + # Since the CUDA profiler is process-wise, call the function for one model is sufficient. + for model in server_context.get_model_list(): + async_engine = server_context.get_engine(model) + async_engine._debug_call_func_on_all_worker( # pylint: disable=protected-access + "mlc.debug_cuda_profiler_stop" + ) + break diff --git a/tests/python/serve/benchmark.py b/tests/python/serve/benchmark.py deleted file mode 100644 index a34b47335c..0000000000 --- a/tests/python/serve/benchmark.py +++ /dev/null @@ -1,178 +0,0 @@ -# pylint: disable=import-error,line-too-long,missing-docstring,no-member,too-many-locals -# type: ignore -import argparse -import json -import os -import random -import time -from typing import Any, Callable, List, Tuple - -import numpy as np -from transformers import AutoTokenizer - -from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncEngine - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--model-lib-path", type=str, required=True) - # Download dataset from - # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - args.add_argument("--dataset", type=str, required=True) - args.add_argument("--device", type=str, default="auto") - args.add_argument("--num-prompts", type=int, default=500) - args.add_argument("--max-num-sequence", type=int, default=80) - args.add_argument("--max-total-seq-length", type=int) - args.add_argument("--seed", type=int, default=0) - args.add_argument("--json-output", type=bool, default=False) - args.add_argument("--cuda-profile", type=bool, default=False) - - parsed = args.parse_args() - parsed.model = os.path.dirname(parsed.model_lib_path) - assert parsed.max_num_sequence % 16 == 0 - return parsed - - -def sample_requests( - dataset_path: str, num_requests: int, model_path: str, json_output: bool = False -) -> Tuple[List[str], List[GenerationConfig]]: - """Sample requests from dataset. 
- Acknowledgement to the benchmark scripts in the vLLM project. - """ - tokenizer = AutoTokenizer.from_pretrained(model_path) - - with open(dataset_path, encoding="utf-8") as f: - dataset = json.load(f) - - # Filter out the conversations with less than 2 turns. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - if len(data["conversations"]) >= 2 - ] - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) - - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - - # Construct generation config. - prompts = [prompt for prompt, _, _ in sampled_requests] - response_format = ResponseFormat("json_object" if json_output else "text") - generation_config_list = [ - GenerationConfig( - temperature=1.0, top_p=1.0, max_tokens=output_len, response_format=response_format - ) - for _, _, output_len in sampled_requests - ] - return prompts, generation_config_list - - -def time_evaluator(func: Callable, args: List[Any], num_runs: int = 3): - times = [] - for _ in range(num_runs): - start = time.perf_counter() - func(*args) - end = time.perf_counter() - times.append(end - start) - - return np.array(times) - - -def benchmark(args: argparse.Namespace): - random.seed(args.seed) - - # Create engine - engine = SyncEngine( - model=args.model, - model_lib_path=args.model_lib_path, - device=args.device, - mode="server", - max_batch_size=args.max_num_sequence, - max_total_sequence_length=args.max_total_seq_length, - ) - - # Sample prompts from dataset - prompts, generation_config = sample_requests( - args.dataset, args.num_prompts, args.model, args.json_output - ) - # Engine statistics - num_runs = 1 - single_token_prefill_latency = [] - single_token_decode_latency = [] - engine_total_prefill_time = [] - engine_total_decode_time = [] - total_prefill_tokens = [] - total_decode_tokens = [] - - def engine_generate(): - engine.reset() - engine.generate(prompts, generation_config) - engine_stats = engine.stats() - single_token_prefill_latency.append(engine_stats["single_token_prefill_latency"]) - single_token_decode_latency.append(engine_stats["single_token_decode_latency"]) - engine_total_prefill_time.append(engine_stats["engine_total_prefill_time"]) - engine_total_decode_time.append(engine_stats["engine_total_decode_time"]) - total_prefill_tokens.append(engine_stats["total_prefill_tokens"]) - total_decode_tokens.append(engine_stats["total_decode_tokens"]) - - if args.cuda_profile: - import cuda - import cuda.cudart - - cuda.cudart.cudaProfilerStart() - engine_generate() - cuda.cudart.cudaProfilerStop() - return - - e2e_latency = time_evaluator(engine_generate, args=[], num_runs=num_runs) - single_token_prefill_latency = 
np.array(single_token_prefill_latency) - single_token_decode_latency = np.array(single_token_decode_latency) - engine_total_prefill_time = np.array(engine_total_prefill_time) - engine_total_decode_time = np.array(engine_total_decode_time) - total_prefill_tokens = np.array(total_prefill_tokens) - total_decode_tokens = np.array(total_decode_tokens) - avg_prefill_tokens = total_prefill_tokens / len(prompts) - avg_decode_tokens = total_decode_tokens / len(prompts) - prefill_throughput = total_prefill_tokens / engine_total_prefill_time - decode_throughput = total_decode_tokens / engine_total_decode_time - overall_throughput = (total_prefill_tokens + total_decode_tokens) / e2e_latency - - print(args) - print(f"Average end-to-end latency: {e2e_latency.mean():.4f} seconds for the entire batch") - print(f"Average prefill tokens: {avg_prefill_tokens.mean():.4f} tok/req") - print(f"Average decode tokens: {avg_decode_tokens.mean():.4f} tok/req") - print(f"Single token prefill latency: {single_token_prefill_latency.mean() * 1e3:.4f} ms/tok") - print(f"Single token decode latency: {single_token_decode_latency.mean() * 1e3:.4f} ms/tok") - print(f"Engine prefill time: {engine_total_prefill_time.mean():.4f} s") - print(f"Engine decode time: {engine_total_decode_time.mean():.4f} s") - print(f"Request throughput: {args.num_prompts / e2e_latency.mean():.4f} req/s") - print(f"Prefill token throughput: {prefill_throughput.mean():.4f} tok/s") - print(f"Decode token throughput: {decode_throughput.mean():.4f} tok/s") - print(f"Overall token throughput: {overall_throughput.mean():.4f} tok/s") - - -if __name__ == "__main__": - ARGS = _parse_args() - benchmark(ARGS) From 2a24f1363431fb7c8318c398d7fd3dcee213294d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 17 Apr 2024 10:01:52 -0400 Subject: [PATCH 195/531] [DOCS] Update introduction (#2151) * [DOCS] Update introduction Some minor tweaks on the introduction doc * Update docs/get_started/introduction.rst Co-authored-by: Ruihang Lai --------- Co-authored-by: Ruihang Lai --- .../{intro.rst => introduction.rst} | 80 ++++++++++--------- docs/index.rst | 2 +- 2 files changed, 45 insertions(+), 37 deletions(-) rename docs/get_started/{intro.rst => introduction.rst} (77%) diff --git a/docs/get_started/intro.rst b/docs/get_started/introduction.rst similarity index 77% rename from docs/get_started/intro.rst rename to docs/get_started/introduction.rst index c76457647a..245bdb6da1 100644 --- a/docs/get_started/intro.rst +++ b/docs/get_started/introduction.rst @@ -7,13 +7,11 @@ Introduction to MLC LLM :local: :depth: 2 -Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal -deployment solution that allows native deployment of any large language models with native APIs -with compiler acceleration. -The mission of this project is to enable everyone to develop, optimize and deploy AI models -natively on everyone's devices with ML compilation techniques. +Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance +universal LLM deployment engine. The mission of this project is to enable everyone to develop, +optimize and deploy AI models natively on everyone's devices with ML compilation techniques. -This page is a quick tutorial to introduce how to try out MLC LLM, and the core steps to +This page is a quick tutorial to introduce how to try out MLC LLM, and the steps to deploy your own models with MLC LLM. 
Installation @@ -35,7 +33,7 @@ Chat CLI -------- As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. -The simplest command to run MLC chat is a one-liner command: +You can run MLC chat through a one-liner command: .. code:: bash @@ -115,8 +113,9 @@ You can save the code below into a Python file and run it. MLC LLM Python API This code example first creates an :class:`mlc_llm.Engine` instance with the the 4-bit quantized Llama-2 model. -**The Python API of** :class:`mlc_llm.Engine` **if fully compatible with OpenAI API**, -which means you can use :class:`mlc_llm.Engine` in the same way of using `OpenAI's Python package `_ +**We design the Python API** :class:`mlc_llm.Engine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.Engine` in the same way of using +`OpenAI's Python package `_ for both synchronous and asynchronous generation. In this code example, we use the synchronous chat completion interface and iterate over @@ -133,14 +132,13 @@ If you want to run without streaming, you can run print(response) You can also try different arguments supported in `OpenAI chat completion API `_. - +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncEngine` instead. REST Server ----------- For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model -for OpenAI chat completion requests. -The server can be launched in command line with +for OpenAI chat completion requests. The server can be launched in command line with .. code:: bash @@ -222,19 +220,19 @@ Now, we can try to run your own model with chat CLI: For the first run, model compilation will be triggered automatically to optimize the model for GPU accelerate and generate the binary model library. -The chat interface will be displayed after model compilation finishes. -By simply replacing the model string ``HF://xxx`` with ``$MLC_MODEL_PATH``, -you can also use this model in Python API, MLC serve and other use scenarios. +The chat interface will be displayed after model JIT compilation finishes. +You can also use this model in Python API, MLC serve and other use scenarios. -(Optional) Compile Model Manually -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +(Optional) Compile Model Library +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In previous sections, model libraries are compiled automatically when the runtime -chat module or :class:`mlc_llm.Engine` launches, +In previous sections, model libraries are compiled when the :class:`mlc_llm.Engine` launches, which is what we call "JIT (Just-in-Time) model compilation". -In some cases (e.g., web / mobile deployment), it is beneficial to manually compile the model libraries, -so that we can deploy LLMs on platforms that come with no compiler environment, -with only the compiled model libraries being shipped. +In some cases, it is beneficial to explicitly compile the model libraries. +We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. +It will also enable advanced options such as cross-compiling the libraries for web and mobile deployments. + + Below is an example command of compiling model libraries in MLC LLM: .. code:: bash @@ -270,30 +268,40 @@ hardware backends, such as WebGPU, iOS and Android. Universal Deployment -------------------- -MLC LLM is high-performance universal deployment solution for large language models. -The examples we ran above use your native local GPU environment (CUDA, ROCm or Metal). 
+MLC LLM is a high-performance universal deployment solution for large language models, +to enable native deployment of any large language models with native APIs with compiler acceleration +So far, we have gone through several examples running on a local GPU environment. +The project supports multiple kinds of GPU backends. -If your local environment is CUDA or ROCm, we can quickly try out the command below -to experience the universal deployment. -This command launches chat CLI with Vulkan runtime rather than CUDA/ROCm runtime. +You can use `--device` option in compilation and runtime to pick a specific GPU backend. +For example, if you have an NVIDIA or AMD GPU, you can try to use the option below +to run chat through the vulkan backend. Vulkan-based LLM applications run in less typical +environments (e.g. SteamDeck). .. code:: bash mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan -Summary -------- +The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as +long as they fit within the memory and computing budget of the corresponding hardware backend. +We also leverage machine learning compilation to build backend-specialized optimizations to +get out the best performance on the targetted backend when possible, and reuse key insights and optimizations +across backends we support. -To briefly summarize this page, +Please checkout the what to do next sections below to find out more about different deployment scenarios, +such as WebGPU-based browser deployment, mobile and other settings. + +Summary and What to Do Next +--------------------------- -- we went through three examples (chat CLI, Python API, and REST server) of running language models with MLC LLM, -- we introduced how to generate MLC config file and convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models manually. -- we showcased the universal deployment of MLC LLM. +To briefly summarize this page, -What to Do Next ---------------- +- We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, +- we introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. +- We also discussed the the universal deployment capability of MLC LLM. -Next, you can check out the pages below for quick start examples and more detailed information. +Next, please feel free to check out the pages below for quick start examples and more detailed information +on specific platforms - :ref:`Quick start examples ` for Python API, chat CLI, REST server, web browser, iOS and Android. - Depending on your use case, check out our API documentation and tutorial pages: diff --git a/docs/index.rst b/docs/index.rst index 7160c95b28..e9835e152d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,7 +22,7 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a :hidden: get_started/quick_start.rst - get_started/intro.rst + get_started/introduction.rst .. toctree:: :maxdepth: 1 From 5a37e5593a5f9bbd6bd35a46d8135884b5528c0f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 17 Apr 2024 16:37:18 -0400 Subject: [PATCH 196/531] [Serving][Python] Rename Engine to LLMEngine (#2152) We rename the public Python serve interface from `Engine` to `LLMEngine` (and from `AsyncEngine` to `AsyncLLMEngine` accordingly) for better class name clarity. 
This is because in cases people do wildcard import, in which case the name `Engine` itself does not convey enough meaning. --- docs/deploy/python_engine.rst | 2 +- docs/get_started/introduction.rst | 18 +++++----- docs/get_started/quick_start.rst | 4 +-- examples/python/sample_mlc_engine.py | 4 +-- python/mlc_llm/__init__.py | 2 +- python/mlc_llm/help.py | 2 +- python/mlc_llm/interface/serve.py | 2 +- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 2 +- python/mlc_llm/serve/engine.py | 30 ++++++++-------- python/mlc_llm/serve/engine_base.py | 34 +++++++++---------- python/mlc_llm/serve/server/server_context.py | 8 ++--- python/mlc_llm/serve/sync_engine.py | 2 +- tests/python/serve/evaluate_engine.py | 4 +-- tests/python/serve/test_serve_async_engine.py | 14 ++++---- .../serve/test_serve_async_engine_spec.py | 11 ++++-- tests/python/serve/test_serve_engine.py | 12 +++---- .../python/serve/test_serve_engine_grammar.py | 12 +++---- tests/python/serve/test_serve_engine_image.py | 4 +-- tests/python/serve/test_serve_engine_spec.py | 20 +++++------ tests/python/serve/test_serve_sync_engine.py | 12 +++---- 21 files changed, 103 insertions(+), 98 deletions(-) diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index 60b9acc4a0..c5d9a072a7 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -4,7 +4,7 @@ Python API ========== .. note:: - This page introduces the Python API with Engine in MLC LLM. + This page introduces the Python API with LLMEngine in MLC LLM. If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, please go to :ref:`deploy-python-chat-module` diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 245bdb6da1..282b4764c2 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -88,11 +88,11 @@ You can save the code below into a Python file and run it. .. code:: python - from mlc_llm import Engine + from mlc_llm import LLMEngine # Create engine model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = Engine(model) + engine = LLMEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -112,9 +112,9 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.Engine` instance with the the 4-bit quantized Llama-2 model. -**We design the Python API** :class:`mlc_llm.Engine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.Engine` in the same way of using +This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-2 model. +**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.LLMEngine` in the same way of using `OpenAI's Python package `_ for both synchronous and asynchronous generation. @@ -132,7 +132,7 @@ If you want to run without streaming, you can run print(response) You can also try different arguments supported in `OpenAI chat completion API `_. -If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncEngine` instead. +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncLLMEngine` instead. REST Server ----------- @@ -226,7 +226,7 @@ You can also use this model in Python API, MLC serve and other use scenarios. 
(Optional) Compile Model Library ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In previous sections, model libraries are compiled when the :class:`mlc_llm.Engine` launches, +In previous sections, model libraries are compiled when the :class:`mlc_llm.LLMEngine` launches, which is what we call "JIT (Just-in-Time) model compilation". In some cases, it is beneficial to explicitly compile the model libraries. We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. @@ -254,12 +254,12 @@ At runtime, we need to specify this model library path to use it. For example, .. code:: python - from mlc_llm import Engine + from mlc_llm import LLMEngine # For Python API model = "models/phi-2" model_lib_path = "models/phi-2/lib.so" - engine = Engine(model, model_lib_path=model_lib_path) + engine = LLMEngine(model, model_lib_path=model_lib_path) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index 93d0f8bb3f..bd3b41218e 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -20,11 +20,11 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: python - from mlc_llm import Engine + from mlc_llm import LLMEngine # Create engine model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = Engine(model) + engine = LLMEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py index 9c65bd4c51..e26e17f1e2 100644 --- a/examples/python/sample_mlc_engine.py +++ b/examples/python/sample_mlc_engine.py @@ -1,8 +1,8 @@ -from mlc_llm import Engine +from mlc_llm import LLMEngine # Create engine model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" -engine = Engine(model) +engine = LLMEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index b891323a5a..8e3aaaa808 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -6,4 +6,4 @@ from . import protocol, serve from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ -from .serve import AsyncEngine, Engine +from .serve import AsyncLLMEngine, LLMEngine diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index ffea30c303..429e8a972d 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -193,7 +193,7 @@ this number. Under mode "server", the actual memory usage may be slightly larger than this number. """, "engine_config_serve": """ -The Engine execution configuration. +The LLMEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. 
diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index bdbb633414..3282762c00 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -34,7 +34,7 @@ def serve( ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Create engine and start the background loop - async_engine = engine.AsyncEngine( + async_engine = engine.AsyncLLMEngine( model=model, device=device, model_lib_path=model_lib_path, diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index abbedc911e..8b99c9bc50 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -4,7 +4,7 @@ from .. import base from .config import EngineConfig, GenerationConfig, KVCacheConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData -from .engine import AsyncEngine, Engine +from .engine import AsyncLLMEngine, LLMEngine from .grammar import BNFGrammar, GrammarStateMatcher from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 77bca9b462..113356156b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -175,7 +175,7 @@ class SpeculativeMode(enum.IntEnum): @dataclass class EngineConfig: - """The class of Engine execution configuration. + """The class of LLMEngine execution configuration. Parameters ---------- diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 99c455f3cd..2ad6b0f1a1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -37,10 +37,10 @@ class Chat: # pylint: disable=too-few-public-methods """The proxy class to direct to chat completions.""" def __init__(self, engine: weakref.ReferenceType) -> None: - assert isinstance(engine(), (AsyncEngine, Engine)) + assert isinstance(engine(), (AsyncLLMEngine, LLMEngine)) self.completions = ( AsyncChatCompletion(engine) # type: ignore - if isinstance(engine(), AsyncEngine) + if isinstance(engine(), AsyncLLMEngine) else ChatCompletion(engine) # type: ignore ) @@ -49,7 +49,7 @@ class AsyncChatCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to async chat completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["AsyncEngine"] + engine: weakref.ReferenceType["AsyncLLMEngine"] else: engine: weakref.ReferenceType @@ -226,7 +226,7 @@ class ChatCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to chat completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["Engine"] + engine: weakref.ReferenceType["LLMEngine"] else: engine: weakref.ReferenceType @@ -401,7 +401,7 @@ class AsyncCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to async completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["AsyncEngine"] + engine: weakref.ReferenceType["AsyncLLMEngine"] else: engine: weakref.ReferenceType @@ -580,7 +580,7 @@ class Completion: # pylint: disable=too-few-public-methods """The proxy class to direct to completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["Engine"] + engine: weakref.ReferenceType["LLMEngine"] else: engine: weakref.ReferenceType @@ -752,8 +752,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals ) -class AsyncEngine(engine_base.EngineBase): - """The AsyncEngine in MLC LLM that provides the 
asynchronous +class AsyncLLMEngine(engine_base.LLMEngineBase): + """The AsyncLLMEngine in MLC LLM that provides the asynchronous interfaces with regard to OpenAI API. Parameters @@ -825,7 +825,7 @@ class AsyncEngine(engine_base.EngineBase): memory usage may be slightly larger than this number. engine_config : Optional[EngineConfig] - The Engine execution configuration. + The LLMEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. @@ -1225,7 +1225,7 @@ async def _generate( generation_config: GenerationConfig, request_id: str, ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]: - """Internal asynchronous text generation interface of AsyncEngine. + """Internal asynchronous text generation interface of AsyncLLMEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. The returned list length is the number of parallel generations specified by `generation_config.n`. @@ -1295,8 +1295,8 @@ def _abort(self, request_id: str): self._ffi["abort_request"](request_id) -class Engine(engine_base.EngineBase): - """The Engine in MLC LLM that provides the synchronous +class LLMEngine(engine_base.LLMEngineBase): + """The LLMEngine in MLC LLM that provides the synchronous interfaces with regard to OpenAI API. Parameters @@ -1368,7 +1368,7 @@ class Engine(engine_base.EngineBase): memory usage may be slightly larger than this number. engine_config : Optional[EngineConfig] - The Engine execution configuration. + The LLMEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. @@ -1761,7 +1761,7 @@ def _generate( # pylint: disable=too-many-locals generation_config: GenerationConfig, request_id: str, ) -> Iterator[List[engine_base.CallbackStreamOutput]]: - """Internal synchronous text generation interface of AsyncEngine. + """Internal synchronous text generation interface of AsyncLLMEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. The returned list length is the number of parallel generations specified by `generation_config.n`. @@ -1815,7 +1815,7 @@ def _generate( # pylint: disable=too-many-locals def _request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] ) -> List[List[engine_base.CallbackStreamOutput]]: - """The underlying implementation of request stream callback of Engine.""" + """The underlying implementation of request stream callback of LLMEngine.""" batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index e61ab626d6..367deda8a4 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -464,7 +464,7 @@ def infer_args_under_mode( @dataclass class CallbackStreamOutput: - """The output of Engine._generate and AsyncEngine._generate + """The output of LLMEngine._generate and AsyncLLMEngine._generate Attributes ---------- @@ -489,7 +489,7 @@ class CallbackStreamOutput: class AsyncRequestStream: - """The asynchronous stream for requests in AsyncEngine. 
+ """The asynchronous stream for requests in AsyncLLMEngine. Each request has its own unique stream. The stream exposes the method `push` for engine to push new generated @@ -548,29 +548,29 @@ async def __anext__(self) -> List[CallbackStreamOutput]: class EngineState: """The engine states that the request stream callback function may use. - This class is used for both AsyncEngine and Engine. - AsyncEngine uses the fields and methods starting with "async", - and Engine uses the ones starting with "sync". + This class is used for both AsyncLLMEngine and LLMEngine. + AsyncLLMEngine uses the fields and methods starting with "async", + and LLMEngine uses the ones starting with "sync". - - For AsyncEngine, the state contains an asynchronous event loop, + - For AsyncLLMEngine, the state contains an asynchronous event loop, the streamers and the number of unfinished generations for each request being processed. - - For Engine, the state contains a callback output blocking queue, + - For LLMEngine, the state contains a callback output blocking queue, the text streamers and the number of unfinished requests. We use this state class to avoid the callback function from capturing - the AsyncEngine. + the AsyncLLMEngine. The state also optionally maintains an event trace recorder, which can provide Chrome tracing when enabled. """ trace_recorder = None - # States used for AsyncEngine + # States used for AsyncLLMEngine async_event_loop: Optional[asyncio.AbstractEventLoop] = None async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} async_num_unfinished_generations: Dict[str, int] = {} - # States used for Engine + # States used for LLMEngine sync_output_queue: queue.Queue = queue.Queue() sync_text_streamers: List[TextStreamer] = [] sync_num_unfinished_generations: int = 0 @@ -632,7 +632,7 @@ def async_lazy_init_event_loop(self) -> None: self.async_event_loop = asyncio.get_event_loop() def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for AsyncEngine to stream back + """The request stream callback function for AsyncLLMEngine to stream back the request generation results. Note @@ -652,7 +652,7 @@ def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamO def _async_request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] ) -> None: - """The underlying implementation of request stream callback for AsyncEngine.""" + """The underlying implementation of request stream callback for AsyncLLMEngine.""" for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() streamers = self.async_streamers.get(request_id, None) @@ -693,28 +693,28 @@ def _async_request_stream_callback_impl( self.record_event(request_id, event="finish callback") def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for Engine to stream back + """The request stream callback function for LLMEngine to stream back the request generation results. """ # Put the delta outputs to the queue in the unblocking way. self.sync_output_queue.put_nowait(delta_outputs) -class EngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods +class LLMEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods """The base engine class, which implements common functions that - are shared by Engine and AsyncEngine. + are shared by LLMEngine and AsyncLLMEngine. 
This class wraps a threaded engine that runs on a standalone thread inside and streams back the delta generated results via callback functions. The internal threaded engine keeps running an loop that drives the engine. - Engine and AsyncEngine inherits this EngineBase class, and implements + LLMEngine and AsyncLLMEngine inherits this LLMEngineBase class, and implements their own methods to process the delta generated results received from callback functions and yield the processed delta results in the forms of standard API protocols. - Checkout subclasses AsyncEngine/Engine for the docstring of constructor parameters. + Checkout subclasses AsyncLLMEngine/LLMEngine for the docstring of constructor parameters. """ def __init__( # pylint: disable=too-many-arguments,too-many-locals diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index ab103c05f8..0a9a1b0b1f 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional -from ..engine import AsyncEngine +from ..engine import AsyncLLMEngine class ServerContext: @@ -13,7 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, AsyncEngine] = {} + self._models: Dict[str, AsyncLLMEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,13 +31,13 @@ def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: AsyncEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - def get_engine(self, model: str) -> Optional[AsyncEngine]: + def get_engine(self, model: str) -> Optional[AsyncLLMEngine]: """Get the async engine of the requested model.""" return self._models.get(model, None) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 12c55259b6..963ea9402f 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -41,7 +41,7 @@ def _create_tvm_module( return {key: module[key] for key in ffi_funcs} -class SyncEngine: +class SyncLLMEngine: """The Python interface of synchronize request serving engine for MLC LLM. The engine receives requests from the "add_request" method. 
For diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 0685261806..4e541b7437 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -5,7 +5,7 @@ from typing import List, Tuple from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.sync_engine import SyncEngine +from mlc_llm.serve.sync_engine import SyncLLMEngine def _parse_args(): @@ -41,7 +41,7 @@ def benchmark(args: argparse.Namespace): random.seed(args.seed) # Create engine - engine = SyncEngine( + engine = SyncLLMEngine( model=args.model, device=args.device, model_lib_path=args.model_lib_path, diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index afa7081bd7..9bece30578 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncEngine, GenerationConfig +from mlc_llm.serve import AsyncLLMEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -23,7 +23,7 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -39,7 +39,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, @@ -80,7 +80,7 @@ async def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -132,7 +132,7 @@ async def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -183,7 +183,7 @@ async def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -235,7 +235,7 @@ async def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index f7ccb13a8d..693f0767c3 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,7 +3,12 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncEngine, EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve import ( + AsyncLLMEngine, + EngineConfig, + GenerationConfig, + SpeculativeMode, +) prompts = [ "What is the meaning of life?", @@ -27,7 +32,7 @@ async def 
test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - async_engine = AsyncEngine( + async_engine = AsyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -44,7 +49,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 376671a884..330bd4cf82 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,7 +2,7 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import Engine, GenerationConfig +from mlc_llm.serve import GenerationConfig, LLMEngine prompts = [ "What is the meaning of life?", @@ -22,7 +22,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = Engine( + engine = LLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -61,7 +61,7 @@ def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = Engine( + engine = LLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -105,7 +105,7 @@ def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = Engine( + engine = LLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -148,7 +148,7 @@ def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = Engine( + engine = LLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -192,7 +192,7 @@ def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = Engine( + engine = LLMEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 1bb985f53a..7f2a33b230 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,9 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncEngine, GenerationConfig +from mlc_llm.serve import AsyncLLMEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncEngine +from mlc_llm.serve.sync_engine import SyncLLMEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -22,7 +22,7 @@ def test_batch_generation_with_grammar(): # Create engine - engine = SyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -69,7 +69,7 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): # Create engine - engine = 
SyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -121,7 +121,7 @@ class Schema(BaseModel): async def run_async_engine(): # Create engine - async_engine = AsyncEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + async_engine = AsyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -142,7 +142,7 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncEngine, + async_engine: AsyncLLMEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index f3e13d600b..ff64e7235b 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -2,7 +2,7 @@ from pathlib import Path from mlc_llm.serve import GenerationConfig, data -from mlc_llm.serve.sync_engine import SyncEngine +from mlc_llm.serve.sync_engine import SyncLLMEngine def get_test_image(config) -> data.ImageData: @@ -13,7 +13,7 @@ def test_engine_generate(): # Create engine model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 818064e423..b398dd62c3 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -14,7 +14,7 @@ data, ) from mlc_llm.serve.engine_base import ModelInfo -from mlc_llm.serve.sync_engine import SyncEngine +from mlc_llm.serve.sync_engine import SyncLLMEngine prompts = [ "What is the meaning of life?", @@ -93,7 +93,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -161,7 +161,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -244,7 +244,7 @@ def step(self) -> None: "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -330,7 +330,7 @@ def step(self) -> None: "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -374,7 +374,7 @@ def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -407,7 +407,7 @@ def test_engine_eagle_generate(): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) - engine = SyncEngine( + 
engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -455,7 +455,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -527,7 +527,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # small_model_lib_path = ( # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) - spec_engine = SyncEngine( + spec_engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -598,7 +598,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - spec_engine = SyncEngine( + spec_engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index 4304348095..c5d521b02d 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -5,7 +5,7 @@ import numpy as np from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data -from mlc_llm.serve.sync_engine import SyncEngine +from mlc_llm.serve.sync_engine import SyncLLMEngine prompts = [ "What is the meaning of life?", @@ -80,7 +80,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -156,7 +156,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -237,7 +237,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -323,7 +323,7 @@ def all_finished(self) -> bool: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -365,7 +365,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncEngine( + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, mode="server", From 751783bc2d8199aab520ce1300f332807fccd56a Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 17 Apr 2024 22:09:09 +0000 Subject: [PATCH 197/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 0f67508236..7a8520581e 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 
0f67508236158e5c7eb7c906df068e4ed95190f9 +Subproject commit 7a8520581e4a70024de05fa9e803b5d2899796f6 From e9a4a0bf719a7c4fd42b438cf9e159a1e8d72590 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 17 Apr 2024 16:55:12 -0700 Subject: [PATCH 198/531] [Quantization] Add e4m3 mode and enable fp8 storage type (#2154) * [Quantization] Add e4m3 mode and enable fp8 storage type * add quantize linear flag --- python/mlc_llm/cli/model_metadata.py | 4 +- python/mlc_llm/interface/convert_weight.py | 5 +- python/mlc_llm/op/moe_matmul.py | 3 +- .../quantization/per_tensor_quantization.py | 80 ++++++++++++------- python/mlc_llm/quantization/quantization.py | 17 +++- python/mlc_llm/quantization/utils.py | 3 +- 6 files changed, 73 insertions(+), 39 deletions(-) diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 9b45561665..81473b1ec7 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Union -import numpy as np +from tvm.runtime import DataType from mlc_llm.support import logging from mlc_llm.support.argparse import ArgumentParser @@ -81,7 +81,7 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa else: # Contains dynamic shape; use config to look up concrete values param_shape = _read_dynamic_shape(param["shape"], config) - params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize + params_bytes += math.prod(param_shape) * DataType(param["dtype"]).itemsize() temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index 90c5c45831..f6c3c5f255 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -7,10 +7,9 @@ from pathlib import Path from typing import Any, Dict, Iterator, Tuple -import numpy as np from tvm import tir from tvm.contrib import tvmjs -from tvm.runtime import Device, NDArray +from tvm.runtime import DataType, Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target @@ -131,7 +130,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]: _check_param(name, param) param_names.add(name) param = param.copyto(cpu_device()) - total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + total_bytes += math.prod(param.shape) * DataType(param.dtype).itemsize() yield name, param total_params = loader.stats.total_param_num diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 95d7fed941..6def4a5ff2 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from tvm import DataType, tir +from tvm import DataType, DataTypeCode, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -218,6 +218,7 @@ def _dequantize(w, s, e, i, j): if num_elem_per_storage == 1: w = tir.reinterpret(quantize_dtype, w[e, i, j]) else: + assert DataType(storage_dtype).type_code == DataTypeCode.UINT tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py 
index c2776b2a86..274a221393 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -16,6 +16,7 @@ compile_quantize_func, convert_uint_packed_fp8_to_float, is_final_fc, + is_moe_gate, pack_weight, ) @@ -30,10 +31,11 @@ class PerTensorQuantize: # pylint: disable=too-many-instance-attributes kind: str activation_dtype: Literal["e4m3_float8", "e5m2_float8"] weight_dtype: Literal["e4m3_float8", "e5m2_float8"] - storage_dtype: Literal["uint32"] + storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"] model_dtype: Literal["float16"] quantize_embedding: bool = True quantize_final_fc: bool = True + quantize_linear: bool = True num_elem_per_storage: int = 0 max_int_value: int = 0 @@ -101,8 +103,11 @@ def visit_module(self, name: str, node: nn.Module) -> Any: f"{name}.q_weight", ] ) - if isinstance(node, nn.Linear) and ( - not is_final_fc(name) or self.config.quantize_final_fc + if ( + isinstance(node, nn.Linear) + and self.config.quantize_linear + and (not is_final_fc(name) or self.config.quantize_final_fc) + and not is_moe_gate(name, node) ): self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight @@ -192,7 +197,11 @@ def _compute_scale(x: te.Tensor) -> te.Tensor: scale = None def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor: - elem_storage_dtype = f"uint{quantize_dtype.bits}" + elem_storage_dtype = ( + f"uint{quantize_dtype.bits}" + if DataType(self.storage_dtype).type_code == DataTypeCode.UINT + else quantize_dtype + ) scaled_weight = te.compute( shape=weight.shape, fcompute=lambda *idx: tir.Cast( @@ -207,6 +216,9 @@ def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> ), ) + if self.weight_dtype == self.storage_dtype: + return scaled_weight + packed_weight = pack_weight( scaled_weight, axis=-1, @@ -248,15 +260,18 @@ def dequantize_float8( out_shape: Optional[Sequence[tir.PrimExpr]] = None, ) -> te.Tensor: """Dequantize a fp8 tensor to higher-precision float.""" - weight = convert_uint_packed_fp8_to_float( - q_weight, - self.num_elem_per_storage, - self.storage_dtype, - self.model_dtype, - quantize_dtype, - axis=-1, - out_shape=out_shape, - ) + if quantize_dtype != self.storage_dtype: + weight = convert_uint_packed_fp8_to_float( + q_weight, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + quantize_dtype, + axis=-1, + out_shape=out_shape, + ) + else: + weight = q_weight.astype(self.model_dtype) if scale is not None: weight = weight * scale return weight @@ -276,7 +291,7 @@ def __init__( # pylint: disable=too-many-arguments super().__init__() self.in_features = in_features self.out_features = out_features - self.out_dtype = out_dtype + self.out_dtype = out_dtype or config.model_dtype self.config = config self.q_weight = nn.Parameter( (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), @@ -341,22 +356,27 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name ret : nn.Tensor The output tensor for the per-tensor quantized linear layer. 
""" - w = nn.op.tensor_expr_op( - lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access - weight, - scale, - out_shape=[ - ( - tir.IntImm("int64", self.out_features) - if isinstance(self.out_features, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.in_features), - ], - ), - "dequantize", - args=[self.q_weight, self.q_scale], - ) + # Note: Use calibration scale when calibration is enabled + x = x.astype(self.config.activation_dtype) + if self.config.weight_dtype == self.config.storage_dtype: + w = self.q_weight + else: + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.in_features), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) w = nn.op.permute_dims(w) x = nn.op.matmul(x, w, out_dtype=self.out_dtype) if self.bias is not None: diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index 1b2d8695cf..ed7d8a6720 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -123,10 +123,23 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr kind="per-tensor-quant", activation_dtype="e5m2_float8", weight_dtype="e5m2_float8", - storage_dtype="uint32", + storage_dtype="e5m2_float8", model_dtype="float16", - quantize_final_fc=True, + quantize_final_fc=False, + quantize_embedding=False, + quantize_linear=True, + use_scale=False, + ), + "e4m3_e4m3_f16": PerTensorQuantize( + name="e4m3_e4m3_f16", + kind="per-tensor-quant", + activation_dtype="e4m3_float8", + weight_dtype="e4m3_float8", + storage_dtype="e4m3_float8", + model_dtype="float16", + quantize_final_fc=False, quantize_embedding=False, + quantize_linear=True, use_scale=False, ), } diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index fdc50ff74d..3e55de4524 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -6,7 +6,7 @@ from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn -from tvm.runtime import DataType +from tvm.runtime import DataType, DataTypeCode from tvm.target import Target from mlc_llm.support import tensor_parallel as tp @@ -105,6 +105,7 @@ def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments ) -> te.Tensor: """Unpack a fp8 value from the storage dtype and convert to float.""" assert quant_dtype in ["e4m3_float8", "e5m2_float8"] + assert DataType(storage_dtype).type_code == DataTypeCode.UINT bits = DataType(quant_dtype).bits elem_storage_dtype = DataType(f"uint{bits}") tir_bin_mask = tir.const((1 << bits) - 1, "uint8") From 7d3f34e686ee64ffd207595043656ff88360d51f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 18 Apr 2024 07:53:16 -0400 Subject: [PATCH 199/531] Revert "[Quantization] Add e4m3 mode and enable fp8 storage type" (#2158) Revert "[Quantization] Add e4m3 mode and enable fp8 storage type (#2154)" This reverts commit e9a4a0bf719a7c4fd42b438cf9e159a1e8d72590. 
--- python/mlc_llm/cli/model_metadata.py | 4 +- python/mlc_llm/interface/convert_weight.py | 5 +- python/mlc_llm/op/moe_matmul.py | 3 +- .../quantization/per_tensor_quantization.py | 80 +++++++------------ python/mlc_llm/quantization/quantization.py | 17 +--- python/mlc_llm/quantization/utils.py | 3 +- 6 files changed, 39 insertions(+), 73 deletions(-) diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 81473b1ec7..9b45561665 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Union -from tvm.runtime import DataType +import numpy as np from mlc_llm.support import logging from mlc_llm.support.argparse import ArgumentParser @@ -81,7 +81,7 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa else: # Contains dynamic shape; use config to look up concrete values param_shape = _read_dynamic_shape(param["shape"], config) - params_bytes += math.prod(param_shape) * DataType(param["dtype"]).itemsize() + params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index f6c3c5f255..90c5c45831 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -7,9 +7,10 @@ from pathlib import Path from typing import Any, Dict, Iterator, Tuple +import numpy as np from tvm import tir from tvm.contrib import tvmjs -from tvm.runtime import DataType, Device, NDArray +from tvm.runtime import Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target @@ -130,7 +131,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]: _check_param(name, param) param_names.add(name) param = param.copyto(cpu_device()) - total_bytes += math.prod(param.shape) * DataType(param.dtype).itemsize() + total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize yield name, param total_params = loader.stats.total_param_num diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 6def4a5ff2..95d7fed941 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from tvm import DataType, DataTypeCode, tir +from tvm import DataType, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -218,7 +218,6 @@ def _dequantize(w, s, e, i, j): if num_elem_per_storage == 1: w = tir.reinterpret(quantize_dtype, w[e, i, j]) else: - assert DataType(storage_dtype).type_code == DataTypeCode.UINT tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py index 274a221393..c2776b2a86 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -16,7 +16,6 @@ compile_quantize_func, convert_uint_packed_fp8_to_float, is_final_fc, - is_moe_gate, pack_weight, ) @@ -31,11 +30,10 @@ class PerTensorQuantize: # pylint: disable=too-many-instance-attributes kind: str activation_dtype: 
Literal["e4m3_float8", "e5m2_float8"] weight_dtype: Literal["e4m3_float8", "e5m2_float8"] - storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"] + storage_dtype: Literal["uint32"] model_dtype: Literal["float16"] quantize_embedding: bool = True quantize_final_fc: bool = True - quantize_linear: bool = True num_elem_per_storage: int = 0 max_int_value: int = 0 @@ -103,11 +101,8 @@ def visit_module(self, name: str, node: nn.Module) -> Any: f"{name}.q_weight", ] ) - if ( - isinstance(node, nn.Linear) - and self.config.quantize_linear - and (not is_final_fc(name) or self.config.quantize_final_fc) - and not is_moe_gate(name, node) + if isinstance(node, nn.Linear) and ( + not is_final_fc(name) or self.config.quantize_final_fc ): self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight @@ -197,11 +192,7 @@ def _compute_scale(x: te.Tensor) -> te.Tensor: scale = None def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor: - elem_storage_dtype = ( - f"uint{quantize_dtype.bits}" - if DataType(self.storage_dtype).type_code == DataTypeCode.UINT - else quantize_dtype - ) + elem_storage_dtype = f"uint{quantize_dtype.bits}" scaled_weight = te.compute( shape=weight.shape, fcompute=lambda *idx: tir.Cast( @@ -216,9 +207,6 @@ def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> ), ) - if self.weight_dtype == self.storage_dtype: - return scaled_weight - packed_weight = pack_weight( scaled_weight, axis=-1, @@ -260,18 +248,15 @@ def dequantize_float8( out_shape: Optional[Sequence[tir.PrimExpr]] = None, ) -> te.Tensor: """Dequantize a fp8 tensor to higher-precision float.""" - if quantize_dtype != self.storage_dtype: - weight = convert_uint_packed_fp8_to_float( - q_weight, - self.num_elem_per_storage, - self.storage_dtype, - self.model_dtype, - quantize_dtype, - axis=-1, - out_shape=out_shape, - ) - else: - weight = q_weight.astype(self.model_dtype) + weight = convert_uint_packed_fp8_to_float( + q_weight, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + quantize_dtype, + axis=-1, + out_shape=out_shape, + ) if scale is not None: weight = weight * scale return weight @@ -291,7 +276,7 @@ def __init__( # pylint: disable=too-many-arguments super().__init__() self.in_features = in_features self.out_features = out_features - self.out_dtype = out_dtype or config.model_dtype + self.out_dtype = out_dtype self.config = config self.q_weight = nn.Parameter( (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), @@ -356,27 +341,22 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name ret : nn.Tensor The output tensor for the per-tensor quantized linear layer. 
""" - # Note: Use calibration scale when calibration is enabled - x = x.astype(self.config.activation_dtype) - if self.config.weight_dtype == self.config.storage_dtype: - w = self.q_weight - else: - w = nn.op.tensor_expr_op( - lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access - weight, - scale, - out_shape=[ - ( - tir.IntImm("int64", self.out_features) - if isinstance(self.out_features, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.in_features), - ], - ), - "dequantize", - args=[self.q_weight, self.q_scale], - ) + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.in_features), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) w = nn.op.permute_dims(w) x = nn.op.matmul(x, w, out_dtype=self.out_dtype) if self.bias is not None: diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index ed7d8a6720..1b2d8695cf 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -123,23 +123,10 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr kind="per-tensor-quant", activation_dtype="e5m2_float8", weight_dtype="e5m2_float8", - storage_dtype="e5m2_float8", - model_dtype="float16", - quantize_final_fc=False, - quantize_embedding=False, - quantize_linear=True, - use_scale=False, - ), - "e4m3_e4m3_f16": PerTensorQuantize( - name="e4m3_e4m3_f16", - kind="per-tensor-quant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - storage_dtype="e4m3_float8", + storage_dtype="uint32", model_dtype="float16", - quantize_final_fc=False, + quantize_final_fc=True, quantize_embedding=False, - quantize_linear=True, use_scale=False, ), } diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 3e55de4524..fdc50ff74d 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -6,7 +6,7 @@ from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn -from tvm.runtime import DataType, DataTypeCode +from tvm.runtime import DataType from tvm.target import Target from mlc_llm.support import tensor_parallel as tp @@ -105,7 +105,6 @@ def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments ) -> te.Tensor: """Unpack a fp8 value from the storage dtype and convert to float.""" assert quant_dtype in ["e4m3_float8", "e5m2_float8"] - assert DataType(storage_dtype).type_code == DataTypeCode.UINT bits = DataType(quant_dtype).bits elem_storage_dtype = DataType(f"uint{bits}") tir_bin_mask = tir.const((1 << bits) - 1, "uint8") From 835223541d4135e511a50cba1deca06731b03abd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 18 Apr 2024 14:04:19 -0400 Subject: [PATCH 200/531] [Serving] EngineConfig refactor (#2159) This PR refactors EngineConfig for a cleaner interface of internal Engine constructor in MLC serve. This is a preparation step towards the engine reload/unload which will be introduced in follow-up PRs for JSONFFIEngine functionality on mobile and other platforms. 
--- cpp/json_ffi/json_ffi_engine.cc | 54 ++---- cpp/llm_chat.cc | 2 - cpp/serve/config.cc | 137 +++---------- cpp/serve/config.h | 86 +++++---- cpp/serve/engine.cc | 181 +++++++----------- cpp/serve/engine.h | 46 +---- cpp/serve/engine_actions/action.h | 17 +- cpp/serve/engine_actions/batch_verify.cc | 16 +- .../engine_actions/eagle_batch_verify.cc | 16 +- .../eagle_new_request_prefill.cc | 34 ++-- .../engine_actions/new_request_prefill.cc | 35 ++-- cpp/serve/function_table.cc | 20 +- cpp/serve/function_table.h | 2 +- cpp/serve/model.cc | 24 +-- cpp/serve/model.h | 16 +- cpp/serve/threaded_engine.cc | 25 +-- cpp/serve/threaded_engine.h | 8 +- python/mlc_llm/cli/serve.py | 16 +- python/mlc_llm/help.py | 10 + python/mlc_llm/interface/serve.py | 8 +- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 151 +++++++-------- python/mlc_llm/serve/engine.py | 14 +- python/mlc_llm/serve/engine_base.py | 89 ++++----- python/mlc_llm/serve/server/popen_server.py | 17 +- python/mlc_llm/serve/sync_engine.py | 46 +++-- tests/python/json_ffi/test_json_ffi_engine.py | 44 +++-- .../serve/test_serve_async_engine_spec.py | 11 +- tests/python/serve/test_serve_engine_spec.py | 26 ++- 29 files changed, 503 insertions(+), 650 deletions(-) diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 489e2e5339..b02a28ca89 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -102,6 +102,7 @@ JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); } class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); @@ -109,41 +110,28 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", &JSONFFIEngineImpl::RunBackgroundStreamBackLoop); TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); - if (_name == "init_background_engine") { - return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { - SelfPtr self = static_cast(_self.get()); - - std::string tokenizer_path = args.At(1); - self->streamer_ = TextStreamer(Tokenizer::FromPath(tokenizer_path)); - - // Callback wrapper - Optional request_stream_callback; - try { - request_stream_callback = args.At>(4); - } catch (const dmlc::Error& e) { - LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; - } + TVM_MODULE_VTABLE_END(); - CHECK(request_stream_callback.defined()) - << "JSONFFIEngine requires request stream callback function, but it is not given."; - self->request_stream_callback_ = request_stream_callback.value(); - - auto frequest_stream_callback_wrapper = [self](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - Array responses = self->GetResponseFromStreamOutput(delta_outputs); - self->request_stream_callback_(responses); - }; - - std::vector values{args.values, args.values + args.size()}; - std::vector type_codes{args.type_codes, args.type_codes + args.size()}; - TVMArgsSetter setter(values.data(), type_codes.data()); - request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - setter(4, request_stream_callback); - 
self->engine_->InitBackgroundEngine(TVMArgs(values.data(), type_codes.data(), args.size())); - }); + void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) { + this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); + + CHECK(request_stream_callback.defined()) + << "JSONFFIEngine requires request stream callback function, but it is not given."; + this->request_stream_callback_ = request_stream_callback.value(); + + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + Array responses = this->GetResponseFromStreamOutput(delta_outputs); + this->request_stream_callback_(responses); + }; + + request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + this->engine_->InitBackgroundEngine( + std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); } - TVM_MODULE_VTABLE_END(); void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 8cadbe8df4..9485ccad02 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -1618,8 +1618,6 @@ class LLMChat { NDArray logits_on_cpu_{nullptr}; // pre-allocated ndarray for decode function's input tokens DRef input_tokens_decode_{nullptr}; - // KV cache config - serve::KVCacheConfig kv_cache_config_{nullptr}; }; /*! diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index ec9694ca1e..5d647ec532 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -5,6 +5,7 @@ #include "config.h" #include +#include #include @@ -222,123 +223,43 @@ String GenerationConfigNode::AsJSONString() const { return picojson::value(config).serialize(true); } -/****************** KVCacheConfig ******************/ - -TVM_REGISTER_OBJECT_TYPE(KVCacheConfigNode); - -KVCacheConfig::KVCacheConfig(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) { - ObjectPtr n = make_object(); - n->page_size = page_size; - n->max_num_sequence = max_num_sequence; - n->max_total_sequence_length = max_total_sequence_length; - n->prefill_chunk_size = prefill_chunk_size; - data_ = std::move(n); -} - -KVCacheConfig::KVCacheConfig(const std::string& config_str, int max_single_sequence_length) { - int page_size; - int max_total_sequence_length; - int max_num_sequence = -1; - int prefill_chunk_size; - - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. 
- picojson::object config = config_json.get(); - if (config.count("page_size")) { - CHECK(config["page_size"].is()); - page_size = config["page_size"].get(); - CHECK_EQ(page_size, 16) << "KV cache page size other than 16 is not supported."; - } else { - LOG(FATAL) << "Key \"page_size\" not found."; - } - if (config.count("max_total_sequence_length")) { - CHECK(config["max_total_sequence_length"].is()); - max_total_sequence_length = config["max_total_sequence_length"].get(); - } else { - LOG(FATAL) << "Key \"max_total_sequence_length\" not found."; - } - if (config.count("prefill_chunk_size")) { - CHECK(config["prefill_chunk_size"].is()); - prefill_chunk_size = config["prefill_chunk_size"].get(); - } else { - LOG(FATAL) << "Key \"prefill_chunk_size\" not found."; - } - if (config.count("max_num_sequence")) { - CHECK(config["max_num_sequence"].is()); - max_num_sequence = config["max_num_sequence"].get(); - CHECK_GT(max_num_sequence, 0) << "Max number of sequence should be positive."; - } else { - LOG(FATAL) << "Key \"max_num_sequence\" not found."; - } - - ObjectPtr n = make_object(); - n->page_size = page_size; - n->max_num_sequence = max_num_sequence; - n->max_total_sequence_length = max_total_sequence_length; - n->prefill_chunk_size = prefill_chunk_size; - data_ = std::move(n); -} - -String KVCacheConfigNode::AsJSONString() const { - picojson::object config; - config["page_size"] = picojson::value(static_cast(this->page_size)); - config["max_num_sequence"] = picojson::value(static_cast(this->max_num_sequence)); - config["max_total_sequence_length"] = - picojson::value(static_cast(this->max_total_sequence_length)); - config["prefill_chunk_size"] = picojson::value(static_cast(this->prefill_chunk_size)); - return picojson::value(config).serialize(true); -} - /****************** EngineConfig ******************/ TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); -EngineConfig::EngineConfig(int spec_draft_length, int speculative_mode) { - ObjectPtr n = make_object(); - n->spec_draft_length = spec_draft_length; - n->speculative_mode = SpeculativeMode(speculative_mode); - data_ = std::move(n); -} - -EngineConfig::EngineConfig(const std::string& config_str) { - int spec_draft_length = 4; - int speculative_mode = 0; - - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. 
- picojson::object config = config_json.get(); - if (config.count("spec_draft_length")) { - CHECK(config["spec_draft_length"].is()); - spec_draft_length = config["spec_draft_length"].get(); - } - if (config.count("speculative_mode")) { - CHECK(config["speculative_mode"].is()); - speculative_mode = config["speculative_mode"].get(); - } - +EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, + int max_total_sequence_length, int max_single_sequence_length, + int prefill_chunk_size, SpeculativeMode speculative_mode, + int spec_draft_length) { ObjectPtr n = make_object(); + n->model = std::move(model); + n->model_lib_path = std::move(model_lib_path); + n->additional_models = std::move(additional_models); + n->additional_model_lib_paths = std::move(additional_model_lib_paths); + n->device = device; + n->kv_cache_page_size = kv_cache_page_size; + n->max_num_sequence = max_num_sequence; + n->max_total_sequence_length = max_total_sequence_length; + n->max_single_sequence_length = max_single_sequence_length; + n->prefill_chunk_size = prefill_chunk_size; n->spec_draft_length = spec_draft_length; - n->speculative_mode = SpeculativeMode(speculative_mode); + n->speculative_mode = speculative_mode; data_ = std::move(n); } -String EngineConfigNode::AsJSONString() const { - picojson::object config; - config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); - config["speculative_mode"] = picojson::value(static_cast(this->speculative_mode)); - return picojson::value(config).serialize(true); -} +TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") + .set_body_typed([](String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, + int spec_draft_length) { + return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), + std::move(additional_model_lib_paths), device, kv_cache_page_size, + max_num_sequence, max_total_sequence_length, max_single_sequence_length, + prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + }); } // namespace serve } // namespace llm diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 214e9ccdd9..404566fe2c 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -68,50 +68,62 @@ class GenerationConfig : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; -/****************** KV Cache config ******************/ - -/*! \brief The configuration of paged KV cache. 
*/ -class KVCacheConfigNode : public Object { - public: - int page_size; - int max_num_sequence; - int max_total_sequence_length; - int prefill_chunk_size; - - String AsJSONString() const; - - static constexpr const char* _type_key = "mlc.serve.KVCacheConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(KVCacheConfigNode, Object); -}; - -class KVCacheConfig : public ObjectRef { - public: - explicit KVCacheConfig(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size); - - explicit KVCacheConfig(const std::string& config_str, int max_single_sequence_length); - - TVM_DEFINE_OBJECT_REF_METHODS(KVCacheConfig, ObjectRef, KVCacheConfigNode); -}; - -/****************** Engine Mode ******************/ +/****************** Engine config ******************/ /*! \brief The speculative mode. */ enum class SpeculativeMode : int { + /*! \brief Disable speculative decoding. */ kDisable = 0, + /*! \brief The normal speculative decoding (small draft) mode. */ kSmallDraft = 1, + /*! \brief The eagle-style speculative decoding. */ kEagle = 2, }; /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: - /* The number of tokens to generate in speculative proposal (draft) */ - int spec_draft_length; - /* The speculative mode. */ + /*************** Models ***************/ + + /*! \brief The path to the model directory. */ + String model; + /*! \brief The path to the model library. */ + String model_lib_path; + /*! \brief The path to the additional models' directories. */ + Array additional_models; + /*! \brief The path to the additional models' libraries. */ + Array additional_model_lib_paths; + + /*************** Device ***************/ + + /*! \brief The device where the models run. */ + DLDevice device; + + /*************** KV cache config and engine capacities ***************/ + + /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ + int kv_cache_page_size; + /*! + * \brief The maximum number of sequences that are allowed to be + * processed by the KV cache at any time. + */ + int max_num_sequence; + /*! \brief The maximum length allowed for a single sequence in the engine. */ + int max_total_sequence_length; + /*! + * \brief The maximum total number of tokens whose KV data are allowed + * to exist in the KV cache at any time. + */ + int max_single_sequence_length; + /*! \brief The maximum total sequence length in a prefill. */ + int prefill_chunk_size; + + /*************** Speculative decoding ***************/ + + /*! \brief The speculative mode. */ SpeculativeMode speculative_mode; + /*! \brief The number of tokens to generate in speculative proposal (draft). 
*/ + int spec_draft_length = 4; String AsJSONString() const; @@ -123,11 +135,13 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: - explicit EngineConfig(int spec_draft_length, int speculative_mode); - - explicit EngineConfig(const std::string& config_str); + explicit EngineConfig(String model, String model_lib_path, Array additional_models, + Array additional_model_lib_paths, DLDevice device, + int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + SpeculativeMode speculative_mode, int spec_draft_length); - TVM_DEFINE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; } // namespace serve diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index c9ca511e85..85d1c66c2d 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -44,100 +44,101 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_config_json_str, - Optional request_stream_callback, - Optional trace_recorder, - const std::vector>& model_infos) { - CHECK_GE(model_infos.size(), 1) << "ValueError: No model is provided in the engine."; + explicit EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + Optional trace_recorder) { // Step 1. Initialize metadata and singleton states inside the engine this->estate_->Reset(); // Being "-1" means there is no limit on single sequence length. - this->max_single_sequence_length_ = max_single_sequence_length != -1 - ? max_single_sequence_length - : std::numeric_limits::max(); - this->kv_cache_config_ = KVCacheConfig(kv_cache_config_json_str, max_single_sequence_length); - this->engine_config_ = EngineConfig(engine_config_json_str); + if (engine_config->max_single_sequence_length == -1) { + engine_config->max_single_sequence_length = std::numeric_limits::max(); + } this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->tokenizer_ = Tokenizer::FromPath(tokenizer_path); + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); this->token_table_ = tokenizer_->TokenTable(); this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); // Step 2. Initialize each model independently. // Create the logit processor and sampler. 
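With the JSON-string constructors gone, the engine now receives a fully populated EngineConfig and only normalizes it in place, as in the max_single_sequence_length handling above. Below is a minimal sketch of that sentinel convention using plain C++ types instead of the TVM object system; the names are illustrative only.

```cpp
// Sketch of the "-1 means no limit" normalization applied to
// max_single_sequence_length before it is used in length checks.
#include <cassert>
#include <limits>

struct SequenceLimits {
  int max_single_sequence_length = -1;  // -1: caller did not set a limit
};

// Replace the sentinel with INT_MAX so later checks can use plain comparisons.
void Normalize(SequenceLimits* limits) {
  if (limits->max_single_sequence_length == -1) {
    limits->max_single_sequence_length = std::numeric_limits<int>::max();
  }
}

// After normalization, request admission reduces to a single comparison.
bool ExceedsLimit(int input_total_length, const SequenceLimits& limits) {
  return input_total_length >= limits.max_single_sequence_length;
}

int main() {
  SequenceLimits limits;
  Normalize(&limits);
  assert(!ExceedsLimit(1 << 20, limits));  // effectively unlimited now
  return 0;
}
```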
this->models_.clear(); this->model_workspaces_.clear(); - for (const auto& model_info : model_infos) { - TVMArgValue model_lib = std::get<0>(model_info); - String model_path = std::get<1>(model_info); - DLDevice device = std::get<2>(model_info); - Model model = Model::Create(model_lib, std::move(model_path), device, - kv_cache_config_->max_num_sequence, + + auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, + const String& model_lib_path) { + Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, + engine_config->max_num_sequence, /*trace_enabled=*/trace_recorder.defined()); - model->CreateKVCache(this->kv_cache_config_); - CHECK_GE(model->GetMaxWindowSize(), this->max_single_sequence_length_) + model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, + engine_config->max_total_sequence_length, + engine_config->prefill_chunk_size); + CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " - << this->max_single_sequence_length_; + << engine_config->max_single_sequence_length; this->models_.push_back(model); this->model_workspaces_.push_back( ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); + }; + + f_create_model(engine_config->model, engine_config->model_lib_path); + CHECK_EQ(engine_config->additional_models.size(), + engine_config->additional_model_lib_paths.size()) + << "The additional model and lib path list has mismatched size."; + for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { + f_create_model(engine_config->additional_models[i], + engine_config->additional_model_lib_paths[i]); } - int max_num_tokens = kv_cache_config_->max_num_sequence; - if (engine_config_->speculative_mode != SpeculativeMode::kDisable) { - max_num_tokens *= engine_config_->spec_draft_length; + + int max_num_tokens = engine_config->max_num_sequence; + if (engine_config->speculative_mode != SpeculativeMode::kDisable) { + max_num_tokens *= engine_config->spec_draft_length; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); // Step 3. Initialize engine actions that represent state transitions. - if (this->engine_config_->speculative_mode != SpeculativeMode::kDisable) { + if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. 
ICHECK_GT(this->models_.size(), 1U); - switch (this->engine_config_->speculative_mode) { + switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_config_, // - this->trace_recorder_), - EngineAction::EagleBatchDraft( - this->models_, logit_processor, sampler, this->model_workspaces_, - this->trace_recorder_, this->engine_config_->spec_draft_length), - EngineAction::EagleBatchVerify( - this->models_, logit_processor, sampler, this->model_workspaces_, - this->kv_cache_config_, this->trace_recorder_)}; + this->actions_ = { + EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, this->trace_recorder_), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, engine_config, + this->trace_recorder_)}; break; default: - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_config_, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->trace_recorder_, - this->engine_config_->spec_draft_length), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - this->kv_cache_config_, this->trace_recorder_)}; + this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->trace_recorder_), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + engine_config, this->trace_recorder_)}; } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // logit_processor, // sampler, // this->model_workspaces_, // - this->kv_cache_config_, // - this->engine_config_, // + engine_config, // this->trace_recorder_), EngineAction::BatchDecode(this->models_, logit_processor, sampler, this->trace_recorder_)}; } // Step 4. Automatically set the threading backend max concurrency. + this->engine_config_ = engine_config; SetThreadMaxConcurrency(); } @@ -166,7 +167,7 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= max_single_sequence_length_) { + if (request->input_total_length >= engine_config_->max_single_sequence_length) { // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. 
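The constructor above selects the engine action pipeline from engine_config->speculative_mode. The sketch below mirrors that branching with placeholder stage names; the real actions take the models, logit processor, sampler, workspaces, and the shared EngineConfig rather than strings.

```cpp
// Illustrative mapping from speculative mode to the action pipeline.
// Stage names are placeholders for the actual EngineAction factories.
#include <string>
#include <vector>

enum class SpeculativeMode : int { kDisable = 0, kSmallDraft = 1, kEagle = 2 };

std::vector<std::string> BuildActionPipeline(SpeculativeMode mode) {
  switch (mode) {
    case SpeculativeMode::kEagle:
      return {"EagleNewRequestPrefill", "EagleBatchDraft", "EagleBatchVerify"};
    case SpeculativeMode::kSmallDraft:
      return {"NewRequestPrefill", "BatchDraft", "BatchVerify"};
    case SpeculativeMode::kDisable:
    default:
      return {"NewRequestPrefill", "BatchDecode"};
  }
}
```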
Array output{RequestStreamOutput( @@ -250,7 +251,8 @@ class EngineImpl : public Engine { Array processed_requests = action->Step(estate_); if (!processed_requests.empty()) { ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_, - request_stream_callback_.value(), max_single_sequence_length_); + request_stream_callback_.value(), + engine_config_->max_single_sequence_length); return; } } @@ -274,8 +276,8 @@ class EngineImpl : public Engine { host_cpu_usage += model->EstimateHostCPURequirement(); } int max_concurrency = tvm::runtime::threading::MaxConcurrency(); - tvm::runtime::threading::SetMaxConcurrency(std::min( - std::max(max_concurrency - host_cpu_usage, 1), kv_cache_config_->max_num_sequence)); + tvm::runtime::threading::SetMaxConcurrency( + std::min(std::max(max_concurrency - host_cpu_usage, 1), engine_config_->max_num_sequence)); } /*! \brief Create a grammar init context according to the response format. If the response format @@ -295,9 +297,7 @@ class EngineImpl : public Engine { // Engine state, managing requests and request states. EngineState estate_; // Configurations and singletons - KVCacheConfig kv_cache_config_; EngineConfig engine_config_; - int max_single_sequence_length_; Tokenizer tokenizer_; std::vector token_table_; // Helper to get the grammar init context for requests. @@ -314,14 +314,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create( - int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_config_json_str, - Optional request_stream_callback, Optional trace_recorder, - const std::vector>& model_infos) { - return std::make_unique( - max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, engine_config_json_str, - request_stream_callback, std::move(trace_recorder), model_infos); +std::unique_ptr Engine::Create(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) { + return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + std::move(trace_recorder)); } /*! 
\brief Clear global memory manager */ @@ -332,48 +329,10 @@ void ClearGlobalMemoryManager() { (*f)(); } -std::unique_ptr CreateEnginePacked(TVMArgs args) { - ClearGlobalMemoryManager(); - const int num_non_model_args = 6; - const int num_model_args = 4; - int num_models = (args.size() - num_non_model_args) / num_model_args; - int max_single_sequence_length; - std::string tokenizer_path; - std::string kv_cache_config_json_str; - std::string engine_config_json_str; - Optional request_stream_callback; - Optional trace_recorder; - std::vector> model_infos; - model_infos.reserve(num_models); - try { - CHECK_LE(num_models * num_model_args + num_non_model_args, args.size()) - << "Incorrect number of arguments."; - max_single_sequence_length = args.At(0); - tokenizer_path = args.At(1); - kv_cache_config_json_str = args.At(2); - engine_config_json_str = args.At(3); - request_stream_callback = args.At>(4); - trace_recorder = args.At>(5); - for (int i = 0; i < num_models; ++i) { - TVMArgValue model_lib = args[i * num_model_args + num_non_model_args]; - std::string model_path = args.At(i * num_model_args + num_non_model_args + 1); - DLDeviceType device_type = - static_cast(args.At(i * num_model_args + num_non_model_args + 2)); - int device_id = args.At(i * num_model_args + num_non_model_args + 3); - model_infos.emplace_back(model_lib, model_path, DLDevice{device_type, device_id}); - } - } catch (const dmlc::Error& e) { - LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; - } - return Engine::Create(max_single_sequence_length, tokenizer_path, kv_cache_config_json_str, - engine_config_json_str, request_stream_callback, std::move(trace_recorder), - model_infos); -} - class EngineModule : public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.engine"); - TVM_MODULE_VTABLE_ENTRY_PACKED("init", &EngineModule::InitPacked); + TVM_MODULE_VTABLE_ENTRY("init", &EngineModule::Init); TVM_MODULE_VTABLE_ENTRY("add_request", &EngineModule::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &EngineModule::Abort); TVM_MODULE_VTABLE_ENTRY("step", &EngineModule::Step); @@ -383,8 +342,12 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_ENTRY("set_request_stream_callback", &EngineModule::SetRequestStreamCallback); TVM_MODULE_VTABLE_END(); - void InitPacked(TVMArgs args, TVMRetValue* rv) { this->engine_ = CreateEnginePacked(args); } - + /*! \brief Initialize the engine with config and other fields. */ + void Init(EngineConfig engine_config, Optional request_stream_callback, + Optional trace_recorder) { + this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), + std::move(trace_recorder)); + } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } /*! \brief Redirection to `Engine::AddRequest`. */ diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 581219c350..fc5e4205ae 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -50,26 +50,14 @@ class Engine { /*! * \brief Create an engine in unique pointer. - * \param max_single_sequence_length The maximum allowed single - * sequence length supported by the engine. - * \param tokenizer_path The tokenizer path on disk. - * \param kv_cache_config_json_str The KV cache config in JSON string. - * \param engine_config_json_str The Engine execution configuration in JSON string. - * \param request_stream_callback The request stream callback function to - * stream back generated output for requests. 
+ * \param engine_config The engine config. + * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. - * \param model_infos The model info tuples. Each tuple contains - * - the model library, which might be a path to the binary file or - * an executable module that is pre-loaded, - * - the path to the model weight parameters, - * - the device to run the model on. * \return The created Engine in pointer. */ - static std::unique_ptr Create( - int max_single_sequence_length, const String& tokenizer_path, - const String& kv_cache_config_json_str, const String& engine_config_json_str, - Optional request_stream_callback, Optional trace_recorder, - const std::vector>& model_infos); + static std::unique_ptr Create(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder); /*! \brief Reset the engine, clean up all running data and statistics. */ virtual void Reset() = 0; @@ -114,30 +102,6 @@ class Engine { virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; }; -/*! - * \brief Create an Engine from packed arguments in TVMArgs. - * \param args The arguments of engine construction. - * \return The constructed engine in unique pointer. - */ -std::unique_ptr CreateEnginePacked(TVMArgs args); - -constexpr const char* kEngineCreationErrorMessage = - "With `n` models, engine initialization " - "takes (6 + 4 * n) arguments. The first 6 arguments should be: " - "1) (int) maximum length of a sequence, which must be equal or smaller than the context " - "window size of each model; " - "2) (string) path to tokenizer configuration files, which in MLC LLM, usually in a model " - "weights directory; " - "3) (string) JSON configuration for the KVCache; " - "4) (string) JSON mode for Engine;" - "5) (packed function, optional) global request stream callback function. " - "6) (EventTraceRecorder, optional) the event trace recorder for requests." - "The following (4 * n) arguments, 4 for each model, should be: " - "1) (tvm.runtime.Module) The model library loaded into TVM's RelaxVM; " - "2) (string) Model path which includes weights and mlc-chat-config.json; " - "3) (int, enum DLDeviceType) Device type, e.g. CUDA, ROCm, etc; " - "4) (int) Device id, i.e. the ordinal index of the device that exists locally."; - } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 1c2387e834..79359c5741 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -56,15 +56,14 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. - * \param kv_cache_config The KV cache config to help decide prefill is doable. - * \param engine_config The engine operation mode. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, + EngineConfig engine_config, Optional trace_recorder); /*! * \brief Create the action that prefills requests in the `waiting_queue` @@ -73,15 +72,13 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. 
* \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. - * \param kv_cache_config The KV cache config to help decide prefill is doable. - * \param engine_config The engine operation mode. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder); /*! @@ -139,12 +136,12 @@ class EngineAction : public ObjectRef { * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. - * \param kv_cache_config The KV cache config to help decide verify is doable. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder); /*! @@ -155,14 +152,14 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. - * \param kv_cache_config The KV cache config to help decide verify is doable. + * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineAction, ObjectRef, EngineActionObj); diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 9270b6d284..6f38292ba3 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -27,12 +27,12 @@ namespace serve { class BatchVerifyActionObj : public EngineActionObj { public: explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), - kv_cache_config_(std::move(kv_cache_config)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -182,8 +182,8 @@ class BatchVerifyActionObj : public EngineActionObj { num_page_requirement.reserve(running_rsentries.size()); for (const RequestStateEntry& rsentry : running_rsentries) { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); - int num_require_pages = - (draft_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; draft_lengths.push_back(draft_length); num_page_requirement.push_back(num_require_pages); total_draft_length += draft_length; @@ -218,8 +218,8 @@ class BatchVerifyActionObj : 
public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; - /*! \brief The kv cache config. */ - KVCacheConfig kv_cache_config_; + /*! \brief The engine config. */ + EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Random number generator. */ @@ -231,10 +231,10 @@ class BatchVerifyActionObj : public EngineActionObj { }; EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, KVCacheConfig kv_cache_config, + Sampler sampler, EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(kv_cache_config), + std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config), std::move(trace_recorder))); } diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 0c2040db9d..043f68b9c2 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -29,13 +29,13 @@ class EagleBatchVerifyActionObj : public EngineActionObj { public: explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - kv_cache_config_(std::move(kv_cache_config)), + engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -279,8 +279,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { num_page_requirement.reserve(running_rsentries.size()); for (const RequestStateEntry& rsentry : running_rsentries) { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); - int num_require_pages = - (draft_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; draft_lengths.push_back(draft_length); num_page_requirement.push_back(num_require_pages); total_draft_length += draft_length; @@ -337,8 +337,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The kv cache config. */ - KVCacheConfig kv_cache_config_; + /*! \brief The engine config. */ + EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Random number generator. 
*/ @@ -352,11 +352,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index d7a397ce92..133c23e8a1 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -24,14 +24,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - kv_cache_config_(std::move(kv_cache_config)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -393,8 +391,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; total_input_length += input_length; total_required_pages += num_require_pages; // - Attempt 1. Check if the entire request state entry can fit for prefill. @@ -417,9 +415,9 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { total_required_pages -= num_require_pages; // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); - if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length || - kv_cache_config_->prefill_chunk_size == total_input_length) { + ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); + if (engine_config_->prefill_chunk_size - total_input_length >= input_length || + engine_config_->prefill_chunk_size == total_input_length) { // 1. If the input length can fit the remaining prefill chunk size, // it means the failure of attempt 1 is not because of the input // length being too long, and thus chunking does not help. 
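Several hunks above and below replace kv_cache_config_->page_size with engine_config_->kv_cache_page_size inside the same ceiling-division formula. A tiny self-contained version of that page-count arithmetic, assuming nothing beyond standard C++:

```cpp
// Ceiling division used to count KV cache pages for a run of tokens.
#include <cassert>

int NumRequiredPages(int num_tokens, int kv_cache_page_size) {
  // A sequence of num_tokens tokens occupies ceil(num_tokens / page_size)
  // pages in the paged KV cache; the last page may be partially filled.
  return (num_tokens + kv_cache_page_size - 1) / kv_cache_page_size;
}

int main() {
  assert(NumRequiredPages(1, 16) == 1);
  assert(NumRequiredPages(16, 16) == 1);
  assert(NumRequiredPages(17, 16) == 2);  // one full page plus one extra token
  return 0;
}
```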
@@ -429,9 +427,9 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { prefill_stops = true; break; } - input_length = kv_cache_config_->prefill_chunk_size - total_input_length; - num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + input_length = engine_config_->prefill_chunk_size - total_input_length; + num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; total_input_length += input_length; total_required_pages += num_require_pages; if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, @@ -456,7 +454,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, int num_required_pages, int num_available_pages, int current_total_seq_len, int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); + ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can // run simultaneously. @@ -464,7 +462,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { ? engine_config_->spec_draft_length : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { + std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { return false; } @@ -475,10 +473,10 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // exceed the limit, where 8 is a watermark number can // be configured and adjusted in the future. int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= kv_cache_config_->prefill_chunk_size && + return total_input_length <= engine_config_->prefill_chunk_size && num_required_pages + new_batch_size <= num_available_pages && current_total_seq_len + total_input_length + 8 * new_batch_size <= - kv_cache_config_->max_total_sequence_length; + engine_config_->max_total_sequence_length; } /*! @@ -582,9 +580,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The KV cache config to help decide prefill is doable. */ - KVCacheConfig kv_cache_config_; - /*! \brief The engine operation mode. */ + /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; @@ -593,13 +589,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { EngineAction EngineAction::EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_config), - std::move(trace_recorder))); + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index d70b9d7edc..c3f7491960 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -23,13 +23,12 @@ class NewRequestPrefillActionObj : public EngineActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - kv_cache_config_(std::move(kv_cache_config)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -332,8 +331,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { } int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; total_input_length += input_length; total_required_pages += num_require_pages; // - Attempt 1. Check if the entire request state entry can fit for prefill. @@ -356,9 +355,9 @@ class NewRequestPrefillActionObj : public EngineActionObj { total_required_pages -= num_require_pages; // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, kv_cache_config_->prefill_chunk_size); - if (kv_cache_config_->prefill_chunk_size - total_input_length >= input_length || - kv_cache_config_->prefill_chunk_size == total_input_length) { + ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); + if (engine_config_->prefill_chunk_size - total_input_length >= input_length || + engine_config_->prefill_chunk_size == total_input_length) { // 1. If the input length can fit the remaining prefill chunk size, // it means the failure of attempt 1 is not because of the input // length being too long, and thus chunking does not help. 
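The CanPrefill checks above (and the near-identical copy in the non-eagle prefill action below) gate prefill admission on batch size, chunk size, free KV pages, and total cache capacity. The following is a simplified, flattened reading of those conditions rather than a drop-in replacement; spec_factor corresponds to spec_draft_length when speculative decoding is enabled and 1 otherwise, and the 8-token watermark mirrors the comment in the source.

```cpp
// Simplified prefill admission check with the config fields flattened into
// a plain struct. Names are illustrative.
#include <algorithm>

struct PrefillBudget {
  int max_num_sequence;
  int prefill_chunk_size;
  int max_total_sequence_length;
};

bool CanPrefill(const PrefillBudget& budget, int num_running_rsentries,
                int num_prefill_rsentries, int total_input_length,
                int num_required_pages, int num_available_pages,
                int current_total_seq_len, int spec_factor) {
  // 1. Do not exceed the number of sequences the engine may run at once.
  int new_batch_size = num_running_rsentries + num_prefill_rsentries;
  if (new_batch_size * spec_factor >
      std::min(budget.max_num_sequence, budget.prefill_chunk_size)) {
    return false;
  }
  // 2. The chunk must fit, enough KV pages must be free, and the projected
  //    total sequence length (plus an 8-token-per-sequence watermark) must
  //    stay within the cache capacity.
  return total_input_length <= budget.prefill_chunk_size &&
         num_required_pages + new_batch_size <= num_available_pages &&
         current_total_seq_len + total_input_length + 8 * new_batch_size <=
             budget.max_total_sequence_length;
}
```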
@@ -368,9 +367,9 @@ class NewRequestPrefillActionObj : public EngineActionObj { prefill_stops = true; break; } - input_length = kv_cache_config_->prefill_chunk_size - total_input_length; - num_require_pages = - (input_length + kv_cache_config_->page_size - 1) / kv_cache_config_->page_size; + input_length = engine_config_->prefill_chunk_size - total_input_length; + num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; total_input_length += input_length; total_required_pages += num_require_pages; if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, @@ -395,7 +394,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, int num_required_pages, int num_available_pages, int current_total_seq_len, int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, kv_cache_config_->max_num_sequence); + ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); // No exceeding of the maximum allowed requests that can // run simultaneously. @@ -403,7 +402,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { ? engine_config_->spec_draft_length : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(kv_cache_config_->max_num_sequence, kv_cache_config_->prefill_chunk_size)) { + std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { return false; } @@ -414,10 +413,10 @@ class NewRequestPrefillActionObj : public EngineActionObj { // exceed the limit, where 8 is a watermark number can // be configured and adjusted in the future. int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= kv_cache_config_->prefill_chunk_size && + return total_input_length <= engine_config_->prefill_chunk_size && num_required_pages + new_batch_size <= num_available_pages && current_total_seq_len + total_input_length + 8 * new_batch_size <= - kv_cache_config_->max_total_sequence_length; + engine_config_->max_total_sequence_length; } /*! @@ -501,9 +500,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The KV cache config to help decide prefill is doable. */ - KVCacheConfig kv_cache_config_; - /*! \brief The engine operation mode. */ + /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. 
*/ Optional trace_recorder_; @@ -512,13 +509,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, - KVCacheConfig kv_cache_config, EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(kv_cache_config), std::move(engine_config), - std::move(trace_recorder))); + std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 8a0bcd66c6..fa24828399 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,7 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,15 +85,6 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - String lib_path{nullptr}; - try { - lib_path = reload_lib.operator String(); - } catch (...) { - LOG(FATAL) - << "ValueError: In multi-GPU inference, we expect the first argument to Reload to be a " - "string path to the model library (.so on Linux or .dll on Windows), but got: " - << ArgTypeCode2Str(reload_lib.type_code()); - } constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; if (Registry::Get(f_create_process_pool) == nullptr) { LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. 
" @@ -116,7 +107,7 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), - lib_path, null_device); + std::move(reload_lib_path), null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { @@ -139,11 +130,10 @@ void FunctionTable::Init(TVMArgValue reload_lib, Device device, picojson::object this->_InitFunctions(); } else { Module executable{nullptr}; - if (reload_lib.type_code() == kTVMModuleHandle) { - executable = reload_lib.operator Module(); + if (false) { + // Todo(mlc-team): system lib reload // reload_lib_path starts with "system://" } else { - String lib_path = reload_lib.operator String(); - executable = tvm::runtime::Module::LoadFromFile(lib_path); + executable = tvm::runtime::Module::LoadFromFile(reload_lib_path); } this->use_disco = false; auto fload_exec = executable->GetFunction("vm_load_executable"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 03b0428096..f6a156b8a3 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,7 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(TVMArgValue reload_lib, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config); ObjectRef LoadParams(const std::string& model_path, Device device); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index eb35bada38..17121d8e28 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -25,10 +25,10 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(TVMArgValue reload_lib, String model_path, DLDevice device, +Model Model::Create(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled) { return Model( - make_object(reload_lib, model_path, device, max_num_sequence, trace_enabled)); + make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); } class ModelImpl : public ModelObj { @@ -37,7 +37,7 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(TVMArgValue reload_lib, String model_path, DLDevice device, + explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled) : device_(device) { // Step 1. Process model config json string. @@ -53,7 +53,7 @@ class ModelImpl : public ModelObj { // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. - this->ft_.Init(reload_lib, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. 
Set max_num_sequence @@ -714,14 +714,16 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ - void CreateKVCache(KVCacheConfig kv_cache_config) final { - IntTuple max_num_sequence{kv_cache_config->max_num_sequence}; - IntTuple max_total_sequence_length{kv_cache_config->max_total_sequence_length}; - IntTuple prefill_chunk_size{kv_cache_config->prefill_chunk_size}; - IntTuple page_size{kv_cache_config->page_size}; + void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, + int prefill_chunk_size) final { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence, max_total_sequence_length, - prefill_chunk_size, page_size, support_sliding_window); + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; } diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 761f936363..da532f83e8 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -227,9 +227,16 @@ class ModelObj : public Object { /*! * \brief Create the KV cache inside the model with regard to the input config. - * \param kv_cache_config The configuration of KV cache. + * \param page_size The number of consecutive tokens handled in each page in paged KV cache. + * \param max_num_sequence The maximum number of sequences that are allowed to be + * processed by the KV cache at any time. + * \param max_total_sequence_length The maximum length allowed for a single sequence + * in the engine. + * \param prefill_chunk_size The maximum total number of tokens whose KV data + * are allowed to exist in the KV cache at any time. */ - virtual void CreateKVCache(KVCacheConfig kv_cache_config) = 0; + virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, + int prefill_chunk_size) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; @@ -306,15 +313,14 @@ class Model : public ObjectRef { public: /*! * \brief Create the runtime module for LLM functions. - * \param reload_lib The model library. It might be a path to the binary - * file or an executable module that is pre-loaded. + * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(TVMArgValue reload_lib, String model_path, DLDevice device, + TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, int max_num_sequence, bool trace_enabled); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index d79b122125..458d2ae5d7 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -35,14 +35,9 @@ enum class InstructionKind : int { /*! 
\brief The implementation of ThreadedEngine. */ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(TVMArgs args) final { - Optional request_stream_callback; - try { - request_stream_callback = args.At>(4); - } catch (const dmlc::Error& e) { - LOG(FATAL) << "ValueError: " << e.what() << kEngineCreationErrorMessage; - } - + void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) final { CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); @@ -62,12 +57,9 @@ class ThreadedEngineImpl : public ThreadedEngine { } }; - std::vector values{args.values, args.values + args.size()}; - std::vector type_codes{args.type_codes, args.type_codes + args.size()}; - TVMArgsSetter setter(values.data(), type_codes.data()); request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - setter(4, request_stream_callback); - background_engine_ = CreateEnginePacked(TVMArgs(values.data(), type_codes.data(), args.size())); + background_engine_ = Engine::Create( + std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); } void AddRequest(Request request) final { @@ -244,6 +236,7 @@ class ThreadedEngineImpl : public ThreadedEngine { class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); + TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); @@ -252,12 +245,6 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); - if (_name == "init_background_engine") { - return PackedFunc([_self](TVMArgs args, TVMRetValue* rv) -> void { - SelfPtr self = static_cast(_self.get()); - self->InitBackgroundEngine(args); - }); - } TVM_MODULE_VTABLE_END(); }; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 2e57afd2a0..3d11ba36f1 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,9 +35,13 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. - * \param args The arguments of engine construction. + * \param engine_config The engine config. + * \param request_stream_callback The request stream callback function to. + * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(TVMArgs args) = 0; + virtual void InitBackgroundEngine(EngineConfig engine_config, + Optional request_stream_callback, + Optional trace_recorder) = 0; /*! \brief Starts the background request processing loop. 
*/ virtual void RunBackgroundLoop() = 0; diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 48a72327e2..9f7c1c3580 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -4,7 +4,7 @@ from mlc_llm.help import HELP from mlc_llm.interface.serve import serve -from mlc_llm.serve.config import EngineConfig +from mlc_llm.serve.config import SpeculativeMode from mlc_llm.support.argparse import ArgumentParser @@ -48,9 +48,14 @@ def main(argv): "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) parser.add_argument( - "--engine-config", - type=EngineConfig.from_str, - help=HELP["engine_config_serve"] + ' (default: "%(default)s")', + "--speculative-mode", + type=str, + choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], + default="DISABLE", + help=HELP["speculative_mode_serve"], + ) + parser.add_argument( + "--spec-draft-length", type=int, default=4, help=HELP["spec_draft_length_serve"] ) parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"]) parser.add_argument( @@ -96,7 +101,8 @@ def main(argv): max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, gpu_memory_utilization=parsed.gpu_memory_utilization, - engine_config=parsed.engine_config, + speculative_mode=SpeculativeMode[parsed.speculative_mode], + spec_draft_length=parsed.spec_draft_length, enable_tracing=parsed.enable_tracing, host=parsed.host, port=parsed.port, diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 429e8a972d..b4321ebdec 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -191,6 +191,16 @@ When it is unspecified, it defaults to 0.90. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. +""", + "speculative_mode_serve": """ +The speculative decoding mode. Right now three options are supported: + - DISABLE, where speculative decoding is not enabled, + - SMALL_DRAFT, denoting the normal speculative decoding (small draft) style, + - EAGLE, denoting the eagle-style speculative decoding. +The default mode is "DISABLE". +""", + "spec_draft_length_serve": """ +The number of draft tokens to generate in speculative proposal. The default values is 4. """, "engine_config_serve": """ The LLMEngine execution configuration. 
diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index 3282762c00..c5696ef473 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -8,7 +8,7 @@ from mlc_llm.protocol import error_protocol from mlc_llm.serve import engine -from mlc_llm.serve.config import EngineConfig +from mlc_llm.serve.config import SpeculativeMode from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -23,7 +23,8 @@ def serve( max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], gpu_memory_utilization: Optional[float], - engine_config: Optional[EngineConfig], + speculative_mode: SpeculativeMode, + spec_draft_length: int, enable_tracing: bool, host: str, port: int, @@ -44,7 +45,8 @@ def serve( max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, gpu_memory_utilization=gpu_memory_utilization, - engine_config=engine_config, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, ) diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 8b99c9bc50..0a59df7421 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. import base -from .config import EngineConfig, GenerationConfig, KVCacheConfig, SpeculativeMode +from .config import EngineConfig, GenerationConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncLLMEngine, LLMEngine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 113356156b..773a00625e 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,12 +1,14 @@ """Configuration dataclasses used in MLC LLM serving""" -import argparse import enum import json from dataclasses import asdict, dataclass, field -from io import StringIO from typing import Dict, List, Literal, Optional +import tvm + +from . import _ffi_api + @dataclass class ResponseFormat: @@ -126,100 +128,89 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) -@dataclass -class KVCacheConfig: - """The KV cache initialization configuration. - - Parameters - ---------- - page_size : int - The number of consecutive tokens handled in each page in paged KV cache. - - max_num_sequence : int - The maximum number of sequences that are allowed to processed by the KV - cache at any time. - - max_total_sequence_length : Optional[int] - The maximum total number of tokens whose KV data are allowed to exist - in the KV cache at any time. - Set it to None to enable automatic computation of the max total - sequence length. - - prefill_chunk_size : Optional[int] - The maximum total sequence length in a prefill. - If not specified, it will be automatically inferred from model config. 
- """ - - page_size: int = 16 - max_num_sequence: int = 32 - max_total_sequence_length: Optional[int] = None - prefill_chunk_size: Optional[int] = None - - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) - - @staticmethod - def from_json(json_str: str) -> "KVCacheConfig": - """Construct a config from JSON string.""" - return KVCacheConfig(**json.loads(json_str)) - - class SpeculativeMode(enum.IntEnum): """The speculative mode.""" + # Disable speculative decoding. DISABLE = 0 + # The normal speculative decoding (small draft) mode. SMALL_DRAFT = 1 + # The eagle-style speculative decoding. EAGLE = 2 -@dataclass -class EngineConfig: +@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access +class EngineConfig(tvm.runtime.Object): """The class of LLMEngine execution configuration. Parameters ---------- - spec_draft_length : int - The number of tokens to generate in speculative proposal (draft), default 4. + model : str + The path to the model directory. - speculative_mode: SpeculativeMode - The speculative mode. - """ + model_lib_path : str + The path to the model library. - spec_draft_length: int = 4 - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE + additional_models : List[str] + The path to the additional models' directories. - def __repr__(self) -> str: - out = StringIO() - print(f"spec_draft_length={self.spec_draft_length}", file=out, end="") - print(f";speculative_mode={self.speculative_mode.name}", file=out, end="") - return out.getvalue().rstrip() + additional_model_lib_paths : List[str] + The path to the additional models' libraries. - def asjson(self) -> str: - """Return the config in string of JSON format.""" - dt = asdict(self) - dt["speculative_mode"] = int(self.speculative_mode) - return json.dumps(dt) + device : tvm.runtime.Device + The device where the models run. - @staticmethod - def from_json(json_str: str) -> "EngineConfig": - """Construct a config from JSON string.""" - return EngineConfig(**json.loads(json_str)) + kv_cache_page_size : int + The number of consecutive tokens handled in each page in paged KV cache. - @staticmethod - def from_str(source: str) -> "EngineConfig": - """Parse engine config from a string.""" - - parser = argparse.ArgumentParser(description="optimization flags") - parser.add_argument("--spec_draft_length", type=int, default=4) - parser.add_argument( - "--speculative_mode", - type=str, - choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], - default="DISABLE", - ) - results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) - return EngineConfig( - spec_draft_length=results.spec_draft_length, - speculative_mode=SpeculativeMode[results.speculative_mode], + max_num_sequence : int + The maximum number of sequences that are allowed to be + processed by the KV cache at any time. + + max_total_sequence_length : int + The maximum length allowed for a single sequence in the engine. + + max_single_sequence_length : int + The maximum total number of tokens whose KV data are allowed + to exist in the KV cache at any time. + + prefill_chunk_size : int + The maximum total sequence length in a prefill. + + speculative_mode : SpeculativeMode + The speculative mode. + + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). 
+ """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + additional_models: List[str], + additional_model_lib_paths: List[str], + device: tvm.runtime.Device, + kv_cache_page_size: int, + max_num_sequence: int, + max_total_sequence_length: int, + max_single_sequence_length: int, + prefill_chunk_size: int, + speculative_mode: SpeculativeMode, + spec_draft_length: int, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member + model, + model_lib_path, + additional_models, + additional_model_lib_paths, + device, + kv_cache_page_size, + max_num_sequence, + max_total_sequence_length, + max_single_sequence_length, + prefill_chunk_size, + speculative_mode, + spec_draft_length, ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 2ad6b0f1a1..3a329cae21 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -22,7 +22,7 @@ from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig +from mlc_llm.serve.config import GenerationConfig, SpeculativeMode from mlc_llm.serve.request import Request from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -847,7 +847,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - engine_config: Optional[EngineConfig] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, enable_tracing: bool = False, ) -> None: super().__init__( @@ -861,7 +862,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, gpu_memory_utilization=gpu_memory_utilization, - engine_config=engine_config, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, ) self.chat = Chat(weakref.ref(self)) @@ -1390,7 +1392,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - engine_config: Optional[EngineConfig] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, enable_tracing: bool = False, ) -> None: super().__init__( @@ -1404,7 +1407,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, gpu_memory_utilization=gpu_memory_utilization, - engine_config=engine_config, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, ) self.chat = Chat(weakref.ref(self)) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 367deda8a4..4c95f6e612 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -20,7 +20,7 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, KVCacheConfig +from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode from mlc_llm.serve.event_trace_recorder import 
EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -75,20 +75,17 @@ def _parse_models( def _process_model_args( models: List[ModelInfo], device: tvm.runtime.Device -) -> Tuple[List[Any], List[str], str, Conversation]: +) -> Tuple[List[Tuple[str, str]], List[str], Conversation]: """Process the input ModelInfo to get the engine initialization arguments.""" - tokenizer_path: Optional[str] = None conversation: Optional[Conversation] = None config_file_paths: List[str] = [] - def _convert_model_info(model: ModelInfo) -> List[Any]: - nonlocal tokenizer_path, conversation + def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: + nonlocal conversation model_path, config_file_path = _get_model_path(model.model) config_file_paths.append(config_file_path) chat_config = _get_chat_config(config_file_path, user_chat_config=None) - if tokenizer_path is None: - tokenizer_path = model_path if conversation is None: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template @@ -112,15 +109,12 @@ def _convert_model_info(model: ModelInfo) -> List[Any]: device=device, ) ) - return [model_lib_path, model_path, device.device_type, device.device_id] + return model_path, model_lib_path - model_args: List[Any] = sum( - (_convert_model_info(model) for model in models), - start=[], - ) + model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] assert conversation is not None - return model_args, config_file_paths, tokenizer_path, conversation + return model_args, config_file_paths, conversation def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments @@ -306,8 +300,14 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[KVCacheConfig, int]: - """Initialize the KV cache config with user input and GPU memory usage estimation.""" +) -> Tuple[int, int, int, int]: + """Initialize the KV cache config with user input and GPU memory usage estimation. + The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - model_max_single_sequence_length + """ ( model_max_single_sequence_length, model_max_prefill_chunk_size, @@ -319,7 +319,7 @@ def infer_args_under_mode( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], - ) -> Tuple[KVCacheConfig, List[float]]: + ) -> Tuple[Tuple[int, int, int], List[float]]: logging_msg = "" # - max_batch_size if max_batch_size is None: @@ -396,11 +396,7 @@ def infer_args_under_mode( # - Construct the KV cache config # - Estimate total GPU memory usage on single GPU. - return KVCacheConfig( - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - prefill_chunk_size=prefill_chunk_size, - ), [ + return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, model_params_bytes, kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, @@ -433,9 +429,9 @@ def infer_args_under_mode( 'The actual engine mode is "%s". 
So max batch size is %s, ' "max KV cache token capacity is %s, prefill chunk size is %s.", green(mode), - green(str(kv_cache_config.max_num_sequence)), - green(str(kv_cache_config.max_total_sequence_length)), - green(str(kv_cache_config.prefill_chunk_size)), + green(str(kv_cache_config[0])), + green(str(kv_cache_config[1])), + green(str(kv_cache_config[2])), ) logger.info( @@ -459,7 +455,7 @@ def infer_args_under_mode( override_msg, ) - return kv_cache_config, model_max_single_sequence_length + return *kv_cache_config, model_max_single_sequence_length @dataclass @@ -729,7 +725,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], gpu_memory_utilization: Optional[float], - engine_config: Optional[EngineConfig], + speculative_mode: SpeculativeMode, + spec_draft_length: int, enable_tracing: bool, ) -> None: # - Initialize model loading info. @@ -740,21 +737,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ( model_args, model_config_paths, - tokenizer_path, self.conv_template, ) = _process_model_args(models, device) # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - # model_args: - # [model_lib_path, model_path, device.device_type, device.device_id] * N - model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + model_info.model_lib_path = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) # - Decide the KV cache config based on mode and user input. - kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, @@ -765,9 +764,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.model_config_dicts, model_config_paths, ) - self.max_input_sequence_length = min( - max_single_sequence_length, kv_cache_config.max_total_sequence_length - ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) # - Initialize engine state and engine. 
self.state = EngineState(enable_tracing) @@ -784,20 +781,26 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "debug_call_func_on_all_worker", ] } - self.tokenizer = Tokenizer(tokenizer_path) - if engine_config is None: - # The default engine mode: non-speculative - engine_config = EngineConfig() + self.tokenizer = Tokenizer(model_args[0][0]) def _background_loop(): self._ffi["init_background_engine"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_config.asjson(), + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), self.state.get_request_stream_callback(kind), self.state.trace_recorder, - *model_args, ) self._ffi["run_background_loop"]() diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 08f5dc229e..1d17f8e66a 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -11,7 +11,7 @@ import requests from tvm.runtime import Device -from mlc_llm.serve.config import EngineConfig +from mlc_llm.serve.config import SpeculativeMode class PopenServer: # pylint: disable=too-many-instance-attributes @@ -30,7 +30,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - engine_config: Optional[EngineConfig] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, enable_tracing: bool = False, host: str = "127.0.0.1", port: int = 8000, @@ -45,7 +46,8 @@ def __init__( # pylint: disable=too-many-arguments self.max_total_sequence_length = max_total_sequence_length self.prefill_chunk_size = prefill_chunk_size self.gpu_memory_utilization = gpu_memory_utilization - self.engine_config = engine_config + self.speculative_mode = speculative_mode + self.spec_draft_length = spec_draft_length self.enable_tracing = enable_tracing self.host = host self.port = port @@ -70,8 +72,13 @@ def start(self) -> None: # pylint: disable=too-many-branches cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] if self.prefill_chunk_size is not None: cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] - if self.engine_config is not None: - cmd += ["--engine-config", str(self.engine_config)] + if self.speculative_mode != SpeculativeMode.DISABLE: + cmd += [ + "--speculative-mode", + self.speculative_mode.name, + "--spec-draft-length", + str(self.spec_draft_length), + ] if self.gpu_memory_utilization is not None: cmd += ["--gpu-memory-utilization", str(self.gpu_memory_utilization)] if self.enable_tracing: diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 963ea9402f..23b151d5c7 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -14,7 +14,7 @@ import tvm from mlc_llm.serve import data -from mlc_llm.serve.config import EngineConfig, GenerationConfig +from mlc_llm.serve.config import EngineConfig, 
GenerationConfig, SpeculativeMode from mlc_llm.serve.engine_base import ( _infer_kv_cache_config, _parse_models, @@ -100,7 +100,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, - engine_config: Optional[EngineConfig] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, ): # - Initialize model loading info. @@ -111,21 +112,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ( model_args, model_config_paths, - tokenizer_path, self.conv_template, ) = _process_model_args(models, device) # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - # model_args: - # [model_lib_path, model_path, device.device_type, device.device_id] * N - model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + model_info.model_lib_path = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) # - Decide the KV cache config based on mode and user input. - kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, @@ -136,9 +139,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.model_config_dicts, model_config_paths, ) - self.max_input_sequence_length = min( - max_single_sequence_length, kv_cache_config.max_total_sequence_length - ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) self._ffi = _create_tvm_module( "mlc.serve.create_engine", @@ -155,20 +156,25 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None - if engine_config is None: - # The default engine mode: non-speculative - engine_config = EngineConfig() - self._ffi["init"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_config.asjson(), + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), request_stream_callback, self.trace_recorder, - *model_args, ) - self.tokenizer = Tokenizer(tokenizer_path) + self.tokenizer = Tokenizer(model_args[0][0]) def generate( # pylint: disable=too-many-locals self, diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index f14d4727b8..b86fd423a9 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -11,6 +11,7 @@ from mlc_llm.serve import engine_utils from mlc_llm.serve.engine_base import ( EngineConfig, + SpeculativeMode, _infer_kv_cache_config, _parse_models, _process_model_args, @@ -62,7 +63,8 @@ def __init__( # 
pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, - engine_config: Optional[EngineConfig] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, gpu_memory_utilization: Optional[float] = None, ) -> None: # - Initialize model loading info. @@ -73,21 +75,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals ( model_args, model_config_paths, - tokenizer_path, self.conv_template, ) = _process_model_args(models, device) # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - # model_args: - # [model_lib_path, model_path, device.device_type, device.device_id] * N - model_info.model_lib_path = model_args[i * (len(model_args) // len(models))] + model_info.model_lib_path = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) # - Decide the KV cache config based on mode and user input. - kv_cache_config, max_single_sequence_length = _infer_kv_cache_config( + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, @@ -98,9 +102,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.model_config_dicts, model_config_paths, ) - self.max_input_sequence_length = min( - max_single_sequence_length, kv_cache_config.max_total_sequence_length - ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) # - Initialize engine state and engine. self.state = EngineState() @@ -117,20 +119,26 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "exit_background_loop", ] } - self.tokenizer = Tokenizer(tokenizer_path) - if engine_config is None: - # The default engine mode: non-speculative - engine_config = EngineConfig() + self.tokenizer = Tokenizer(model_args[0][0]) def _background_loop(): self._ffi["init_background_engine"]( - max_single_sequence_length, - tokenizer_path, - kv_cache_config.asjson(), - engine_config.asjson(), + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ), self.state.get_request_stream_callback(), None, - *model_args, ) self._ffi["run_background_loop"]() diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index 693f0767c3..de91c845b3 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -1,14 +1,9 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +# pylint: disable=too-many-arguments,too-many-locals import asyncio from typing import List -from mlc_llm.serve import ( - AsyncLLMEngine, - EngineConfig, - GenerationConfig, - SpeculativeMode, -) +from mlc_llm.serve import 
AsyncLLMEngine, GenerationConfig, SpeculativeMode prompts = [ "What is the meaning of life?", @@ -37,7 +32,7 @@ async def test_engine_generate(): model_lib_path=model_lib_path, mode="server", additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + speculative_mode=SpeculativeMode.SMALL_DRAFT, ) num_requests = 10 diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index b398dd62c3..60be02ce1a 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -1,19 +1,16 @@ # pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +# pylint: disable=too-many-arguments,too-many-locals from typing import Callable, List, Optional import numpy as np from mlc_llm.serve import ( - EngineConfig, GenerationConfig, - KVCacheConfig, Request, RequestStreamOutput, SpeculativeMode, data, ) -from mlc_llm.serve.engine_base import ModelInfo from mlc_llm.serve.sync_engine import SyncLLMEngine prompts = [ @@ -99,7 +96,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + speculative_mode=SpeculativeMode.SMALL_DRAFT, request_stream_callback=fcallback, ) @@ -167,7 +164,8 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(spec_draft_length=2, speculative_mode=SpeculativeMode.EAGLE), + speculative_mode=SpeculativeMode.EAGLE, + spec_draft_length=2, request_stream_callback=fcallback, ) @@ -250,7 +248,7 @@ def step(self) -> None: mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + speculative_mode=SpeculativeMode.SMALL_DRAFT, request_stream_callback=timer.callback_getter(), ) @@ -336,7 +334,7 @@ def step(self) -> None: mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.EAGLE), + speculative_mode=SpeculativeMode.EAGLE, request_stream_callback=timer.callback_getter(), ) @@ -380,7 +378,7 @@ def test_engine_generate(): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.SMALL_DRAFT), + speculative_mode=SpeculativeMode.SMALL_DRAFT, ) num_requests = 10 @@ -413,7 +411,7 @@ def test_engine_eagle_generate(): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(speculative_mode=SpeculativeMode.EAGLE), + speculative_mode=SpeculativeMode.EAGLE, ) num_requests = 10 @@ -533,9 +531,8 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig( - spec_draft_length=6, speculative_mode=SpeculativeMode.SMALL_DRAFT - ), + spec_draft_length=6, + speculative_mode=SpeculativeMode.SMALL_DRAFT, request_stream_callback=fcallback, ) @@ -604,7 
+601,8 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): mode="server", max_total_sequence_length=4096, additional_models=[small_model + ":" + small_model_lib_path], - engine_config=EngineConfig(spec_draft_length=6, speculative_mode=SpeculativeMode.EAGLE), + spec_draft_length=6, + speculative_mode=SpeculativeMode.EAGLE, request_stream_callback=fcallback, ) From ad770d88f6a4325668a0f82978f860873afd9aa4 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:32:38 -0400 Subject: [PATCH 201/531] [Llama3] Support Llama 3 (#2163) * Add conv template and model preset * Fix conv template * Trivial --- python/mlc_llm/conversation_template.py | 22 +++++++++++ python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/model/model_preset.py | 50 +++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 1b2a06feab..fa926708d3 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -36,6 +36,28 @@ def get_conv_template(name: str) -> Optional[Conversation]: ############## Preset Conversation Templates ############## +# Llama3 +# See https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models +# and https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-3", + system_template=( + "<|start_header_id|>system<|end_header_id|>\n\n", + f"{MessagePlaceholders.SYSTEM.value}", + ), + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "user", "assistant": "assistant"}, + seps=["<|eot_id|><|start_header_id|>"], + role_content_sep="<|end_header_id|>\n\n", + role_empty_sep="<|end_header_id|>\n\n", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + stop_token_ids=[128001, 128009], # "<|end_of_text|>", "<|eot_id|>" + system_prefix_token_ids=[128000], # "<|begin_of_text|>" + add_role_after_system_message=True, + ) +) + # Llama2 ConvTemplateRegistry.register_conv_template( Conversation( diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index d22aa7d231..8e617fc3d2 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -274,6 +274,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b # FIXME: Copy RWKV tokenizer file # pylint: disable=fixme CONV_TEMPLATES = { + "llama-3", "chatml", "open_hermes_mistral", "neural_hermes_mistral", diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 3bfe1cb891..41abf0292c 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -660,4 +660,54 @@ "eos_token_id": 2, "pad_token_id": 0, }, + "llama3_8b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, 
+ "llama3_70b": { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 500000.0, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "use_cache": True, + "vocab_size": 128256, + }, } From bee19286f7ef14fd23eb11ee4cceccd37e6bc357 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:41:57 -0400 Subject: [PATCH 202/531] [Fix] Fix llama 3 conv template (#2164) Fix llama 3 conv template --- python/mlc_llm/conversation_template.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index fa926708d3..917e229632 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -43,8 +43,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: Conversation( name="llama-3", system_template=( - "<|start_header_id|>system<|end_header_id|>\n\n", - f"{MessagePlaceholders.SYSTEM.value}", + f"<|start_header_id|>system<|end_header_id|>\n\n{MessagePlaceholders.SYSTEM.value}" ), system_message="You are a helpful, respectful and honest assistant.", roles={"user": "user", "assistant": "assistant"}, From d6724b1e939cb347afae1e5a20a8a5667403f69d Mon Sep 17 00:00:00 2001 From: Git bot Date: Thu, 18 Apr 2024 20:32:10 +0000 Subject: [PATCH 203/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 7a8520581e..d694451c58 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 7a8520581e4a70024de05fa9e803b5d2899796f6 +Subproject commit d694451c580a931116a2c93571f21f7d791c7fa0 From c6edba8ca5147f712f80c72d5cf6e63363a94222 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 18 Apr 2024 22:52:13 -0400 Subject: [PATCH 204/531] [Serving][HotFix] No `std::move()` for disco CallPacked (#2166) The disco `CallPacked` function cannot handle `std::move()` very well. A previous engine refactor PR introduced a regression that broke our tensor parallelism support. This commit fixes the issue. 
--- cpp/serve/function_table.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index fa24828399..289abfda16 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -107,7 +107,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), - std::move(reload_lib_path), null_device); + reload_lib_path, null_device); this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { From de9852430695a6ef915c598d24059cdeb5f81307 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 01:03:29 -0400 Subject: [PATCH 205/531] [Docs] Update example for Llama3 (#2169) This PR updates the huggingface repo examples to use Llama3. --- docs/deploy/cli.rst | 18 ++++++++---------- docs/get_started/introduction.rst | 10 +++++----- docs/get_started/quick_start.rst | 8 ++++---- docs/prebuilt_models.rst | 2 +- examples/python/sample_mlc_engine.py | 2 +- 5 files changed, 19 insertions(+), 21 deletions(-) diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index f341e31e71..b2e91ce2b1 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -54,13 +54,13 @@ To run a model with MLC LLM in any platform, you can either: **Option 1: Use model prebuilts** To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. -For example, to run the MLC Llama 2 7B Q4F16_1 model (`Repo link `_), -simply use ``HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC``. The model weights and library will be downloaded +For example, to run the MLC Llama 3 8B Q4F16_1 model (`Repo link `_), +simply use ``HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC``. The model weights and library will be downloaded automatically from Huggingface. .. code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 .. code:: shell @@ -74,13 +74,11 @@ automatically from Huggingface. Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! The meaning - of life is a deeply personal and subjective topic, and there are many different - perspectives on what it might be. However, here are some possible answers that have been - proposed by various thinkers and cultures: - ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... **Option 2: Use locally compiled model weights and libraries** diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 282b4764c2..b69bd1d504 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -37,7 +37,7 @@ You can run MLC chat through a one-liner command: .. 
code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC It may take 1-2 minutes for the first time running this command. After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. @@ -91,7 +91,7 @@ You can save the code below into a Python file and run it. from mlc_llm import LLMEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" engine = LLMEngine(model) # Run chat completion in OpenAI API. @@ -142,7 +142,7 @@ for OpenAI chat completion requests. The server can be launched in command line .. code:: bash - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC The server is hooked at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` to set a different host and port. @@ -154,7 +154,7 @@ we can open a new shell and send a cURL request via the following command: curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -280,7 +280,7 @@ environments (e.g. SteamDeck). .. code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC --device vulkan + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device vulkan The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as long as they fit within the memory and computing budget of the corresponding hardware backend. diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index bd3b41218e..604688f790 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -23,7 +23,7 @@ It is recommended to have at least 6GB free VRAM to run it. from mlc_llm import LLMEngine # Create engine - model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" engine = LLMEngine(model) # Run chat completion in OpenAI API. @@ -57,7 +57,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: shell - mlc_llm serve HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), open a new shell and send a request via the following command: @@ -67,7 +67,7 @@ It is recommended to have at least 6GB free VRAM to run it. curl -X POST \ -H "Content-Type: application/json" \ -d '{ - "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", "messages": [ {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} ] @@ -94,7 +94,7 @@ It is recommended to have at least 6GB free VRAM to run it. .. 
code:: bash - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC If you are using windows/linux/steamdeck and would like to use vulkan, diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index f97909a515..2f772a5d7e 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -68,7 +68,7 @@ For more, please see :ref:`the CLI page `, and the :ref:`the Python .. code:: shell - mlc_llm chat HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC + mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py index e26e17f1e2..f76e44c620 100644 --- a/examples/python/sample_mlc_engine.py +++ b/examples/python/sample_mlc_engine.py @@ -1,7 +1,7 @@ from mlc_llm import LLMEngine # Create engine -model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" engine = LLMEngine(model) # Run chat completion in OpenAI API. From 3dbc1d515c99c9ffe278262454dd9228954b4dd7 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 18 Apr 2024 22:33:36 -0700 Subject: [PATCH 206/531] [README] Fix broken link to Python API (#2168) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bea5ccc0e..da3099c11e 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_mode ## Universal Deployment APIs MLC LLM provides multiple sets of APIs across platforms and environments. These include -* [Python API](https://llm.mlc.ai/docs/deploy/python.html) +* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html) * [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) * [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) * [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) From 856204eeb237dbd6dc478c3cb83c0caad0028050 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 01:34:12 -0400 Subject: [PATCH 207/531] [Docs] Update README (#2170) This PR updates README for Llama3 quick start examples. --- README.md | 138 ++++++++++++++++++++++++++---- docs/deploy/cli.rst | 2 +- docs/get_started/introduction.rst | 22 ++--- 3 files changed, 136 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index da3099c11e..782d647531 100644 --- a/README.md +++ b/README.md @@ -64,23 +64,131 @@ Scaling of fp16 and 4-bit CodeLlama-34 and Llama2-70B on A100-80G-PCIe and A10G-

-## News - -* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm are official. -* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). -* [08/25/2023] CodeLlama support is up. -* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. -* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. -* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking. -* [07/19/2023] Support for Llama2-7B/13B/70B is up. -* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up. -* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android. -* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends. -* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend. ## Getting Started -Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. +We introduce the quick start examples of chat CLI, Python API and REST server here to use MLC LLM. +We use 4-bit quantized 8B Llama-3 model for demonstration purpose. +The pre-quantized Llama-3 weights is available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC. +You can also try out unquantized Llama-3 model by replacing `q4f16_1` to `q0f16` in the examples below. +Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for detailed quick start and introduction. + +### Installation + +MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). +It is always recommended to install it in an isolated conda virtual environment. + +To verify the installation, activate your virtual environment, run + +```bash +python -c "import mlc_llm; print(mlc_llm.__path__)" +``` + +You are expected to see the installation path of MLC LLM Python package. + +### Chat CLI + +We can try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. + +```bash +mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +It may take 1-2 minutes for the first time running this command. +After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. + +``` +You can use the following special commands: +/help print the special commands +/exit quit the cli +/stats print out the latest stats (token/sec) +/reset restart a fresh chat +/set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). +Multi-line input: Use escape+enter to start a new line. + +user: What's the meaning of life +assistant: +What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + +The concept of the meaning of life has been debated and... 
+``` + +### Python API + +We can run the Llama-3 model with the chat completion Python API of MLC LLM. +You can save the code below into a Python file and run it. + +```python +from mlc_llm import LLMEngine + +# Create engine +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" +engine = LLMEngine(model) + +# Run chat completion in OpenAI API. +for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, +): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) +print("\n") + +engine.terminate() +``` + +**We design the Python API `mlc_llm.LLMEngine` to align with OpenAI API**, +which means you can use LLMEngine in the same way of using +[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) +for both synchronous and asynchronous generation. + +In this code example, we use the synchronous chat completion interface and iterate over +all the stream responses. +If you want to run without streaming, you can run + +```python +response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, +) +print(response) +``` + +You can also try different arguments supported in [OpenAI chat completion API](https://platform.openai.com/docs/api-reference/chat/create). +If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncLLMEngine` instead. + +### REST Server + +We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. + +```bash +mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +The server is hooked at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port` +to set a different host and port. +When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`), +we can open a new shell and send a cURL request via the following command: + +```bash +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions +``` + +The server will process this request and send back the response. +Similar to [Python API](#python-api), you can pass argument ``"stream": true`` +to request for stream responses. ## Model Support @@ -97,7 +205,7 @@ use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_mode Llama - Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored + Llama-3, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored GPT-NeoX diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index b2e91ce2b1..a7ebe28d6d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -62,7 +62,7 @@ automatically from Huggingface. mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 -.. code:: shell +.. 
code:: You can use the following special commands: /help print the special commands diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index b69bd1d504..de979dbb57 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -32,7 +32,7 @@ You are expected to see the installation path of MLC LLM Python package. Chat CLI -------- -As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 7B Llama-2 model. +As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. You can run MLC chat through a one-liner command: .. code:: bash @@ -54,17 +54,19 @@ After waiting, this command launch a chat interface where you can enter your pro Note: Separate stop words in the `stop` option with commas (,). Multi-line input: Use escape+enter to start a new line. - [INST]: What's the meaning of life? - [/INST]: - Ah, a question that has puzzled philosophers and theologians for centuries! ... + user: What's the meaning of life + assistant: + What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + + The concept of the meaning of life has been debated and... The figure below shows what run under the hood of this chat CLI command. For the first time running the command, there are three major phases. -- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-2 model from `Hugging Face `_ and saves it to your local cache directory. -- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-2 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. -- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-2 model. +- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-3 model from `Hugging Face `_ and saves it to your local cache directory. +- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-3 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. +- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-3 model. We cache the pre-quantized model weights and compiled model library locally. Therefore, phase 1 and 2 will only execute **once** over multiple runs. @@ -83,7 +85,7 @@ Therefore, phase 1 and 2 will only execute **once** over multiple runs. Python API ---------- -In the second example, we run the Llama-2 model with the chat completion Python API of MLC LLM. +In the second example, we run the Llama-3 model with the chat completion Python API of MLC LLM. You can save the code below into a Python file and run it. .. code:: python @@ -112,7 +114,7 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-2 model. 
+This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-3 model. **We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, which means you can use :class:`mlc_llm.LLMEngine` in the same way of using `OpenAI's Python package `_ @@ -137,7 +139,7 @@ If you would like to do concurrent asynchronous generation, you can use :class:` REST Server ----------- -For the third example, we launch a REST server to serve the 4-bit quantized Llama-2 model +For the third example, we launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. The server can be launched in command line with .. code:: bash From 855f9a2fae8fc92e365b03dfd5a31b705c7bb4b7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 03:10:16 -0400 Subject: [PATCH 208/531] [Docs] Documentation of LLMEngine in Python API (#2172) This PR completes the documentation page of LLMEngine and AsyncLLMEngine in our Python API. --- cpp/serve/engine_actions/action_commons.h | 2 +- cpp/serve/event_trace_recorder.h | 2 +- cpp/serve/grammar/grammar_serializer.h | 2 +- docs/deploy/python_engine.rst | 255 ++++++++++++++++++- docs/get_started/introduction.rst | 5 +- python/mlc_llm/serve/engine_base.py | 2 +- python/mlc_llm/serve/event_trace_recorder.py | 2 +- python/mlc_llm/testing/debug_chat.py | 2 +- 8 files changed, 258 insertions(+), 14 deletions(-) diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index aea455a1be..78e3937d0b 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -47,7 +47,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array`. + + +Verify Installation +------------------- + +.. code:: bash + + python -c "from mlc_llm import LLMEngine; print(LLMEngine)" + +You are expected to see the output of ````. + +If the command above results in error, follow :ref:`install-mlc-packages` to install prebuilt pip +packages or build MLC LLM from source. + + +Run LLMEngine +------------- + +:class:`mlc_llm.LLMEngine` provides the interface of OpenAI chat completion synchronously. + +**Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, +we introduced the basic use of :class:`mlc_llm.LLMEngine`. + +.. code:: python + + from mlc_llm import LLMEngine + + # Create engine + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + engine = LLMEngine(model) + + # Run chat completion in OpenAI API. + for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + print("\n") + + engine.terminate() + +This code example first creates an :class:`mlc_llm.LLMEngine` instance with the 8B Llama-3 model. +**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.LLMEngine` in the same way of using +`OpenAI's Python package `_ +for both synchronous and asynchronous generation. + +**Non-stream Response.** The code example above uses the synchronous chat completion +interface and iterate over all the stream responses. +If you want to run without streaming, you can run + +.. 
code:: python + + response = engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + + +Run AsyncLLMEngine +------------------ + +:class:`mlc_llm.AsyncLLMEngine` provides the interface of OpenAI chat completion with +asynchronous features. + +**Stream Response.** The core use of :class:`mlc_llm.AsyncLLMEngine` for stream responses is as follows. + +.. code:: python + + async for response in await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, + ): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) + +.. collapse:: The collapsed is a complete runnable example of AsyncLLMEngine in Python. + + .. code:: python + + import asyncio + from typing import Dict + + from mlc_llm.serve import AsyncLLMEngine + + model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" + prompts = [ + "Write a three-day travel plan to Pittsburgh.", + "What is the meaning of life?", + ] + + + async def test_completion(): + # Create engine + async_engine = AsyncLLMEngine(model=model) + + num_requests = len(prompts) + output_texts: Dict[str, str] = {} + + async def generate_task(prompt: str): + async for response in await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + model=model, + stream=True, + ): + if response.id not in output_texts: + output_texts[response.id] = "" + output_texts[response.id] += response.choices[0].delta.content + + tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)] + await asyncio.gather(*tasks) + + # Print output. + for request_id, output in output_texts.items(): + print(f"Output of request {request_id}:\n{output}\n") + + async_engine.terminate() + + + asyncio.run(test_completion()) + +| + +**Non-stream Response.** Similarly, :class:`mlc_llm.AsyncEngine` provides the non-stream response +interface. + +.. code:: python + + response = await engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=False, + ) + print(response) + +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + + +Engine Mode +----------- + +To ease the engine configuration, the constructors of :class:`mlc_llm.LLMEngine` and +:class:`mlc_llm.AsyncLLMEngine` have an optional argument ``mode``, +which falls into one of the three options ``"local"``, ``"interactive"`` or ``"server"``. +The default mode is ``"local"``. + +Each mode denotes a pre-defined configuration of the engine to satisfy different use cases. +The choice of the mode controls the request concurrency of the engine, +as well as engine's KV cache token capacity (or in other words, the maximum +number of tokens that the engine's KV cache can hold), +and further affects the GPU memory usage of the engine. + +In short, + +- mode ``"local"`` uses low request concurrency and low KV cache capacity, which is suitable for cases where **concurrent requests are not too many, and the user wants to save GPU memory usage**. +- mode ``"interactive"`` uses 1 as the request concurrency and low KV cache capacity, which is designed for **interactive use cases** such as chats and conversations. 
+- mode ``"server"`` uses as much request concurrency and KV cache capacity as possible. This mode aims to **fully utilize the GPU memory for large server scenarios** where concurrent requests may be many. + +Please refer to :ref:`python-engine-api-reference` for detailed documentation of the engine mode. + + +Deploy Your Own Model with Python API +------------------------------------- + +The :ref:`introduction page ` introduces how we can deploy our +own models with MLC LLM. +This section introduces how you can use the model weights you convert and the model library you build +in :class:`mlc_llm.LLMEngine` and :class:`mlc_llm.AsyncLLMEngine`. + +We use the `Phi-2 `_ as the example model. + +**Specify Model Weight Path.** Assume you have converted the model weights for your own model, +you can construct a :class:`mlc_llm.LLMEngine` as follows: + +.. code:: python + + from mlc_llm import LLMEngine + + model = "models/phi-2" # Assuming the converted phi-2 model weights are under "models/phi-2" + engine = LLMEngine(model) + + +**Specify Model Library Path.** Further, if you build the model library on your own, +you can use it in :class:`mlc_llm.LLMEngine` by passing the library path through argument ``model_lib_path``. + +.. code:: python + + from mlc_llm import LLMEngine + + model = "models/phi-2" + model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" + engine = LLMEngine(model, model_lib_path=model_lib_path) + + +The same applies to :class:`mlc_llm.AsyncLLMEngine`. + + +.. _python-engine-api-reference: + +API Reference +------------- + +The :class:`mlc_llm.LLMEngine` and :class:`mlc_llm.AsyncLLMEngine` classes provide the following constructors. + +The LLMEngine and AsyncLLMEngine have full OpenAI API completeness. +Please refer to `OpenAI's Python package `_ +and `OpenAI chat completion API `_ +for the complete chat completion interface. + +.. currentmodule:: mlc_llm + +.. autoclass:: LLMEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ + +.. autoclass:: AsyncLLMEngine + :members: + :exclude-members: evaluate + :undoc-members: + :show-inheritance: + + .. automethod:: __init__ diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index de979dbb57..32bcfc4cdb 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -114,7 +114,7 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the the 4-bit quantized Llama-3 model. +This code example first creates an :class:`mlc_llm.LLMEngine` instance with the 4-bit quantized Llama-3 model. **We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, which means you can use :class:`mlc_llm.LLMEngine` in the same way of using `OpenAI's Python package `_ @@ -167,6 +167,7 @@ The server will process this request and send back the response. Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass argument ``"stream": true`` to request for stream responses. +.. _introduction-deploy-your-own-model: Deploy Your Own Model --------------------- @@ -300,7 +301,7 @@ To briefly summarize this page, - We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, - we introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. 
-- We also discussed the the universal deployment capability of MLC LLM. +- We also discussed the universal deployment capability of MLC LLM. Next, please feel free to check out the pages below for quick start examples and more detailed information on specific platforms diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 4c95f6e612..9a25401d3f 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -577,7 +577,7 @@ def __init__(self, enable_tracing: bool) -> None: self.trace_recorder = EventTraceRecorder() def record_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace + """Record a event for the input request in the trace recorder when the recorder exists. Parameters diff --git a/python/mlc_llm/serve/event_trace_recorder.py b/python/mlc_llm/serve/event_trace_recorder.py index 7a8a8177fe..457918d598 100644 --- a/python/mlc_llm/serve/event_trace_recorder.py +++ b/python/mlc_llm/serve/event_trace_recorder.py @@ -17,7 +17,7 @@ def __init__(self) -> None: ) def add_event(self, request_id: str, event: str) -> None: - """Record a event for the the input request in the trace recorder. + """Record a event for the input request in the trace recorder. Parameters ---------- diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 2a70154bba..4f1cfe103d 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -118,7 +118,7 @@ def __call__(self, func, name, before_run, ret_val, *args): print(f"{red(f'{func_name} has INF')}: {num_infs}") self.first_inf_occurred = True - # Save the the arguments to npz + # Save the arguments to npz arg_dict = {} for i, arg in enumerate(args): if isinstance(arg, tvm.nd.NDArray): From f87745d26f1b1ba0746c6fb8da29c9fd88355d13 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 10:37:33 -0400 Subject: [PATCH 209/531] [Docs] Update project website (#2175) This PR mainly updates the project website, and also updates some minor points for other docs. --- README.md | 98 +----------- docs/deploy/python_engine.rst | 6 + docs/index.rst | 1 - site/index.md | 285 +++++++++++++++++++++++++++------- 4 files changed, 243 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 782d647531..647b9047f2 100644 --- a/README.md +++ b/README.md @@ -50,22 +50,7 @@ -**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, cloud and gaming GPUs. Below -showcases our single batch decoding performance with prefilling = 1 and decoding = 256. - -Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX: -

-Scaling of fp16 and 4-bit CodeLlama-34 and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs:

- - -## Getting Started +## Quick Start We introduce the quick start examples of chat CLI, Python API and REST server here to use MLC LLM. We use 4-bit quantized 8B Llama-3 model for demonstration purpose. @@ -140,30 +125,17 @@ print("\n") engine.terminate() ``` -**We design the Python API `mlc_llm.LLMEngine` to align with OpenAI API**, -which means you can use LLMEngine in the same way of using +**The Python API of `mlc_llm.LLMEngine` fully aligns with OpenAI API**. +You can use LLMEngine in the same way of using [OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) for both synchronous and asynchronous generation. -In this code example, we use the synchronous chat completion interface and iterate over -all the stream responses. -If you want to run without streaming, you can run - -```python -response = engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=False, -) -print(response) -``` - -You can also try different arguments supported in [OpenAI chat completion API](https://platform.openai.com/docs/api-reference/chat/create). If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncLLMEngine` instead. ### REST Server We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. +The server has fully OpenAI API completeness. ```bash mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC @@ -186,66 +158,6 @@ curl -X POST \ http://127.0.0.1:8000/v1/chat/completions ``` -The server will process this request and send back the response. -Similar to [Python API](#python-api), you can pass argument ``"stream": true`` -to request for stream responses. - -## Model Support - -MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can -use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Architecture | Prebuilt Model Variants |
|--------------|-------------------------|
| Llama        | Llama-3, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
| GPT-NeoX     | RedPajama |
| GPT-J        | |
| RWKV         | RWKV-raven |
| MiniGPT      | |
| GPTBigCode   | WizardCoder |
| ChatGLM      | |
| StableLM     | |
| Mistral      | |
| Phi          | |
- ## Universal Deployment APIs MLC LLM provides multiple sets of APIs across platforms and environments. These include @@ -273,7 +185,7 @@ The underlying techniques of MLC LLM include:
References (Click to expand) - + ```bibtex @inproceedings{tensorir, author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index e3b88cec9c..cfbc3b5d4c 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -38,6 +38,8 @@ Run LLMEngine ------------- :class:`mlc_llm.LLMEngine` provides the interface of OpenAI chat completion synchronously. +:class:`mlc_llm.LLMEngine` does not batch concurrent request due to the synchronous design, +and please use :ref:`AsyncLLMEngine ` for request batching process. **Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, we introduced the basic use of :class:`mlc_llm.LLMEngine`. @@ -86,11 +88,14 @@ and `OpenAI chat completion API - -

- -Note: Llama-7B takes 4GB of RAM and RedPajama-3B takes 2.2GB to run. We recommend a latest device with 6GB RAM for Llama-7B, or 4GB RAM for RedPajama-3B, to run the app. The text generation speed could vary from time to time, for example, slow in the beginning but recover to a normal speed then. - -### Android - -The demo APK is available to [download](https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk). The demo is tested on Samsung S23 with Snapdragon 8 Gen 2 chip, Redmi Note 12 Pro with Snapdragon 685 and Google Pixel phones. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/android.html) is available for building android apps with MLC LLM. - -


- -### Windows Linux Mac - -Our cpp interface runs on AMD, Intel, Apple and NVIDIA GPUs. -Besides the [Getting Started](https://llm.mlc.ai/docs/get_started/try_out.html) page, -[documentation](https://llm.mlc.ai/docs/deploy/cli.html) is available for building C++ apps with MLC LLM. - -


- -### Web Browser - -[WebLLM](https://webllm.mlc.ai/) is our companion project that deploys MLC LLM natively to browsers using WebGPU and WebAssembly. Still everything runs inside the browser without server resources, and accelerated by local GPUs (e.g. AMD, Intel, Apple or NVIDIA). +[Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] + +**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. + +**Universal deployment.** MLC LLM supports the following platforms and hardware: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
|              | AMD GPU         | NVIDIA GPU      | Apple GPU | Intel GPU       |
|--------------|-----------------|-----------------|-----------|-----------------|
| Linux / Win  | ✅ Vulkan, ROCm | ✅ Vulkan, CUDA | N/A       | ✅ Vulkan       |
| macOS        | ✅ Metal (dGPU) | N/A             | ✅ Metal  | ✅ Metal (iGPU) |
| Web Browser  | ✅ WebGPU and WASM | | | |
| iOS / iPadOS | ✅ Metal on Apple A-series GPU | | | |
| Android      | ✅ OpenCL on Adreno GPU | ✅ OpenCL on Mali GPU | | |
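
As a rough illustration of how the table above maps to actual use, the backend can be pinned explicitly instead of auto-detected. The snippet below is a sketch only: it assumes the Python engine constructor exposes a `device` argument accepting strings such as `"cuda"`, `"vulkan"`, `"metal"`, or `"opencl"`; see the Python API documentation for the exact signature.

```python
from mlc_llm import LLMEngine

# Assumption: `device` selects one of the backends listed in the table above;
# the default "auto" detects the local GPU, while an explicit string pins it.
engine = LLMEngine(
    "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
    device="vulkan",  # e.g. AMD, Intel, or NVIDIA GPUs on Linux / Windows
)
```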
+ + +## Quick Start + +We introduce the quick start examples of chat CLI, Python API and REST server here to use MLC LLM. +We use 4-bit quantized 8B Llama-3 model for demonstration purpose. +The pre-quantized Llama-3 weights is available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC. +You can also try out unquantized Llama-3 model by replacing `q4f16_1` to `q0f16` in the examples below. +Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for detailed quick start and introduction. + +### Installation + +MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). +It is always recommended to install it in an isolated conda virtual environment. + +To verify the installation, activate your virtual environment, run + +```bash +python -c "import mlc_llm; print(mlc_llm.__path__)" +``` + +You are expected to see the installation path of MLC LLM Python package. + +### Chat CLI + +We can try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. + +```bash +mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +It may take 1-2 minutes for the first time running this command. +After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. + +``` +You can use the following special commands: +/help print the special commands +/exit quit the cli +/stats print out the latest stats (token/sec) +/reset restart a fresh chat +/set [overrides] override settings in the generation config. For example, + `/set temperature=0.5;max_gen_len=100;stop=end,stop` + Note: Separate stop words in the `stop` option with commas (,). +Multi-line input: Use escape+enter to start a new line. + +user: What's the meaning of life +assistant: +What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. + +The concept of the meaning of life has been debated and... +``` + +### Python API + +We can run the Llama-3 model with the chat completion Python API of MLC LLM. +You can save the code below into a Python file and run it. + +```python +from mlc_llm import LLMEngine + +# Create engine +model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" +engine = LLMEngine(model) + +# Run chat completion in OpenAI API. +for response in engine.chat.completions.create( + messages=[{"role": "user", "content": "What is the meaning of life?"}], + model=model, + stream=True, +): + for choice in response.choices: + print(choice.delta.content, end="", flush=True) +print("\n") + +engine.terminate() +``` + +**The Python API of `mlc_llm.LLMEngine` fully aligns with OpenAI API**. +You can use LLMEngine in the same way of using +[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) +for both synchronous and asynchronous generation. + +If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncLLMEngine` instead. + +### REST Server + +We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. +The server has fully OpenAI API completeness. + +```bash +mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC +``` + +The server is hooked at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port` +to set a different host and port. 
+When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`), +we can open a new shell and send a cURL request via the following command: + +```bash +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", + "messages": [ + {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} + ] + }' \ + http://127.0.0.1:8000/v1/chat/completions +``` + +## Universal Deployment APIs + +MLC LLM provides multiple sets of APIs across platforms and environments. These include +* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html) +* [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) +* [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) +* [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) +* [Swift API for iOS App](https://llm.mlc.ai/docs/deploy/ios.html) +* [Java API and Android App](https://llm.mlc.ai/docs/deploy/android.html) + +## Citation + +Please consider citing our project if you find it useful: + +```bibtex +@software{mlc-llm, + author = {MLC team}, + title = {{MLC-LLM}}, + url = {https://github.com/mlc-ai/mlc-llm}, + year = {2023} +} +``` + +The underlying techniques of MLC LLM include: + +
+ References (Click to expand) + + ```bibtex + @inproceedings{tensorir, + author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, + title = {TensorIR: An Abstraction for Automatic Tensorized Program Optimization}, + year = {2023}, + isbn = {9781450399166}, + publisher = {Association for Computing Machinery}, + address = {New York, NY, USA}, + url = {https://doi.org/10.1145/3575693.3576933}, + doi = {10.1145/3575693.3576933}, + booktitle = {Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, + pages = {804–817}, + numpages = {14}, + keywords = {Tensor Computation, Machine Learning Compiler, Deep Neural Network}, + location = {Vancouver, BC, Canada}, + series = {ASPLOS 2023} + } + + @inproceedings{metaschedule, + author = {Shao, Junru and Zhou, Xiyou and Feng, Siyuan and Hou, Bohan and Lai, Ruihang and Jin, Hongyi and Lin, Wuwei and Masuda, Masahiro and Yu, Cody Hao and Chen, Tianqi}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, + pages = {35783--35796}, + publisher = {Curran Associates, Inc.}, + title = {Tensor Program Optimization with Probabilistic Programs}, + url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/e894eafae43e68b4c8dfdacf742bcbf3-Paper-Conference.pdf}, + volume = {35}, + year = {2022} + } + + @inproceedings{tvm, + author = {Tianqi Chen and Thierry Moreau and Ziheng Jiang and Lianmin Zheng and Eddie Yan and Haichen Shen and Meghan Cowan and Leyuan Wang and Yuwei Hu and Luis Ceze and Carlos Guestrin and Arvind Krishnamurthy}, + title = {{TVM}: An Automated {End-to-End} Optimizing Compiler for Deep Learning}, + booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)}, + year = {2018}, + isbn = {978-1-939133-08-3}, + address = {Carlsbad, CA}, + pages = {578--594}, + url = {https://www.usenix.org/conference/osdi18/presentation/chen}, + publisher = {USENIX Association}, + month = oct, + } + ``` +
## Links -* Our official [GitHub repo](https://github.com/mlc-ai/mlc-llm); -* Our companion project [WebLLM](https://webllm.mlc.ai/) that enables running LLMs purely in browser. -* [Web Stable Diffusion](https://websd.mlc.ai/) is another MLC-series that runs the diffusion models purely in the browser. -* [Machine Learning Compilation course](https://mlc.ai) is available for a systematic walkthrough of our approach to universal deployment. +- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic +walkthrough of our approaches. +- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. +- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. ## Disclaimer From b3b7f237760af689e1d7c28d6ba4a5e5aa3ae7cc Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 11:08:18 -0400 Subject: [PATCH 210/531] [Docs][Fix] Update index.md for jekyll failure (#2176) This PR fixes the jekyll failure of the project website by removing the citation section (having it in README is sufficient). --- site/index.md | 63 --------------------------------------------------- 1 file changed, 63 deletions(-) diff --git a/site/index.md b/site/index.md index 7bd71d3529..41b220b45f 100644 --- a/site/index.md +++ b/site/index.md @@ -172,69 +172,6 @@ MLC LLM provides multiple sets of APIs across platforms and environments. These * [Swift API for iOS App](https://llm.mlc.ai/docs/deploy/ios.html) * [Java API and Android App](https://llm.mlc.ai/docs/deploy/android.html) -## Citation - -Please consider citing our project if you find it useful: - -```bibtex -@software{mlc-llm, - author = {MLC team}, - title = {{MLC-LLM}}, - url = {https://github.com/mlc-ai/mlc-llm}, - year = {2023} -} -``` - -The underlying techniques of MLC LLM include: - -
- References (Click to expand) - - ```bibtex - @inproceedings{tensorir, - author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, - title = {TensorIR: An Abstraction for Automatic Tensorized Program Optimization}, - year = {2023}, - isbn = {9781450399166}, - publisher = {Association for Computing Machinery}, - address = {New York, NY, USA}, - url = {https://doi.org/10.1145/3575693.3576933}, - doi = {10.1145/3575693.3576933}, - booktitle = {Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, - pages = {804–817}, - numpages = {14}, - keywords = {Tensor Computation, Machine Learning Compiler, Deep Neural Network}, - location = {Vancouver, BC, Canada}, - series = {ASPLOS 2023} - } - - @inproceedings{metaschedule, - author = {Shao, Junru and Zhou, Xiyou and Feng, Siyuan and Hou, Bohan and Lai, Ruihang and Jin, Hongyi and Lin, Wuwei and Masuda, Masahiro and Yu, Cody Hao and Chen, Tianqi}, - booktitle = {Advances in Neural Information Processing Systems}, - editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, - pages = {35783--35796}, - publisher = {Curran Associates, Inc.}, - title = {Tensor Program Optimization with Probabilistic Programs}, - url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/e894eafae43e68b4c8dfdacf742bcbf3-Paper-Conference.pdf}, - volume = {35}, - year = {2022} - } - - @inproceedings{tvm, - author = {Tianqi Chen and Thierry Moreau and Ziheng Jiang and Lianmin Zheng and Eddie Yan and Haichen Shen and Meghan Cowan and Leyuan Wang and Yuwei Hu and Luis Ceze and Carlos Guestrin and Arvind Krishnamurthy}, - title = {{TVM}: An Automated {End-to-End} Optimizing Compiler for Deep Learning}, - booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)}, - year = {2018}, - isbn = {978-1-939133-08-3}, - address = {Carlsbad, CA}, - pages = {578--594}, - url = {https://www.usenix.org/conference/osdi18/presentation/chen}, - publisher = {USENIX Association}, - month = oct, - } - ``` -
- ## Links - You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic From 9216467cda604978c702ef336eb46f0e1afaf82b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 19 Apr 2024 10:31:14 -0700 Subject: [PATCH 211/531] [Quantization] Add e4m3 mode and enable fp8 storage type (reland #2154) (#2161) * [Quantization] Add e4m3 mode and enable fp8 storage type * add quantize linear flag --- python/mlc_llm/cli/model_metadata.py | 4 +- python/mlc_llm/interface/convert_weight.py | 5 +- python/mlc_llm/op/moe_matmul.py | 3 +- .../quantization/per_tensor_quantization.py | 80 ++++++++++++------- python/mlc_llm/quantization/quantization.py | 17 +++- python/mlc_llm/quantization/utils.py | 3 +- 6 files changed, 73 insertions(+), 39 deletions(-) diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 9b45561665..81473b1ec7 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Union -import numpy as np +from tvm.runtime import DataType from mlc_llm.support import logging from mlc_llm.support.argparse import ArgumentParser @@ -81,7 +81,7 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa else: # Contains dynamic shape; use config to look up concrete values param_shape = _read_dynamic_shape(param["shape"], config) - params_bytes += math.prod(param_shape) * np.dtype(param["dtype"]).itemsize + params_bytes += math.prod(param_shape) * DataType(param["dtype"]).itemsize() temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index 90c5c45831..f6c3c5f255 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ b/python/mlc_llm/interface/convert_weight.py @@ -7,10 +7,9 @@ from pathlib import Path from typing import Any, Dict, Iterator, Tuple -import numpy as np from tvm import tir from tvm.contrib import tvmjs -from tvm.runtime import Device, NDArray +from tvm.runtime import DataType, Device, NDArray from tvm.runtime import cpu as cpu_device from tvm.target import Target @@ -131,7 +130,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]: _check_param(name, param) param_names.add(name) param = param.copyto(cpu_device()) - total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + total_bytes += math.prod(param.shape) * DataType(param.dtype).itemsize() yield name, param total_params = loader.stats.total_param_num diff --git a/python/mlc_llm/op/moe_matmul.py b/python/mlc_llm/op/moe_matmul.py index 95d7fed941..6def4a5ff2 100644 --- a/python/mlc_llm/op/moe_matmul.py +++ b/python/mlc_llm/op/moe_matmul.py @@ -2,7 +2,7 @@ from typing import Literal, Optional -from tvm import DataType, tir +from tvm import DataType, DataTypeCode, tir from tvm.relax.frontend.nn import Tensor, op from tvm.script import tir as T @@ -218,6 +218,7 @@ def _dequantize(w, s, e, i, j): if num_elem_per_storage == 1: w = tir.reinterpret(quantize_dtype, w[e, i, j]) else: + assert DataType(storage_dtype).type_code == DataTypeCode.UINT tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype) w = w[e, i, j // num_elem_per_storage] shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype) diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py 
b/python/mlc_llm/quantization/per_tensor_quantization.py index c2776b2a86..274a221393 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -16,6 +16,7 @@ compile_quantize_func, convert_uint_packed_fp8_to_float, is_final_fc, + is_moe_gate, pack_weight, ) @@ -30,10 +31,11 @@ class PerTensorQuantize: # pylint: disable=too-many-instance-attributes kind: str activation_dtype: Literal["e4m3_float8", "e5m2_float8"] weight_dtype: Literal["e4m3_float8", "e5m2_float8"] - storage_dtype: Literal["uint32"] + storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"] model_dtype: Literal["float16"] quantize_embedding: bool = True quantize_final_fc: bool = True + quantize_linear: bool = True num_elem_per_storage: int = 0 max_int_value: int = 0 @@ -101,8 +103,11 @@ def visit_module(self, name: str, node: nn.Module) -> Any: f"{name}.q_weight", ] ) - if isinstance(node, nn.Linear) and ( - not is_final_fc(name) or self.config.quantize_final_fc + if ( + isinstance(node, nn.Linear) + and self.config.quantize_linear + and (not is_final_fc(name) or self.config.quantize_final_fc) + and not is_moe_gate(name, node) ): self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight @@ -192,7 +197,11 @@ def _compute_scale(x: te.Tensor) -> te.Tensor: scale = None def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor: - elem_storage_dtype = f"uint{quantize_dtype.bits}" + elem_storage_dtype = ( + f"uint{quantize_dtype.bits}" + if DataType(self.storage_dtype).type_code == DataTypeCode.UINT + else quantize_dtype + ) scaled_weight = te.compute( shape=weight.shape, fcompute=lambda *idx: tir.Cast( @@ -207,6 +216,9 @@ def _compute_quantized_weight(weight: te.Tensor, scale: Optional[te.Tensor]) -> ), ) + if self.weight_dtype == self.storage_dtype: + return scaled_weight + packed_weight = pack_weight( scaled_weight, axis=-1, @@ -248,15 +260,18 @@ def dequantize_float8( out_shape: Optional[Sequence[tir.PrimExpr]] = None, ) -> te.Tensor: """Dequantize a fp8 tensor to higher-precision float.""" - weight = convert_uint_packed_fp8_to_float( - q_weight, - self.num_elem_per_storage, - self.storage_dtype, - self.model_dtype, - quantize_dtype, - axis=-1, - out_shape=out_shape, - ) + if quantize_dtype != self.storage_dtype: + weight = convert_uint_packed_fp8_to_float( + q_weight, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + quantize_dtype, + axis=-1, + out_shape=out_shape, + ) + else: + weight = q_weight.astype(self.model_dtype) if scale is not None: weight = weight * scale return weight @@ -276,7 +291,7 @@ def __init__( # pylint: disable=too-many-arguments super().__init__() self.in_features = in_features self.out_features = out_features - self.out_dtype = out_dtype + self.out_dtype = out_dtype or config.model_dtype self.config = config self.q_weight = nn.Parameter( (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), @@ -341,22 +356,27 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name ret : nn.Tensor The output tensor for the per-tensor quantized linear layer. 
""" - w = nn.op.tensor_expr_op( - lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access - weight, - scale, - out_shape=[ - ( - tir.IntImm("int64", self.out_features) - if isinstance(self.out_features, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.in_features), - ], - ), - "dequantize", - args=[self.q_weight, self.q_scale], - ) + # Note: Use calibration scale when calibration is enabled + x = x.astype(self.config.activation_dtype) + if self.config.weight_dtype == self.config.storage_dtype: + w = self.q_weight + else: + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.in_features), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) w = nn.op.permute_dims(w) x = nn.op.matmul(x, w, out_dtype=self.out_dtype) if self.bias is not None: diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index 1b2d8695cf..ed7d8a6720 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -123,10 +123,23 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr kind="per-tensor-quant", activation_dtype="e5m2_float8", weight_dtype="e5m2_float8", - storage_dtype="uint32", + storage_dtype="e5m2_float8", model_dtype="float16", - quantize_final_fc=True, + quantize_final_fc=False, + quantize_embedding=False, + quantize_linear=True, + use_scale=False, + ), + "e4m3_e4m3_f16": PerTensorQuantize( + name="e4m3_e4m3_f16", + kind="per-tensor-quant", + activation_dtype="e4m3_float8", + weight_dtype="e4m3_float8", + storage_dtype="e4m3_float8", + model_dtype="float16", + quantize_final_fc=False, quantize_embedding=False, + quantize_linear=True, use_scale=False, ), } diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index fdc50ff74d..3e55de4524 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -6,7 +6,7 @@ from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn -from tvm.runtime import DataType +from tvm.runtime import DataType, DataTypeCode from tvm.target import Target from mlc_llm.support import tensor_parallel as tp @@ -105,6 +105,7 @@ def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments ) -> te.Tensor: """Unpack a fp8 value from the storage dtype and convert to float.""" assert quant_dtype in ["e4m3_float8", "e5m2_float8"] + assert DataType(storage_dtype).type_code == DataTypeCode.UINT bits = DataType(quant_dtype).bits elem_storage_dtype = DataType(f"uint{bits}") tir_bin_mask = tir.const((1 << bits) - 1, "uint8") From a50fae0e3cd6e2c19cce69c6d364bf0f813f19bb Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 15:30:40 -0400 Subject: [PATCH 212/531] [Docs] Fix API reference not displayed (#2177) This PR fixes the issue of the API reference not displayed in the documentation. 
--- docs/requirements.txt | 4 ++++ scripts/build_mlc_for_docs.sh | 8 ++++++++ scripts/build_site.sh | 1 + scripts/gh_deploy_site.sh | 1 + 4 files changed, 14 insertions(+) create mode 100755 scripts/build_mlc_for_docs.sh diff --git a/docs/requirements.txt b/docs/requirements.txt index bc020bc662..0156a180b0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,5 +6,9 @@ tlcpack-sphinx-addon==0.2.2 sphinxcontrib_httpdomain==1.8.1 sphinxcontrib-napoleon==0.7 sphinx-reredirects==0.1.2 +shortuuid +pydantic +uvicorn +fastapi --find-links https://mlc.ai/wheels mlc-ai-nightly diff --git a/scripts/build_mlc_for_docs.sh b/scripts/build_mlc_for_docs.sh new file mode 100755 index 0000000000..50eee3231a --- /dev/null +++ b/scripts/build_mlc_for_docs.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -euxo pipefail + +mkdir -p build +cd build +cmake .. +make -j$(nproc) +cd - diff --git a/scripts/build_site.sh b/scripts/build_site.sh index 6340ee838e..062f8094de 100755 --- a/scripts/build_site.sh +++ b/scripts/build_site.sh @@ -1,6 +1,7 @@ #!/bin/bash set -euxo pipefail +export PYTHONPATH=$PWD/python cd docs && make html && cd .. cd site && jekyll b && cd .. diff --git a/scripts/gh_deploy_site.sh b/scripts/gh_deploy_site.sh index 1b21c52d16..326c280484 100755 --- a/scripts/gh_deploy_site.sh +++ b/scripts/gh_deploy_site.sh @@ -4,6 +4,7 @@ set -euxo pipefail +scripts/build_mlc_for_docs.sh scripts/build_site.sh git fetch From 675319f2ee08c6fd973b8b31722989bf6a673fff Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 15:54:26 -0400 Subject: [PATCH 213/531] [Docs] Update project website (#2180) This PR updates the project landing website to remove some information. --- site/index.md | 155 +++----------------------------------------------- 1 file changed, 9 insertions(+), 146 deletions(-) diff --git a/site/index.md b/site/index.md index 41b220b45f..ac0367cdb2 100644 --- a/site/index.md +++ b/site/index.md @@ -6,63 +6,15 @@ notitle: true # MLC LLM -[Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] +Documentation: [https://llm.mlc.ai/docs](https://llm.mlc.ai/docs) **M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. -**Universal deployment.** MLC LLM supports the following platforms and hardware: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|              | AMD GPU         | NVIDIA GPU      | Apple GPU | Intel GPU       |
|--------------|-----------------|-----------------|-----------|-----------------|
| Linux / Win  | ✅ Vulkan, ROCm | ✅ Vulkan, CUDA | N/A       | ✅ Vulkan       |
| macOS        | ✅ Metal (dGPU) | N/A             | ✅ Metal  | ✅ Metal (iGPU) |
| Web Browser  | ✅ WebGPU and WASM | | | |
| iOS / iPadOS | ✅ Metal on Apple A-series GPU | | | |
| Android      | ✅ OpenCL on Adreno GPU | ✅ OpenCL on Mali GPU | | |
+

+ +

- -## Quick Start - -We introduce the quick start examples of chat CLI, Python API and REST server here to use MLC LLM. -We use 4-bit quantized 8B Llama-3 model for demonstration purpose. -The pre-quantized Llama-3 weights is available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC. -You can also try out unquantized Llama-3 model by replacing `q4f16_1` to `q0f16` in the examples below. -Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for detailed quick start and introduction. - -### Installation +## Installation MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). It is always recommended to install it in an isolated conda virtual environment. @@ -75,102 +27,13 @@ python -c "import mlc_llm; print(mlc_llm.__path__)" You are expected to see the installation path of MLC LLM Python package. -### Chat CLI - -We can try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. - -```bash -mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC -``` - -It may take 1-2 minutes for the first time running this command. -After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. - -``` -You can use the following special commands: -/help print the special commands -/exit quit the cli -/stats print out the latest stats (token/sec) -/reset restart a fresh chat -/set [overrides] override settings in the generation config. For example, - `/set temperature=0.5;max_gen_len=100;stop=end,stop` - Note: Separate stop words in the `stop` option with commas (,). -Multi-line input: Use escape+enter to start a new line. - -user: What's the meaning of life -assistant: -What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. - -The concept of the meaning of life has been debated and... -``` - -### Python API - -We can run the Llama-3 model with the chat completion Python API of MLC LLM. -You can save the code below into a Python file and run it. - -```python -from mlc_llm import LLMEngine - -# Create engine -model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" -engine = LLMEngine(model) - -# Run chat completion in OpenAI API. -for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, -): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) -print("\n") - -engine.terminate() -``` - -**The Python API of `mlc_llm.LLMEngine` fully aligns with OpenAI API**. -You can use LLMEngine in the same way of using -[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) -for both synchronous and asynchronous generation. - -If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncLLMEngine` instead. - -### REST Server - -We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. -The server has fully OpenAI API completeness. - -```bash -mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC -``` - -The server is hooked at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port` -to set a different host and port. 
-When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`), -we can open a new shell and send a cURL request via the following command: +## Quick Start -```bash -curl -X POST \ - -H "Content-Type: application/json" \ - -d '{ - "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} - ] - }' \ - http://127.0.0.1:8000/v1/chat/completions -``` +Please check out our documentation for the [quick start](https://llm.mlc.ai/docs/get_started/quick_start.html). -## Universal Deployment APIs +## Introduction -MLC LLM provides multiple sets of APIs across platforms and environments. These include -* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html) -* [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) -* [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) -* [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) -* [Swift API for iOS App](https://llm.mlc.ai/docs/deploy/ios.html) -* [Java API and Android App](https://llm.mlc.ai/docs/deploy/android.html) +Please check out our documentation for the [introduction](https://llm.mlc.ai/docs/get_started/introduction.html). ## Links From 0ec6c7aa93093394f1e9f85d2ae15dbde6f9d29a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 19 Apr 2024 18:17:08 -0400 Subject: [PATCH 214/531] [Misc] Pass env along when calling `subprocess.run` (#2179) The uses of `subprocess.run` in the codebase did not pass the environment, which may cause some issues in cases. --- python/mlc_llm/chat_module.py | 2 +- python/mlc_llm/cli/delivery.py | 10 ++++++++-- python/mlc_llm/interface/jit.py | 2 +- python/mlc_llm/support/auto_device.py | 3 +++ python/mlc_llm/support/download.py | 4 +++- 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 943f98c7e2..090bfab0bc 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -664,7 +664,7 @@ def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): "--mlc-chat-config", config_file_path, ] - subprocess.run(cmd, check=False) + subprocess.run(cmd, check=False, env=os.environ) class ChatModule: # pylint: disable=too-many-instance-attributes diff --git a/python/mlc_llm/cli/delivery.py b/python/mlc_llm/cli/delivery.py index 50b9c7e170..a7dd6408b0 100644 --- a/python/mlc_llm/cli/delivery.py +++ b/python/mlc_llm/cli/delivery.py @@ -1,7 +1,9 @@ """Continuous model delivery for MLC LLM models.""" + import argparse import dataclasses import json +import os import shutil import subprocess import sys @@ -131,7 +133,9 @@ def _run_quantization( cmd += ["--" + optional_arg.replace("_", "-"), str(optional_arg_val)] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) cmd = [ sys.executable, "-m", @@ -146,7 +150,9 @@ def _run_quantization( output_dir, ] print(" ".join(cmd), file=log_file, flush=True) - subprocess.run(cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT) + subprocess.run( + cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ + ) logger.info("[MLC] Complete!") if not (Path(output_dir) / "ndarray-cache.json").exists(): logger.error( diff --git 
a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index 25548e0e4a..ecc2b0de0c 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -93,7 +93,7 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): ] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) - subprocess.run(cmd, check=True) + subprocess.run(cmd, check=True, env=os.environ) shutil.move(dso_path, dst) logger.info("Using compiled model lib: %s", bold(dst)) diff --git a/python/mlc_llm/support/auto_device.py b/python/mlc_llm/support/auto_device.py index cf6d09495a..bddb9954c6 100644 --- a/python/mlc_llm/support/auto_device.py +++ b/python/mlc_llm/support/auto_device.py @@ -1,4 +1,6 @@ """Automatic detection of the device available on the local machine.""" + +import os import subprocess import sys from typing import Dict, Optional @@ -65,6 +67,7 @@ def _device_exists(device: Device) -> bool: capture_output=True, text=True, check=False, + env=os.environ, ) .stdout.strip() .splitlines() diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index a109c967bc..770833e9af 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -36,11 +36,13 @@ def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: command = ["git", "clone", url, repo_name] _ensure_directory_not_exist(destination, force_redo=False) try: + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: logger.info("[Git] Cloning %s to %s", bold(url), destination) subprocess.run( command, - env={"GIT_LFS_SKIP_SMUDGE": "1"}, + env=env, cwd=tmp_dir, check=True, stdout=subprocess.DEVNULL, From 132ad03077398f9e496cabb4a392df0e396c23c3 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Sat, 20 Apr 2024 00:04:12 -0400 Subject: [PATCH 215/531] Change OpenAI protocol default value to None and supply using model config (#2178) * Change OpenAI protocol default value to None and supply using model config * Fix lint --- .../mlc_llm/protocol/openai_api_protocol.py | 29 +++++++++++++------ python/mlc_llm/protocol/protocol_utils.py | 3 +- python/mlc_llm/serve/engine.py | 2 ++ python/mlc_llm/serve/engine_base.py | 6 ++-- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 1cbf0bd228..1a732488a0 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -88,8 +88,8 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[int]] best_of: int = 1 echo: bool = False - frequency_penalty: float = 0.0 - presence_penalty: float = 0.0 + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None logprobs: bool = False top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = None @@ -99,8 +99,8 @@ class CompletionRequest(BaseModel): stop: Optional[Union[str, List[str]]] = None stream: bool = False suffix: Optional[str] = None - temperature: float = 1.0 - top_p: float = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None user: Optional[str] = None ignore_eos: bool = False response_format: Optional[RequestResponseFormat] = None @@ -201,8 +201,8 @@ class ChatCompletionRequest(BaseModel): messages: List[ChatCompletionMessage] model: str - frequency_penalty: float = 0.0 - presence_penalty: float = 0.0 + frequency_penalty: Optional[float] = None + 
presence_penalty: Optional[float] = None logprobs: bool = False top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = None @@ -211,8 +211,8 @@ class ChatCompletionRequest(BaseModel): seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stream: bool = False - temperature: float = 1.0 - top_p: float = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None tools: Optional[List[ChatTool]] = None tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None user: Optional[str] = None @@ -386,7 +386,7 @@ def openai_api_get_unsupported_fields( def openai_api_get_generation_config( - request: Union[CompletionRequest, ChatCompletionRequest] + request: Union[CompletionRequest, ChatCompletionRequest], model_config: Dict[str, Any] ) -> Dict[str, Any]: """Create the generation config from the given request.""" from ..serve.config import ResponseFormat # pylint: disable=import-outside-toplevel @@ -407,6 +407,17 @@ def openai_api_get_generation_config( ] for arg_name in arg_names: kwargs[arg_name] = getattr(request, arg_name) + + # If per-request generation config values are missing, try loading from model config. + # If still not found, then use the default OpenAI API value + if kwargs["temperature"] is None: + kwargs["temperature"] = model_config.get("temperature", 1.0) + if kwargs["top_p"] is None: + kwargs["top_p"] = model_config.get("top_p", 1.0) + if kwargs["frequency_penalty"] is None: + kwargs["frequency_penalty"] = model_config.get("frequency_penalty", 0.0) + if kwargs["presence_penalty"] is None: + kwargs["presence_penalty"] = model_config.get("presence_penalty", 0.0) if kwargs["max_tokens"] is None: # Setting to -1 means the generation will not stop until # exceeding model capability or hit any stop criteria. diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index f4273d0302..3005909bbd 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -23,13 +23,14 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: def get_generation_config( request: RequestProtocol, + model_config: Dict[str, Any], extra_stop_token_ids: Optional[List[int]] = None, extra_stop_str: Optional[List[str]] = None, ) -> GenerationConfig: """Create the generation config in MLC LLM out from the input request protocol.""" kwargs: Dict[str, Any] if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): - kwargs = openai_api_get_generation_config(request) + kwargs = openai_api_get_generation_config(request, model_config) else: raise RuntimeError("Cannot reach here") diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 3a329cae21..a84f98fb33 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1189,6 +1189,7 @@ async def _handle_completion( request, request_id, self.state, + self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) @@ -1729,6 +1730,7 @@ def _handle_completion( request, request_id, self.state, + self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 9a25401d3f..0f3e06f1bd 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -919,6 +919,7 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments # Process generation config. Create request id. 
generation_cfg = protocol_utils.get_generation_config( request, + model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1039,10 +1040,11 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments return response, num_completion_tokens -def process_completion_request( +def process_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, + model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: @@ -1094,7 +1096,7 @@ def process_completion_request( assert isinstance(prompt, list) # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config(request) + generation_cfg = protocol_utils.get_generation_config(request, model_config) # - Echo back the prompt. echo_response = None From d43e10e67c7629ecd07028b63b8fd173cbef92ea Mon Sep 17 00:00:00 2001 From: DearFishi <89983913+DearFishi@users.noreply.github.com> Date: Sun, 21 Apr 2024 05:31:13 +0800 Subject: [PATCH 216/531] [Serving][Spec] Fix the output inconsistent bug of q0f32 spec decoding (#2184) - According to https://github.com/mlc-ai/mlc-llm/issues/2167, the problem that the output of spec decoding in q0f32 is inconsistent with the single model of q0f32 has been fixed. - Modified the test_engine_generate function located in `tests/python/serve/test_serve_engine_spec.py` to support comparison of the output of a single model and the output of spec decoding - The accuracy comparison with hugging face is left (because the current version of llama-2-7b of q0f32 cannot be consistent with the output of hugging face model) - The output of spec decoding for q0f16 cannot be consistent with the output of a single model of q0f16, but this may be due to floating point errors. Co-authored-by: DearFishi --- cpp/serve/engine_actions/batch_verify.cc | 4 +- tests/python/serve/test_serve_engine_spec.py | 51 ++++++++++++++++++-- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 6f38292ba3..aa51b647c0 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -128,10 +128,8 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.total_accepted_length += accept_length; - // - Minus one because the last draft token has no kv cache entry - // - Take max with 0 in case of all accepted. int rollback_length = - std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length - 1, 0); + std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); // rollback kv cache // NOTE: when number of small models is more than 1 (in the future), // it is possible to re-compute prefill for the small models. 
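
For concreteness, the change above only affects how many KV cache entries are rolled back after verification. A minimal numeric sketch (the values are illustrative, not taken from the test suite below) is:

```python
# Illustrative values for a single request in the verification batch.
cum_verify_lengths = [0, 5]  # this request verified 5 tokens
accept_length = 3            # 3 of the 5 draft tokens were accepted

verify_len = cum_verify_lengths[1] - cum_verify_lengths[0]

old_rollback = max(verify_len - accept_length - 1, 0)  # 1 entry before this patch
new_rollback = max(verify_len - accept_length, 0)      # 2 entries after this patch

# Rolling back one more entry removes every rejected draft token from the
# KV cache, which is what restores consistency between q0f32 speculative
# decoding and the single-model output reported in #2167.
print(old_rollback, new_rollback)
```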
diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 60be02ce1a..6647c7af19 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -364,7 +364,19 @@ def step(self) -> None: # assert fin_time == request.generation_config.max_tokens - 1 -def test_engine_generate(): +def compare_output_text(output_text1, output_text2): + if isinstance(output_text1, list) and isinstance(output_text2, list): + for item1, item2 in zip(output_text1, output_text2): + if not compare_output_text(item1, item2): + return False + elif output_text1 != output_text2: + print(output_text1) + print(output_text2) + return False + return True + + +def test_engine_generate(compare_precision=False): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" @@ -372,6 +384,7 @@ def test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) + engine = SyncLLMEngine( model=model, model_lib_path=model_lib_path, @@ -385,9 +398,31 @@ def test_engine_generate(): max_tokens = 256 # Generate output. - output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) - ) + if compare_precision: + print("compare precision") + generation_config = GenerationConfig( + temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], n=1 + ) + engine_single_model = SyncLLMEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + output_texts_single_model, _ = engine_single_model.generate( + prompts[:num_requests], generation_config + ) + for req_id, outputs in enumerate(output_texts_single_model): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + # TODO: Add pytorch precision + else: + generation_config = GenerationConfig(max_tokens=max_tokens, n=3) + output_texts, _ = engine.generate(prompts[:num_requests], generation_config) for req_id, outputs in enumerate(output_texts): print(f"Prompt {req_id}: {prompts[req_id]}") if len(outputs) == 1: @@ -395,6 +430,12 @@ def test_engine_generate(): else: for i, output in enumerate(outputs): print(f"Output {req_id}({i}):{output}\n") + if compare_precision: + precision_flag = compare_output_text(output_texts, output_texts_single_model) + if precision_flag: + print(f"Accuracy verification succeed\n") + else: + print(f"Accuracy verification failed\n") def test_engine_eagle_generate(): @@ -643,7 +684,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): test_engine_eagle_basic() test_engine_continuous_batching_1() test_engine_eagle_continuous_batching_1() - test_engine_generate() + test_engine_generate(compare_precision=True) test_engine_eagle_generate() test_engine_efficiency() test_engine_spec_efficiency() From 54a679474aeb17757eea46d44e2e314a7a803900 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 20 Apr 2024 22:23:39 -0400 Subject: [PATCH 217/531] [Serving] Support ThreadedEngine Reload/Unload/Reset (#2185) This PR brings the support of reload (reload the engine with a new model), unload (unload the current running model) and reset (reset the engine to the initial states without unloading) to ThreadedEngine and JSONFFIEngine. 
These functions are useful for app bindings for iOS/Android. --- cpp/json_ffi/json_ffi_engine.cc | 9 +++ cpp/serve/engine.cc | 26 +++++- cpp/serve/engine.h | 3 + cpp/serve/threaded_engine.cc | 79 +++++++++++++++++-- cpp/serve/threaded_engine.h | 9 +++ tests/python/json_ffi/test_json_ffi_engine.py | 77 +++++++++++++----- 6 files changed, 175 insertions(+), 28 deletions(-) diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index b02a28ca89..fc26c46b26 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -103,6 +103,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.json_ffi"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &JSONFFIEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &JSONFFIEngineImpl::Reload); + TVM_MODULE_VTABLE_ENTRY("unload", &JSONFFIEngineImpl::Unload); + TVM_MODULE_VTABLE_ENTRY("reset", &JSONFFIEngineImpl::Reset); TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion); TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort); TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError); @@ -133,6 +136,12 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); } + void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } + + void Unload() { this->engine_->Unload(); } + + void Reset() { this->engine_->Reset(); } + void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); } void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 85d1c66c2d..8e47564945 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -143,6 +143,7 @@ class EngineImpl : public Engine { } void Reset() final { + AbortAllRequests(); estate_->Reset(); for (Model model : models_) { model->Reset(); @@ -167,7 +168,8 @@ class EngineImpl : public Engine { request = Request::FromUntokenized(request, tokenizer_); ICHECK_NE(request->input_total_length, -1); - if (request->input_total_length >= engine_config_->max_single_sequence_length) { + if (request->input_total_length >= engine_config_->max_single_sequence_length && + request_stream_callback_.defined()) { // If the request input length exceeds the maximum allowed single sequence length, // invoke callback and do not process the request. Array output{RequestStreamOutput( @@ -240,6 +242,28 @@ class EngineImpl : public Engine { // The request to abort is in waiting queue estate_->waiting_queue.erase(it_waiting); } + + // Send a callback to notice the abortion. + if (request_stream_callback_.defined()) { + Array output{RequestStreamOutput( + request_id, std::vector(request->generation_cfg->n), + Optional>>(), + std::vector>(request->generation_cfg->n, String("abort")))}; + request_stream_callback_.value()(std::move(output)); + } + } + + void AbortAllRequests() final { + // - Collect all the request ids. + std::vector request_ids; + request_ids.reserve(estate_->request_states.size()); + for (const auto& kv : estate_->request_states) { + request_ids.push_back(kv.first); + } + // - Abort all the requests. 
+ for (const String& request_id : request_ids) { + AbortRequest(request_id); + } } /*********************** Engine Action ***********************/ diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index fc5e4205ae..bcc1b80988 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -82,6 +82,9 @@ class Engine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; + /*! \brief Abort all requests from the engine. */ + virtual void AbortAllRequests() = 0; + /*********************** Engine Action ***********************/ /*! diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 458d2ae5d7..b9def964c4 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -29,7 +29,8 @@ enum class InstructionKind : int { kAbortRequest = 1, kUnloadEngine = 2, kReloadEngine = 3, - kDebugCallFuncOnAllAllWorker = 4, + kResetEngine = 4, + kDebugCallFuncOnAllAllWorker = 5, }; /*! \brief The implementation of ThreadedEngine. */ @@ -41,6 +42,7 @@ class ThreadedEngineImpl : public ThreadedEngine { CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); + trace_recorder_ = trace_recorder; auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.size(), 1); @@ -62,6 +64,45 @@ class ThreadedEngineImpl : public ThreadedEngine { std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); } + void Reload(EngineConfig engine_config) final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kReloadEngine, std::move(engine_config)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void Unload() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kUnloadEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + + void Reset() final { + bool need_notify = false; + { + std::lock_guard lock(background_loop_mutex_); + instruction_queue_.emplace_back(InstructionKind::kResetEngine, ObjectRef(nullptr)); + ++pending_request_operation_cnt_; + need_notify = engine_waiting_; + } + if (need_notify) { + background_loop_cv_.notify_one(); + } + } + void AddRequest(Request request) final { bool need_notify = false; { @@ -97,7 +138,8 @@ class ThreadedEngineImpl : public ThreadedEngine { std::unique_lock lock(background_loop_mutex_); engine_waiting_ = true; background_loop_cv_.wait(lock, [this] { - return !background_engine_->Empty() || pending_request_operation_cnt_.load() > 0 || + return (background_engine_ != nullptr && !background_engine_->Empty()) || + pending_request_operation_cnt_.load() > 0 || exit_now_.load(std::memory_order_relaxed); }); engine_waiting_ = false; @@ -108,22 +150,31 @@ class ThreadedEngineImpl : public ThreadedEngine { } for (const auto& [kind, arg] : local_instruction_queue) { if (kind == InstructionKind::kAddRequest) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->AddRequest(Downcast(arg)); } else if (kind == InstructionKind::kAbortRequest) { + CHECK(background_engine_ != 
nullptr) << "Background engine is not loaded."; background_engine_->AbortRequest(Downcast(arg)); } else if (kind == InstructionKind::kUnloadEngine) { - // Todo(mlc-team): implement engine unload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { - // Todo(mlc-team): implement engine reload - LOG(FATAL) << "Not implemented yet."; + EngineUnloadImpl(); + InitBackgroundEngine(Downcast(arg), request_stream_callback_, + trace_recorder_); + } else if (kind == InstructionKind::kResetEngine) { + if (background_engine_ != nullptr) { + background_engine_->Reset(); + } } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) { + CHECK(background_engine_ != nullptr) << "Background engine is not loaded."; background_engine_->DebugCallFuncOnAllAllWorker(Downcast(arg)); } else { LOG(FATAL) << "Cannot reach here"; } } - background_engine_->Step(); + if (background_engine_ != nullptr) { + background_engine_->Step(); + } } } @@ -184,10 +235,24 @@ class ThreadedEngineImpl : public ThreadedEngine { } private: + void EngineUnloadImpl() { + if (background_engine_ != nullptr) { + background_engine_->AbortAllRequests(); + background_engine_ = nullptr; + // Clear the allocated memory in cached memory pool. + const PackedFunc* fclear_memory_manager = + tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear"); + ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear"; + (*fclear_memory_manager)(); + } + } + /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*! \brief The request stream callback. */ PackedFunc request_stream_callback_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; /*! \brief The mutex ensuring only one thread can access critical regions. */ std::mutex background_loop_mutex_; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 3d11ba36f1..da969fe879 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -43,6 +43,15 @@ class ThreadedEngine { Optional request_stream_callback, Optional trace_recorder) = 0; + /*! \brief Reload the engine with the new engine config. */ + virtual void Reload(EngineConfig engine_config) = 0; + + /*! \brief Unload the background engine. */ + virtual void Unload() = 0; + + /*! \brief Reset the engine to the initial state. */ + virtual void Reset() = 0; + /*! \brief Starts the background request processing loop. 
*/ virtual void RunBackgroundLoop() = 0; diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b86fd423a9..578463066b 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -111,6 +111,9 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals key: module[key] for key in [ "init_background_engine", + "reload", + "unload", + "reset", "chat_completion", "abort", "get_last_error", @@ -121,22 +124,24 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals } self.tokenizer = Tokenizer(model_args[0][0]) + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + def _background_loop(): self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), + self.engine_config, self.state.get_request_stream_callback(), None, ) @@ -251,8 +256,17 @@ def _handle_chat_completion( self._ffi["abort"](request_id) raise exception + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() + -def test_chat_completion(engine: JSONFFIEngine): +def run_chat_completion(engine: JSONFFIEngine, model: str): num_requests = 2 max_tokens = 64 n = 1 @@ -284,13 +298,27 @@ def test_chat_completion(engine: JSONFFIEngine): print(f"Output {req_id}({i}):{output}\n") -def test_malformed_request(engine: JSONFFIEngine): +def test_chat_completion(): + # Create engine. + model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + engine = JSONFFIEngine( + model, + model_lib_path=model_lib_path, + max_total_sequence_length=1024, + ) + + run_chat_completion(engine, model) + + # Test malformed requests. for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): assert len(response.choices) == 1 assert response.choices[0].finish_reason == "error" + engine.terminate() -if __name__ == "__main__": + +def test_reload_reset_unload(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" @@ -300,8 +328,17 @@ def test_malformed_request(engine: JSONFFIEngine): max_total_sequence_length=1024, ) - test_chat_completion(engine) - test_malformed_request(engine) + # Run chat completion before and after reload/reset. 
+ run_chat_completion(engine, model) + engine._test_reload() + run_chat_completion(engine, model) + engine._test_reset() + run_chat_completion(engine, model) + engine._test_unload() engine.terminate() - del engine + + +if __name__ == "__main__": + test_chat_completion() + test_reload_reset_unload() From 81862034b2da8dd579e08bba87dc0e5afaa46f65 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Sun, 21 Apr 2024 07:07:43 -0400 Subject: [PATCH 218/531] [WASM] Support grammar schema in wasm (#2187) --- cpp/serve/grammar/grammar_state_matcher.cc | 5 +++-- cpp/serve/grammar/json_schema_converter.cc | 8 ++++++++ python/mlc_llm/serve/grammar.py | 2 +- web/emcc/mlc_wasm_runtime.cc | 3 +++ 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index d9954f1e28..5c4ef98efe 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -469,9 +469,10 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenizer") TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") .set_body([](TVMArgs args, TVMRetValue* rv) { BNFGrammar grammar = args[0]; + Array token_table_arr = args[1]; std::vector token_table; - for (int i = 1; i < args.size() - 1; ++i) { - token_table.push_back(args[i]); + for (int i = 0; i < token_table_arr.size(); ++i) { + token_table.push_back(token_table_arr[i]); } int max_rollback_steps = args[args.size() - 1]; auto init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 93d693f3c6..83be710cf5 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -23,6 +23,14 @@ namespace serve { using namespace tvm::runtime; +// EMCC somehow cannot pickup operator overload from picojson.h, so we copy here. +#ifdef COMPILE_MLC_WASM_RUNTIME +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#endif + /*! * \brief Manage the indent and separator for the generation of EBNF grammar. * \param indent The number of spaces for each indent. 
If it is std::nullopt, there will be no diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index d5ad862a42..cf491884c2 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -247,7 +247,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.GrammarStateMatcherFromTokenTable, # type: ignore # pylint: disable=no-member grammar, - *tokenizer, + tokenizer, max_rollback_steps, ) else: diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index 3f05eb259f..b9a7f55bfa 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -29,6 +29,8 @@ // Pass in COMPILE_MLC_WASM_RUNTIME so unsupported code would not be compiled in to the .bc file #define COMPILE_MLC_WASM_RUNTIME 1 +#define __STDC_FORMAT_MACROS 1 +#define PICOJSON_USE_INT64 #define DMLC_USE_LOGGING_LIBRARY @@ -38,4 +40,5 @@ #include "serve/grammar/grammar_serializer.cc" #include "serve/grammar/grammar_simplifier.cc" #include "serve/grammar/grammar_state_matcher.cc" +#include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc" From 4994c5cc172a19441c52d6226f86ac432ba37abd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 21 Apr 2024 19:54:40 -0400 Subject: [PATCH 219/531] [Serving] Support loading system library (#2189) This PR introduces the support of loading system libraries. Now in engine reload, when the given library path starts with `"system://"`, we recognize this as a system library and will try to load the the library from the path after the `"system://"` prefix. This PR also decouples the InitBackgroundEngine of ThreadedEngine into two parts, where the reload is now called explicitly when initializing the engine. This can be also done for the JSONFFIEngine. However, we need to move the construction of streamers in JSONFFIEngine before doing the same thing for JSONFFIEngine. So this is marked as a TODO item. --- cpp/json_ffi/json_ffi_engine.cc | 8 +++- cpp/serve/function_table.cc | 17 +++++-- cpp/serve/threaded_engine.cc | 47 ++++++++++--------- cpp/serve/threaded_engine.h | 9 ++-- cpp/support/utils.h | 17 +++++++ python/mlc_llm/serve/engine_base.py | 39 ++++++++------- tests/python/json_ffi/test_json_ffi_engine.py | 10 ++-- 7 files changed, 91 insertions(+), 56 deletions(-) diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index fc26c46b26..2f5bf49ce3 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -118,6 +118,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { void InitBackgroundEngine(EngineConfig engine_config, Optional request_stream_callback, Optional trace_recorder) { + // Todo(mlc-team): decouple InitBackgroundEngine into two functions + // by removing `engine_config` from arguments, after properly handling + // streamers. 
this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); CHECK(request_stream_callback.defined()) @@ -132,8 +135,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); + this->engine_->InitBackgroundEngine(std::move(request_stream_callback), + std::move(trace_recorder)); + this->engine_->Reload(std::move(engine_config)); } void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 289abfda16..823d3c6164 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -130,14 +130,23 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->_InitFunctions(); } else { Module executable{nullptr}; - if (false) { - // Todo(mlc-team): system lib reload // reload_lib_path starts with "system://" + PackedFunc fload_exec{nullptr}; + if (StartsWith(reload_lib_path, "system://")) { + const PackedFunc* f_load_system_lib = Registry::Get("runtime.SystemLib"); + ICHECK_NOTNULL(f_load_system_lib); + std::string system_lib_prefix = std::string(reload_lib_path).substr(9); + std::replace(system_lib_prefix.begin(), system_lib_prefix.end(), /*old=*/'-', /*new=*/'_'); + executable = (*f_load_system_lib)(system_lib_prefix + "_"); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) + << "Cannot find system lib with " << system_lib_prefix + << ", please make sure you set model_lib field consistently with the compilation "; } else { executable = tvm::runtime::Module::LoadFromFile(reload_lib_path); + fload_exec = executable->GetFunction("vm_load_executable"); + ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; } this->use_disco = false; - auto fload_exec = executable->GetFunction("vm_load_executable"); - ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; this->local_vm = fload_exec(); this->local_vm->GetFunction("vm_initialization")( static_cast(device.device_type), device.device_id, diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index b9def964c4..f234dfbbc3 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -36,32 +36,12 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. 
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + void InitBackgroundEngine(Optional request_stream_callback, Optional trace_recorder) final { CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); trace_recorder_ = trace_recorder; - - auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { - ICHECK_EQ(args.size(), 1); - Array delta_outputs = args[0]; - bool need_notify = false; - { - std::lock_guard lock(request_stream_callback_mutex_); - request_stream_callback_inputs_.push_back(std::move(delta_outputs)); - ++pending_request_stream_callback_cnt_; - need_notify = stream_callback_waiting_; - } - if (need_notify) { - request_stream_callback_cv_.notify_one(); - } - }; - - request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create( - std::move(engine_config), std::move(request_stream_callback), std::move(trace_recorder)); } void Reload(EngineConfig engine_config) final { @@ -159,8 +139,7 @@ class ThreadedEngineImpl : public ThreadedEngine { EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { EngineUnloadImpl(); - InitBackgroundEngine(Downcast(arg), request_stream_callback_, - trace_recorder_); + EngineReloadImpl(Downcast(arg)); } else if (kind == InstructionKind::kResetEngine) { if (background_engine_ != nullptr) { background_engine_->Reset(); @@ -235,6 +214,27 @@ class ThreadedEngineImpl : public ThreadedEngine { } private: + void EngineReloadImpl(EngineConfig engine_config) { + auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { + ICHECK_EQ(args.size(), 1); + Array delta_outputs = args[0]; + bool need_notify = false; + { + std::lock_guard lock(request_stream_callback_mutex_); + request_stream_callback_inputs_.push_back(std::move(delta_outputs)); + ++pending_request_stream_callback_cnt_; + need_notify = stream_callback_waiting_; + } + if (need_notify) { + request_stream_callback_cv_.notify_one(); + } + }; + + Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); + background_engine_ = Engine::Create(std::move(engine_config), + std::move(request_stream_callback), trace_recorder_); + } + void EngineUnloadImpl() { if (background_engine_ != nullptr) { background_engine_->AbortAllRequests(); @@ -302,6 +302,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("reload", &ThreadedEngineImpl::Reload); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); TVM_MODULE_VTABLE_ENTRY("run_background_loop", &ThreadedEngineImpl::RunBackgroundLoop); diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index da969fe879..f3d9c2b70c 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,15 +35,16 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. - * \param engine_config The engine config. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. 
*/ - virtual void InitBackgroundEngine(EngineConfig engine_config, - Optional request_stream_callback, + virtual void InitBackgroundEngine(Optional request_stream_callback, Optional trace_recorder) = 0; - /*! \brief Reload the engine with the new engine config. */ + /*! + * \brief Reload the engine with the new engine config. + * \param engine_config The engine config. + */ virtual void Reload(EngineConfig engine_config) = 0; /*! \brief Unload the background engine. */ diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 5360f0496c..6c53e35715 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -10,6 +10,7 @@ namespace mlc { namespace llm { +/*! \brief Split the input string by the given delimiter character. */ inline std::vector Split(const std::string& str, char delim) { std::string item; std::istringstream is(str); @@ -20,5 +21,21 @@ inline std::vector Split(const std::string& str, char delim) { return ret; } +/*! + * \brief Check whether the string starts with a given prefix. + * \param str The given string. + * \param prefix The given prefix. + * \return Whether the prefix matched. + */ +inline bool StartsWith(const std::string& str, const char* prefix) { + size_t n = str.length(); + for (size_t i = 0; i < n; i++) { + if (prefix[i] == '\0') return true; + if (str.data()[i] != prefix[i]) return false; + } + // return true if the str is equal to the prefix + return prefix[n] == '\0'; +} + } // namespace llm } // namespace mlc diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 0f3e06f1bd..9b0f27723a 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -776,32 +776,35 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "abort_request", "run_background_loop", "run_background_stream_back_loop", + "reload", "init_background_engine", "exit_background_loop", "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) + self._ffi["init_background_engine"]( + self.state.get_request_stream_callback(kind), + self.state.trace_recorder, + ) + self._ffi["reload"]( + EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + device=device, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + ) def _background_loop(): - self._ffi["init_background_engine"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), - self.state.get_request_stream_callback(kind), - self.state.trace_recorder, - ) self._ffi["run_background_loop"]() def _background_stream_back_loop(): diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 578463066b..b8a8d492b9 
100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -138,13 +138,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) + self._ffi["init_background_engine"]( + self.engine_config, + self.state.get_request_stream_callback(), + None, + ) def _background_loop(): - self._ffi["init_background_engine"]( - self.engine_config, - self.state.get_request_stream_callback(), - None, - ) self._ffi["run_background_loop"]() def _background_stream_back_loop(): From 830c908f6528eb513bbaa9ec797d0de447918799 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Sun, 21 Apr 2024 20:45:35 -0400 Subject: [PATCH 220/531] [Op] Batch verify for speculative decoding (#2186) This PR adds batch verify for spec decode ---- Co-authored-by: Wuwei Lin --- python/mlc_llm/op/__init__.py | 2 + python/mlc_llm/op/batch_spec_verify.py | 170 ++++++++++++++++++++++ tests/python/op/test_batch_spec_verify.py | 146 +++++++++++++++++++ 3 files changed, 318 insertions(+) create mode 100644 python/mlc_llm/op/batch_spec_verify.py create mode 100644 tests/python/op/test_batch_spec_verify.py diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index 342568639d..b5db353a3b 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -1,6 +1,8 @@ """Extern module for compiler.""" + from . import moe_matmul, moe_misc from .attention import attention +from .batch_spec_verify import batch_spec_verify from .extern import configure, enable, get_store from .ft_gemm import faster_transformer_dequantize_gemm from .position_embedding import llama_rope diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py new file mode 100644 index 0000000000..9cdbe2be21 --- /dev/null +++ b/python/mlc_llm/op/batch_spec_verify.py @@ -0,0 +1,170 @@ +"""Operators for batch verify in speculative decoding.""" + +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments, +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def batch_spec_verify(vocab_size): + """Batch draft verify function. This function verifies the token tree. + + Before calling the function + + - token_tree_parent_ptr[b] should store the root of the tree + + - draft_probs[node_id, :] stores the prob that samples the correspond tree node + - model_probs[node_id, :] stores the prob that should be used to sample its children + - Please note that the storage convention difference between model_probs and draft_probs + draft_probs was stored on the token node, while model_probs stores on the parent. + This is an intentional design since we can sample different child token with different + proposal draft probabilities, but the ground truth model_prob is unique per parent. + + After calling the function + - token_tree_parent_ptr[b] points to the last token accepted + - There should be a followup sample step that samples from model_probs[token_tree_parent_ptr[b], :] + This token will be appended to the token generated. + + This function will inplace update model_probs if a token was rejected and renormalization is needed. 
+ + Parameters + ---------- + draft_probs: + The draft probability attached to each tree node + + draft_tokens: + The draft token in each node + + model_probs: + The model proability attached to each parent + + token_tree_first_child: + The first child of each tree node, if there is no child, it should be -1 + + token_tree_next_sibling + The next sibling of each tree node, if there is no next sibling, it should be -1 + + uniform_samples + Per node uniform sample used to check rejection + + token_tree_parent_ptr: + Current parent ptr state + """ + TX = 128 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_draft_probs: T.handle, + var_draft_tokens: T.handle, + var_model_probs: T.handle, + var_token_tree_first_child: T.handle, + var_token_tree_next_sibling: T.handle, + var_uniform_samples: T.handle, + var_token_tree_parent_ptr: T.handle, + ): + """ + [ + blockIdx.x on batch, + threadIdx.x on vocab_size, + for loop over excessive amounts + ] + """ + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + num_nodes = T.int32(is_size_var=True) + nbatch = T.int32(is_size_var=True) + + draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size), "float32") + draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32") + model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size), "float32") + token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32") + token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32") + uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,), "float32") + token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32") + + with T.block("kernel"): + child_ptr = _var() + parent_ptr = _var() + child_token = _var() + done = _var("bool") + psum = _var("float32") + t0 = _var("float32") + model_prob_local = _var("float32") + draft_prob_local = _var("float32") + p_child = _var("float32") + q_child = _var("float32") + uniform_sample = _var("float32") + + pred_shared = T.alloc_buffer((1,), "bool", scope="shared") + pred_local = T.alloc_buffer((1,), "bool", scope="local") + + for _bx in T.thread_binding(0, nbatch, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + # batch size + b = T.axis.S(nbatch, _bx) + tx = T.axis.S(TX, _tx) + + parent_ptr[0] = token_tree_parent_ptr[b] + child_ptr[0] = token_tree_first_child[parent_ptr[0]] + done[0] = False + + while T.Not(done[0]): + T.tvm_storage_sync("shared") # ensure all effects last round are visible + if child_ptr[0] == -1: + done[0] = True + T.tvm_storage_sync("shared") # sync before exit + else: + # decide to validate current ptr + if tx == 0: + child_token[0] = draft_tokens[child_ptr[0]] + p_child[0] = model_probs[parent_ptr[0], child_token[0]] + q_child[0] = draft_probs[child_ptr[0], child_token[0]] + uniform_sample[0] = uniform_samples[child_ptr[0]] + pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0] # use multiplication to avoid division by zero + T.tvm_storage_sync("shared") # make sure all read of model_probs are done + pred_local[0] = pred_shared[0] + + # accept the proposal, we move to child + if pred_local[0]: + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + psum[0] = 0.0 + # renormalize probability, predicated by stopped_expansion[b]: + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = 
T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] + psum[0] += model_prob_local[0] + + with T.block("block_cross_thread"): + T.reads(psum[0]) + T.writes(t0[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") + + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_probs[parent_ptr[0], k] = model_probs[parent_ptr[0], k] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + + if tx == 0: + token_tree_parent_ptr[b] = parent_ptr[0] + # fmt: on + + return _func diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py new file mode 100644 index 0000000000..359fafdbd0 --- /dev/null +++ b/tests/python/op/test_batch_spec_verify.py @@ -0,0 +1,146 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.batch_spec_verify import batch_spec_verify + + +@pytest.mark.parametrize("nbatch", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001]) +@pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) +def test_batch_spec_verify(nbatch, vocab, plist): + def numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ): + nbatch = token_tree_parent_ptr.shape[0] + for b in range(nbatch): + parent_ptr = token_tree_parent_ptr[b] + child_ptr = token_tree_first_child[parent_ptr] + while child_ptr != -1: + child_token = draft_tokens[child_ptr] + p_child = model_probs[parent_ptr, child_token] + q_child = draft_probs[child_ptr, child_token] + uniform_sample = uniform_samples[child_ptr] + if p_child / q_child >= uniform_sample: + parent_ptr = child_ptr + child_ptr = token_tree_first_child[child_ptr] + else: + model_probs[parent_ptr, :] = np.maximum( + model_probs[parent_ptr, :] - draft_probs[child_ptr, :], 0.0 + ) + psum = np.sum(model_probs[parent_ptr, :]) + model_probs[parent_ptr, :] /= psum + child_ptr = token_tree_next_sibling[child_ptr] + token_tree_parent_ptr[b] = parent_ptr + + np.random.seed(0) + + def gen_chain(num_nodes, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + for i in range(num_nodes): + token_tree_first_child.append(base + i + 1 if i + 1 < num_nodes else -1) + token_tree_next_sibling.append(-1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + def gen_full_binary_tree(height, base): + token_tree_first_child = list() + token_tree_next_sibling = list() + num_nodes = 2**height - 1 + for i in range(num_nodes): + token_tree_first_child.append(base + i * 2 + 1 if i * 2 + 1 < num_nodes else -1) + token_tree_next_sibling.append(base + i * 2 + 2 if i * 2 + 2 < num_nodes else -1) + return token_tree_first_child, token_tree_next_sibling, base, base + 1 + + ### Inputs + num_nodes = 0 + token_tree_first_child = list() + token_tree_next_sibling = list() + token_tree_parent_ptr = list() + + for _ in range(nbatch): + choice = np.random.choice(2, 1, p=plist) + if choice == 0: + nodes_batch = np.random.randint(3, 32) + res = gen_chain(nodes_batch, num_nodes) + num_nodes += nodes_batch + 
else: + height = np.random.randint(3, 5) + res = gen_full_binary_tree(height, num_nodes) + num_nodes += 2**height - 1 + token_tree_first_child.extend(res[0]) + token_tree_next_sibling.extend(res[1]) + token_tree_parent_ptr.append(res[2]) + + token_tree_first_child = np.array(token_tree_first_child).astype("int32") + token_tree_next_sibling = np.array(token_tree_next_sibling).astype("int32") + token_tree_parent_ptr = np.array(token_tree_parent_ptr).astype("int32") + + draft_probs = np.random.rand(num_nodes, vocab).astype("float32") + draft_probs /= np.sum(draft_probs, axis=1, keepdims=True) + draft_tokens = np.random.randint(0, vocab, num_nodes).astype("int32") + model_probs = np.random.rand(num_nodes, vocab).astype("float32") + model_probs /= np.sum(model_probs, axis=1, keepdims=True) + uniform_samples = np.random.rand(num_nodes).astype("float32") + + ### TVM Inputs + dev = tvm.cuda(0) + draft_probs_tvm = tvm.nd.array(draft_probs, dev) + draft_tokens_tvm = tvm.nd.array(draft_tokens, dev) + model_probs_tvm = tvm.nd.array(model_probs, dev) + token_tree_first_child_tvm = tvm.nd.array(token_tree_first_child, dev) + token_tree_next_sibling_tvm = tvm.nd.array(token_tree_next_sibling, dev) + uniform_samples_tvm = tvm.nd.array(uniform_samples, dev) + token_tree_parent_ptr_tvm = tvm.nd.array(token_tree_parent_ptr, dev) + + # print("draft_probs", draft_probs) + # print("draft_tokens", draft_tokens) + # print("model_probs", model_probs) + # print("token_tree_first_child", token_tree_first_child) + # print("token_tree_next_sibling", token_tree_next_sibling) + # print("uniform_samples", uniform_samples) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### Numpy reference + numpy_reference( + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ) + # print("model_probs", model_probs) + # print("token_tree_parent_ptr", token_tree_parent_ptr) + + ### TVM + kernel = batch_spec_verify(vocab) + mod = tvm.build(kernel, target="cuda") + mod( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + # print("model_probs", model_probs_tvm.asnumpy()) + # print("token_tree_parent_ptr", token_tree_parent_ptr_tvm.asnumpy()) + + tvm.testing.assert_allclose(model_probs, model_probs_tvm.asnumpy()) + tvm.testing.assert_allclose( + token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 + ) + + +if __name__ == "__main__": + tvm.testing.main() From a1830c166ea64d884886a079cca6e594f4604d56 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 21 Apr 2024 20:46:14 -0400 Subject: [PATCH 221/531] [JIT] Better organize JIT and AOT handling (#2191) * [JIT] Better organize JIT and AOT handling Previously we do JIT when AOT lib lookup failed. The error message can become cryptic when JIT also fails, it will show up as cannot find None-vulkan.dll. This PR changes the behavior to only to lookup when model_lib_path is provided, or only to JIT when it is not. This will leads to cleaner error message overall. 
* Windows compact * More windows instructions --- docs/install/mlc_llm.rst | 7 +++++++ docs/install/tvm.rst | 9 ++++++++- python/mlc_llm/chat_module.py | 6 +++--- python/mlc_llm/interface/jit.py | 6 +++++- python/mlc_llm/serve/engine_base.py | 10 +++++++--- 5 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index c6602559ae..7b64dce9fb 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -118,6 +118,13 @@ Select your operating system/compute platform and run the command in your termin python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 849152cce6..ed4977e5e3 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -112,6 +112,13 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: + Make sure you also install vulkan loader and clang to avoid vulkan + not found error or clang not found(needed for jit compile) + + .. code-block:: bash + + conda install -c conda-forge clang libvulkan-loader + If encountering the error below: .. code-block:: bash @@ -213,7 +220,7 @@ While it is generally recommended to always use the prebuilt TVM Unity, if you r If you are using CUDA and your compute capability is above 80, then it is require to build with ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during runtime. - + To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. Once ``config.cmake`` is edited accordingly, kick off build with the commands below: diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 090bfab0bc..24ad8faecf 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -768,7 +768,7 @@ def __init__( # pylint: disable=too-many-arguments self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 4. Look up model library - try: + if model_lib_path is not None: self.model_lib_path = _get_lib_module_path( model, self.model_path, @@ -777,8 +777,8 @@ def __init__( # pylint: disable=too-many-arguments self.device.MASK2STR[self.device.device_type], self.config_file_path, ) - except FileNotFoundError: - logger.info("Model lib not found. 
Now compiling model lib on device...") + else: + logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel self.model_lib_path = str( diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index ecc2b0de0c..e999a36468 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -93,7 +93,11 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): ] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) - subprocess.run(cmd, check=True, env=os.environ) + subprocess.run(cmd, check=False, env=os.environ) + # note on windows: compilation can succeed but return code is still nonzero + # check whether file exists instead + if not os.path.isfile(dso_path): + raise RuntimeError("Cannot find compilation output, compilation failed") shutil.move(dso_path, dst) logger.info("Using compiled model lib: %s", bold(dst)) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 9b0f27723a..23dea5d015 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -89,8 +89,10 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template - # Try look up model library, and do JIT compile if model library not found. - try: + + if model.model_lib_path is not None: + # do model lib search if the model lib path is provided + # error out if file not found model_lib_path = _get_lib_module_path( model=model.model, model_path=model_path, @@ -99,7 +101,9 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: device_name=device.MASK2STR[device.device_type], config_file_path=config_file_path, ) - except FileNotFoundError: + else: + # TODO(mlc-team) add logging information + # Run jit if model_lib_path is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel model_lib_path = str( From f1f5cd142305e711dc9d518fddea30b2e3d6e63f Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 22 Apr 2024 14:21:36 +0200 Subject: [PATCH 222/531] Fix prefill and context flag names in doc (#2192) * Update compile_models.rst Fix flag names for prefill chunk size and context window size. * Update compile_models.rst --- docs/compilation/compile_models.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 00beb5cc4d..4706e09811 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -235,7 +235,7 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -664,7 +664,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. 
note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text @@ -793,7 +793,7 @@ generalized to any model variant, as long as mlc-llm supports the architecture. RuntimeError: Cannot find libraries: wasm_runtime.bc .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or lower ``context_window_size`` to decrease memory usage. + For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. Otherwise, you may run into issues like: .. code:: text From 17a2c6af623cd7e1bd027d5f1b2e1192aed17766 Mon Sep 17 00:00:00 2001 From: Ewout ter Hoeven Date: Mon, 22 Apr 2024 16:54:51 +0200 Subject: [PATCH 223/531] [Docs] Update quick start to mention Llama 3 8B (#2196) This commit updates the quick start to mention Llama 3 8B instead of Llama 2 7B. The code blocks where already updated. --- docs/get_started/quick_start.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index 604688f790..76d971275b 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -6,7 +6,7 @@ Quick Start Examples -------- -To begin with, try out MLC LLM support for int4-quantized Llama2 7B. +To begin with, try out MLC LLM support for int4-quantized Llama3 8B. It is recommended to have at least 6GB free VRAM to run it. .. tabs:: @@ -133,7 +133,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. Llama2-7B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. **Tutorial and source code**. The source code of the iOS app is fully `open source `__, and a :ref:`tutorial ` is included in documentation. @@ -154,7 +154,7 @@ It is recommended to have at least 6GB free VRAM to run it. | - **Requirement**. Llama2-7B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. + **Requirement**. Llama3-8B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. The demo is tested on - Samsung S23 with Snapdragon 8 Gen 2 chip From 253cd0d0e122da7b50ad64e84cdcece8c09926f1 Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Mon, 22 Apr 2024 17:22:10 -0400 Subject: [PATCH 224/531] [SERVING] Add Conv Template and Function Calling support to JSON FFI (#2190) This PR adds conv template support to the JSON FFI Engine. Also add function calling and pass stop str to generation config. 
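For illustration only (this snippet is not part of the patch): a minimal Python sketch of the kind of prompt assembly the new Conversation::AsPrompt performs. The field names, placeholder strings, and separator handling below are simplified assumptions modeled on the C++ code in this diff, not an exact specification of its behavior. In the patch itself the template is parsed from JSON by Conversation::FromJSON in conv_template.cc and applied inside JSONFFIEngine::AddRequest.

    def as_prompt(system_template, system_message, roles, role_templates,
                  messages, seps, role_content_sep, role_empty_sep):
        """Flatten a conversation into a single prompt string (sketch)."""
        # Substitute the system message into its template.
        prompt = system_template.replace("{system_message}", system_message)
        if prompt:
            prompt += seps[0]
        for role, content in messages:
            if content is None:
                # The model's turn to generate: emit the bare role tag.
                prompt += roles[role] + role_empty_sep
                continue
            # Substitute the message text into the per-role template.
            text = role_templates[role].replace("{%s_message}" % role, content)
            sep = seps[-1] if role == "assistant" else seps[0]
            prompt += roles[role] + role_content_sep + text + sep
        return prompt

    example = as_prompt(
        system_template="{system_message}",
        system_message="You are a helpful assistant.",
        roles={"user": "[INST]", "assistant": "[/INST]"},
        role_templates={"user": "{user_message}",
                        "assistant": "{assistant_message}"},
        messages=[("user", "What is 2 + 2?"), ("assistant", None)],
        seps=[" "],
        role_content_sep=" ",
        role_empty_sep="",
    )
    print(example)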
Co-authored-by: Shrey Gupta --- cpp/json_ffi/conv_template.cc | 313 ++++++++++++++++++ cpp/json_ffi/conv_template.h | 121 +++++++ cpp/json_ffi/json_ffi_engine.cc | 59 ++-- cpp/json_ffi/json_ffi_engine.h | 2 + cpp/json_ffi/openai_api_protocol.cc | 278 +++++++++++++++- cpp/json_ffi/openai_api_protocol.h | 39 ++- cpp/serve/config.cc | 12 +- cpp/serve/config.h | 6 +- tests/python/json_ffi/test_json_ffi_engine.py | 64 +++- 9 files changed, 831 insertions(+), 63 deletions(-) create mode 100644 cpp/json_ffi/conv_template.cc create mode 100644 cpp/json_ffi/conv_template.h diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc new file mode 100644 index 0000000000..02e0b3bdbd --- /dev/null +++ b/cpp/json_ffi/conv_template.cc @@ -0,0 +1,313 @@ +#include "conv_template.h" + +#include "../metadata/json_parser.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace mlc::llm; + +std::map PLACEHOLDERS = { + {MessagePlaceholders::SYSTEM, "{system_message}"}, + {MessagePlaceholders::USER, "{user_message}"}, + {MessagePlaceholders::ASSISTANT, "{assistant_message}"}, + {MessagePlaceholders::TOOL, "{tool_message}"}, + {MessagePlaceholders::FUNCTION, "{function_string}"}}; + +MessagePlaceholders MessagePlaceholderFromString(const std::string& role) { + static const std::unordered_map enum_map = { + {"system", MessagePlaceholders::SYSTEM}, {"user", MessagePlaceholders::USER}, + {"assistant", MessagePlaceholders::ASSISTANT}, {"tool", MessagePlaceholders::TOOL}, + {"function", MessagePlaceholders::FUNCTION}, + }; + + return enum_map.at(role); +} + +Conversation::Conversation() + : role_templates({{"user", PLACEHOLDERS[MessagePlaceholders::USER]}, + {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, + {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} + +std::vector Conversation::CheckMessageSeps(std::vector& seps) { + if (seps.size() == 0 || seps.size() > 2) { + throw std::invalid_argument("seps should have size 1 or 2."); + } + return seps; +} + +std::optional> Conversation::AsPrompt(std::string* err) { + // Get the system message + std::string system_msg = system_template; + size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); + if (pos != std::string::npos) { + system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(), + this->system_message); + } + + // Get the message strings + std::vector message_list; + std::vector separators = seps; + if (separators.size() == 1) { + separators.push_back(separators[0]); + } + + if (!system_msg.empty()) { + system_msg += separators[0]; + message_list.push_back(TextData(system_message)); + } + + for (int i = 0; i < messages.size(); i++) { + std::string role = messages[i].role; + std::optional>> content = + messages[i].content; + if (roles.find(role) == roles.end()) { + *err += "\nRole " + role + " is not supported. 
"; + return std::nullopt; + } + + std::string separator = separators[role == "assistant"]; // check assistant role + + // If content is empty, add the role and separator + // assistant's turn to generate text + if (!content.has_value()) { + message_list.push_back(TextData(roles[role] + role_empty_sep)); + continue; + } + + std::string message = ""; + std::string role_prefix = ""; + // Do not append role prefix if this is the first message and there + // is already a system message + if (add_role_after_system_message || system_msg.empty() || i != 0) { + role_prefix = roles[role] + role_content_sep; + } + + message += role_prefix; + + for (auto& item : content.value()) { + if (item.find("type") == item.end()) { + *err += "Content item should have a type field"; + return std::nullopt; + } + if (item["type"] == "text") { + if (item.find("text") == item.end()) { + *err += "Content item should have a text field"; + return std::nullopt; + } + // replace placeholder[ROLE] with input message from role + std::string role_text = role_templates[role]; + std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; + size_t pos = role_text.find(placeholder); + if (pos != std::string::npos) { + role_text.replace(pos, placeholder.length(), item["text"]); + } + if (use_function_calling.has_value() && use_function_calling.value()) { + // replace placeholder[FUNCTION] with function_string + // this assumes function calling is used for a single request scenario only + if (!function_string.has_value()) { + *err += "Function string is required for function calling"; + return std::nullopt; + } + pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); + if (pos != std::string::npos) { + role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(), + function_string.value()); + } + } + message += role_text; + } else { + *err += "Unsupported content type: " + item["type"]; + return std::nullopt; + } + } + + message += separator; + message_list.push_back(TextData(message)); + } + + return message_list; +} + +std::optional Conversation::FromJSON(const picojson::object& json, std::string* err) { + Conversation conv; + + // name + std::string name; + if (json::ParseJSONField(json, "name", name, err, false)) { + conv.name = name; + } + + std::string system_template; + if (!json::ParseJSONField(json, "system_template", system_template, err, true)) { + return std::nullopt; + } + conv.system_template = system_template; + + std::string system_message; + if (!json::ParseJSONField(json, "system_message", system_message, err, true)) { + return std::nullopt; + } + conv.system_message = system_message; + + picojson::array system_prefix_token_ids_arr; + if (json::ParseJSONField(json, "system_prefix_token_ids", system_prefix_token_ids_arr, err, + false)) { + std::vector system_prefix_token_ids; + for (const auto& token_id : system_prefix_token_ids_arr) { + if (!token_id.is()) { + *err += "system_prefix_token_ids should be an array of integers."; + return std::nullopt; + } + system_prefix_token_ids.push_back(token_id.get()); + } + conv.system_prefix_token_ids = system_prefix_token_ids; + } + + bool add_role_after_system_message; + if (!json::ParseJSONField(json, "add_role_after_system_message", add_role_after_system_message, + err, true)) { + return std::nullopt; + } + conv.add_role_after_system_message = add_role_after_system_message; + + picojson::object roles_object; + if (!json::ParseJSONField(json, "roles", roles_object, err, true)) { + return std::nullopt; + } + std::unordered_map 
roles; + for (const auto& role : roles_object) { + if (!role.second.is()) { + *err += "roles should be a map of string to string."; + return std::nullopt; + } + roles[role.first] = role.second.get(); + } + conv.roles = roles; + + picojson::object role_templates_object; + if (json::ParseJSONField(json, "role_templates", role_templates_object, err, false)) { + for (const auto& role : role_templates_object) { + if (!role.second.is()) { + *err += "role_templates should be a map of string to string."; + return std::nullopt; + } + conv.role_templates[role.first] = role.second.get(); + } + } + + picojson::array messages_arr; + if (!json::ParseJSONField(json, "messages", messages_arr, err, true)) { + return std::nullopt; + } + std::vector messages; + for (const auto& message : messages_arr) { + if (!message.is()) { + *err += "messages should be an array of objects."; + return std::nullopt; + } + picojson::object message_obj = message.get(); + std::string role; + if (!json::ParseJSONField(message_obj, "role", role, err, true)) { + *err += "role field is required in messages."; + return std::nullopt; + } + picojson::array content_arr; + std::vector> content; + if (json::ParseJSONField(message_obj, "content", content_arr, err, false)) { + for (const auto& item : content_arr) { + if (!item.is()) { + *err += "Content item is not an object"; + return std::nullopt; + } + std::unordered_map item_map; + picojson::object item_obj = item.get(); + for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); + ++i) { + item_map[i->first] = i->second.to_str(); + } + content.push_back(item_map); + } + } + messages.push_back({role, content}); + } + conv.messages = messages; + + picojson::array seps_arr; + if (!json::ParseJSONField(json, "seps", seps_arr, err, true)) { + return std::nullopt; + } + std::vector seps; + for (const auto& sep : seps_arr) { + if (!sep.is()) { + *err += "seps should be an array of strings."; + return std::nullopt; + } + seps.push_back(sep.get()); + } + conv.seps = seps; + + std::string role_content_sep; + if (!json::ParseJSONField(json, "role_content_sep", role_content_sep, err, true)) { + return std::nullopt; + } + conv.role_content_sep = role_content_sep; + + std::string role_empty_sep; + if (!json::ParseJSONField(json, "role_empty_sep", role_empty_sep, err, true)) { + return std::nullopt; + } + conv.role_empty_sep = role_empty_sep; + + picojson::array stop_str_arr; + if (!json::ParseJSONField(json, "stop_str", stop_str_arr, err, true)) { + return std::nullopt; + } + std::vector stop_str; + for (const auto& stop : stop_str_arr) { + if (!stop.is()) { + *err += "stop_str should be an array of strings."; + return std::nullopt; + } + stop_str.push_back(stop.get()); + } + conv.stop_str = stop_str; + + picojson::array stop_token_ids_arr; + if (!json::ParseJSONField(json, "stop_token_ids", stop_token_ids_arr, err, true)) { + return std::nullopt; + } + std::vector stop_token_ids; + for (const auto& stop : stop_token_ids_arr) { + if (!stop.is()) { + *err += "stop_token_ids should be an array of integers."; + return std::nullopt; + } + stop_token_ids.push_back(stop.get()); + } + conv.stop_token_ids = stop_token_ids; + + std::string function_string; + if (!json::ParseJSONField(json, "function_string", function_string, err, false)) { + conv.function_string = function_string; + } + + bool use_function_calling; + if (json::ParseJSONField(json, "use_function_calling", use_function_calling, err, false)) { + conv.use_function_calling = use_function_calling; + } + + return 
conv; +} + +std::optional Conversation::FromJSON(const std::string& json_str, std::string* err) { + std::optional json_obj = json::LoadJSONFromString(json_str, err); + if (!json_obj.has_value()) { + return std::nullopt; + } + return Conversation::FromJSON(json_obj.value(), err); +} +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h new file mode 100644 index 0000000000..d3a1d1de2f --- /dev/null +++ b/cpp/json_ffi/conv_template.h @@ -0,0 +1,121 @@ +#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H +#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H + +#include +#include +#include +#include +#include +#include +#include + +#include "../serve/data.h" +#include "picojson.h" + +using namespace mlc::llm::serve; + +namespace mlc { +namespace llm { +namespace json_ffi { + +enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; + +MessagePlaceholders messagePlaceholderFromString(const std::string& role); + +class Message { + public: + std::string role; + std::optional>> content = std::nullopt; +}; + +/** + * @brief A struct that specifies the convention template of conversation + * and contains the conversation history. + */ +struct Conversation { + // Optional name of the template. + std::optional name = std::nullopt; + + // The system prompt template, it optionally contains the system + // message placeholder, and the placeholder will be replaced with + // the system message below. + std::string system_template; + + // The content of the system prompt (without the template format). + std::string system_message; + + // The system token ids to be prepended at the beginning of tokenized + // generated prompt. + std::optional> system_prefix_token_ids = std::nullopt; + + // Whether or not to append user role and separator after the system message. + // This is mainly for [INST] [/INST] style prompt format + bool add_role_after_system_message = true; + + // The conversation roles + std::unordered_map roles; + + // The roles prompt template, it optionally contains the defaults + // message placeholders and will be replaced by actual content + std::unordered_map role_templates; + + // The conversation history messages. + // Each message is a pair of strings, denoting "(role, content)". + // The content can be None. + std::vector messages; + + // The separators between messages when concatenating into a single prompt. + // List size should be either 1 or 2. + // - When size is 1, the separator will be used between adjacent messages. + // - When size is 2, seps[0] is used after user message, and + // seps[1] is used after assistant message. + std::vector seps; + + // The separator between the role and the content in a message. + std::string role_content_sep; + + // The separator between the role and empty contents. + std::string role_empty_sep; + + // The stop criteria + std::vector stop_str; + std::vector stop_token_ids; + + // Function call fields + // whether using function calling or not, helps check for output message format in API call + std::optional function_string = std::nullopt; + std::optional use_function_calling = false; + + Conversation(); + + /** + * @brief Checks the size of the separators vector. + * This function checks if the size of the separators vector is either 1 or 2. + * If the size is not 1 or 2, it throws an invalid_argument exception. + */ + static std::vector CheckMessageSeps(std::vector& seps); + + /*! 
+ * \brief Create the list of prompts from the messages based on the conversation template. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + std::optional> AsPrompt(std::string* err); + + /*! + * \brief Create a Conversation instance from the given JSON object. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const picojson::object& json, std::string* err); + + /*! + * \brief Parse and create a Conversation instance from the given JSON string. + * When creation fails, errors are dumped to the input error string, and nullopt is returned. + */ + static std::optional FromJSON(const std::string& json_str, std::string* err); +}; + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */ diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 2f5bf49ce3..0e21735e2f 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -51,33 +51,40 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request // TODO: Check if request_id is present already // inputs - // TODO: Apply conv template - Array inputs; + Conversation conv_template = this->conv_template_; + std::vector messages; for (const auto& message : request.messages) { - if (message.content.has_value()) { - for (const auto& content : message.content.value()) { - if (content.find("type") == content.end()) { - err_ += "Content should have a type field"; - return false; - } - std::string type = content.at("type"); - if (type == "text") { - if (content.find("text") == content.end()) { - err_ += "Content should have a text field"; - return false; - } - std::string text = content.at("text"); - inputs.push_back(TextData(text)); - } else { - err_ += "Content type not supported"; - return false; - } - } + std::string role; + if (message.role == Role::user) { + role = "user"; + } else if (message.role == Role::assistant) { + role = "assistant"; + } else if (message.role == Role::tool) { + role = "tool"; + } else { + role = "system"; } + messages.push_back({role, message.content}); + } + messages.push_back({"assistant", std::nullopt}); + conv_template.messages = messages; + + // check function calling + bool success_check = request.CheckFunctionCalling(conv_template, &err_); + if (!success_check) { + return false; } + // get prompt + std::optional> inputs_obj = conv_template.AsPrompt(&err_); + if (!inputs_obj.has_value()) { + return false; + } + Array inputs = inputs_obj.value(); + // generation_cfg - Optional generation_cfg = GenerationConfig::FromJSON(request_json_str, &err_); + Optional generation_cfg = + GenerationConfig::FromJSON(request_json_str, &err_, conv_template); if (!generation_cfg.defined()) { return false; } @@ -115,9 +122,15 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(EngineConfig engine_config, + void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config, Optional request_stream_callback, Optional trace_recorder) { + std::optional conv_template = Conversation::FromJSON(conv_template_str, &err_); + if (!conv_template.has_value()) { + LOG(FATAL) << "Invalid conversation template JSON: " << err_; + } + this->conv_template_ = conv_template.value(); + // Todo(mlc-team): decouple 
InitBackgroundEngine into two functions // by removing `engine_config` from arguments, after properly handling // streamers. diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 83013b5876..2c7501c337 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,6 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" +#include "conv_template.h" #include "openai_api_protocol.h" namespace mlc { @@ -47,6 +48,7 @@ class JSONFFIEngine { std::string err_; PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request + Conversation conv_template_; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 41378fc3e0..13f4b140ce 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -11,14 +11,166 @@ namespace mlc { namespace llm { namespace json_ffi { -std::optional ChatCompletionMessage::FromJSON(const picojson::value& json, - std::string* err) { - if (!json.is()) { - *err += "Input is not a valid JSON object"; +std::string generate_uuid_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +std::optional ChatFunction::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunction chatFunc; + + // description (optional) + std::string description; + if (json::ParseJSONField(json_obj, "description", description, err, false)) { + chatFunc.description = description; + } + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFunc.name = name; + + // parameters + picojson::object parameters_obj; + if (!json::ParseJSONField(json_obj, "parameters", parameters_obj, err, true)) { + return std::nullopt; + } + std::unordered_map parameters; + for (picojson::value::object::const_iterator i = parameters_obj.begin(); + i != parameters_obj.end(); ++i) { + parameters[i->first] = i->second.to_str(); + } + chatFunc.parameters = parameters; + + return chatFunc; +} + +picojson::object ChatFunction::ToJSON() const { + picojson::object obj; + if (this->description.has_value()) { + obj["description"] = picojson::value(this->description.value()); + } + obj["name"] = picojson::value(this->name); + picojson::object parameters_obj; + for (const auto& pair : this->parameters) { + parameters_obj[pair.first] = picojson::value(pair.second); + } + obj["parameters"] = picojson::value(parameters_obj); + return obj; +} + +std::optional ChatTool::FromJSON(const picojson::object& json_obj, std::string* err) { + ChatTool chatTool; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunction::FromJSON(function_obj, err); + if (!function.has_value()) { return std::nullopt; } - picojson::object json_obj = json.get(); + chatTool.function = function.value(); + + return chatTool; +} +picojson::object ChatTool::ToJSON() const { + picojson::object obj; + obj["type"] = picojson::value("function"); + obj["function"] = picojson::value(this->function.ToJSON()); + return obj; +} + 
+std::optional ChatFunctionCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatFunctionCall chatFuncCall; + + // name + std::string name; + if (!json::ParseJSONField(json_obj, "name", name, err, true)) { + return std::nullopt; + } + chatFuncCall.name = name; + + // arguments + picojson::object arguments_obj; + if (json::ParseJSONField(json_obj, "arguments", arguments_obj, err, false)) { + std::unordered_map arguments; + for (picojson::value::object::const_iterator i = arguments_obj.begin(); + i != arguments_obj.end(); ++i) { + arguments[i->first] = i->second.to_str(); + } + chatFuncCall.arguments = arguments; + } + + return chatFuncCall; +} + +picojson::object ChatFunctionCall::ToJSON() const { + picojson::object obj; + picojson::object arguments_obj; + if (this->arguments.has_value()) { + for (const auto& pair : this->arguments.value()) { + arguments_obj[pair.first] = picojson::value(pair.second); + } + obj["arguments"] = picojson::value(arguments_obj); + } + + obj["name"] = picojson::value(this->name); + return obj; +} + +std::optional ChatToolCall::FromJSON(const picojson::object& json_obj, + std::string* err) { + ChatToolCall chatToolCall; + + // function + picojson::object function_obj; + if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { + return std::nullopt; + } + + std::optional function = ChatFunctionCall::FromJSON(function_obj, err); + if (!function.has_value()) { + return std::nullopt; + }; + chatToolCall.function = function.value(); + + // overwrite default id + std::string id; + if (!json::ParseJSONField(json_obj, "id", id, err, false)) { + return std::nullopt; + } + chatToolCall.id = id; + + return chatToolCall; +} + +picojson::object ChatToolCall::ToJSON() const { + picojson::object obj; + obj["id"] = picojson::value(this->id); + obj["function"] = picojson::value(this->function.ToJSON()); + obj["type"] = picojson::value("function"); + return obj; +} + +std::optional ChatCompletionMessage::FromJSON( + const picojson::object& json_obj, std::string* err) { ChatCompletionMessage message; // content @@ -65,7 +217,30 @@ std::optional ChatCompletionMessage::FromJSON(const picoj message.name = name; } - // TODO: tool_calls and tool_call_id + // tool calls + picojson::array tool_calls_arr; + if (json::ParseJSONField(json_obj, "tool_calls", tool_calls_arr, err, false)) { + std::vector tool_calls; + for (const auto& item : tool_calls_arr) { + if (!item.is()) { + *err += "Chat Tool Call item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool_call = ChatToolCall::FromJSON(item_obj, err); + if (!tool_call.has_value()) { + return std::nullopt; + }; + tool_calls.push_back(tool_call.value()); + } + message.tool_calls = tool_calls; + } + + // tool call id + std::string tool_call_id; + if (json::ParseJSONField(json_obj, "tool_call_id", tool_call_id, err, false)) { + message.tool_call_id = tool_call_id; + } return message; } @@ -81,7 +256,8 @@ std::optional ChatCompletionRequest::FromJSON( } std::vector messages; for (const auto& item : messages_arr) { - std::optional message = ChatCompletionMessage::FromJSON(item, err); + picojson::object item_obj = item.get(); + std::optional message = ChatCompletionMessage::FromJSON(item_obj, err); if (!message.has_value()) { return std::nullopt; } @@ -108,6 +284,32 @@ std::optional ChatCompletionRequest::FromJSON( request.presence_penalty = presence_penalty; } + // tool_choice + std::string tool_choice = "auto"; + request.tool_choice = tool_choice; + 
if (json::ParseJSONField(json_obj, "tool_choice", tool_choice, err, false)) { + request.tool_choice = tool_choice; + } + + // tools + picojson::array tools_arr; + if (json::ParseJSONField(json_obj, "tools", tools_arr, err, false)) { + std::vector tools; + for (const auto& item : tools_arr) { + if (!item.is()) { + *err += "Chat Tool item is not an object"; + return std::nullopt; + } + picojson::object item_obj = item.get(); + std::optional tool = ChatTool::FromJSON(item_obj, err); + if (!tool.has_value()) { + return std::nullopt; + }; + tools.push_back(tool.value()); + } + request.tools = tools; + } + // TODO: Other parameters return request; @@ -122,7 +324,7 @@ std::optional ChatCompletionRequest::FromJSON(const std:: return ChatCompletionRequest::FromJSON(json_obj.value(), err); } -picojson::object ChatCompletionMessage::ToJSON() { +picojson::object ChatCompletionMessage::ToJSON() const { picojson::object obj; picojson::array content_arr; for (const auto& item : this->content.value()) { @@ -142,13 +344,57 @@ picojson::object ChatCompletionMessage::ToJSON() { } else if (this->role == Role::tool) { obj["role"] = picojson::value("tool"); } - if (name.has_value()) { - obj["name"] = picojson::value(name.value()); + if (this->name.has_value()) { + obj["name"] = picojson::value(this->name.value()); + } + if (this->tool_call_id.has_value()) { + obj["tool_call_id"] = picojson::value(this->tool_call_id.value()); + } + if (this->tool_calls.has_value()) { + picojson::array tool_calls_arr; + for (const auto& tool_call : this->tool_calls.value()) { + tool_calls_arr.push_back(picojson::value(tool_call.ToJSON())); + } + obj["tool_calls"] = picojson::value(tool_calls_arr); } return obj; } -picojson::object ChatCompletionResponseChoice::ToJSON() { +bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, std::string* err) { + if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) { + conv_template.use_function_calling = false; + return true; + } + std::vector tools_ = tools.value(); + std::string tool_choice_ = tool_choice.value(); + + // TODO: support with tool choice as dict + for (const auto& tool : tools_) { + if (tool.function.name == tool_choice_) { + conv_template.use_function_calling = true; + picojson::value function_str(tool.function.ToJSON()); + conv_template.function_string = function_str.serialize(); + return true; + } + } + + if (tool_choice_ != "auto") { + *err += "Invalid tool_choice value: " + tool_choice_; + return false; + } + + picojson::array function_list; + for (const auto& tool : tools_) { + function_list.push_back(picojson::value(tool.function.ToJSON())); + } + + conv_template.use_function_calling = true; + picojson::value function_list_json(function_list); + conv_template.function_string = function_list_json.serialize(); + return true; +}; + +picojson::object ChatCompletionResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -168,7 +414,7 @@ picojson::object ChatCompletionResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponseChoice::ToJSON() { +picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -189,11 +435,11 @@ picojson::object ChatCompletionStreamResponseChoice::ToJSON() { return obj; } -picojson::object ChatCompletionResponse::ToJSON() { +picojson::object 
ChatCompletionResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); @@ -204,11 +450,11 @@ picojson::object ChatCompletionResponse::ToJSON() { return obj; } -picojson::object ChatCompletionStreamResponse::ToJSON() { +picojson::object ChatCompletionStreamResponse::ToJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; - for (auto& choice : this->choices) { + for (const auto& choice : this->choices) { choices_arr.push_back(picojson::value(choice.ToJSON())); } obj["choices"] = picojson::value(choices_arr); diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 1579b5f337..bed225d3d0 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -8,10 +8,12 @@ #include #include +#include #include #include #include +#include "conv_template.h" #include "picojson.h" namespace mlc { @@ -22,7 +24,8 @@ enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; -// TODO: Implement the following class +std::string generate_uuid_string(size_t length); + class ChatFunction { public: std::optional description = std::nullopt; @@ -30,32 +33,37 @@ class ChatFunction { std::unordered_map parameters; // Assuming parameters are string key-value pairs - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatTool { public: Type type = Type::function; ChatFunction function; - static std::optional FromJSON(const picojson::value& json, std::string* err); + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatFunctionCall { public: std::string name; std::optional> arguments = std::nullopt; // Assuming arguments are string key-value pairs + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; -// TODO: Implement the following class class ChatToolCall { public: - std::string id; // TODO: python code initializes this to an random string + std::string id = "call_" + generate_uuid_string(8); Type type = Type::function; ChatFunctionCall function; + + static std::optional FromJSON(const picojson::object& json, std::string* err); + picojson::object ToJSON() const; }; class ChatCompletionMessage { @@ -64,12 +72,12 @@ class ChatCompletionMessage { std::nullopt; // Assuming content is a list of string key-value pairs Role role; std::optional name = std::nullopt; - std::optional> tool_calls = std::nullopt; // TODO: Implement this - std::optional tool_call_id = std::nullopt; // TODO: Implement this + std::optional> tool_calls = std::nullopt; + std::optional tool_call_id = std::nullopt; - static std::optional FromJSON(const picojson::value& json, + static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class RequestResponseFormat { @@ -113,6 +121,7 @@ class ChatCompletionRequest { static std::optional 
FromJSON(const std::string& json_str, std::string* err); + bool CheckFunctionCalling(Conversation& conv_template, std::string* err); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; @@ -123,7 +132,7 @@ class ChatCompletionResponseChoice { ChatCompletionMessage message; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponseChoice { @@ -133,7 +142,7 @@ class ChatCompletionStreamResponseChoice { ChatCompletionMessage delta; // TODO: logprobs - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionResponse { @@ -146,7 +155,7 @@ class ChatCompletionResponse { std::string object = "chat.completion"; // TODO: usage_info - picojson::object ToJSON(); + picojson::object ToJSON() const; }; class ChatCompletionStreamResponse { @@ -158,7 +167,7 @@ class ChatCompletionStreamResponse { std::string system_fingerprint; std::string object = "chat.completion.chunk"; - picojson::object ToJSON(); + picojson::object ToJSON() const; }; } // namespace json_ffi diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 5d647ec532..7379bad7ed 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -161,8 +161,8 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } -Optional GenerationConfig::FromJSON(const std::string& json_str, - std::string* err) { +Optional GenerationConfig::FromJSON(const std::string& json_str, std::string* err, + const Conversation& conv_template) { std::optional json_obj = json::LoadJSONFromString(json_str, err); if (!err->empty() || !json_obj.has_value()) { return NullOpt; @@ -171,6 +171,14 @@ Optional GenerationConfig::FromJSON(const std::string& json_st // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. + // Copy stop str from conversation template to generation config + for (auto& stop_str : conv_template.stop_str) { + n->stop_strs.push_back(stop_str); + } + for (auto& stop_token_id : conv_template.stop_token_ids) { + n->stop_token_ids.push_back(stop_token_id); + } + if (!err->empty()) { return NullOpt; } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 404566fe2c..41ddb3c6e4 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -11,12 +11,15 @@ #include +#include "../json_ffi/conv_template.h" + namespace mlc { namespace llm { namespace serve { using namespace tvm; using namespace tvm::runtime; +using namespace mlc::llm::json_ffi; /****************** GenerationConfig ******************/ @@ -63,7 +66,8 @@ class GenerationConfig : public ObjectRef { * \brief Parse the generation config from the given JSON string. * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. */ - static Optional FromJSON(const std::string& json_str, std::string* err); + static Optional FromJSON(const std::string& json_str, std::string* err, + const Conversation& conv_template); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b8a8d492b9..9b594e9784 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -19,7 +19,7 @@ ) from mlc_llm.tokenizer import Tokenizer -prompts = [ +chat_completion_prompts = [ "What is the meaning of life?", "Introduce the history of Pittsburgh to me. Please elaborate in detail.", "Write a three-day Seattle travel plan. 
Please elaborate in detail.", @@ -32,6 +32,33 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", ] +function_calling_prompts = [ + "What is the temperature in Pittsburgh, PA?", + "What is the temperature in Tokyo, JP?", + "What is the temperature in Pittsburgh, PA and Tokyo, JP?", +] + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } +] + class EngineState: sync_queue: queue.Queue @@ -139,6 +166,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals spec_draft_length=spec_draft_length, ) self._ffi["init_background_engine"]( + self.conv_template.model_dump_json(), self.engine_config, self.state.get_request_stream_callback(), None, @@ -266,7 +294,12 @@ def _test_unload(self): self._ffi["unload"]() -def run_chat_completion(engine: JSONFFIEngine, model: str): +def run_chat_completion( + engine: JSONFFIEngine, + model: str, + prompts: List[str] = chat_completion_prompts, + tools: Optional[List[Dict]] = None, +): num_requests = 2 max_tokens = 64 n = 1 @@ -280,6 +313,7 @@ def run_chat_completion(engine: JSONFFIEngine, model: str): max_tokens=max_tokens, n=n, request_id=str(rid), + tools=tools, ): for choice in response.choices: assert choice.delta.role == "assistant" @@ -300,8 +334,8 @@ def run_chat_completion(engine: JSONFFIEngine, model: str): def test_chat_completion(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, model_lib_path=model_lib_path, @@ -320,8 +354,8 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, model_lib_path=model_lib_path, @@ -339,6 +373,24 @@ def test_reload_reset_unload(): engine.terminate() +def test_function_calling(): + model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC" + model_lib_path = ( + "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" + ) + engine = JSONFFIEngine( + model, + model_lib_path=model_lib_path, + max_total_sequence_length=1024, + ) + + # run function calling + run_chat_completion(engine, model, function_calling_prompts, tools) + + engine.terminate() + + if __name__ == "__main__": test_chat_completion() test_reload_reset_unload() + test_function_calling() From 12647d57c4f3c8a86d2212764319e11a564d78c1 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 22 Apr 2024 18:18:25 -0700 Subject: [PATCH 225/531] [Serving] Paged Radix Tree for Prefix Caching (#2183) This PR introduces the Paged Radix Tree data structure, as foundation and prerequisite of prefix caching. 
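For illustration only (not part of the patch): a minimal sketch of how the Python
wrapper introduced below can be driven for prefix lookup and reuse. The pool sizes
and token ids are arbitrary placeholders; only the PagedRadixTree API added in this
change is assumed.

    from mlc_llm.serve import PagedRadixTree

    # 128 pages of 256 bytes each, tracking up to 16 sequences (arbitrary sizes).
    tree = PagedRadixTree(num_pages=128, page_size=256, num_seqs=16)

    # Register sequence 0 and record its prompt tokens.
    tree.add(0)
    tree.extend(0, [101, 7592, 2088, 102])

    # A new request first queries the longest shared prefix.
    matched_offset, seq_ids = tree.match([101, 7592, 2088, 103])
    # Here matched_offset == 3 and seq_ids contains sequence 0.

    # Fork sequence 1 from sequence 0 at the matched position to reuse the prefix.
    tree.fork(1, 0, matched_offset)

    # When a sequence finishes, release its now-unshared pages back to the pool.
    tree.remove(0)
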
--- cpp/serve/radix_tree.cc | 718 ++++++++++++++++++++++++++ cpp/serve/radix_tree.h | 110 ++++ python/mlc_llm/serve/__init__.py | 1 + python/mlc_llm/serve/radix_tree.py | 150 ++++++ tests/python/serve/test_radix_tree.py | 79 +++ 5 files changed, 1058 insertions(+) create mode 100644 cpp/serve/radix_tree.cc create mode 100644 cpp/serve/radix_tree.h create mode 100644 python/mlc_llm/serve/radix_tree.py create mode 100644 tests/python/serve/test_radix_tree.py diff --git a/cpp/serve/radix_tree.cc b/cpp/serve/radix_tree.cc new file mode 100644 index 0000000000..5d5c311593 --- /dev/null +++ b/cpp/serve/radix_tree.cc @@ -0,0 +1,718 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.cc + */ +#include "radix_tree.h" + +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The sequence ID linked list structure in paged radix tree node. + */ +struct SequenceIDNode { + /*! \brief The stored sequence ID. */ + int64_t id = 0; + /*! \brief The pointer to the next sequence ID. */ + SequenceIDNode* next = nullptr; +}; + +/*! + * \brief The sequence Id node pool. + * + * The sequence Id node pool allocates all sequence ID nodes when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class SequenceIDNodePool { + public: + /*! \brief The constructor of sequence Id node pool, allocating memory for each node. */ + SequenceIDNodePool(size_t num_nodes) : num_nodes_(num_nodes) { + nodes_.reserve(num_nodes); + free_node_indicess_.reserve(num_nodes); + used_nodes_.clear(); + raw_pool_ = new SequenceIDNode[num_nodes_]; + for (size_t i = 0; i < num_nodes; ++i) { + nodes_.push_back(&raw_pool_[i]); + free_node_indicess_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool, and assign the fields. + * \param seq_id The assigned sequence ID of allocated sequence ID node. + * \param node The next sequence ID node pointer of allocated sequence ID node. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + SequenceIDNode* Allocate(int64_t seq_id, SequenceIDNode* next) { + CHECK(!free_node_indicess_.empty()) << "Sequence ID node pool has no free sequence ID nodes."; + size_t id = free_node_indicess_.back(); + free_node_indicess_.pop_back(); + SequenceIDNode* node = nodes_[id]; + used_nodes_[node] = id; + node->id = seq_id; + node->next = next; + return node; + } + + /*! + * \brief Free a sequence ID node to pool. + * \param node The sequence ID node to free. + */ + void Free(SequenceIDNode* node) { + CHECK(used_nodes_.find(node) != used_nodes_.end()); + free_node_indicess_.push_back(used_nodes_[node]); + used_nodes_.erase(node); + } + + /*! \brief The destructor of sequence Id node pool, freeing memory for each node. */ + ~SequenceIDNodePool() { delete[] raw_pool_; } + + private: + /*! \brief The number of nodes in sequence ID node pool. */ + size_t num_nodes_; + /*! \brief The raw sequence ID node pool. */ + SequenceIDNode* raw_pool_; + /*! \brief The sequence ID node pool. */ + std::vector nodes_; + /*! \brief The indices of free sequence ID node in node pool. */ + std::vector free_node_indicess_; + /*! \brief The map from used paged sequence ID node to its index in node pool. */ + std::unordered_map used_nodes_; +}; + +/*! + * \brief The paged radix tree node data structure. 
+ * + * The paged radix tree node is similar to original radix tree node, but with the limited length for + * prefix in page, so that the memory usage in each page is the same and is fixed once allocated. + * Since the page only consists of pointers and int tokens, the page memory layout is int array + * indeed. The lower offset is the pointers and page information, while the higher offset is the + * stored prefix tokens. + * + * And since the vocabulary size may be very large, the paged Radix tree is represented + * as left-child, right-sibling binary tree. + * + * Also, due to possible pop/push front/back tokens in page, the page is designed as circular + * buffer, to make full use of each page. + * + * Each page records the sequence excatly ends with the prefix tokens stored in page. In other word, + * all sequences locate in the boundary of each page, or the end of each page. + */ +struct RedixPage { + /*! \brief The parent page. */ + RedixPage* parent; + /*! \brief The first child page. */ + RedixPage* first_child; + /*! \brief The sibling page shareing the same parent page. */ + RedixPage* next_sibiling; + /*! \brief The head of sequence ID linked list. */ + SequenceIDNode* seq_ids; + /*! \brief The capacity of maximum stored prefix tokens. */ + size_t capacity; + /*! \brief The start offset of stored prefix tokens. The legal value is of [0, capacity). */ + size_t offset; + /*! \brief The length of stored prefix tokens. The legal value is of [0, capacity). */ + size_t length; + /*! \brief The offset of first prefix token in memory layout. */ + static constexpr int DATA_OFFSET = (sizeof(RedixPage*) * 3 + sizeof(SequenceIDNode*) + + sizeof(size_t) * 3 + sizeof(int32_t) - 1) / + sizeof(int32_t); + + /*! + * \brief Overload opeartor [] to get the prefix tokens by index as simple int array. + * \param i The prefix token index. + * \return The value of i-th prefix token. + */ + int32_t& operator[](size_t i) { + return reinterpret_cast(this)[DATA_OFFSET + (i + offset) % capacity]; + } + + /*! + * \brief Extend or push back a suffix tokens in page. + * \param suffix The suffix tokens array. + * \param suffix_length The suffix length to extend. + * \throw Error if suffix length is larger than current vacant space. + */ + void Extend(const int64_t* suffix, size_t suffix_length) { + CHECK_LE(suffix_length + length, capacity); + for (int i = 0; i < suffix_length; ++i) { + (*this)[i + length] = (int32_t)suffix[i]; + } + length += suffix_length; + } + + /*! + * \brief Add a sequence ID in page. + * \param pool The sequence ID node pool to allocate new node. + * \param id The sequence ID to add. + */ + void AddSequence(SequenceIDNodePool* pool, int64_t id) { seq_ids = pool->Allocate(id, seq_ids); } + + /*! + * \brief Pop a sequence ID in page. + * \param pool The sequence ID node pool to free popped node. + * \param id The sequence ID to pop. + * \throw Error if no such sequence ID in page. + */ + void PopSequence(SequenceIDNodePool* pool, int64_t id) { + if (seq_ids->id == id) { + // If the popped sequencs ID is the first node in linked list, + // directly skip from head and free it. + SequenceIDNode* next = seq_ids->next; + pool->Free(seq_ids); + seq_ids = next; + } else { + // If the popped sequencs ID is not the first node in linked list, + // skip it from previous node and free it. 
+ SequenceIDNode* last = seq_ids; + SequenceIDNode* cur = seq_ids->next; + while (cur) { + if (cur->id == id) { + last->next = cur->next; + pool->Free(cur); + return; + } + } + LOG(FATAL) << "Sequence ID = " << id << " not found."; + } + } + + /*! + * \brief Get all sequence ID in page. + * \return The std::vector of sequence ID in page. + */ + std::vector GetLocalSequence() { + std::vector output; + for (SequenceIDNode* node = seq_ids; node; node = node->next) { + output.push_back(node->id); + } + return output; + } + + /*! + * \brief Get any sequence ID in current page or child pages. + * Since there is always a sequence in leaf pages, it only check first child if no sequence ID in + * current page. + * \return The any sequence ID in current page or child pages. + */ + int32_t FindAnyChildSequence() { + if (seq_ids) return seq_ids->id; + return first_child->FindAnyChildSequence(); + } + + /*! + * \brief Get all sequence ID in current page and child pages, using Iterate method with lambda + * expression as callback to avoid frequently memory allocation of std::vector. + * \return The std::vector of all sequence ID in current page and child pages. + */ + std::vector FindAllChildSequence() { + std::vector output = GetLocalSequence(); + if (first_child) { + first_child->Iterate([&output](const RedixPage* page) { + for (SequenceIDNode* node = page->seq_ids; node; node = node->next) { + output.push_back(node->id); + } + }); + } + return output; + } + + /*! + * \brief The iteration method for tree or sub-tree traverse. + * \param f The callback function to invoke at each radix page visited. + */ + template + void Iterate(CallbackFunc f) { + f(this); + if (next_sibiling) next_sibiling->Iterate(f); + if (first_child) first_child->Iterate(f); + } + + /*! + * \brief Get the last sibling of current page. + * \return The page whose next_sibling is current page, or nullptr if current is the fisrt_child + * of its parent page. + */ + RedixPage* GetLastSibling() { + if (parent == nullptr) return nullptr; + if (parent->first_child == this) return nullptr; + for (RedixPage* child = parent->first_child; child; child = child->next_sibiling) { + if (child->next_sibiling == this) return child; + } + return nullptr; + } + + /*! + * \brief Find the child indexed by first token. + * \return The child page started with first token, or nullptr if no such child page. + */ + RedixPage* FindChild(int64_t first_token) { + int32_t casted = first_token; + // Iterate all child radix pages, as the child radix pages are stored unorderly. + for (RedixPage* child = first_child; child; child = child->next_sibiling) { + if ((*child)[0] == casted) return child; + } + return nullptr; + } + + /*! \brief Insert a new child page. */ + void InsertChild(RedixPage* child) { + child->parent = this; + child->next_sibiling = first_child; + first_child = child; + } + + /*! + * \brief Remove a child page. + * \throw Error if page to be removed is not child page. + */ + void RemoveChild(RedixPage* child) { + CHECK(child->parent == this); + if (first_child == child) { + first_child = child->next_sibiling; + } else { + child->GetLastSibling()->next_sibiling = child->next_sibiling; + } + } + + /*! + * \brief Check current page is mergable with its child page. + * The page is mergable if and only if + * 1. No sequence ID in current page, as sequence ID is not allowed to exist within page. + * 2. The current page has child page. + * 3. The current page has only one child page. + * 4. 
The current page perfix and the child page prefix can be concatenated into one page. + * \return True if current page is mergable, or false. + */ + bool Mergeable() { + if (seq_ids) return false; + if (!first_child) return false; + if (first_child->next_sibiling) return false; + if (length + first_child->length > capacity) return false; + return true; + } + + /*! + * \brief Match the given prefix within page. + * \param prefix The prefix token array. + * \param prefix_length The length of prefix token array. + * \return The matched prefix offset within page, or the first mismatched token position. The + * possible return value is [0, page->length], where page->length means the page is completely the + * prefix of given prefix. + */ + size_t MatchPrefix(const int64_t* prefix, size_t prefix_length) { + size_t n = std::min(length, prefix_length); + for (int i = 0; i < n; ++i) { + if ((*this)[i] != prefix[i]) return i; + } + return n; + } +}; + +/*! + * \brief The paged radix tree page pool. + * + * The paged radix tree page pool allocates all radix tree pages when construction and frees when + * destruction, to avoid frequent memory operation. + */ +class RadixPagePool { + public: + /*! \brief The constructor of paged radix tree page pool, allocating memory for each page. */ + RadixPagePool(size_t page_size, size_t num_pages) : page_size_(page_size), num_pages_(num_pages) { + pages_.reserve(num_pages); + free_page_indices_.reserve(num_pages); + raw_pool_ = new int32_t[num_pages * page_size / sizeof(int32_t)]; + int32_t num_int = page_size / sizeof(int32_t); + for (size_t i = 0; i < num_pages; ++i) { + pages_.push_back(reinterpret_cast(raw_pool_ + i * num_int)); + free_page_indices_.push_back(i); + } + } + + /*! + * \brief Get a radix page from pool. + * \return The allocated radix page. + * \throw Error if no free radix page available in pool. + */ + RedixPage* Allocate() { + CHECK(!free_page_indices_.empty()) << "Radix page pool has no free radix tree pages."; + int id = free_page_indices_.back(); + free_page_indices_.pop_back(); + RedixPage* page = pages_[id]; + used_pages_[page] = id; + page->parent = page->first_child = page->next_sibiling = nullptr; + page->capacity = page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET; + page->offset = page->length = 0; + page->seq_ids = nullptr; + return page; + } + + /*! + * \brief Free a radix page to pool. + * \param page The radix page to free. + */ + void Free(RedixPage* page) { + CHECK_EQ(page->seq_ids, nullptr); + CHECK(used_pages_.find(page) != used_pages_.end()); + free_page_indices_.push_back(used_pages_[page]); + CHECK(used_pages_.erase(page)); + } + + /*! + * \brief Get the token capacity of free pages. + * \return The the token capacity of free pages. + */ + size_t FreeCapacity() { + return free_page_indices_.size() * (page_size_ / sizeof(int32_t) - RedixPage::DATA_OFFSET); + } + + /*! \brief The destructor of paged radix tree page pool, freeing memory for each page. */ + ~RadixPagePool() { delete[] raw_pool_; } + + private: + /*! \brief The page size of each paged radix tree page. */ + size_t page_size_; + /*! \brief The number of pages in paged radix tree page pool. */ + size_t num_pages_; + /*! \brief The raw paged radix tree page pool. */ + int32_t* raw_pool_; + /*! \brief The paged radix tree page pool. */ + std::vector pages_; + /*! \brief The indices of free paged radix page in page pool. */ + std::vector free_page_indices_; + /*! \brief The map from used paged radix tree page to its index in page pool. 
*/ + std::unordered_map used_pages_; +}; + +// PagedRadixTree + +/*! + * \brief The paged radix tree data structure. + */ +class PagedRadixTreeImpl : public PagedRadixTreeObj { + public: + /*! \brief The page size of each paged radix tree node. */ + size_t page_size; + /*! \brief The number of pages in paged radix tree page pool. */ + size_t num_pages; + /*! \brief The maximum number of sequence ID in paged radix tree page pool. */ + size_t num_seqs; + /*! \brief The map from sequence to paged radix tree node it is stored. */ + std::unordered_map seq2page; + /*! \brief The sequence ID node pool. */ + SequenceIDNodePool* seq_id_node_pool = nullptr; + /*! \brief The radix page pool. */ + RadixPagePool* radix_page_pool = nullptr; + /*! \brief The root page of paged radix tree. */ + RedixPage* root = nullptr; + + explicit PagedRadixTreeImpl(size_t num_pages, size_t page_size, size_t num_seqs) { + num_pages = num_pages; + page_size = page_size; + num_seqs = num_seqs; + + seq_id_node_pool = new SequenceIDNodePool(num_seqs); + radix_page_pool = new RadixPagePool(page_size, num_pages); + + root = reinterpret_cast(new int32_t[RedixPage::DATA_OFFSET]); + root->parent = root->first_child = root->next_sibiling = nullptr; + root->offset = root->length = root->capacity = 0; + root->seq_ids = nullptr; + } + + /*! + * \brief Get a sequence's all tokens. + * \param seq_id The sequence ID for index. + * \return The sequence tokens. + * \throw Error if sequence ID is not valid. + */ + IntTuple GetSequence(int64_t seq_id) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + size_t length = GetSequenceLength(seq_id); + std::vector output(length); + size_t offset = length; + for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) { + offset -= page->length; + for (int i = 0; i < page->length; ++i) { + output[offset + i] = (*page)[i]; + } + } + return IntTuple(output); + } + + /*! + * \brief Get all sequences with longest common prefix with give prefix tokens. + * \param tokens The prefix tokens for reference. + * \return The pair of matched prefix length and the array of matched sequences indices. + */ + std::pair> MatchPrefix(IntTuple tokens) { + const int64_t* prefix = tokens.data(); + size_t length = tokens.size(); + auto [page, offset, in_page_offset] = MatchSequence(root, prefix, length); + if (!offset) return std::make_pair(0, std::vector()); + return std::make_pair(offset, page->FindAllChildSequence()); + } + + /*! + * \brief Get a sequence's length. + * \param seq_id The sequence ID for index. + * \return The sequence length. + * \throw Error if sequence ID is not valid. + */ + size_t GetSequenceLength(int64_t seq_id) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + size_t length = 0; + for (RedixPage* page = seq2page[seq_id]; page; page = page->parent) { + length += page->length; + } + return length; + } + + /*! + * \brief Fork a sequence from parent sequence at given position. + * \param seq_id The new sequence ID. + * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. 
+ */ + void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + CHECK(seq2page.find(parent_seq_id) != seq2page.end()); + CHECK_GT(forked_offset, 0); + size_t length = GetSequenceLength(parent_seq_id); + CHECK_LE(forked_offset, length); + for (RedixPage* page = seq2page[parent_seq_id]; page; page = page->parent) { + if (forked_offset >= length - page->length) { + if (forked_offset < length) { + // Split radix page if forked position is within page + page = SplitPage(page, forked_offset + page->length - length); + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + return; + } + length -= page->length; + } + } + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + void AddSequence(int64_t seq_id) { + CHECK(seq2page.find(seq_id) == seq2page.end()); + root->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = root; + } + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + void ExtendSequence(int64_t seq_id, IntTuple tokens) { + CHECK(seq2page.find(seq_id) != seq2page.end()); + const int64_t* suffix = tokens.data(); + size_t length = tokens.size(); + RedixPage* original_page = seq2page[seq_id]; + original_page->PopSequence(seq_id_node_pool, seq_id); + auto [page, offset, in_page_offset] = MatchSequence(original_page, suffix, length); + if (in_page_offset < page->length) { + // Split page if extended sequence mismatches within page + page = SplitPage(page, in_page_offset); + } + if (offset < length && !page->seq_ids && !page->first_child && page->capacity > page->length) { + // Extend in the existing leaf page first if possible. + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + while (offset < length) { + // Allocate new radix page and extend tokens + RedixPage* new_page = radix_page_pool->Allocate(); + page->InsertChild(new_page); + page = new_page; + size_t suffix_length = std::min(page->capacity - page->length, length - offset); + page->Extend(suffix + offset, suffix_length); + offset += suffix_length; + } + page->AddSequence(seq_id_node_pool, seq_id); + seq2page[seq_id] = page; + if (original_page->Mergeable()) { + // The original page may be mergeable, as the sequence ID changes + MergePage(original_page); + } + } + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. + */ + void RemoveSequence(int64_t seq_id) { + RedixPage* page = seq2page[seq_id]; + page->PopSequence(seq_id_node_pool, seq_id); + seq2page.erase(seq_id); + while (page->parent && !page->seq_ids && !page->first_child) { + RedixPage* parent = page->parent; + parent->RemoveChild(page); + radix_page_pool->Free(page); + page = parent; + } + if (page && page->Mergeable()) { + // The remaining page may be mergeable, as the sequence ID changes + MergePage(page); + } + } + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + size_t FreeCapacity() { return radix_page_pool->FreeCapacity(); } + + /*! \brief The destructor to free root page. 
*/ + ~PagedRadixTreeImpl() { + delete[] reinterpret_cast(root); + delete seq_id_node_pool; + delete radix_page_pool; + } + + private: + /*! + * \brief Merge a radix tree page with its child radix tree page, to save radix tree page. + * e.g. MergePage([1, 2, _, _, _] -> [3, 4, 5, _, _]) = [1, 2, 3, 4, 5]. + * And the page to be merged should be page->Mergeable(). + * \param page The parent radix tree page. + */ + void MergePage(RedixPage* page) { + CHECK(page->Mergeable()); + RedixPage* child = page->first_child; + for (int i = 0; i < child->length; ++i) { + (*page)[i + page->length] = (*child)[i]; + } + page->length += child->length; + page->first_child = child->first_child; + for (RedixPage* p = child->first_child; p; p = p->next_sibiling) { + p->parent = page; + } + page->seq_ids = child->seq_ids; + std::vector seq_ids = page->GetLocalSequence(); + for (int64_t id : seq_ids) seq2page[id] = page; + child->seq_ids = nullptr; + radix_page_pool->Free(child); + } + + /*! + * \brief Split a radix tree page at given postition, to accept new sequence. + * e.g. SplitPage([1, 2, 3, 4, 5], 2) = [1, 2, _, _, _] -> [3, 4, 5, _, _]. + * \param page The radix tree page to split. + * \param offset The position to split the radix tree page. + * \return The splitted radix tree page. It can be different from the input radix tree page, as + * there may be implicit radix tree page merge. + */ + RedixPage* SplitPage(RedixPage* page, size_t offset) { + CHECK_LT(offset, page->length); + RedixPage* child = radix_page_pool->Allocate(); + child->parent = page; + child->first_child = page->first_child; + for (RedixPage* p = page->first_child; p; p = p->next_sibiling) { + p->parent = child; + } + page->first_child = child; + for (int i = offset; i < page->length; ++i) { + (*child)[i - offset] = (*page)[i]; + } + child->length = page->length - offset; + page->length = offset; + if (child->Mergeable()) { + // The child page may be mergeable + MergePage(child); + } + if (page->parent && page->parent->Mergeable()) { + // The parent page may be mergeable + page = page->parent; + MergePage(page); + } + return page; + } + + /*! + * \brief Match with given token from a radix tree page, stopping at first mismatch. + * \param page The radix tree page to start matching. + * \param tokens The given tokens to match. + * \param length The length of given tokens. 
+ */ + std::tuple MatchSequence(RedixPage* page, const int64_t* tokens, + size_t length) { + size_t offset = 0; + while (offset < length) { + if (RedixPage* child = page->FindChild(tokens[offset])) { + // If child page starts with offset-th token, common prefix at least ends with child page + size_t matched_offset = child->MatchPrefix(tokens + offset, length - offset); + offset += matched_offset; + if (matched_offset < child->length) { + // Common prefix ends within child page + return std::make_tuple(child, offset, matched_offset); + } + page = child; + } else { + // No child page starts with offset-th token, common prefix ends with current page + return std::make_tuple(page, offset, page->length); + } + } + return std::make_tuple(page, length, page->length); + } +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeImpl); + +PagedRadixTree::PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs) { + data_ = std::move(make_object(num_pages, page_size, num_pages)); +} + +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTree") + .set_body_typed([](uint64_t num_pages, uint64_t page_size, uint64_t num_seqs) { + return PagedRadixTree(num_pages, page_size, num_seqs); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeMatchPrefix") + .set_body_typed([](PagedRadixTree paged_radix_tree, IntTuple tokens) { + auto [offset, seq_ids] = paged_radix_tree->MatchPrefix(tokens); + seq_ids.insert(seq_ids.begin(), offset); + return IntTuple(seq_ids); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeExtendSequence") + .set_body_method(&PagedRadixTreeObj::ExtendSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeForkSequence") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id, int64_t parent_seq_id, + uint64_t forked_offset) { + paged_radix_tree->ForkSequence(seq_id, parent_seq_id, forked_offset); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeAddSequence") + .set_body_method(&PagedRadixTreeObj::AddSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeRemoveSequence") + .set_body_method(&PagedRadixTreeObj::RemoveSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequence") + .set_body_method(&PagedRadixTreeObj::GetSequence); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeGetSequenceLength") + .set_body_typed([](PagedRadixTree paged_radix_tree, int64_t seq_id) { + return (int64_t)paged_radix_tree->GetSequenceLength(seq_id); + }); +TVM_REGISTER_GLOBAL("mlc.serve.PagedRadixTreeFreeCapacity") + .set_body_typed([](PagedRadixTree paged_radix_tree) { + return (int64_t)paged_radix_tree->FreeCapacity(); + }); +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/radix_tree.h b/cpp/serve/radix_tree.h new file mode 100644 index 0000000000..ed831c17b1 --- /dev/null +++ b/cpp/serve/radix_tree.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file serve/radix_tree.h + */ +#ifndef MLC_LLM_SERVE_RADIX_TREE_H_ +#define MLC_LLM_SERVE_RADIX_TREE_H_ +#include +#include + +#include +#include + +namespace mlc { +namespace llm { +namespace serve { + +using namespace tvm::runtime; + +/*! + * \brief The paged radix tree data structure. + */ +class PagedRadixTreeObj : public Object { + public: + /*! + * \brief Get a sequence's all tokens. + * \param seq_id The sequence ID for index. + * \return The sequence tokens. + * \throw Error if sequence ID is not valid. + */ + virtual IntTuple GetSequence(int64_t seq_id) = 0; + + /*! + * \brief Get all sequences with longest common prefix with give prefix tokens. 
+ * \param tokens The prefix tokens for reference. + * \return The pair of matched prefix length and the array of matched sequences indices. + */ + virtual std::pair> MatchPrefix(IntTuple tokens) = 0; + + /*! + * \brief Get a sequence's length. + * \param seq_id The sequence ID for index. + * \return The sequence length. + * \throw Error if sequence ID is not valid. + */ + virtual size_t GetSequenceLength(int64_t seq_id) = 0; + + /*! + * \brief Fork a sequence from parent sequence at given position. + * \param seq_id The new sequence ID. + * \param parent_seq_id The parent sequence ID to fork from. + * \param forked_offset The position of parent sequence to fork at. + * The valid value is [1, length of forked sequence]. If the position equals the length of forked + * sequence, the new sequence will copy the entire forked sequence. + * \throw Error if sequence ID or + * forked postion is not valid. + */ + virtual void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) = 0; + + /*! + * \brief Add an empty sequence at root. + * \param seq_id The new sequence ID. + * \throw Error if sequence ID is not valid. + */ + virtual void AddSequence(int64_t seq_id) = 0; + + /*! + * \brief Extend a sequence with given tokens. + * \param seq_id The sequence ID for index. + * \param tokens The given tokens to extend. + * \throw Error if sequence ID is not valid. + */ + virtual void ExtendSequence(int64_t seq_id, IntTuple tokens) = 0; + + /*! + * \brief Remove a sequence. + * \param seq_id The sequence ID to remove. + * \throw Error if sequence ID is not valid. + */ + virtual void RemoveSequence(int64_t seq_id) = 0; + + /*! + * \brief Get the remaining token capacity of the paged radix tree. + * \return The the remaining token capacity of the paged radix tree. + */ + virtual size_t FreeCapacity() = 0; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "mlc.serve.PagedRadixTree"; + TVM_DECLARE_BASE_OBJECT_INFO(PagedRadixTreeObj, Object) +}; + +TVM_REGISTER_OBJECT_TYPE(PagedRadixTreeObj); + +class PagedRadixTree : public ObjectRef { + public: + /*! + * \brief Constructor of paged radix tree. + * \param num_pages The number of radix tree pages. + * \param page_size The page size of each radix tree page. + * \param num_seqs The maximum number of sequence ID. + */ + PagedRadixTree(size_t num_pages, size_t page_size, size_t num_seqs); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PagedRadixTree, ObjectRef, PagedRadixTreeObj); +}; +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_RADIX_TREE_H_ diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 0a59df7421..79caff7cad 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -6,5 +6,6 @@ from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncLLMEngine, LLMEngine from .grammar import BNFGrammar, GrammarStateMatcher +from .radix_tree import PagedRadixTree from .request import Request from .server import PopenServer diff --git a/python/mlc_llm/serve/radix_tree.py b/python/mlc_llm/serve/radix_tree.py new file mode 100644 index 0000000000..102cdac675 --- /dev/null +++ b/python/mlc_llm/serve/radix_tree.py @@ -0,0 +1,150 @@ +"""The Paged Radix Tree class.""" + +from typing import List, Tuple, Union + +import tvm +import tvm._ffi +from tvm.runtime import Object, ShapeTuple + +from . 
import _ffi_api + + +@tvm._ffi.register_object("mlc.serve.PagedRadixTree") # pylint: disable=protected-access +class PagedRadixTree(Object): + """The paged radix tree to manage prefix and sequence.""" + + def __init__(self, num_pages: int, page_size: int, num_seqs: int): + """ + Constructor of paged radix tree. + + Parameters + ---------- + num_pages : int + The number of radix tree pages. + page_size : int + The page size of each radix tree page. + num_seqs : int + The maximum number of sequence ID. + """ + self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree, num_pages, page_size, num_seqs) # type: ignore # pylint: disable=no-member + + def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple]: + """ + Get all sequences with longest common prefix with given prefix tokens. + + Parameters + ---------- + tokens : Union[ShapeTuple, List, Tuple] + The prefix tokens for reference. + + Returns + ------ + matched_offset : int + The matched prefix length. + seq_ids : ShapeTuple + The array of matched sequence indice. + """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + output = _ffi_api.PagedRadixTreeMatchPrefix(self, tokens) # type: ignore # pylint: disable=no-member + if len(output) == 1: + return output[0], [] + return output[0], output[1:] + + def add(self, seq_id: int) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + """ + _ffi_api.PagedRadixTreeAddSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def remove(self, seq_id: int) -> None: + """ + Remove a sequence. + + Parameters + ---------- + seq_id : int + The sequence ID to remove. + """ + _ffi_api.PagedRadixTreeRemoveSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None: + """ + Get all sequences with longest common prefix with give prefix tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + tokens : Union[ShapeTuple, List, Tuple] + The given tokens to extend. + """ + if isinstance(tokens, (list, tuple)): + tokens = ShapeTuple(tokens) + _ffi_api.PagedRadixTreeExtendSequence(self, seq_id, tokens) # type: ignore # pylint: disable=no-member + + def fork(self, seq_id: int, parent_seq_id: int, forked_offset: int) -> None: + """ + Fork a sequence from parent sequence at given position. + + Parameters + ---------- + seq_id : int + The new sequence ID. + parent_seq_id : int + The parent sequence ID to fork from. + forked_offset : int + The position of parent sequence to fork at. + The valid value is [1, length of forked sequence]. + If the position equals the length of forked sequence, + the new sequence will copy the entire forked sequence. + """ + _ffi_api.PagedRadixTreeForkSequence(self, seq_id, parent_seq_id, forked_offset) # type: ignore # pylint: disable=no-member + + def get(self, seq_id: int) -> ShapeTuple: + """ + Get a sequence's all tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + tokens : ShapeTuple + The sequence tokens. + """ + return _ffi_api.PagedRadixTreeGetSequence(self, seq_id) # type: ignore # pylint: disable=no-member + + def get_length(self, seq_id: int) -> int: + """ + Get a sequence's length. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + + Returns + ------ + length : int + The sequence length. 
+ """ + return _ffi_api.PagedRadixTreeGetSequenceLength(self, seq_id) # type: ignore # pylint: disable=no-member + + def free_capacity(self) -> int: + """ + Get the remaining token capacity of the paged radix tree. + + Returns + ------ + capacity : int + The remaining token capacity of the paged radix tree. + """ + return _ffi_api.PagedRadixTreeFreeCapacity(self) # type: ignore # pylint: disable=no-member diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py new file mode 100644 index 0000000000..cea421cd95 --- /dev/null +++ b/tests/python/serve/test_radix_tree.py @@ -0,0 +1,79 @@ +from tvm import TVMError +from tvm.runtime import ShapeTuple + +from mlc_llm.serve import PagedRadixTree + + +def test_add(): + prt = PagedRadixTree(16, 128, 16) + prt.add(0) + assert prt.get(0) == [] + + +def test_remove(): + prt = PagedRadixTree(32, 128, 16) + capacity = prt.free_capacity() + prt.add(0) + prt.remove(0) + prt.add(0) + prt.extend(0, [1 for _ in range(200)]) + prt.remove(0) + assert prt.free_capacity() == capacity + + prt.add(1) + prt.extend(1, [1 for _ in range(200)]) + capacity = prt.free_capacity() + prt.add(2) + prt.extend(2, [1 for _ in range(100)] + [2 for _ in range(100)]) + prt.remove(2) + assert prt.free_capacity() == capacity + + prt.add(3) + prt.extend(3, [1 for _ in range(200)]) + prt.remove(3) + assert prt.free_capacity() == capacity + + +def test_extend(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + for start_pos in [0, H, L, L + H]: + for length in [Q, L - H, L, 2 * L - H, 2 * L]: + prt.add(seq_id) + if start_pos: + tokens_1 = [seq_id for _ in range(start_pos)] + prt.extend(seq_id, tokens_1) + assert prt.get(seq_id) == tokens_1 + else: + tokens_1 = [] + tokens_2 = [seq_id for _ in range(length)] + prt.extend(seq_id, tokens_2) + assert prt.get(seq_id) == tokens_1 + tokens_2 + seq_id += 1 + + +def test_fork(): + prt = PagedRadixTree(1024, 256, 256) + L = prt.free_capacity() // 1024 + H = L // 2 + Q = L // 4 + seq_id = 0 + length_list = [Q, H, L, L + Q, L + H, L * 2] + for p_idx in range(1, len(length_list)): + for c_idx in range(0, p_idx + 1): + prt.add(seq_id) + tokens = [seq_id for _ in range(length_list[p_idx])] + prt.extend(seq_id, tokens) + prt.fork(seq_id + 1, seq_id, length_list[c_idx]) + assert prt.get(seq_id + 1) == tokens[: length_list[c_idx]] + seq_id += 2 + + +if __name__ == "__main__": + test_add() + test_remove() + test_extend() + test_fork() From dc3988a8224f9ddc65b8c4d466930d1919095782 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 22 Apr 2024 21:35:51 -0400 Subject: [PATCH 226/531] [Serving] Remove mandatory model check in server (#2195) This PR removes the mandatory model check in server since as of now we serve one engine at most which means there is always a unique engine being served. As issue #2155 points out, the model check in server can be a bad experience when the model string mismatches. 
--- .../mlc_llm/protocol/openai_api_protocol.py | 10 +++--- python/mlc_llm/serve/engine.py | 32 +++++++++---------- python/mlc_llm/serve/server/server_context.py | 7 ++-- tests/python/serve/server/test_server.py | 18 ----------- 4 files changed, 26 insertions(+), 41 deletions(-) diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 1a732488a0..d6ce4a4fcb 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -84,7 +84,7 @@ class CompletionRequest(BaseModel): API reference: https://platform.openai.com/docs/api-reference/completions/create """ - model: str + model: Optional[str] = None prompt: Union[str, List[int]] best_of: int = 1 echo: bool = False @@ -154,7 +154,7 @@ class CompletionResponse(BaseModel): id: str choices: List[CompletionResponseChoice] created: int = Field(default_factory=lambda: int(time.time())) - model: str + model: Optional[str] = None object: str = "text_completion" usage: UsageInfo = Field( default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda @@ -200,7 +200,7 @@ class ChatCompletionRequest(BaseModel): """ messages: List[ChatCompletionMessage] - model: str + model: Optional[str] = None frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None logprobs: bool = False @@ -343,7 +343,7 @@ class ChatCompletionResponse(BaseModel): id: str choices: List[ChatCompletionResponseChoice] created: int = Field(default_factory=lambda: int(time.time())) - model: str + model: Optional[str] = None system_fingerprint: str object: Literal["chat.completion"] = "chat.completion" usage: UsageInfo = Field( @@ -359,7 +359,7 @@ class ChatCompletionStreamResponse(BaseModel): id: str choices: List[ChatCompletionStreamResponseChoice] created: int = Field(default_factory=lambda: int(time.time())) - model: str + model: Optional[str] = None system_fingerprint: str object: Literal["chat.completion.chunk"] = "chat.completion.chunk" usage: UsageInfo = Field( diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index a84f98fb33..5bbdc149d4 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -61,8 +61,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, stream: Literal[True], + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -111,7 +111,7 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -160,7 +160,7 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -238,8 +238,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, stream: Literal[True], + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -288,7 +288,7 @@ def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 
0.0, logprobs: bool = False, @@ -335,7 +335,7 @@ def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -412,9 +412,9 @@ def __init__(self, engine: weakref.ReferenceType) -> None: async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], stream: Literal[True], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -463,8 +463,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -511,8 +511,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals async def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -591,9 +591,9 @@ def __init__(self, engine: weakref.ReferenceType) -> None: def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], stream: Literal[True], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -642,8 +642,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -690,8 +690,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals def create( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -883,7 +883,7 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -1003,8 +1003,8 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local async def _completion( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, @@ -1429,7 +1429,7 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: Optional[str] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, logprobs: bool = False, @@ -1549,8 +1549,8 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals def _completion( # pylint: disable=too-many-arguments,too-many-locals self, *, - model: str, prompt: Union[str, List[int]], + model: Optional[str] = None, best_of: int = 1, echo: bool = False, frequency_penalty: float = 0.0, diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index 0a9a1b0b1f..46b841aaa9 100644 --- 
a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -37,8 +37,11 @@ def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - def get_engine(self, model: str) -> Optional[AsyncLLMEngine]: - """Get the async engine of the requested model.""" + def get_engine(self, model: Optional[str]) -> Optional[AsyncLLMEngine]: + """Get the async engine of the requested model, or the unique async engine + if only one engine is served.""" + if len(self._models) == 1: + return next(iter(self._models.values())) return self._models.get(model, None) def get_model_list(self) -> List[str]: diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index cca9a4265e..e4f64d2ce4 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -329,23 +329,6 @@ def test_openai_v1_completions_openai_package( ) -def test_openai_v1_completions_invalid_requested_model( - launch_server, # pylint: disable=unused-argument -): - # `launch_server` is a pytest fixture defined in conftest.py. - - model = "unserved_model" - payload = { - "model": model, - "prompt": "What is the meaning of life?", - "max_tokens": 10, - } - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - expect_error( - response_str=response.json(), msg_prefix=f'The requested model "{model}" is not served.' - ) - - @pytest.mark.parametrize("stream", [False, True]) def test_openai_v1_completions_echo( served_model: Tuple[str, str], @@ -1319,7 +1302,6 @@ def test_debug_dump_event_trace( test_openai_v1_completions(MODEL, None, stream=True) test_openai_v1_completions_openai_package(MODEL, None, stream=False) test_openai_v1_completions_openai_package(MODEL, None, stream=True) - test_openai_v1_completions_invalid_requested_model(None) test_openai_v1_completions_echo(MODEL, None, stream=False) test_openai_v1_completions_echo(MODEL, None, stream=True) test_openai_v1_completions_suffix(MODEL, None, stream=False) From 651c2a0c295a85fe70469382c297b4de4e2ea4f3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Apr 2024 22:40:22 -0700 Subject: [PATCH 227/531] [Sampler] Enable GPU sampler for draft verification (#2198) * [Eagle] Attach gpu verifier to model * WIP * WIP * fix * Enable GPU verifier * lint * lint --- .../engine_actions/eagle_batch_verify.cc | 1 - cpp/serve/function_table.cc | 1 + cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 4 +- cpp/serve/sampler/gpu_sampler.cc | 140 +++++++++++++++++- .../mlc_llm/compiler_pass/attach_sampler.py | 50 +++++++ 6 files changed, 191 insertions(+), 6 deletions(-) diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 043f68b9c2..6718afaccf 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -88,7 +88,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - CHECK(draft_mstate->draft_output_prob_dist[0]->device.device_type == kDLCPU); draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 823d3c6164..55b494dae0 100644 --- a/cpp/serve/function_table.cc +++ 
b/cpp/serve/function_table.cc @@ -265,6 +265,7 @@ void FunctionTable::_InitFunctions() { gpu_argsort_probs_func_ = mod->GetFunction("argsort_probs", true); gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true); gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true); + gpu_verify_draft_tokens_func_ = mod->GetFunction("sampler_verify_draft_tokens", true); } this->nd_view_func_ = get_global_func("vm.builtin.reshape"); this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index f6a156b8a3..5f08a9ba5c 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -104,6 +104,7 @@ struct FunctionTable { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 17121d8e28..fc8c8b485c 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -768,9 +768,7 @@ class ModelImpl : public ModelObj { Sampler CreateSampler(int max_num_sample, int num_models, Optional trace_recorder) { - if (num_models > 1) { // speculative decoding uses cpu sampler - return Sampler::CreateCPUSampler(std::move(trace_recorder)); - } else if (Sampler::SupportGPUSampler(device_)) { + if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } else { diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index b376523dac..1f1d2e9eb3 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -43,6 +43,7 @@ class GPUSampler : public SamplerObj { gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_), gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_), gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_), + gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_), trace_recorder_(std::move(trace_recorder)) { ICHECK(gpu_multinomial_from_uniform_func_.defined()); ICHECK(gpu_argsort_probs_func_.defined()); @@ -92,11 +93,20 @@ class GPUSampler : public SamplerObj { NVTXScopedRange nvtx_scope("BatchSampleTokens"); // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); - CHECK(output_prob_dist == nullptr) << "GPU sampler does not support collecting output probs."; CHECK_EQ(probs_on_device->ndim, 2); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; int vocab_size = probs_on_device->shape[1]; + if (output_prob_dist != nullptr) { + ICHECK(output_prob_dist->empty()); + output_prob_dist->reserve(num_probs); + for (int i = 0; i < num_probs; ++i) { + NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); + float* p_prob = static_cast(probs_on_device->data) + i * vocab_size; + prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); + output_prob_dist->push_back(std::move(prob_dist)); + } + } ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); @@ -132,7 +142,132 @@ class GPUSampler : public SamplerObj { const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) final { - LOG(FATAL) << "GPU sampler does not support batch verification for now."; + std::vector> 
sample_results; + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_device->ndim, 2); + + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; + CHECK_EQ(rngs.size(), num_sequence); + CHECK_EQ(draft_output_tokens.size(), num_sequence); + CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + sample_results.resize(num_sequence); + + int num_nodes = cum_verify_lengths.back(); + NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); + NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); + NDArray draft_probs_device = NDArray::Empty({num_nodes, vocab_size_}, dtype_f32_, device_); + NDArray draft_tokens_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); + NDArray draft_tokens_host = + NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + + // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + for (int i = 0; i < num_sequence; i++) { + const std::vector& draft_output_tokens_i = draft_output_tokens[i]; + const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + // start/end is the range of the sequence i in probs_on_device, which includes the prob dist + // of the draft tokens and the last committed token + ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); + ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); + for (int j = 0; j < end - start - 1; j++) { + // Copy prob dist + ICHECK_EQ(draft_probs_device->dtype.bits, 32); + float* p_draft_probs = + static_cast(draft_probs_device->data) + + (j + start + 1) * + vocab_size_; // shift by one, q of the last committed token is undefined + // Copy sampled token id + draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); + *(static_cast(draft_tokens_host->data) + j + start + 1) = + draft_output_tokens_i[j].sampled_token_id.first; + } + } + CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); + + float* p_uniform_samples = static_cast(uniform_samples_host->data); + for (int i = 0; i < num_sequence; ++i) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + for (int j = start; j < end; j++) { + p_uniform_samples[j] = rngs[i]->GetRandomNumber(); + } + } + CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_); + + // This should be refactored to use the cached tensors + NDArray token_tree_first_child_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); + NDArray token_tree_next_sibling_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); + NDArray token_tree_parent_ptr_device = NDArray::Empty({num_sequence}, dtype_i32_, device_); + NDArray token_tree_first_child_host = + NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + NDArray token_tree_next_sibling_host = + NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + NDArray token_tree_parent_ptr_host = + NDArray::Empty({num_sequence}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + NDArray token_tree_child_to_parent_host = + NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + + // Build the tree structure on CPU + for (int i = 0; i < num_sequence; i++) { + // Assuming no tree structure for now + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + ICHECK_EQ(end - start, 2); // one committed 
token and assuming only one draft token + static_cast(token_tree_child_to_parent_host->data)[start] = -1; // root has no parent + for (int j = 0; j < end - start; j++) { + int cur_node = j + start; + int child_node = j + 1 >= end - start ? -1 : cur_node + 1; + static_cast(token_tree_first_child_host->data)[cur_node] = child_node; + if (child_node != -1) { + static_cast(token_tree_child_to_parent_host->data)[child_node] = cur_node; + } + static_cast(token_tree_next_sibling_host->data)[cur_node] = -1; + } + static_cast(token_tree_parent_ptr_host->data)[i] = start; // point to the root + } + // Copy token tree structure to GPU + CopyArray(token_tree_first_child_host, token_tree_first_child_device, copy_stream_); + CopyArray(token_tree_next_sibling_host, token_tree_next_sibling_device, copy_stream_); + CopyArray(token_tree_parent_ptr_host, token_tree_parent_ptr_device, copy_stream_); + + SyncCopyStream(device_, compute_stream_, copy_stream_); + + gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + token_tree_first_child_device, token_tree_next_sibling_device, + uniform_samples_device, token_tree_parent_ptr_device); + + CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, compute_stream_); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + + std::vector sample_indices; + + for (int i = 0; i < num_sequence; i++) { + int start = cum_verify_lengths[i]; + int end = cum_verify_lengths[i + 1]; + int last_accepted = static_cast(token_tree_parent_ptr_host->data)[i]; + int num_accepted = 0; + for (int cur_node = last_accepted; cur_node != start; + cur_node = static_cast(token_tree_child_to_parent_host->data)[cur_node]) { + sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]); + num_accepted++; + } + std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted); + sample_indices.push_back(last_accepted); + } + std::vector additional_sample_result; + // This only works for top-p = 1. To enable top-p, we need to normalize the probs before + // verifying. 
+ additional_sample_result = this->BatchSampleTokens(probs_on_device, sample_indices, request_ids, + generation_cfg, rngs, nullptr); + ICHECK_EQ(additional_sample_result.size(), num_sequence); + for (int i = 0; i < num_sequence; i++) { + sample_results[i].push_back(additional_sample_result[i]); + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish draft verification"); + return sample_results; } private: @@ -370,6 +505,7 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_argsort_probs_func_; PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; + PackedFunc gpu_verify_draft_tokens_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 1b7b0328a9..f044c3a6d8 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,6 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T +from ..op.batch_spec_verify import batch_spec_verify + @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") class AttachGPUSamplingFunc: # pylint: disable=too-few-public-methods @@ -46,6 +48,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_argsort_func(bb, vocab_size), _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), + _attach_batch_verifier(bb, vocab_size), ] ] @@ -289,3 +292,50 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv + + +def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + num_nodes = tir.Var("num_nodes", "int64") + nbatch = tir.Var("nbatch", "int64") + draft_probs = relax.Var( + "draft_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + draft_tokens = relax.Var("draft_tokens", relax.TensorStructInfo((num_nodes,), "int32")) + model_probs = relax.Var( + "model_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") + ) + token_tree_first_child = relax.Var( + "token_tree_first_child", relax.TensorStructInfo((num_nodes,), "int32") + ) + token_tree_next_sibling = relax.Var( + "token_tree_next_sibling", relax.TensorStructInfo((num_nodes,), "int32") + ) + uniform_samples = relax.Var("uniform_samples", relax.TensorStructInfo((num_nodes,), "float32")) + token_tree_parent_ptr = relax.Var( + "token_tree_parent_ptr", relax.TensorStructInfo((nbatch,), "int32") + ) + args = [ + draft_probs, + draft_tokens, + model_probs, + token_tree_first_child, + token_tree_next_sibling, + uniform_samples, + token_tree_parent_ptr, + ] + with bb.function("sampler_verify_draft_tokens", args): + with bb.dataflow(): + res = bb.emit( + relax.call_tir_inplace( + bb.add_func(batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"), + args, + inplace_indices=[args.index(model_probs), args.index(token_tree_parent_ptr)], + out_sinfo=[ + model_probs.struct_info, # pylint: disable=no-member + token_tree_parent_ptr.struct_info, # pylint: disable=no-member + ], + ) + ) + bb.emit_output(res) + gv = bb.emit_func_output(res) + return gv From 0ed4bcb7c1756b8df1d09d3d4260a587e03bb926 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 23 Apr 2024 04:06:20 -0700 Subject: [PATCH 228/531] [Eagle] Make eagle disco compatible (#2197) * [Eagle] Make BatchSelectLastHidden able to run on the controller --- 
cpp/serve/engine.cc | 24 ++++++------- cpp/serve/function_table.cc | 4 +-- cpp/serve/model.cc | 42 ++++++++++++++++++----- python/mlc_llm/model/llama/llama_model.py | 4 +-- 4 files changed, 48 insertions(+), 26 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 8e47564945..afde4d1eb5 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -102,18 +102,18 @@ class EngineImpl : public Engine { ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = { - EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, this->trace_recorder_), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft( + this->models_, logit_processor, sampler, this->model_workspaces_, + this->trace_recorder_, engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, engine_config, + this->trace_recorder_)}; break; default: this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 55b494dae0..792f98094b 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -234,8 +234,8 @@ void FunctionTable::_InitFunctions() { this->verify_to_last_hidden_func_ = mod_get_func("batch_verify_to_last_hidden_states"); this->fuse_embed_hidden_func_ = mod_get_func("fuse_embed_hidden_states"); Module mod = this->use_disco ? this->disco_mod->DebugGetFromRemote(0) : this->local_vm; - this->get_logits_func_ = mod->GetFunction("get_logits", true); - this->batch_get_logits_func_ = mod->GetFunction("batch_get_logits", true); + this->get_logits_func_ = mod_get_func("get_logits"); + this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index fc8c8b485c..3583b5d84b 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -136,16 +136,23 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = + hidden_states = hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); + // This copy can be avoided by not copying the hidden states to engine. 
+ hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - NDArray logits; - logits = Downcast(ret); + NDArray logits{nullptr}; + if (ret->IsInstance()) { + logits = Downcast(ret)->DebugGetFromRemote(0); + } else { + logits = Downcast(ret); + } CHECK(logits.defined()); // logits: (b * s, v) ICHECK_EQ(logits->ndim, 2); @@ -185,8 +192,11 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states->device.device_type, device_.device_type); ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); + + // This copy can be avoided by not copying the hidden states to engine. + hidden_states_dref_or_nd = ft_.CopyToWorker0( + hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); ObjectRef ret = ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); @@ -218,8 +228,15 @@ class ModelImpl : public ModelObj { p_logit_pos[i] = total_length - 1; } NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); + + // This step runs on the engine thread. + // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device + // tensor without actually copying to the worker. + bool use_disco = ft_.use_disco; + ft_.use_disco = false; ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); + ft_.use_disco = use_disco; CHECK(ft_.batch_select_last_hidden_func_.defined()) << "`batch_select_last_hidden_states` function is not found in the model."; @@ -240,7 +257,7 @@ class ModelImpl : public ModelObj { hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); + ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } @@ -265,10 +282,17 @@ class ModelImpl : public ModelObj { // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. 
hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); // Reuse the copy embedding function - ft_.nd_copy_embedding_to_offset_func_(hidden, *dst, cum_length); + ObjectRef hidden_dref_or_nd = + ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); + ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); cum_length += 1; } - NDArray ret = Downcast(*dst); + NDArray ret{nullptr}; + if ((*dst)->IsInstance()) { + ret = Downcast(*dst)->DebugGetFromRemote(0); + } else { + ret = Downcast(*dst); + } ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); return ret; } @@ -295,7 +319,7 @@ class ModelImpl : public ModelObj { return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); } } else { - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; + ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 7a01cc20de..18238f688e 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -257,8 +257,6 @@ def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() - if self.tensor_parallel_shards > 1: - logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states @@ -382,7 +380,7 @@ def get_default_spec(self): "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), "$": { - "param_mode": "packed", + "param_mode": "none", "effect_mode": "none", }, }, From af8206ba2fda5e934741feb2cfb87610afb933fe Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 23 Apr 2024 14:59:01 -0400 Subject: [PATCH 229/531] [Serving][Spec] Fix normal mode verification for extra draft token (#2206) This PR updates the draft verification of the normal mode speculative decoding. Prior to this PR, we did not effectively leverage all the draft tokens, and this PR fixes the issue. 
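Concretely, each request now verifies one more position than it has draft tokens, because the distribution at the last committed token is scored together with every draft token. A small illustrative sketch in plain Python (made-up token ids, not engine code):

# One request's state, for illustration only.
committed = [11, 27, 93]   # tokens already accepted for this request
draft = [41, 8]            # tokens proposed by the draft model

# The verifier scores the last committed token plus all draft tokens,
# so the verify length is len(draft) + 1 rather than len(draft).
tokens_to_verify = [committed[-1]] + draft
verify_length = len(draft) + 1
assert len(tokens_to_verify) == verify_length == 3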
--- cpp/serve/engine.cc | 2 +- cpp/serve/engine_actions/batch_verify.cc | 44 +++++++++---------- .../eagle_new_request_prefill.cc | 2 +- .../engine_actions/new_request_prefill.cc | 2 +- cpp/serve/sampler/gpu_sampler.cc | 1 + 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index afde4d1eb5..8568c6ce94 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -90,7 +90,7 @@ class EngineImpl : public Engine { int max_num_tokens = engine_config->max_num_sequence; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { - max_num_tokens *= engine_config->spec_draft_length; + max_num_tokens *= engine_config->spec_draft_length + 1; } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index aa51b647c0..f8e7939e44 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -42,8 +42,8 @@ class BatchVerifyActionObj : public EngineActionObj { return {}; } - const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate); - ICHECK_EQ(rsentries.size(), draft_lengths.size()); + const auto& [rsentries, verify_lengths, total_verify_length] = GetDraftsToVerify(estate); + ICHECK_EQ(rsentries.size(), verify_lengths.size()); if (rsentries.empty()) { return {}; } @@ -62,7 +62,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> draft_output_tokens; std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); - all_tokens_to_verify.reserve(total_draft_length); + all_tokens_to_verify.reserve(total_verify_length); verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); @@ -73,12 +73,12 @@ class BatchVerifyActionObj : public EngineActionObj { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_]; request_internal_ids.push_back(verify_mstate->internal_id); - ICHECK(!draft_lengths.empty()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); - // the last committed token + all the draft tokens but the last one. + ICHECK(!verify_lengths.empty()); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + // the last committed token + all the draft tokens. 
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); - for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()) - 1; ++j) { + for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); } verify_request_mstates.push_back(verify_mstate); @@ -95,19 +95,19 @@ class BatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start verify"); NDArray logits = - models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, draft_lengths); + models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, verify_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); ICHECK_EQ(logits->ndim, 3); ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], total_draft_length); + ICHECK_EQ(logits->shape[1], total_verify_length); // - Update logits. std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); for (int i = 0; i < num_rsentries; ++i) { - cum_verify_lengths.push_back(cum_verify_lengths.back() + draft_lengths[i]); + cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths[i]); } - logits = logits.CreateView({total_draft_length, logits->shape[2]}, logits->dtype); + logits = logits.CreateView({total_verify_length, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, request_ids, &cum_verify_lengths, &draft_output_tokens); @@ -156,10 +156,10 @@ class BatchVerifyActionObj : public EngineActionObj { struct DraftRequestStateEntries { /*! \brief The request state entries to verify. */ Array draft_rsentries; - /*! \brief The draft length of each request state. */ - std::vector draft_lengths; + /*! \brief The length to verify for each request state. */ + std::vector verify_lengths; /*! \brief The total draft length. */ - int total_draft_length; + int total_verify_length; }; /*! @@ -169,8 +169,8 @@ class BatchVerifyActionObj : public EngineActionObj { * state and input length. 
*/ DraftRequestStateEntries GetDraftsToVerify(EngineState estate) { - std::vector draft_lengths; - int total_draft_length = 0; + std::vector verify_lengths; + int total_verify_length = 0; int total_required_pages = 0; int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages(); @@ -182,24 +182,24 @@ class BatchVerifyActionObj : public EngineActionObj { int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size(); int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) / engine_config_->kv_cache_page_size; - draft_lengths.push_back(draft_length); + verify_lengths.push_back(draft_length + 1); num_page_requirement.push_back(num_require_pages); - total_draft_length += draft_length; + total_verify_length += draft_length + 1; total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { - total_draft_length -= draft_lengths.back(); + total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); - draft_lengths.pop_back(); + verify_lengths.pop_back(); num_page_requirement.pop_back(); running_rsentries.pop_back(); } } - return {running_rsentries, draft_lengths, total_draft_length}; + return {running_rsentries, verify_lengths, total_verify_length}; } bool CanVerify(int num_required_pages) { diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 133c23e8a1..5fb294d1b4 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -459,7 +459,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? (engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c3f7491960..8c0999bb71 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -399,7 +399,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { // No exceeding of the maximum allowed requests that can // run simultaneously. int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? engine_config_->spec_draft_length + ? 
(engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 1f1d2e9eb3..af4cc9615f 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -142,6 +142,7 @@ class GPUSampler : public SamplerObj { const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) final { + NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); std::vector> sample_results; // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); From d7c5a6e300d9d30bdddddfeadce1ce23744ba02b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 23 Apr 2024 18:31:09 -0400 Subject: [PATCH 230/531] [Sampler] Prob renormalization with top p for spec decoding (#2201) This PR introduces a renormalization interface with regard to top-p values for speculative decoding. This is helpful for simplifying the logic of speculative decoding verification stage, as all probs have been already updated with the top-p values and no top-p needs to be taken into consideration. So for speculative decoding, we always renorm the probability distribution before sampling/verifying. For non speculative decoding mode, we keep using the previous flow, which applies top-p together when sampling. Co-authored-by: Wuwei Lin --- cpp/serve/engine_actions/batch_decode.cc | 2 +- cpp/serve/engine_actions/batch_draft.cc | 6 +- cpp/serve/engine_actions/batch_verify.cc | 12 +- cpp/serve/engine_actions/eagle_batch_draft.cc | 6 +- .../engine_actions/eagle_batch_verify.cc | 18 +- .../eagle_new_request_prefill.cc | 6 +- .../engine_actions/new_request_prefill.cc | 2 +- cpp/serve/function_table.cc | 1 + cpp/serve/function_table.h | 1 + cpp/serve/sampler/cpu_sampler.cc | 281 ++++++++++++++---- cpp/serve/sampler/gpu_sampler.cc | 254 +++++++++++----- cpp/serve/sampler/sampler.h | 58 +++- .../mlc_llm/compiler_pass/attach_sampler.py | 59 +++- python/mlc_llm/help.py | 2 +- python/mlc_llm/serve/engine.py | 4 +- python/mlc_llm/serve/engine_base.py | 2 +- 16 files changed, 536 insertions(+), 178 deletions(-) diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 94e441279a..36acc6b06e 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -114,7 +114,7 @@ class BatchDecodeActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. 
std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index b56f7fa9b6..c1ddeb6e4e 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -116,8 +116,10 @@ class BatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index f8e7939e44..42c9bbe018 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -7,6 +7,7 @@ #include #include +#include #include "../../random.h" #include "../config.h" @@ -115,9 +116,14 @@ class BatchVerifyActionObj : public EngineActionObj { NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index 50393c38a2..fde314a5c5 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -145,8 +145,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. 
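These call sites all follow the same pattern introduced by this patch: renormalize the probabilities by top-p once (BatchRenormalizeProbsByTopP), then sample or verify without any further top-p handling. A rough NumPy sketch of that renormalization step, with illustrative names that are not part of the engine API:

import numpy as np

def renormalize_by_top_p(probs, top_p, eps=1e-6):
    """Keep only the top-p nucleus of a distribution and rescale it (illustrative sketch)."""
    order = np.argsort(probs)[::-1]
    cumsum = np.cumsum(probs[order])
    # Smallest prefix of the sorted probs whose mass reaches top_p; eps guards float round-off.
    cutoff = probs[order][np.searchsorted(cumsum, top_p - eps)]
    kept = np.where(probs >= cutoff, probs, 0.0)
    return kept / kept.sum()

probs = np.array([0.5, 0.3, 0.15, 0.05])
renorm = renormalize_by_top_p(probs, top_p=0.8)
# renorm == [0.625, 0.375, 0.0, 0.0]; after this step, sampling and draft
# verification can proceed as if top-p did not exist.
token = np.random.default_rng(0).choice(len(renorm), p=renorm)

Renormalizing up front is what lets the verification path compare the target and draft probabilities directly, without re-deriving the top-p nucleus per token.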
diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 6718afaccf..b259417050 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -128,10 +128,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Compute probability distributions. NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits( logits, generation_cfg, request_ids, &cum_verify_lengths); - - std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokens( - probs_on_device, request_ids, cum_verify_lengths, generation_cfg, rngs, draft_output_tokens, - draft_output_prob_dist); + std::vector sample_indices(num_rsentries); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector> sample_results_arr = + sampler_->BatchVerifyDraftTokensWithProbAfterTopP( + renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, + draft_output_tokens, draft_output_prob_dist); ICHECK_EQ(sample_results_arr.size(), num_rsentries); std::vector last_hidden_states; @@ -229,8 +233,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 5fb294d1b4..a687e7eb7f 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -277,8 +277,10 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } } std::vector prob_dist; - std::vector sample_results = sampler_->BatchSampleTokens( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. 
diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 8c0999bb71..c80c5e0ede 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -229,7 +229,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector sample_results = sampler_->BatchSampleTokens( + std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 792f98094b..b33d3709e8 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -266,6 +266,7 @@ void FunctionTable::_InitFunctions() { gpu_sample_with_top_p_func_ = mod->GetFunction("sample_with_top_p", true); gpu_sampler_take_probs_func_ = mod->GetFunction("sampler_take_probs", true); gpu_verify_draft_tokens_func_ = mod->GetFunction("sampler_verify_draft_tokens", true); + gpu_renormalize_by_top_p_func_ = mod->GetFunction("renormalize_by_top_p", true); } this->nd_view_func_ = get_global_func("vm.builtin.reshape"); this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index 5f08a9ba5c..b6ea3287ad 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -105,6 +105,7 @@ struct FunctionTable { PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 02b7e2a81d..98080c979d 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include "../../random.h" @@ -43,12 +44,7 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o ICHECK(prob.IsContiguous()); ICHECK(prob.DataType() == DataType::Float(32)); - - if (prob->device.device_type != kDLCPU) { - prob = prob.CopyTo(DLDevice{kDLCPU, 0}); - } - - ICHECK(prob->device.device_type == kDLCPU); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); int64_t ndata = prob->shape[prob->ndim - 1]; const float* __restrict p_prob = @@ -186,6 +182,98 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o return {sampled_index.second, sampled_index.first}; } +/*! + * \brief Renormalize the probability distribution by the top p value. + * \param prob The input batch of probability distributions. + * \param unit_offset The offset specifying which distribution to output + * \param top_p The top p value for renormalization. + * \param eps A small epsilon value for comparison stability. + */ +void RenormalizeProbByTopP(NDArray prob, int unit_offset, double top_p, double eps) { + // prob: (*, v) + // The prob array may have arbitrary ndim and shape. + // The last dimension corresponds to the prob distribution size. + // We use the `unit_offset` parameter to determine which slice + // of the prob array we will renormalize. 
+ ICHECK(prob.IsContiguous()); + ICHECK(prob.DataType() == DataType::Float(32)); + ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU); + + int vocab_size = prob->shape[prob->ndim - 1]; + float* __restrict p_prob = + static_cast(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * vocab_size); + + // We manually choice the cutoff values of "top_p / 256" and "top_p / 8192". + // In most of the cases, only one round is needed. + std::vector cutoff_values{top_p / 256, top_p / 8192, 0.0f}; + + // Create the upper partition vector and the lower partition rolling vectors. + std::vector upper_partition; + std::vector lower_partitions[2]; + upper_partition.reserve(vocab_size); + lower_partitions[0].reserve(vocab_size); + lower_partitions[1].reserve(vocab_size); + float upper_partition_sum = 0.0; + for (int round = 0; round < static_cast(cutoff_values.size()); ++round) { + const float* lower_partition_begin; + const float* lower_partition_end; + if (round == 0) { + lower_partition_begin = p_prob; + lower_partition_end = p_prob + vocab_size; + } else { + int idx = (round - 1) & 1; + lower_partition_begin = lower_partitions[idx].data(); + lower_partition_end = lower_partitions[idx].data() + lower_partitions[idx].size(); + } + + // - Partition the last round lower partition into upper and lower + // based on the new cutoff value. + std::vector& lower_partition = lower_partitions[round & 1]; + lower_partition.clear(); + for (const float* ptr = lower_partition_begin; ptr != lower_partition_end; ++ptr) { + if (*ptr >= cutoff_values[round]) { + upper_partition.push_back(*ptr); + upper_partition_sum += *ptr; + } else { + lower_partition.push_back(*ptr); + } + } + // - If the upper partition sum is at least top p, exit the loop. + if (upper_partition_sum >= top_p - eps) { + break; + } + } + + // - Sort the upper partition in descending order. + std::sort(upper_partition.begin(), upper_partition.end(), std::greater<>()); + // - Find the top p boundary prob value. + float boundary_value = -1.0; + upper_partition_sum = 0.0; + for (float upper_value : upper_partition) { + upper_partition_sum += upper_value; + if (upper_partition_sum >= top_p - eps) { + boundary_value = upper_value; + break; + } + } + // - Mask all values smaller than the boundary to 0. + float renormalize_sum = 0.0; + std::vector upper_partition_indices; + upper_partition_indices.reserve(vocab_size); + for (int i = 0; i < vocab_size; ++i) { + if (p_prob[i] >= boundary_value) { + upper_partition_indices.push_back(i); + renormalize_sum += p_prob[i]; + } else { + p_prob[i] = 0.0; + } + } + // - Renormalize. + for (int idx : upper_partition_indices) { + p_prob[idx] /= renormalize_sum; + } +} + namespace detail { /*! \brief Implementation of getting top probs on CPU. 
*/ @@ -266,68 +354,87 @@ class CPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); - - // - Sample tokens from probabilities. - int n = request_ids.size(); - ICHECK_EQ(generation_cfg.size(), n); - ICHECK_EQ(rngs.size(), n); - - std::vector sample_results; - sample_results.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + + std::vector top_p_indices; + std::vector top_p_values; + for (int i = 0; i < num_samples; ++i) { + if (top_p_indices.empty() || top_p_indices.back() != sample_indices[i]) { + top_p_indices.push_back(sample_indices[i]); + top_p_values.push_back(generation_cfg[i]->top_p); + } else { + CHECK(fabs(top_p_values.back() - generation_cfg[i]->top_p) < eps_) + << "Sampler requires the top_p values for each prob distribution are the same."; + } + } + if (top_p_indices.empty()) { + // Return if no top p needs to apply. + return probs_on_host; } tvm::runtime::parallel_for_with_threading_backend( - [this, &sample_results, &probs_host, &generation_cfg, &rngs, &request_ids, sample_indices, - output_prob_dist](int i) { - RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); - // Sample top p from probability. - sample_results[i].sampled_token_id = SampleTopPFromProb( - probs_host, i, sample_indices[i], - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber(), output_prob_dist); - if (output_prob_dist == nullptr) { - // When `output_prob_dist` is not nullptr, it means right now - // we are sampling for a small model in speculation, in which - // case we do not need to get the top probs. 
- sample_results[i].top_prob_tokens = - ComputeTopProbs(probs_host, i, generation_cfg[i]->top_logprobs); - } - RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + [this, &probs_on_host, &request_ids, &top_p_indices, &top_p_values](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start renormalize by top p"); + RenormalizeProbByTopP(probs_on_host, top_p_indices[i], top_p_values[i], eps_); + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish renormalize by top p"); }, - 0, n); - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); - return sample_results; + 0, static_cast(top_p_indices.size())); + + return probs_on_host; } - std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, - const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_device->ndim, 2); // - Copy probs to CPU RECORD_EVENT(trace_recorder_, request_ids, "start copy probs to CPU"); - NDArray probs_host = CopyProbsToCPU(probs_on_device); + NDArray probs_on_host = CopyProbsToCPU(probs_on_device); RECORD_EVENT(trace_recorder_, request_ids, "finish copy probs to CPU"); + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/false); + } + + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist) final { + return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, + /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs_on_host, const Array& request_ids, + const std::vector& cum_verify_lengths, const Array& generation_cfg, + const std::vector& rngs, + const std::vector>& draft_output_tokens, + const std::vector>& draft_output_prob_dist) final { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); + CHECK_EQ(probs_on_host->ndim, 2); + int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); @@ -337,8 +444,8 @@ class CPUSampler : public SamplerObj { sample_results.resize(num_sequence); float* __restrict global_p_probs = - static_cast(__builtin_assume_aligned(probs_host->data, 4)); - int vocab_size = probs_host->shape[1]; + static_cast(__builtin_assume_aligned(probs_on_host->data, 4)); + int vocab_size = probs_on_host->shape[1]; tvm::runtime::parallel_for_with_threading_backend( [&](int i) { @@ -355,7 +462,7 @@ class CPUSampler : public SamplerObj { if (p_value >= q_value) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -363,7 +470,7 @@ class CPUSampler : public SamplerObj { if (r < p_value / (q_value + 
eps_)) { sample_results[i].push_back( SampleResult{{cur_token, p_value}, - ComputeTopProbs(probs_host, verify_start + cur_token_idx, + ComputeTopProbs(probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs)}); continue; } @@ -388,11 +495,10 @@ class CPUSampler : public SamplerObj { // sample a new token from the new distribution SampleResult sample_result; sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); break; } @@ -403,11 +509,10 @@ class CPUSampler : public SamplerObj { SampleResult sample_result; // sample a new token from the original distribution sample_result.sampled_token_id = SampleTopPFromProb( - probs_host, verify_start + cur_token_idx, verify_start + cur_token_idx, - generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p, - rngs[i]->GetRandomNumber()); + probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx, + /*top_p=*/1.0f, rngs[i]->GetRandomNumber()); sample_result.top_prob_tokens = ComputeTopProbs( - probs_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); + probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs); sample_results[i].push_back(sample_result); } }, @@ -417,6 +522,56 @@ class CPUSampler : public SamplerObj { } private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { + // probs_on_host: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + ICHECK_EQ(probs_on_host->ndim, 2); + ICHECK_EQ(probs_on_host->device.device_type, DLDeviceType::kDLCPU); + + // - Sample tokens from probabilities. + int n = request_ids.size(); + ICHECK_EQ(generation_cfg.size(), n); + ICHECK_EQ(rngs.size(), n); + + std::vector sample_results; + sample_results.resize(n); + if (output_prob_dist) { + output_prob_dist->resize(n); + } + + tvm::runtime::parallel_for_with_threading_backend( + [this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied, + sample_indices, output_prob_dist](int i) { + RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); + // Sample top p from probability. + double top_p = + top_p_applied + ? 1.0f + : (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p); + sample_results[i].sampled_token_id = + SampleTopPFromProb(probs_on_host, i, sample_indices[i], top_p, + rngs[i]->GetRandomNumber(), output_prob_dist); + if (output_prob_dist == nullptr) { + // When `output_prob_dist` is not nullptr, it means right now + // we are sampling for a small model in speculation, in which + // case we do not need to get the top probs. 
+ sample_results[i].top_prob_tokens = + ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs); + } + RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); + }, + 0, n); + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + /*! \brief Copy prob distributions from device to CPU. */ NDArray CopyProbsToCPU(NDArray probs_on_device) { // probs_on_device: (n, v) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index af4cc9615f..c80a846b19 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -44,6 +44,7 @@ class GPUSampler : public SamplerObj { gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_), gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_), gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_), + gpu_renormalize_by_top_p_func_(ft->gpu_renormalize_by_top_p_func_), trace_recorder_(std::move(trace_recorder)) { ICHECK(gpu_multinomial_from_uniform_func_.defined()); ICHECK(gpu_argsort_probs_func_.defined()); @@ -57,6 +58,10 @@ class GPUSampler : public SamplerObj { sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); + draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_next_sibling_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); + token_tree_parent_ptr_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_token_ids_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); sampled_probs_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); top_prob_probs_host_ = NDArray::Empty({max_num_sample * 5}, dtype_f32_, device_cpu); @@ -66,6 +71,11 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); + draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); + draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. 
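For reference, the renormalization that the CPU routine above and the GPU `renormalize_by_top_p` function (attached in `attach_sampler.py` later in this patch) both compute can be summarized by the minimal Python sketch below. It is illustrative only and not part of the patch: the helper name and the plain-list representation are chosen for clarity, while the real kernels operate on NDArrays and the CPU path narrows the candidate set with cutoff values instead of sorting the full vocabulary.

```python
# Illustrative sketch only (not part of this patch): the top-p renormalization
# semantics implemented by the CPU routine and the GPU kernel.
from typing import List


def renormalize_by_top_p(prob: List[float], top_p: float, eps: float = 1e-5) -> List[float]:
    """Keep the smallest set of highest-probability entries whose mass reaches
    `top_p`, zero out the rest, and rescale the kept entries to sum to 1."""
    # Find the boundary probability: walk the values in descending order until
    # the cumulative mass first reaches top_p.
    boundary = 0.0
    cumulative = 0.0
    for value in sorted(prob, reverse=True):
        cumulative += value
        if cumulative >= top_p - eps:
            boundary = value
            break
    # Mask everything below the boundary and renormalize the remainder.
    kept = [p if p >= boundary else 0.0 for p in prob]
    total = sum(kept)
    return [p / total for p in kept]


# Example: with top_p = 0.9, the 0.05 tail entry is dropped and the rest rescaled.
print(renormalize_by_top_p([0.5, 0.3, 0.15, 0.05], top_p=0.9))
# -> [0.526..., 0.315..., 0.157..., 0.0]
```

On the GPU path the probabilities are argsorted on device first, so the sorted values are already available to the renormalization kernel; the CPU path avoids a full sort by partitioning against the cutoff values before sorting only the retained head.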
@@ -84,59 +94,71 @@ class GPUSampler : public SamplerObj { } } - std::vector BatchSampleTokens(NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { - NVTXScopedRange nvtx_scope("BatchSampleTokens"); + NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) final { + NVTXScopedRange nvtx_scope("BatchRenormalizeProbsByTopP"); // probs_on_device: (n, v) - RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + RECORD_EVENT(trace_recorder_, request_ids, "start renormalization by top p"); CHECK_EQ(probs_on_device->ndim, 2); int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; int vocab_size = probs_on_device->shape[1]; - if (output_prob_dist != nullptr) { - ICHECK(output_prob_dist->empty()); - output_prob_dist->reserve(num_probs); - for (int i = 0; i < num_probs; ++i) { - NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); - float* p_prob = static_cast(probs_on_device->data) + i * vocab_size; - prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); - output_prob_dist->push_back(std::move(prob_dist)); - } - } + ICHECK_LE(num_probs, max_num_sample_); ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); - ICHECK_EQ(rngs.size(), num_samples); - // Since `num_samples` may be larger than `max_num_sample_` in some cases, - // we apply chunking to support large `num_samples`. - std::vector sample_results; - if (num_samples <= max_num_sample_) { - sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs); - } else { - for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { - int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); - std::vector sample_indices_chunk(sample_indices.begin() + chunk_start, - sample_indices.begin() + chunk_end); - Array generation_cfg_chunk(generation_cfg.begin() + chunk_start, - generation_cfg.begin() + chunk_end); - std::vector rngs_chunk(rngs.begin() + chunk_start, - rngs.begin() + chunk_end); - std::vector sample_results_chunk = ChunkSampleTokensImpl( - probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk); - sample_results.insert(sample_results.end(), sample_results_chunk.begin(), - sample_results_chunk.end()); - } + // - Check if there is need for applying top p. + bool need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + if (!need_top_p) { + return probs_on_device; } - RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); - return sample_results; + // - Argsort the probability. + Array argsort_results = gpu_argsort_probs_func_(probs_on_device); + ICHECK_EQ(argsort_results.size(), 2); + NDArray sorted_probs_on_device = argsort_results[0]; + NDArray sorted_indices_on_device = argsort_results[1]; + + // - Copy auxiliary array for top-p. + NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); + NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); + CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + SyncCopyStream(device_, compute_stream_, copy_stream_); + + // - Renormalize the prob with top p. 
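+  //   The attached `renormalize_by_top_p` function consumes the original probs,
+  //   the device-sorted probs, and the per-distribution top_p values, and
+  //   returns the renormalized distributions without copying back to the host.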
+ NDArray renormed_probs_on_device = + gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device); + + RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p"); + return renormed_probs_on_device; + } + + std::vector BatchSampleTokensWithProbBeforeTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbBeforeTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/false); } - std::vector> BatchVerifyDraftTokens( + std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + std::vector* output_prob_dist = nullptr) final { + NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbAfterTopP"); + return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, + generation_cfg, rngs, /*top_p_applied=*/true, output_prob_dist); + } + + std::vector> BatchVerifyDraftTokensWithProbAfterTopP( NDArray probs_on_device, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, @@ -157,10 +179,10 @@ class GPUSampler : public SamplerObj { int num_nodes = cum_verify_lengths.back(); NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); - NDArray draft_probs_device = NDArray::Empty({num_nodes, vocab_size_}, dtype_f32_, device_); - NDArray draft_tokens_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); - NDArray draft_tokens_host = - NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + NDArray draft_probs_device = + draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); + NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); + NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) for (int i = 0; i < num_sequence; i++) { @@ -197,32 +219,33 @@ class GPUSampler : public SamplerObj { } CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_); - // This should be refactored to use the cached tensors - NDArray token_tree_first_child_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); - NDArray token_tree_next_sibling_device = NDArray::Empty({num_nodes}, dtype_i32_, device_); - NDArray token_tree_parent_ptr_device = NDArray::Empty({num_sequence}, dtype_i32_, device_); NDArray token_tree_first_child_host = - NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + token_tree_first_child_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_first_child_device = + token_tree_first_child_device_.CreateView({num_nodes}, dtype_i32_); NDArray token_tree_next_sibling_host = - NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + token_tree_next_sibling_host_.CreateView({num_nodes}, dtype_i32_); + NDArray token_tree_next_sibling_device = + token_tree_next_sibling_device_.CreateView({num_nodes}, dtype_i32_); NDArray token_tree_parent_ptr_host = - NDArray::Empty({num_sequence}, dtype_i32_, 
DLDevice{DLDeviceType::kDLCPU, 0}); - NDArray token_tree_child_to_parent_host = - NDArray::Empty({num_nodes}, dtype_i32_, DLDevice{DLDeviceType::kDLCPU, 0}); + token_tree_parent_ptr_host_.CreateView({num_sequence}, dtype_i32_); + NDArray token_tree_parent_ptr_device = + token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_); + std::vector token_tree_child_to_parent(/*n=*/num_nodes); // Build the tree structure on CPU for (int i = 0; i < num_sequence; i++) { // Assuming no tree structure for now int start = cum_verify_lengths[i]; int end = cum_verify_lengths[i + 1]; - ICHECK_EQ(end - start, 2); // one committed token and assuming only one draft token - static_cast(token_tree_child_to_parent_host->data)[start] = -1; // root has no parent + ICHECK_GE(end - start, 2); + token_tree_child_to_parent[start] = -1; // root has no parent for (int j = 0; j < end - start; j++) { int cur_node = j + start; int child_node = j + 1 >= end - start ? -1 : cur_node + 1; static_cast(token_tree_first_child_host->data)[cur_node] = child_node; if (child_node != -1) { - static_cast(token_tree_child_to_parent_host->data)[child_node] = cur_node; + token_tree_child_to_parent[child_node] = cur_node; } static_cast(token_tree_next_sibling_host->data)[cur_node] = -1; } @@ -250,7 +273,7 @@ class GPUSampler : public SamplerObj { int last_accepted = static_cast(token_tree_parent_ptr_host->data)[i]; int num_accepted = 0; for (int cur_node = last_accepted; cur_node != start; - cur_node = static_cast(token_tree_child_to_parent_host->data)[cur_node]) { + cur_node = token_tree_child_to_parent[cur_node]) { sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]); num_accepted++; } @@ -258,10 +281,8 @@ class GPUSampler : public SamplerObj { sample_indices.push_back(last_accepted); } std::vector additional_sample_result; - // This only works for top-p = 1. To enable top-p, we need to normalize the probs before - // verifying. 
- additional_sample_result = this->BatchSampleTokens(probs_on_device, sample_indices, request_ids, - generation_cfg, rngs, nullptr); + additional_sample_result = this->BatchSampleTokensWithProbAfterTopP( + probs_on_device, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(additional_sample_result.size(), num_sequence); for (int i = 0; i < num_sequence; i++) { sample_results[i].push_back(additional_sample_result[i]); @@ -272,10 +293,67 @@ class GPUSampler : public SamplerObj { } private: + std::vector BatchSampleTokensImpl( + NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied, // + std::vector* output_prob_dist = nullptr) { + // probs_on_device: (n, v) + RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); + CHECK_EQ(probs_on_device->ndim, 2); + CHECK_EQ(probs_on_device->device.device_id, device_.device_id); + CHECK_EQ(probs_on_device->device.device_type, device_.device_type); + int num_samples = sample_indices.size(); + int num_probs = probs_on_device->shape[0]; + int vocab_size = probs_on_device->shape[1]; + if (output_prob_dist != nullptr) { + ICHECK(output_prob_dist->empty()); + output_prob_dist->reserve(num_probs); + for (int i = 0; i < num_probs; ++i) { + NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); + float* p_prob = static_cast(probs_on_device->data) + i * vocab_size; + prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); + output_prob_dist->push_back(std::move(prob_dist)); + } + } + ICHECK_EQ(request_ids.size(), num_samples); + ICHECK_EQ(generation_cfg.size(), num_samples); + ICHECK_EQ(rngs.size(), num_samples); + + // Since `num_samples` may be larger than `max_num_sample_` in some cases, + // we apply chunking to support large `num_samples`. + std::vector sample_results; + if (num_samples <= max_num_sample_) { + sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs, + top_p_applied); + } else { + for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) { + int chunk_end = std::min(chunk_start + max_num_sample_, num_samples); + std::vector sample_indices_chunk(sample_indices.begin() + chunk_start, + sample_indices.begin() + chunk_end); + Array generation_cfg_chunk(generation_cfg.begin() + chunk_start, + generation_cfg.begin() + chunk_end); + std::vector rngs_chunk(rngs.begin() + chunk_start, + rngs.begin() + chunk_end); + std::vector sample_results_chunk = ChunkSampleTokensImpl( + probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk, top_p_applied); + sample_results.insert(sample_results.end(), sample_results_chunk.begin(), + sample_results_chunk.end()); + } + } + + RECORD_EVENT(trace_recorder_, request_ids, "finish sampling"); + return sample_results; + } + std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // const std::vector& sample_indices, // const Array& generation_cfg, // - const std::vector& rngs) { + const std::vector& rngs, // + bool top_p_applied) { // probs_on_device: (n, v) int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; @@ -289,11 +367,13 @@ class GPUSampler : public SamplerObj { // - Check if there is need for applying top p or prob values, // so that argsort is needed. 
bool need_top_p = false; - bool need_prob_values = false; + if (!top_p_applied) { + need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size); + } // The indptr array of the number of top probs for each sample. std::vector top_prob_offset_indptr; - CheckTopPAndProbValues(generation_cfg, sample_indices, num_probs, num_samples, vocab_size, - &need_top_p, &need_prob_values, &top_prob_offset_indptr); + bool need_prob_values = CheckProbValues(generation_cfg, sample_indices, num_probs, num_samples, + vocab_size, &top_prob_offset_indptr); // - Sample tokens on GPU, and take out the probability values if needed. std::vector device_arrays = @@ -353,30 +433,39 @@ class GPUSampler : public SamplerObj { return {uniform_samples_device, sample_indices_device}; } - /*! \brief Check if top p and prob values are needed, and collect info when necessary. */ - void CheckTopPAndProbValues(const Array& generation_cfg, - const std::vector& sample_indices, int num_probs, - int num_samples, int vocab_size, bool* need_top_p, - bool* need_prob_values, std::vector* top_prob_offset_indptr) { - top_prob_offset_indptr->reserve(num_samples + 1); - top_prob_offset_indptr->push_back(0); + /*! \brief Check if top p is needed. Update host top p array in place. */ + bool CheckTopP(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size) { // Initialize top p values with -1. float* p_top_p = static_cast(top_p_host_->data); for (int i = 0; i < num_probs; ++i) { p_top_p[i] = -1.0; } - int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); - int num_top_probs = 0; + bool need_top_p = false; for (int i = 0; i < num_samples; ++i) { if (p_top_p[sample_indices[i]] == -1.0) { p_top_p[sample_indices[i]] = generation_cfg[i]->top_p; - *need_top_p |= generation_cfg[i]->top_p != 1.0; + need_top_p |= generation_cfg[i]->top_p != 1.0; } else { CHECK(fabs(p_top_p[sample_indices[i]] - generation_cfg[i]->top_p) < eps_) << "GPU sampler requires the top_p values for each prob distribution are the same."; } + } + return need_top_p; + } - *need_prob_values |= generation_cfg[i]->logprobs; + /*! \brief Check whether prob values are needed, and collect info when necessary. */ + bool CheckProbValues(const Array& generation_cfg, + const std::vector& sample_indices, int num_probs, int num_samples, + int vocab_size, std::vector* top_prob_offset_indptr) { + top_prob_offset_indptr->reserve(num_samples + 1); + top_prob_offset_indptr->push_back(0); + int* p_top_prob_offsets = static_cast(top_prob_offsets_host_->data); + int num_top_probs = 0; + bool need_prob_values = false; + for (int i = 0; i < num_samples; ++i) { + need_prob_values |= generation_cfg[i]->logprobs; for (int j = 0; j < generation_cfg[i]->top_logprobs; ++j) { p_top_prob_offsets[num_top_probs++] = sample_indices[i] * vocab_size + j; } @@ -384,6 +473,7 @@ class GPUSampler : public SamplerObj { generation_cfg[i]->top_logprobs); } ICHECK_EQ(num_top_probs, top_prob_offset_indptr->back()); + return need_prob_values; } /*! \brief Sample tokens on GPU. Take out the probability values when needed. 
*/ @@ -507,11 +597,16 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_sample_with_top_p_func_; PackedFunc gpu_sampler_take_probs_func_; PackedFunc gpu_verify_draft_tokens_func_; + PackedFunc gpu_renormalize_by_top_p_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; NDArray top_p_host_; NDArray top_prob_offsets_host_; + NDArray draft_tokens_host_; + NDArray token_tree_first_child_host_; + NDArray token_tree_next_sibling_host_; + NDArray token_tree_parent_ptr_host_; NDArray sampled_token_ids_host_; NDArray sampled_probs_host_; NDArray top_prob_probs_host_; @@ -521,6 +616,11 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; + NDArray draft_probs_device_; + NDArray draft_tokens_device_; + NDArray token_tree_first_child_device_; + NDArray token_tree_next_sibling_device_; + NDArray token_tree_parent_ptr_device_; // The event trace recorder for requests. */ Optional trace_recorder_; // The device stream for the default computation operations. diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 03d031bdb7..7943231e55 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -26,14 +26,33 @@ using namespace tvm::runtime; /*! * \brief The base class of runtime sampler. - * Its main function is `BatchSampleTokens`, which takes a batch of + * Its main function is `BatchSampleTokensWithProbBeforeTopP`, which takes a batch of * logits and corresponding configuration, and sample one token * for each instance of the batch. */ class SamplerObj : public Object { public: + /*! + * \brief Renormalize the input batch of probability distributions with top p values. + * \param probs_on_device The batch of prob distributions before normalization. + * \param sample_indices Specifying which request we will sample for + * in i-th output for the sampling later on. + * The output result of the sampling will be as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * For renormalization, the sample indices are used for determine the top-p grouping. + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request in the input batch. + * \return The renormalized probability distributions, residing on device + * if the sampler is GPU sampler, or on host if the sampler is CPU sampler. + */ + virtual NDArray BatchRenormalizeProbsByTopP(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg) = 0; + /*! * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are not yet applied with top-p. * \param probs_on_device The prob distributions on GPU to sample tokens from. * \param sample_indices Specifying which request we should sample for * in i-th output. The output result is sample as follow: @@ -42,22 +61,46 @@ class SamplerObj : public Object { * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. - * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. 
*/ - virtual std::vector BatchSampleTokens( + virtual std::vector BatchSampleTokensWithProbBeforeTopP( NDArray probs_on_device, // const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // + const std::vector& rngs) = 0; + + /*! + * \brief Sample tokens from the input batch of prob distribution on device. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. + * \param sample_indices Specifying which request we should sample for + * in i-th output. The output result is sample as follow: + * result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i])); + * \param request_ids The id of each request. + * \param generation_cfg The generation config of each request + * in the input batch. + * \param rngs The random number generator of each sequence. + * \param output_prob_dist The output probability distribution + * \return The batch of sampling results, which contain the sampled token id + * and other probability info. + */ + virtual std::vector BatchSampleTokensWithProbAfterTopP( + NDArray probs, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // const std::vector& rngs, // std::vector* output_prob_dist = nullptr) = 0; /*! * \brief Verify draft tokens generated by small models in the large model * in speculative decoding. The input corresponds to a batch of sequences. - * \param probs_on_device The prob distributions on GPU to sample tokens from. + * The input prob distributions are already applied with top-p. + * \param probs The prob distributions on GPU to sample tokens from. + * It resides on GPU if the sampler is GPU sampler, or on host if hte sampler is CPU sampler. * \param request_ids The id of each request. * \param cum_verify_lengths The cumulative draft lengths to verify of all sequences. * \param generation_cfg The generation config of each request @@ -69,10 +112,9 @@ class SamplerObj : public Object { * small model for each sequence. * \return The list of accepted tokens for each request. 
*/ - virtual std::vector> BatchVerifyDraftTokens( - NDArray probs_on_device, const Array& request_ids, - const std::vector& cum_verify_lengths, const Array& generation_cfg, - const std::vector& rngs, + virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( + NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, + const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, const std::vector>& draft_output_prob_dist) = 0; diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index f044c3a6d8..46dc40c106 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -49,6 +49,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), _attach_batch_verifier(bb, vocab_size), + _attach_renormalize_by_top_p(bb, vocab_size), ] ] @@ -129,6 +130,17 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): return gv +@T.prim_func +def full(var_result: T.handle, value: T.int32): + """The filling function for top k.""" + batch_size = T.int32(is_size_var=True) + result = T.match_buffer(var_result, (batch_size, 1), "int32") + for i in T.serial(batch_size): + with T.block("block"): + vi = T.axis.spatial(batch_size, i) + result[vi, 0] = value + + def _attach_sample_with_top_p( # pylint: disable=too-many-locals bb: relax.BlockBuilder, vocab_size: tir.PrimExpr ): @@ -146,15 +158,6 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals sample_indices = relax.Var("sample_indices", relax.TensorStructInfo((num_samples,), "int32")) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - @T.prim_func - def full(var_result: T.handle, value: T.int32): - batch_size = T.int32(is_size_var=True) - result = T.match_buffer(var_result, (batch_size, 1), "int32") - for i in T.serial(batch_size): - with T.block("block"): - vi = T.axis.spatial(batch_size, i) - result[vi, 0] = value - with bb.function( "sample_with_top_p", [sorted_probs, sorted_indices, uniform_samples, sample_indices, top_p], @@ -224,6 +227,44 @@ def full(var_result: T.handle, value: T.int32): return gv +def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): + batch_size = tir.Var("batch_size", "int64") + probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) + sorted_probs = relax.Var( + "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") + ) + top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) + with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + with bb.dataflow(): + probs_tensor = nn.wrap_nested(probs, name="probs") + sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") + top_p_shape = relax.ShapeExpr([batch_size, 1]) + top_p_tensor = nn.wrap_nested( + relax.call_pure_packed( + "vm.builtin.reshape", + top_p, + top_p_shape, + sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), + ), + name="sample_indices", + ) + top_k_tensor = nn.tensor_ir_op( + full, + name_hint="full", + args=[vocab_size], + out=nn.Tensor.placeholder( + [batch_size, 1], + "int32", + ), + ) + renormalized_probs = nn.renormalize_top_p_top_k_prob( + probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + ) + bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access + gv = 
bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + return gv + + def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index b4321ebdec..eff6f6f46e 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -188,7 +188,7 @@ "gpu_memory_utilization_serve": """ A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. -When it is unspecified, it defaults to 0.90. +When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. """, diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 5bbdc149d4..febf88e99e 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -819,7 +819,7 @@ class AsyncLLMEngine(engine_base.LLMEngineBase): gpu_memory_utilization : Optional[float] A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. - When it is unspecified, it defaults to 0.90. + When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. @@ -1365,7 +1365,7 @@ class LLMEngine(engine_base.LLMEngineBase): gpu_memory_utilization : Optional[float] A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. - When it is unspecified, it defaults to 0.90. + When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 23dea5d015..6d89d223d1 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -199,7 +199,7 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma if gpu_size_bytes is None: raise ValueError("Cannot read total GPU global memory from device.") if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 + gpu_memory_utilization = 0.85 model_max_total_sequence_length = int( ( From 9ec75ee258b28fbe2aec6f1cfd61bb6c1b7c6b20 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 24 Apr 2024 09:52:51 -0400 Subject: [PATCH 231/531] [Python] Rename LLMEngine to MLCEngine (#2210) This commit renames the LLMEngine to MLCEngine. 
--- README.md | 10 +-- docs/deploy/python_engine.rst | 72 +++++++++---------- docs/get_started/introduction.rst | 18 ++--- docs/get_started/quick_start.rst | 4 +- examples/python/sample_mlc_engine.py | 4 +- python/mlc_llm/__init__.py | 2 +- python/mlc_llm/help.py | 2 +- python/mlc_llm/interface/serve.py | 2 +- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 2 +- python/mlc_llm/serve/engine.py | 30 ++++---- python/mlc_llm/serve/engine_base.py | 34 ++++----- python/mlc_llm/serve/server/server_context.py | 8 +-- python/mlc_llm/serve/sync_engine.py | 2 +- tests/python/serve/evaluate_engine.py | 4 +- tests/python/serve/test_serve_async_engine.py | 14 ++-- .../serve/test_serve_async_engine_spec.py | 6 +- tests/python/serve/test_serve_engine.py | 12 ++-- .../python/serve/test_serve_engine_grammar.py | 12 ++-- tests/python/serve/test_serve_engine_image.py | 4 +- tests/python/serve/test_serve_engine_spec.py | 22 +++--- tests/python/serve/test_serve_sync_engine.py | 12 ++-- 22 files changed, 139 insertions(+), 139 deletions(-) diff --git a/README.md b/README.md index 647b9047f2..88e3abd07d 100644 --- a/README.md +++ b/README.md @@ -106,11 +106,11 @@ We can run the Llama-3 model with the chat completion Python API of MLC LLM. You can save the code below into a Python file and run it. ```python -from mlc_llm import LLMEngine +from mlc_llm import MLCEngine # Create engine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" -engine = LLMEngine(model) +engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -125,12 +125,12 @@ print("\n") engine.terminate() ``` -**The Python API of `mlc_llm.LLMEngine` fully aligns with OpenAI API**. -You can use LLMEngine in the same way of using +**The Python API of `mlc_llm.MLCEngine` fully aligns with OpenAI API**. +You can use MLCEngine in the same way of using [OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) for both synchronous and asynchronous generation. -If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncLLMEngine` instead. +If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncMLCEngine` instead. ### REST Server diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index cfbc3b5d4c..89c60ac422 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -4,7 +4,7 @@ Python API ========== .. note:: - This page introduces the Python API with LLMEngine in MLC LLM. + This page introduces the Python API with MLCEngine in MLC LLM. If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, please go to :ref:`deploy-python-chat-module` @@ -13,10 +13,10 @@ Python API :depth: 2 -MLC LLM provides Python API through classes :class:`mlc_llm.LLMEngine` and :class:`mlc_llm.AsyncLLMEngine` +MLC LLM provides Python API through classes :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` which **support full OpenAI API completeness** for easy integration into other Python projects. -This page introduces how to use the LLM engines in MLC LLM. +This page introduces how to use the engines in MLC LLM. The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via the :ref:`installation page `. @@ -26,31 +26,31 @@ Verify Installation .. 
code:: bash - python -c "from mlc_llm import LLMEngine; print(LLMEngine)" + python -c "from mlc_llm import MLCEngine; print(MLCEngine)" -You are expected to see the output of ````. +You are expected to see the output of ````. If the command above results in error, follow :ref:`install-mlc-packages` to install prebuilt pip packages or build MLC LLM from source. -Run LLMEngine +Run MLCEngine ------------- -:class:`mlc_llm.LLMEngine` provides the interface of OpenAI chat completion synchronously. -:class:`mlc_llm.LLMEngine` does not batch concurrent request due to the synchronous design, -and please use :ref:`AsyncLLMEngine ` for request batching process. +:class:`mlc_llm.MLCEngine` provides the interface of OpenAI chat completion synchronously. +:class:`mlc_llm.MLCEngine` does not batch concurrent request due to the synchronous design, +and please use :ref:`AsyncMLCEngine ` for request batching process. **Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, -we introduced the basic use of :class:`mlc_llm.LLMEngine`. +we introduced the basic use of :class:`mlc_llm.MLCEngine`. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = LLMEngine(model) + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -64,9 +64,9 @@ we introduced the basic use of :class:`mlc_llm.LLMEngine`. engine.terminate() -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the 8B Llama-3 model. -**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.LLMEngine` in the same way of using +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 8B Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way of using `OpenAI's Python package `_ for both synchronous and asynchronous generation. @@ -90,14 +90,14 @@ for the complete chat completion interface. .. _python-engine-async-llm-engine: -Run AsyncLLMEngine +Run AsyncMLCEngine ------------------ -:class:`mlc_llm.AsyncLLMEngine` provides the interface of OpenAI chat completion with +:class:`mlc_llm.AsyncMLCEngine` provides the interface of OpenAI chat completion with asynchronous features. -**We recommend using** :class:`mlc_llm.AsyncLLMEngine` **to batch concurrent request for better throughput.** +**We recommend using** :class:`mlc_llm.AsyncMLCEngine` **to batch concurrent request for better throughput.** -**Stream Response.** The core use of :class:`mlc_llm.AsyncLLMEngine` for stream responses is as follows. +**Stream Response.** The core use of :class:`mlc_llm.AsyncMLCEngine` for stream responses is as follows. .. code:: python @@ -109,14 +109,14 @@ asynchronous features. for choice in response.choices: print(choice.delta.content, end="", flush=True) -.. collapse:: The collapsed is a complete runnable example of AsyncLLMEngine in Python. +.. collapse:: The collapsed is a complete runnable example of AsyncMLCEngine in Python. .. code:: python import asyncio from typing import Dict - from mlc_llm.serve import AsyncLLMEngine + from mlc_llm.serve import AsyncMLCEngine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" prompts = [ @@ -127,7 +127,7 @@ asynchronous features. 
async def test_completion(): # Create engine - async_engine = AsyncLLMEngine(model=model) + async_engine = AsyncMLCEngine(model=model) num_requests = len(prompts) output_texts: Dict[str, str] = {} @@ -176,8 +176,8 @@ for the complete chat completion interface. Engine Mode ----------- -To ease the engine configuration, the constructors of :class:`mlc_llm.LLMEngine` and -:class:`mlc_llm.AsyncLLMEngine` have an optional argument ``mode``, +To ease the engine configuration, the constructors of :class:`mlc_llm.MLCEngine` and +:class:`mlc_llm.AsyncMLCEngine` have an optional argument ``mode``, which falls into one of the three options ``"local"``, ``"interactive"`` or ``"server"``. The default mode is ``"local"``. @@ -203,34 +203,34 @@ Deploy Your Own Model with Python API The :ref:`introduction page ` introduces how we can deploy our own models with MLC LLM. This section introduces how you can use the model weights you convert and the model library you build -in :class:`mlc_llm.LLMEngine` and :class:`mlc_llm.AsyncLLMEngine`. +in :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`. We use the `Phi-2 `_ as the example model. **Specify Model Weight Path.** Assume you have converted the model weights for your own model, -you can construct a :class:`mlc_llm.LLMEngine` as follows: +you can construct a :class:`mlc_llm.MLCEngine` as follows: .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine model = "models/phi-2" # Assuming the converted phi-2 model weights are under "models/phi-2" - engine = LLMEngine(model) + engine = MLCEngine(model) **Specify Model Library Path.** Further, if you build the model library on your own, -you can use it in :class:`mlc_llm.LLMEngine` by passing the library path through argument ``model_lib_path``. +you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine model = "models/phi-2" model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" - engine = LLMEngine(model, model_lib_path=model_lib_path) + engine = MLCEngine(model, model_lib_path=model_lib_path) -The same applies to :class:`mlc_llm.AsyncLLMEngine`. +The same applies to :class:`mlc_llm.AsyncMLCEngine`. .. _python-engine-api-reference: @@ -238,16 +238,16 @@ The same applies to :class:`mlc_llm.AsyncLLMEngine`. API Reference ------------- -The :class:`mlc_llm.LLMEngine` and :class:`mlc_llm.AsyncLLMEngine` classes provide the following constructors. +The :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` classes provide the following constructors. -The LLMEngine and AsyncLLMEngine have full OpenAI API completeness. +The MLCEngine and AsyncMLCEngine have full OpenAI API completeness. Please refer to `OpenAI's Python package `_ and `OpenAI chat completion API `_ for the complete chat completion interface. .. currentmodule:: mlc_llm -.. autoclass:: LLMEngine +.. autoclass:: MLCEngine :members: :exclude-members: evaluate :undoc-members: @@ -255,7 +255,7 @@ for the complete chat completion interface. .. automethod:: __init__ -.. autoclass:: AsyncLLMEngine +.. 
autoclass:: AsyncMLCEngine :members: :exclude-members: evaluate :undoc-members: diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 32bcfc4cdb..29060d5a60 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -90,11 +90,11 @@ You can save the code below into a Python file and run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = LLMEngine(model) + engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( @@ -114,9 +114,9 @@ You can save the code below into a Python file and run it. MLC LLM Python API -This code example first creates an :class:`mlc_llm.LLMEngine` instance with the 4-bit quantized Llama-3 model. -**We design the Python API** :class:`mlc_llm.LLMEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.LLMEngine` in the same way of using +This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 4-bit quantized Llama-3 model. +**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, +which means you can use :class:`mlc_llm.MLCEngine` in the same way of using `OpenAI's Python package `_ for both synchronous and asynchronous generation. @@ -134,7 +134,7 @@ If you want to run without streaming, you can run print(response) You can also try different arguments supported in `OpenAI chat completion API `_. -If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncLLMEngine` instead. +If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncMLCEngine` instead. REST Server ----------- @@ -229,7 +229,7 @@ You can also use this model in Python API, MLC serve and other use scenarios. (Optional) Compile Model Library ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In previous sections, model libraries are compiled when the :class:`mlc_llm.LLMEngine` launches, +In previous sections, model libraries are compiled when the :class:`mlc_llm.MLCEngine` launches, which is what we call "JIT (Just-in-Time) model compilation". In some cases, it is beneficial to explicitly compile the model libraries. We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. @@ -257,12 +257,12 @@ At runtime, we need to specify this model library path to use it. For example, .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # For Python API model = "models/phi-2" model_lib_path = "models/phi-2/lib.so" - engine = LLMEngine(model, model_lib_path=model_lib_path) + engine = MLCEngine(model, model_lib_path=model_lib_path) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst index 76d971275b..8349197eda 100644 --- a/docs/get_started/quick_start.rst +++ b/docs/get_started/quick_start.rst @@ -20,11 +20,11 @@ It is recommended to have at least 6GB free VRAM to run it. .. code:: python - from mlc_llm import LLMEngine + from mlc_llm import MLCEngine # Create engine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = LLMEngine(model) + engine = MLCEngine(model) # Run chat completion in OpenAI API. 
for response in engine.chat.completions.create( diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py index f76e44c620..e4f869930f 100644 --- a/examples/python/sample_mlc_engine.py +++ b/examples/python/sample_mlc_engine.py @@ -1,8 +1,8 @@ -from mlc_llm import LLMEngine +from mlc_llm import MLCEngine # Create engine model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" -engine = LLMEngine(model) +engine = MLCEngine(model) # Run chat completion in OpenAI API. for response in engine.chat.completions.create( diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index 8e3aaaa808..4843c6766d 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -6,4 +6,4 @@ from . import protocol, serve from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ -from .serve import AsyncLLMEngine, LLMEngine +from .serve import AsyncMLCEngine, MLCEngine diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index eff6f6f46e..14e5cee321 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -203,7 +203,7 @@ The number of draft tokens to generate in speculative proposal. The default values is 4. """, "engine_config_serve": """ -The LLMEngine execution configuration. +The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c5696ef473..d0cbd4690b 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -35,7 +35,7 @@ def serve( ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Create engine and start the background loop - async_engine = engine.AsyncLLMEngine( + async_engine = engine.AsyncMLCEngine( model=model, device=device, model_lib_path=model_lib_path, diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 79caff7cad..59358c1646 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -4,7 +4,7 @@ from .. import base from .config import EngineConfig, GenerationConfig, SpeculativeMode from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData -from .engine import AsyncLLMEngine, LLMEngine +from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher from .radix_tree import PagedRadixTree from .request import Request diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 773a00625e..60e4eca8c5 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -141,7 +141,7 @@ class SpeculativeMode(enum.IntEnum): @tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access class EngineConfig(tvm.runtime.Object): - """The class of LLMEngine execution configuration. + """The class of MLCEngine execution configuration. 
Parameters ---------- diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index febf88e99e..d9721b4864 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -37,10 +37,10 @@ class Chat: # pylint: disable=too-few-public-methods """The proxy class to direct to chat completions.""" def __init__(self, engine: weakref.ReferenceType) -> None: - assert isinstance(engine(), (AsyncLLMEngine, LLMEngine)) + assert isinstance(engine(), (AsyncMLCEngine, MLCEngine)) self.completions = ( AsyncChatCompletion(engine) # type: ignore - if isinstance(engine(), AsyncLLMEngine) + if isinstance(engine(), AsyncMLCEngine) else ChatCompletion(engine) # type: ignore ) @@ -49,7 +49,7 @@ class AsyncChatCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to async chat completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["AsyncLLMEngine"] + engine: weakref.ReferenceType["AsyncMLCEngine"] else: engine: weakref.ReferenceType @@ -226,7 +226,7 @@ class ChatCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to chat completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["LLMEngine"] + engine: weakref.ReferenceType["MLCEngine"] else: engine: weakref.ReferenceType @@ -401,7 +401,7 @@ class AsyncCompletion: # pylint: disable=too-few-public-methods """The proxy class to direct to async completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["AsyncLLMEngine"] + engine: weakref.ReferenceType["AsyncMLCEngine"] else: engine: weakref.ReferenceType @@ -580,7 +580,7 @@ class Completion: # pylint: disable=too-few-public-methods """The proxy class to direct to completions.""" if sys.version_info >= (3, 9): - engine: weakref.ReferenceType["LLMEngine"] + engine: weakref.ReferenceType["MLCEngine"] else: engine: weakref.ReferenceType @@ -752,8 +752,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals ) -class AsyncLLMEngine(engine_base.LLMEngineBase): - """The AsyncLLMEngine in MLC LLM that provides the asynchronous +class AsyncMLCEngine(engine_base.MLCEngineBase): + """The AsyncMLCEngine in MLC LLM that provides the asynchronous interfaces with regard to OpenAI API. Parameters @@ -825,7 +825,7 @@ class AsyncLLMEngine(engine_base.LLMEngineBase): memory usage may be slightly larger than this number. engine_config : Optional[EngineConfig] - The LLMEngine execution configuration. + The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. @@ -1228,7 +1228,7 @@ async def _generate( generation_config: GenerationConfig, request_id: str, ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]: - """Internal asynchronous text generation interface of AsyncLLMEngine. + """Internal asynchronous text generation interface of AsyncMLCEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. The returned list length is the number of parallel generations specified by `generation_config.n`. 
@@ -1298,8 +1298,8 @@ def _abort(self, request_id: str): self._ffi["abort_request"](request_id) -class LLMEngine(engine_base.LLMEngineBase): - """The LLMEngine in MLC LLM that provides the synchronous +class MLCEngine(engine_base.MLCEngineBase): + """The MLCEngine in MLC LLM that provides the synchronous interfaces with regard to OpenAI API. Parameters @@ -1371,7 +1371,7 @@ class LLMEngine(engine_base.LLMEngineBase): memory usage may be slightly larger than this number. engine_config : Optional[EngineConfig] - The LLMEngine execution configuration. + The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. @@ -1767,7 +1767,7 @@ def _generate( # pylint: disable=too-many-locals generation_config: GenerationConfig, request_id: str, ) -> Iterator[List[engine_base.CallbackStreamOutput]]: - """Internal synchronous text generation interface of AsyncLLMEngine. + """Internal synchronous text generation interface of AsyncMLCEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. The returned list length is the number of parallel generations specified by `generation_config.n`. @@ -1821,7 +1821,7 @@ def _generate( # pylint: disable=too-many-locals def _request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] ) -> List[List[engine_base.CallbackStreamOutput]]: - """The underlying implementation of request stream callback of LLMEngine.""" + """The underlying implementation of request stream callback of MLCEngine.""" batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 6d89d223d1..7b2ede60b2 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -464,7 +464,7 @@ def infer_args_under_mode( @dataclass class CallbackStreamOutput: - """The output of LLMEngine._generate and AsyncLLMEngine._generate + """The output of MLCEngine._generate and AsyncMLCEngine._generate Attributes ---------- @@ -489,7 +489,7 @@ class CallbackStreamOutput: class AsyncRequestStream: - """The asynchronous stream for requests in AsyncLLMEngine. + """The asynchronous stream for requests in AsyncMLCEngine. Each request has its own unique stream. The stream exposes the method `push` for engine to push new generated @@ -548,29 +548,29 @@ async def __anext__(self) -> List[CallbackStreamOutput]: class EngineState: """The engine states that the request stream callback function may use. - This class is used for both AsyncLLMEngine and LLMEngine. - AsyncLLMEngine uses the fields and methods starting with "async", - and LLMEngine uses the ones starting with "sync". + This class is used for both AsyncMLCEngine and MLCEngine. + AsyncMLCEngine uses the fields and methods starting with "async", + and MLCEngine uses the ones starting with "sync". - - For AsyncLLMEngine, the state contains an asynchronous event loop, + - For AsyncMLCEngine, the state contains an asynchronous event loop, the streamers and the number of unfinished generations for each request being processed. - - For LLMEngine, the state contains a callback output blocking queue, + - For MLCEngine, the state contains a callback output blocking queue, the text streamers and the number of unfinished requests. 
We use this state class to avoid the callback function from capturing - the AsyncLLMEngine. + the AsyncMLCEngine. The state also optionally maintains an event trace recorder, which can provide Chrome tracing when enabled. """ trace_recorder = None - # States used for AsyncLLMEngine + # States used for AsyncMLCEngine async_event_loop: Optional[asyncio.AbstractEventLoop] = None async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} async_num_unfinished_generations: Dict[str, int] = {} - # States used for LLMEngine + # States used for MLCEngine sync_output_queue: queue.Queue = queue.Queue() sync_text_streamers: List[TextStreamer] = [] sync_num_unfinished_generations: int = 0 @@ -632,7 +632,7 @@ def async_lazy_init_event_loop(self) -> None: self.async_event_loop = asyncio.get_event_loop() def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for AsyncLLMEngine to stream back + """The request stream callback function for AsyncMLCEngine to stream back the request generation results. Note @@ -652,7 +652,7 @@ def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamO def _async_request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] ) -> None: - """The underlying implementation of request stream callback for AsyncLLMEngine.""" + """The underlying implementation of request stream callback for AsyncMLCEngine.""" for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() streamers = self.async_streamers.get(request_id, None) @@ -693,28 +693,28 @@ def _async_request_stream_callback_impl( self.record_event(request_id, event="finish callback") def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: - """The request stream callback function for LLMEngine to stream back + """The request stream callback function for MLCEngine to stream back the request generation results. """ # Put the delta outputs to the queue in the unblocking way. self.sync_output_queue.put_nowait(delta_outputs) -class LLMEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods +class MLCEngineBase: # pylint: disable=too-many-instance-attributes,too-few-public-methods """The base engine class, which implements common functions that - are shared by LLMEngine and AsyncLLMEngine. + are shared by MLCEngine and AsyncMLCEngine. This class wraps a threaded engine that runs on a standalone thread inside and streams back the delta generated results via callback functions. The internal threaded engine keeps running an loop that drives the engine. - LLMEngine and AsyncLLMEngine inherits this LLMEngineBase class, and implements + MLCEngine and AsyncMLCEngine inherits this MLCEngineBase class, and implements their own methods to process the delta generated results received from callback functions and yield the processed delta results in the forms of standard API protocols. - Checkout subclasses AsyncLLMEngine/LLMEngine for the docstring of constructor parameters. + Checkout subclasses AsyncMLCEngine/MLCEngine for the docstring of constructor parameters. 
""" def __init__( # pylint: disable=too-many-arguments,too-many-locals diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index 46b841aaa9..d6acd4a2be 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional -from ..engine import AsyncLLMEngine +from ..engine import AsyncMLCEngine class ServerContext: @@ -13,7 +13,7 @@ class ServerContext: server_context: Optional["ServerContext"] = None def __init__(self): - self._models: Dict[str, AsyncLLMEngine] = {} + self._models: Dict[str, AsyncMLCEngine] = {} def __enter__(self): if ServerContext.server_context is not None: @@ -31,13 +31,13 @@ def current(): """Returns the current ServerContext.""" return ServerContext.server_context - def add_model(self, hosted_model: str, engine: AsyncLLMEngine) -> None: + def add_model(self, hosted_model: str, engine: AsyncMLCEngine) -> None: """Add a new model to the server context together with the engine.""" if hosted_model in self._models: raise RuntimeError(f"Model {hosted_model} already running.") self._models[hosted_model] = engine - def get_engine(self, model: Optional[str]) -> Optional[AsyncLLMEngine]: + def get_engine(self, model: Optional[str]) -> Optional[AsyncMLCEngine]: """Get the async engine of the requested model, or the unique async engine if only one engine is served.""" if len(self._models) == 1: diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 23b151d5c7..257338da3a 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -41,7 +41,7 @@ def _create_tvm_module( return {key: module[key] for key in ffi_funcs} -class SyncLLMEngine: +class SyncMLCEngine: """The Python interface of synchronize request serving engine for MLC LLM. The engine receives requests from the "add_request" method. 
For diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 4e541b7437..c89a9e2c38 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -5,7 +5,7 @@ from typing import List, Tuple from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def _parse_args(): @@ -41,7 +41,7 @@ def benchmark(args: argparse.Namespace): random.seed(args.seed) # Create engine - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=args.model, device=args.device, model_lib_path=args.model_lib_path, diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 9bece30578..6e3835238a 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -23,7 +23,7 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -39,7 +39,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, @@ -80,7 +80,7 @@ async def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -132,7 +132,7 @@ async def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -183,7 +183,7 @@ async def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -235,7 +235,7 @@ async def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index de91c845b3..c3963af613 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig, SpeculativeMode +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode prompts = [ "What is the meaning of life?", @@ -27,7 +27,7 @@ async def 
test_engine_generate(): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - async_engine = AsyncLLMEngine( + async_engine = AsyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -44,7 +44,7 @@ async def test_engine_generate(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 330bd4cf82..f965e8cc82 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,7 +2,7 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import GenerationConfig, LLMEngine +from mlc_llm.serve import GenerationConfig, MLCEngine prompts = [ "What is the meaning of life?", @@ -22,7 +22,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( + engine = MLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -61,7 +61,7 @@ def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( + engine = MLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -105,7 +105,7 @@ def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( + engine = MLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -148,7 +148,7 @@ def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( + engine = MLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -192,7 +192,7 @@ def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = LLMEngine( + engine = MLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 7f2a33b230..b764c62cd2 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,9 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncLLMEngine, GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts_list = [ "Generate a JSON string containing 20 objects:", @@ -22,7 +22,7 @@ def test_batch_generation_with_grammar(): # Create engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -69,7 +69,7 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): # Create 
engine - engine = SyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -121,7 +121,7 @@ class Schema(BaseModel): async def run_async_engine(): # Create engine - async_engine = AsyncLLMEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") prompts = prompts_list * 20 @@ -142,7 +142,7 @@ async def run_async_engine(): ] async def generate_task( - async_engine: AsyncLLMEngine, + async_engine: AsyncMLCEngine, prompt: str, generation_cfg: GenerationConfig, request_id: str, diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index ff64e7235b..59e8c97196 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -2,7 +2,7 @@ from pathlib import Path from mlc_llm.serve import GenerationConfig, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine def get_test_image(config) -> data.ImageData: @@ -13,7 +13,7 @@ def test_engine_generate(): # Create engine model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 6647c7af19..33c06b1c5e 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -11,7 +11,7 @@ SpeculativeMode, data, ) -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the meaning of life?", @@ -90,7 +90,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -158,7 +158,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -242,7 +242,7 @@ def step(self) -> None: "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -328,7 +328,7 @@ def step(self) -> None: "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -385,7 +385,7 @@ def test_engine_generate(compare_precision=False): "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -403,7 +403,7 @@ def test_engine_generate(compare_precision=False): generation_config = GenerationConfig( temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], 
n=1 ) - engine_single_model = SyncLLMEngine( + engine_single_model = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -446,7 +446,7 @@ def test_engine_eagle_generate(): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -494,7 +494,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -566,7 +566,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # small_model_lib_path = ( # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -636,7 +636,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): small_model_lib_path = ( "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" ) - spec_engine = SyncLLMEngine( + spec_engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index c5d521b02d..f68f48b7c5 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -5,7 +5,7 @@ import numpy as np from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data -from mlc_llm.serve.sync_engine import SyncLLMEngine +from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ "What is the meaning of life?", @@ -80,7 +80,7 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -156,7 +156,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -237,7 +237,7 @@ def step(self) -> None: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -323,7 +323,7 @@ def all_finished(self) -> bool: timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", @@ -365,7 +365,7 @@ def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncLLMEngine( + engine = SyncMLCEngine( model=model, model_lib_path=model_lib_path, mode="server", From e115dde2455711ff62abed377f5611508520ceac Mon Sep 17 00:00:00 2001 
From: Mengshiun Yu Date: Wed, 24 Apr 2024 15:24:21 -0400 Subject: [PATCH 232/531] [Fix] CUDA architecture detection bug fix (#2211) This commit returns a list of integers and adds an assert to check that the string of CUDA architecture must contain numbers only. Co-authored-by: msyu --- python/mlc_llm/interface/compiler_flags.py | 3 +-- python/mlc_llm/support/auto_target.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 2d0d668672..77b55c5a48 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -2,7 +2,6 @@ import dataclasses import enum -import re from io import StringIO from typing import Optional @@ -96,7 +95,7 @@ def _flashinfer(target) -> bool: return False arch_list = detect_cuda_arch_list(target) for arch in arch_list: - if int(re.findall(r"\d+", arch)[0]) < 80: + if arch < 80: logger.warning("flashinfer is not supported on CUDA arch < 80") return False return True diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 3cf49c43ba..5239756d9d 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -293,14 +293,20 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): return build -def detect_cuda_arch_list(target: Target) -> List[str]: +def detect_cuda_arch_list(target: Target) -> List[int]: """Detect the CUDA architecture list from the target.""" + + def convert_to_num(arch_str): + arch_num_str = "".join(filter(str.isdigit, arch_str)) + assert arch_num_str, f"'{arch_str}' does not contain any digits" + return int(arch_num_str) + assert target.kind.name == "cuda", f"Expect target to be CUDA, but got {target}" if MLC_MULTI_ARCH is not None: - multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(",")] + multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(",")] else: assert target.arch.startswith("sm_") - multi_arch = [target.arch[3:]] + multi_arch = [convert_to_num(target.arch[3:])] multi_arch = list(set(multi_arch)) return multi_arch From 55b5c007d065f20d3168afc83384111dd46d278c Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 25 Apr 2024 17:45:22 +0530 Subject: [PATCH 233/531] [Android ] Enable OpenCL host pointer usage (#2215) Take advantage of OpenCl host ptr that improves copy performance --- android/library/prepare_libs.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/android/library/prepare_libs.sh b/android/library/prepare_libs.sh index a06e9f067d..c089927d09 100755 --- a/android/library/prepare_libs.sh +++ b/android/library/prepare_libs.sh @@ -27,6 +27,7 @@ cmake .. \ -DMLC_LLM_INSTALL_STATIC_LIB=ON \ -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \ -DUSE_OPENCL=ON \ + -DUSE_OPENCL_ENABLE_HOST_PTR=ON \ -DUSE_CUSTOM_LOGGING=ON \ cmake --build . --target tvm4j_runtime_packed --config release From 85fffee2d9dc4083ec406dd6e983cda65def18c5 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 25 Apr 2024 18:46:23 +0530 Subject: [PATCH 234/531] [PYTHON][KVCACHE] Enhance the thread limit for opencl (#2216) It improves 2x time for tir based page attention for opencl adreno. 
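[Editor's note] The hunk that follows raises the OpenCL/Adreno thread cap for the TIR paged-attention decode kernel from 64 to 256 in python/mlc_llm/nn/kv_cache.py. As a minimal, self-contained sketch of the clamping idea only (the helper name and arguments below are hypothetical; the real code also queries the device's reported limit via get_max_num_threads_per_block), the target-aware limit amounts to:

# Hypothetical sketch of the target-aware thread-limit clamping; not the actual MLC LLM code.
def pick_thread_limit(backend: str, host: str, device_max_threads: int) -> int:
    """Pick the per-workgroup thread cap for the decode attention kernel."""
    cap = 512  # default cap
    if backend == "opencl" and "android" in host:
        cap = 256  # raised from 64 by this patch; ~2x faster TIR paged attention on Adreno
    return min(device_max_threads, cap)  # never exceed what the device itself reports

if __name__ == "__main__":
    # e.g. an Adreno GPU whose driver reports 1024 threads per workgroup
    print(pick_thread_limit("opencl", "llvm -mtriple=aarch64-linux-android", 1024))  # -> 256
    print(pick_thread_limit("cuda", "", 1024))  # -> 512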
--- python/mlc_llm/nn/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 4a058c6e03..e4cbf1c047 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -887,7 +887,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 64 + THREAD_LIMIT = 256 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) From 71c7b3cf06b07fbbf25d8fb97345919086fc98e7 Mon Sep 17 00:00:00 2001 From: Linyu Wu <95223577+Celve@users.noreply.github.com> Date: Fri, 26 Apr 2024 04:48:27 +0800 Subject: [PATCH 235/531] [Serving] Support RWKV for serving (#2111) feat: support serving for rwkv --- cpp/serve/config.cc | 13 +- cpp/serve/config.h | 11 + cpp/serve/engine.cc | 3 +- .../engine_actions/new_request_prefill.cc | 5 + cpp/serve/function_table.cc | 7 +- cpp/serve/model.cc | 53 +++- cpp/serve/model.h | 6 +- python/mlc_llm/cli/serve.py | 4 + python/mlc_llm/conversation_template.py | 2 +- python/mlc_llm/help.py | 5 + python/mlc_llm/interface/serve.py | 2 + python/mlc_llm/model/rwkv5/rwkv5_model.py | 70 +++-- python/mlc_llm/model/rwkv6/rwkv6_model.py | 68 ++++- python/mlc_llm/serve/config.py | 17 ++ python/mlc_llm/serve/engine.py | 7 + python/mlc_llm/serve/engine_base.py | 264 +++++++++++++++++- python/mlc_llm/serve/sync_engine.py | 6 + tests/python/json_ffi/test_json_ffi_engine.py | 6 + tests/python/serve/test_serve_engine.py | 106 ++++--- 19 files changed, 543 insertions(+), 112 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 7379bad7ed..f36bc151a3 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -239,8 +239,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array ad Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, SpeculativeMode speculative_mode, - int spec_draft_length) { + int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind, + SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); @@ -252,6 +252,8 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array ad n->max_total_sequence_length = max_total_sequence_length; n->max_single_sequence_length = max_single_sequence_length; n->prefill_chunk_size = prefill_chunk_size; + n->max_history_size = max_history_size; + n->kv_state_kind = kv_state_kind; n->spec_draft_length = spec_draft_length; n->speculative_mode = speculative_mode; data_ = std::move(n); @@ -261,12 +263,13 @@ TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int speculative_mode, - int spec_draft_length) { + int max_single_sequence_length, int prefill_chunk_size, int max_history_size, + int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), std::move(additional_model_lib_paths), device, kv_cache_page_size, 
max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, SpeculativeMode(speculative_mode), spec_draft_length); + prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), + SpeculativeMode(speculative_mode), spec_draft_length); }); } // namespace serve diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 41ddb3c6e4..ef147b751b 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -84,6 +84,12 @@ enum class SpeculativeMode : int { kEagle = 2, }; +/*! \brief The kind of cache. */ +enum KVStateKind { + kAttention = 0, + kRNNState = 1, +}; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -121,6 +127,10 @@ class EngineConfigNode : public Object { int max_single_sequence_length; /*! \brief The maximum total sequence length in a prefill. */ int prefill_chunk_size; + /*! \brief The maximum history size for RNN state. KV cache does not need this. */ + int max_history_size; + /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ + KVStateKind kv_state_kind; /*************** Speculative decoding ***************/ @@ -143,6 +153,7 @@ class EngineConfig : public ObjectRef { Array additional_model_lib_paths, DLDevice device, int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 8568c6ce94..0348f7f40a 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -69,7 +69,8 @@ class EngineImpl : public Engine { /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size); + engine_config->prefill_chunk_size, engine_config->max_history_size, + engine_config->kv_state_kind); CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) << "The window size of the model, " << model->GetMaxWindowSize() << ", is smaller than the pre-defined max single sequence length, " diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index c80c5e0ede..b4192a04f1 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -396,6 +396,11 @@ class NewRequestPrefillActionObj : public EngineActionObj { int num_running_rsentries) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + // No exceeding of the maximum allowed requests that can // run simultaneously. 
int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index b33d3709e8..b721eae7c3 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -244,7 +244,12 @@ void FunctionTable::_InitFunctions() { this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor"); this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache"); if (!this->create_kv_cache_func_.defined()) { - this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state"); + if (f_create_rnn_state.defined()) { + this->create_kv_cache_func_ = f_create_rnn_state; + } else { + this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); + } ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 3583b5d84b..27a0043850 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -13,6 +13,7 @@ #include +#include "config.h" #include "logit_processor.h" namespace mlc { @@ -68,6 +69,12 @@ class ModelImpl : public ModelObj { token_ids_storage_ = memory::Storage( allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); + // Step 7. Set model type + if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + this->kind = KVStateKind::kRNNState; + } else { + this->kind = KVStateKind::kAttention; + } } /*********************** Model Computation ***********************/ @@ -739,16 +746,26 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) final { - IntTuple max_num_sequence_tuple{max_num_sequence}; - IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; - IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; - IntTuple page_size_tuple{page_size}; - IntTuple support_sliding_window{sliding_window_size_ != -1}; - kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, - prefill_chunk_size_tuple, page_size_tuple, - support_sliding_window); - local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) final { + if (kv_state_kind == KVStateKind::kAttention) { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; + IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; + IntTuple page_size_tuple{page_size}; + IntTuple support_sliding_window{sliding_window_size_ != -1}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple, + prefill_chunk_size_tuple, page_size_tuple, + support_sliding_window); + local_kv_cache_ = + ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else { + IntTuple max_num_sequence_tuple{max_num_sequence}; + IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; + kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); + local_kv_cache_ = + ft_.use_disco ? 
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } } void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } @@ -775,11 +792,21 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ int GetNumAvailablePages() const final { - return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not introduce new page at runtime + return std::numeric_limits::max(); + } else { + return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_); + } } int GetCurrentTotalSequenceLength() const final { - return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + if (this->kind == KVStateKind::kRNNState) { + // RNNState does not have a total sequence length limit + return 0; + } else { + return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_); + } } /*********************** Utilities ***********************/ @@ -946,6 +973,8 @@ class ModelImpl : public ModelObj { NDArray logit_pos_arr_{nullptr}; // A boolean indicating if tracing is enabled. bool trace_enabled_; + // An enum indicating whether it's RNN-based. + KVStateKind kind; }; TVM_REGISTER_GLOBAL("mlc.copy_embedding_to_offset") diff --git a/cpp/serve/model.h b/cpp/serve/model.h index da532f83e8..045daff874 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -234,9 +234,13 @@ class ModelObj : public Object { * in the engine. * \param prefill_chunk_size The maximum total number of tokens whose KV data * are allowed to exist in the KV cache at any time. + * \param max_history_size The maximum history size for RNN state to roll back. + * The KV cache does not need this. + * \param kv_state_kind The kind of cache. It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size) = 0; + int prefill_chunk_size, int max_history_size, + KVStateKind kv_state_kind) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. 
*/ virtual void AddNewSequence(int64_t seq_id) = 0; diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9f7c1c3580..6663a0c230 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -44,6 +44,9 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) + parser.add_argument( + "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + ) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) @@ -100,6 +103,7 @@ def main(argv): max_batch_size=parsed.max_batch_size, max_total_sequence_length=parsed.max_total_seq_length, prefill_chunk_size=parsed.prefill_chunk_size, + max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, speculative_mode=SpeculativeMode[parsed.speculative_mode], spec_draft_length=parsed.spec_draft_length, diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 917e229632..1c599fa875 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -365,7 +365,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: # RWKV World ConvTemplateRegistry.register_conv_template( Conversation( - name="rwkv-world", + name="rwkv_world", system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", system_message=( "Hi. I am your assistant and I will provide expert full response " diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 14e5cee321..86930fa5ea 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -152,6 +152,11 @@ The maximum number of tokens the model passes for prefill each time. It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. """.strip(), "enable_tracing_serve": """ Enable Chrome Tracing for the server. diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index d0cbd4690b..40fa9fdda8 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -22,6 +22,7 @@ def serve( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -44,6 +45,7 @@ def serve( max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 49386720da..81c9e9aa7f 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -40,6 +40,7 @@ class RWKV5Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. 
prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -129,23 +130,18 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - # x.shape = (batch, seq_len, hidden_size) - # state.shape = (batch, hidden_size) - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): # x.shape = (batch, seq_len, hidden_size) batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -350,10 +346,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -367,11 +367,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -386,7 +402,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -396,9 +411,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -406,7 +419,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + 
"batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -414,8 +452,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index 0e1887310d..a8faf48a6b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -40,6 +40,7 @@ class RWKV6Config(ConfigBase): # pylint: disable=too-many-instance-attributes context_window_size: int = -1 # RWKV does not have context window limitation. prefill_chunk_size: int = 4096 num_heads: int = 0 + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -126,20 +127,17 @@ def wkv_func( def token_shift(state: Tensor, x: Tensor): - seq_len = x.shape[1] - def _te_token_shift(state: te.Tensor, x: te.Tensor): return te.compute( x.shape, lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]), ) - return state if seq_len == 1 else op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) + return op.tensor_expr_op(_te_token_shift, "token_shift", [state, x]) def last_token(x: Tensor): batch, seq_len, hidden_size = x.shape - assert batch == 1 def _te_last_token(x: te.Tensor): return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j]) @@ -390,10 +388,14 @@ def to(self, dtype: Optional[str] = None): def embed(self, input_ids: Tensor): return self.model.embeddings(input_ids) - def forward(self, input_embed: Tensor, state: RNNState): + def forward( + self, input_embed: Tensor, state: RNNState, logit_positions: Optional[Tensor] = None + ): """Forward pass.""" hidden_states, state = self.model(input_embed, state) hidden_states = last_token(hidden_states) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) logits = self.head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") @@ -407,11 +409,27 @@ def decode(self, input_embed: Tensor, state: RNNState): """Decoding step.""" return self.forward(input_embed, state) + def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState): + """Prefilling the prompt.""" + return self.forward(input_embeds, state, logit_positions=logit_positions) + + def batch_decode(self, input_embeds: Tensor, state: RNNState): + """Decoding step.""" + return self.forward(input_embeds, state) + + def batch_verify(self, input_embeds: Tensor, state: RNNState): + """Verify step.""" + return self.forward(input_embeds, state) + def softmax_with_temperature(self, logits: Tensor, temperature: 
Tensor): """Softmax.""" - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Object: + def create_rnn_state( + self, + max_batch_size: tir.Var, + max_history: tir.Var, + ) -> Object: """Create RNN state.""" init_values = [ op.zeros((self.hidden_size,), dtype=self.dtype), # ATT_X @@ -426,7 +444,6 @@ def create_rnn_state(self, max_batch_size: tir.Var, max_history: tir.Var) -> Obj ) def get_default_spec(self): - batch_size = 1 mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), @@ -436,9 +453,7 @@ def get_default_spec(self): }, }, "prefill": { - "input_embed": nn.spec.Tensor( - [batch_size, "seq_len", self.hidden_size], self.dtype - ), + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -446,7 +461,32 @@ def get_default_spec(self): }, }, "decode": { - "input_embed": nn.spec.Tensor([batch_size, 1, self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "state": nn.spec.Object(object_type=RNNState), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), "state": nn.spec.Object(object_type=RNNState), "$": { "param_mode": "packed", @@ -454,8 +494,8 @@ def get_default_spec(self): }, }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([batch_size, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 60e4eca8c5..40c53e336a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -128,6 +128,13 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) +class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods + """Possible kinds of KV state.""" + + ATTENTION = 0 + RNNSTATE = 1 + + class SpeculativeMode(enum.IntEnum): """The speculative mode.""" @@ -177,6 +184,12 @@ class EngineConfig(tvm.runtime.Object): prefill_chunk_size : int The maximum total sequence length in a prefill. + max_history_size: int + The maximum history size for RNN state to rool back. + + kv_state_kind: KVStateKind + The kind of cache. + speculative_mode : SpeculativeMode The speculative mode. 
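[Editor's note] The hunks around this point add KVStateKind and a max_history_size field to EngineConfig so that RNN-state models such as RWKV can be served alongside KV-cache models; the engine infers RNNSTATE versus ATTENTION from the model itself rather than asking the user for it. As a hedged usage sketch only (the model directory and library path are placeholders, not artifacts produced by this patch), a synchronous engine for an RWKV model could be created like this once the patch is applied:

from mlc_llm.serve import MLCEngine

# Placeholder model id and library path; substitute a real compiled RWKV model.
engine = MLCEngine(
    model="dist/rwkv-6-world-1b6-q4f16_1-MLC",
    model_lib_path="dist/rwkv-6-world-1b6-q4f16_1-MLC/rwkv-6-world-1b6-q4f16_1-MLC-cuda.so",
    mode="server",
    max_history_size=1,  # steps of RNN state kept for rollback; KV-cache models do not need this
)

Note that kv_state_kind is not passed here: it is filled in by the engine's config inference (see the engine_base.py changes below) based on whether the model is RNN-based.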
@@ -196,6 +209,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length: int, max_single_sequence_length: int, prefill_chunk_size: int, + max_history_size: int, + kv_state_kind: KVStateKind, speculative_mode: SpeculativeMode, spec_draft_length: int, ) -> None: @@ -211,6 +226,8 @@ def __init__( # pylint: disable=too-many-arguments max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, + kv_state_kind, speculative_mode, spec_draft_length, ) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index d9721b4864..413c856db1 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -816,6 +816,9 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): It should not exceed the prefill chunk size in model config. If not specified, this defaults to the prefill chunk size in model config. + max_history_size : Optional[int] + The maximum history for RNN state. + gpu_memory_utilization : Optional[float] A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. @@ -846,6 +849,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -861,6 +865,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, @@ -1392,6 +1397,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -1407,6 +1413,7 @@ def __init__( # pylint: disable=too-many-arguments max_batch_size=max_batch_size, max_total_sequence_length=max_total_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, gpu_memory_utilization=gpu_memory_utilization, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 7b2ede60b2..5d62dd5fb1 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -17,10 +17,16 @@ from tvm.runtime import Device from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path +from mlc_llm.cli.model_metadata import _compute_memory_usage, _extract_metadata from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import ( + EngineConfig, + GenerationConfig, + KVStateKind, + SpeculativeMode, +) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import 
logging @@ -121,7 +127,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-many-locals,too-many-arguments +def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments models: List[ModelInfo], device: tvm.runtime.Device, model_config_paths: List[str], @@ -240,6 +246,77 @@ def _estimate_mem_usage_and_max_total_sequence_length( # pylint: disable=too-ma ) +def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_paths: List[str], + model_config_dicts: List[Dict[str, Any]], + max_num_sequence: int, + gpu_memory_utilization: Optional[float], +) -> Tuple[float, float, float, int]: + # Get single-card GPU size. + gpu_size_bytes = device.total_global_memory + if gpu_size_bytes is None: + raise ValueError("Cannot read total GPU global memory from device.") + if gpu_memory_utilization is None: + gpu_memory_utilization = 0.90 + + rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 + param_bytes = 0.0 + model_workspace_bytes = 0.0 + logit_processor_workspace_bytes = 0.0 + for model, model_config_dict in zip(models, model_config_dicts): + model_config = model_config_dict["model_config"] + vocab_size = model_config_dict["vocab_size"] + head_size = model_config["head_size"] + num_heads = model_config["num_heads"] + num_layers = model_config["num_hidden_layers"] + hidden_size = model_config["hidden_size"] + prefill_chunk_size = model_config["prefill_chunk_size"] + logit_processor_workspace_bytes += ( + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 + ) + + model_workspace_bytes += ( + prefill_chunk_size * 4 + + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 + ) + + rnn_state_base_bytes += ( + max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 + ) + + metadata = _extract_metadata(Path(model.model_lib_path)) + metadata["memory_usage"] = {} + metadata["kv_cache_bytes"] = 0 + current_param_bytes, _, _ = _compute_memory_usage(metadata, model_config_dict) + param_bytes += current_param_bytes + + max_history_size = int( + ( + gpu_size_bytes * gpu_memory_utilization + - logit_processor_workspace_bytes + - model_workspace_bytes + - param_bytes + ) + / rnn_state_base_bytes + ) + if max_history_size < 1: + raise ValueError( + f"Memory required by models may be larger than available GPU memory " + f"size {gpu_size_bytes * gpu_memory_utilization} bytes." 
+ ) + + return ( + param_bytes, + model_workspace_bytes + logit_processor_workspace_bytes, + rnn_state_base_bytes, + max_history_size, + ) + + def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: """Read the model config dictionaries, and return the maximum single sequence length the models can support, the maximum prefill chunk @@ -294,7 +371,7 @@ def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[i return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements mode: Literal["local", "interactive", "server"], max_batch_size: Optional[int], max_total_sequence_length: Optional[int], @@ -304,12 +381,13 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local device: tvm.runtime.Device, model_config_dicts: List[Dict[str, Any]], model_config_paths: List[str], -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, KVStateKind, int]: """Initialize the KV cache config with user input and GPU memory usage estimation. The returned four integers are: - max_batch_size - max_total_sequence_length - prefill_chunk_size + - kv_state_kind - model_max_single_sequence_length """ ( @@ -323,7 +401,7 @@ def infer_args_under_mode( max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int], List[float]]: + ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: logging_msg = "" # - max_batch_size if max_batch_size is None: @@ -343,7 +421,7 @@ def infer_args_under_mode( kv_aux_workspace_bytes, temp_workspace_bytes, model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length( + ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( models, device, model_config_paths, @@ -400,7 +478,12 @@ def infer_args_under_mode( # - Construct the KV cache config # - Estimate total GPU memory usage on single GPU. - return (max_batch_size, max_total_sequence_length, prefill_chunk_size), [ + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + KVStateKind.ATTENTION, + ), [ total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, model_params_bytes, kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, @@ -462,6 +545,167 @@ def infer_args_under_mode( return *kv_cache_config, model_max_single_sequence_length +def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, KVStateKind, int]: + """Initialize the RNN state config with user input and GPU memory usage estimation. 
+ The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - kv_state_kind + - max_history_size + """ + logging_msg = "" + prefill_chunk_size = 0 + + if prefill_chunk_size is None: + prefill_chunk_size = min( + config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 + for config in model_config_dicts + ) + logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " + else: + logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " + if max_batch_size is None: + max_batch_size = 1 if mode == "interactive" else 4 + logging_msg += f"max batch size is set to {max_batch_size}, " + else: + logging_msg += f"max batch size {max_batch_size} is specified by user, " + + if mode == "local": + logging_msg += ( + "We choose small max batch size and RNN state capacity to use less GPU memory." + ) + elif mode == "interactive": + logging_msg += "We fix max batch size to 1 for interactive single sequence use." + else: + logging_msg += ( + "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." + ) + logger.info('Under mode "%s", %s', mode, logging_msg) + + ( + model_param_bytes, + model_temp_bytes, + model_rnn_state_base_bytes, + model_max_history_size, + ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( + models, + device, + model_config_paths, + model_config_dicts, + max_batch_size, + gpu_memory_utilization, + ) + if max_history_size is None: + max_history_size = model_max_history_size + else: + max_history_size = min(max_history_size, model_max_history_size) + max_total_sequence_length = 32768 + prefill_chunk_size = 0 + kind = KVStateKind.RNNSTATE + + logger.info( + "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " + "The actual usage might be slightly larger than the estimated number.", + green("Estimated total single GPU memory usage"), + (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, + model_param_bytes / 1024 / 1024, + max_history_size * model_rnn_state_base_bytes / 1024 / 1024, + model_temp_bytes / 1024 / 1024, + ) + + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kind, + max_history_size, + ) + + +def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + mode: Literal["local", "interactive", "server"], + max_batch_size: Optional[int], + max_total_sequence_length: Optional[int], + prefill_chunk_size: Optional[int], + max_history_size: Optional[int], + gpu_memory_utilization: Optional[float], + models: List[ModelInfo], + device: tvm.runtime.Device, + model_config_dicts: List[Dict[str, Any]], + model_config_paths: List[str], +) -> Tuple[int, int, int, int, int, KVStateKind]: + """Initialize the cache config with user input and GPU memory usage estimation. 
+ The returned four integers are: + - max_batch_size + - max_total_sequence_length + - prefill_chunk_size + - max_single_sequence_length + - max_history_size + - kv_state_kind + """ + if all("rwkv" not in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_single_sequence_length, + ) = _infer_kv_cache_config_for_kv_cache( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_history_size = 0 # KV cache doesn't need this + elif all("rwkv" in model.model for model in models): + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + kv_state_kind, + max_history_size, + ) = _infer_kv_cache_config_for_rnn_state( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + model_config_dicts, + model_config_paths, + ) + max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this + else: + raise ValueError("The models should be either all KV cache models or all RNN state models.") + return ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) + + @dataclass class CallbackStreamOutput: """The output of MLCEngine._generate and AsyncMLCEngine._generate @@ -728,6 +972,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int], max_total_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + max_history_size: Optional[int], gpu_memory_utilization: Optional[float], speculative_mode: SpeculativeMode, spec_draft_length: int, @@ -757,11 +1002,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -803,6 +1051,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 257338da3a..7469ddc241 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -98,6 +98,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, @@ -128,11 +129,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -168,6 +172,8 @@ def __init__( # pylint: 
disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 9b594e9784..c0c749c0a7 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -89,6 +89,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, prefill_chunk_size: Optional[int] = None, speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, spec_draft_length: int = 4, @@ -118,11 +119,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length, prefill_chunk_size, max_single_sequence_length, + max_history_size, + kv_state_kind, ) = _infer_kv_cache_config( mode, max_batch_size, max_total_sequence_length, prefill_chunk_size, + max_history_size, gpu_memory_utilization, models, device, @@ -162,6 +166,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_total_sequence_length=max_total_sequence_length, max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index f965e8cc82..37d1833b14 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,6 +2,8 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List +import pytest + from mlc_llm.serve import GenerationConfig, MLCEngine prompts = [ @@ -17,17 +19,39 @@ "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", ] +test_models = [ + ( + "dist/Llama-2-7b-chat-hf-q0f16-MLC", + "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", + ), + ( + "dist/rwkv-6-world-1b6-q0f16-MLC", + "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so", + ), +] -def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + +def create_engine(model: str, model_lib_path: str): + if "rwkv" in model: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_batch_size=8, + max_history_size=1, + ) + else: + return MLCEngine( + model=model, + model_lib_path=model_lib_path, + mode="server", + max_total_sequence_length=4096, + ) + + +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_engine_generate(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 10 max_tokens = 256 @@ -57,16 +81,10 @@ def test_engine_generate(): del engine -def test_chat_completion(): +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion(model: str, model_lib_path: str): # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -101,16 +119,9 @@ def test_chat_completion(): del engine -def test_chat_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_chat_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 64 @@ -144,16 +155,9 @@ def test_chat_completion_non_stream(): del engine -def test_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -188,16 +192,9 @@ def test_completion(): del engine -def test_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) +@pytest.mark.parametrize("model,model_lib_path", test_models) +def test_completion_non_stream(model: str, model_lib_path: str): + engine = create_engine(model, model_lib_path) num_requests = 2 max_tokens = 128 @@ -232,8 +229,9 @@ def test_completion_non_stream(): if __name__ == 
"__main__": - test_engine_generate() - test_chat_completion() - test_chat_completion_non_stream() - test_completion() - test_completion_non_stream() + for model, model_lib_path in test_models: + test_engine_generate(model, model_lib_path) + test_chat_completion(model, model_lib_path) + test_chat_completion_non_stream(model, model_lib_path) + test_completion(model, model_lib_path) + test_completion_non_stream(model, model_lib_path) From fab0dd33b75efc98b8f9cad1eac0b4f0cb670ccd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 25 Apr 2024 21:00:02 -0400 Subject: [PATCH 236/531] [Serving] Remove `cli.model_metadata` import from engine base (#2226) This PR removes the imports of functions in `cli.model_metadata` from engine_base.py. The file `cli.model_metadata` is not designed for import directly, and when importing functions from the file, it repetitively reports warnings of ``` RuntimeWarning: 'mlc_llm.cli.model_metadata' found in sys.modules after import of package 'mlc_llm.cli', but prior to execution of 'mlc_llm.cli.model_metadata'; this may result in unpredictable behaviour ``` --- python/mlc_llm/serve/engine_base.py | 30 ++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 5d62dd5fb1..85720adcac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -17,7 +17,6 @@ from tvm.runtime import Device from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path -from mlc_llm.cli.model_metadata import _compute_memory_usage, _extract_metadata from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils @@ -263,9 +262,27 @@ def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=t rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 param_bytes = 0.0 + temp_func_bytes = 0.0 model_workspace_bytes = 0.0 logit_processor_workspace_bytes = 0.0 - for model, model_config_dict in zip(models, model_config_dicts): + for model, model_config_path, model_config_dict in zip( + models, model_config_paths, model_config_dicts + ): + # Read metadata for the parameter size and the temporary memory size. 
+ cmd = [ + sys.executable, + "-m", + "mlc_llm.cli.model_metadata", + model.model_lib_path, + "--print-memory-usage-in-json", + "--mlc-chat-config", + model_config_path, + ] + usage_str = subprocess.check_output(cmd, universal_newlines=True) + usage_json = json.loads(usage_str) + param_bytes += usage_json["params_bytes"] + temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) + model_config = model_config_dict["model_config"] vocab_size = model_config_dict["vocab_size"] head_size = model_config["head_size"] @@ -288,18 +305,13 @@ def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=t + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 ) - metadata = _extract_metadata(Path(model.model_lib_path)) - metadata["memory_usage"] = {} - metadata["kv_cache_bytes"] = 0 - current_param_bytes, _, _ = _compute_memory_usage(metadata, model_config_dict) - param_bytes += current_param_bytes - max_history_size = int( ( gpu_size_bytes * gpu_memory_utilization - logit_processor_workspace_bytes - model_workspace_bytes - param_bytes + - temp_func_bytes ) / rnn_state_base_bytes ) @@ -311,7 +323,7 @@ def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=t return ( param_bytes, - model_workspace_bytes + logit_processor_workspace_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, rnn_state_base_bytes, max_history_size, ) From 1cdd0f914a55b027224b890599f744b07a4776d8 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Fri, 26 Apr 2024 09:27:00 -0400 Subject: [PATCH 237/531] [JSONFFIEngine] Support generation config in JSONFFIEngine. Default config values to NOT_GIVEN (#2225) * Change OpenAI protocol default value to None in JSON FFI engine * [JSONFFIEngine] Support generation config in JSONFFIEngine. 
Default config values to NOT_GIVEN --- cpp/json_ffi/{conv_template.cc => config.cc} | 46 +++++++++++++++- cpp/json_ffi/{conv_template.h => config.h} | 55 ++++++++++++++++++- cpp/json_ffi/json_ffi_engine.cc | 10 ++-- cpp/json_ffi/json_ffi_engine.h | 3 +- cpp/json_ffi/openai_api_protocol.h | 10 ++-- cpp/metadata/json_parser.h | 16 ++++++ cpp/serve/config.cc | 24 +++++--- cpp/serve/config.h | 12 ++-- .../mlc_llm/protocol/openai_api_protocol.py | 2 +- python/mlc_llm/serve/engine_base.py | 22 ++++++++ python/mlc_llm/support/auto_config.py | 2 +- tests/python/json_ffi/_ffi_api.py | 6 ++ tests/python/json_ffi/test_json_ffi_engine.py | 51 +++++++++++++++-- 13 files changed, 226 insertions(+), 33 deletions(-) rename cpp/json_ffi/{conv_template.cc => config.cc} (85%) rename cpp/json_ffi/{conv_template.h => config.h} (67%) create mode 100644 tests/python/json_ffi/_ffi_api.py diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/config.cc similarity index 85% rename from cpp/json_ffi/conv_template.cc rename to cpp/json_ffi/config.cc index 02e0b3bdbd..8f5c0e1062 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/config.cc @@ -1,4 +1,6 @@ -#include "conv_template.h" +#include "config.h" + +#include #include "../metadata/json_parser.h" @@ -8,6 +10,29 @@ namespace json_ffi { using namespace mlc::llm; +/****************** Model-defined generation config ******************/ + +TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode); + +ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p, + double frequency_penalty, + double presence_penalty) { + ObjectPtr n = make_object(); + n->temperature = temperature; + n->top_p = top_p; + n->frequency_penalty = frequency_penalty; + n->presence_penalty = presence_penalty; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig") + .set_body_typed([](double temperature, double top_p, double frequency_penalty, + double presence_penalty) { + return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty); + }); + +/****************** Conversation template ******************/ + std::map PLACEHOLDERS = { {MessagePlaceholders::SYSTEM, "{system_message}"}, {MessagePlaceholders::USER, "{user_message}"}, @@ -308,6 +333,25 @@ std::optional Conversation::FromJSON(const std::string& json_str, } return Conversation::FromJSON(json_obj.value(), err); } + +/****************** JSON FFI engine config ******************/ + +TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode); + +JSONFFIEngineConfig::JSONFFIEngineConfig( + String conv_template, Map model_generation_cfgs) { + ObjectPtr n = make_object(); + n->conv_template = conv_template; + n->model_generation_cfgs = model_generation_cfgs; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig") + .set_body_typed([](String conv_template, + Map model_generation_cfgs) { + return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs)); + }); + } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/config.h similarity index 67% rename from cpp/json_ffi/conv_template.h rename to cpp/json_ffi/config.h index d3a1d1de2f..fe5e4e42e2 100644 --- a/cpp/json_ffi/conv_template.h +++ b/cpp/json_ffi/config.h @@ -1,5 +1,9 @@ -#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H -#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H +#ifndef MLC_LLM_JSON_FFI_CONFIG_H +#define MLC_LLM_JSON_FFI_CONFIG_H + +#include +#include +#include #include 
#include @@ -18,6 +22,32 @@ namespace mlc { namespace llm { namespace json_ffi { +/****************** Model-defined generation config ******************/ + +class ModelDefinedGenerationConfigNode : public Object { + public: + double temperature; + double top_p; + double frequency_penalty; + double presence_penalty; + + static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object); +}; + +class ModelDefinedGenerationConfig : public ObjectRef { + public: + explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty, + double presence_penalty); + + TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef, + ModelDefinedGenerationConfigNode); +}; + +/****************** Conversation template ******************/ + enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; MessagePlaceholders messagePlaceholderFromString(const std::string& role); @@ -114,6 +144,27 @@ struct Conversation { static std::optional FromJSON(const std::string& json_str, std::string* err); }; +/****************** JSON FFI engine config ******************/ + +class JSONFFIEngineConfigNode : public Object { + public: + String conv_template; + Map model_generation_cfgs; + + static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig"; + static constexpr const bool _type_has_method_sequal_reduce = false; + static constexpr const bool _type_has_method_shash_reduce = false; + TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object); +}; + +class JSONFFIEngineConfig : public ObjectRef { + public: + explicit JSONFFIEngineConfig(String conv_template, + Map model_generation_cfgs); + + TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode); +}; + } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 0e21735e2f..1a21c2962d 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -83,8 +83,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request Array inputs = inputs_obj.value(); // generation_cfg - Optional generation_cfg = - GenerationConfig::FromJSON(request_json_str, &err_, conv_template); + Optional generation_cfg = GenerationConfig::Create( + request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]); if (!generation_cfg.defined()) { return false; } @@ -122,14 +122,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(std::string conv_template_str, EngineConfig engine_config, + void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, Optional request_stream_callback, Optional trace_recorder) { - std::optional conv_template = Conversation::FromJSON(conv_template_str, &err_); + std::optional conv_template = + Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); if (!conv_template.has_value()) { LOG(FATAL) << "Invalid conversation template JSON: " << err_; } this->conv_template_ = conv_template.value(); + this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs; // 
Todo(mlc-team): decouple InitBackgroundEngine into two functions // by removing `engine_config` from arguments, after properly handling diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index 2c7501c337..d57384abb5 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,7 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" -#include "conv_template.h" +#include "config.h" #include "openai_api_protocol.h" namespace mlc { @@ -49,6 +49,7 @@ class JSONFFIEngine { PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request Conversation conv_template_; + Map model_generation_cfgs; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index bed225d3d0..429050da3c 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,7 +13,7 @@ #include #include -#include "conv_template.h" +#include "config.h" #include "picojson.h" namespace mlc { @@ -90,8 +90,8 @@ class ChatCompletionRequest { public: std::vector messages; std::string model; - double frequency_penalty = 0.0; - double presence_penalty = 0.0; + std::optional frequency_penalty = std::nullopt; + std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; std::optional> logit_bias = std::nullopt; @@ -100,8 +100,8 @@ class ChatCompletionRequest { std::optional seed = std::nullopt; std::optional> stop = std::nullopt; bool stream = false; - double temperature = 1.0; - double top_p = 1.0; + std::optional temperature = std::nullopt; + std::optional top_p = std::nullopt; std::optional> tools = std::nullopt; std::optional tool_choice = std::nullopt; std::optional user = std::nullopt; diff --git a/cpp/metadata/json_parser.h b/cpp/metadata/json_parser.h index f6ff10e1ac..99a284fc42 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/metadata/json_parser.h @@ -149,6 +149,22 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end()) { + return default_value; + } + + if (it->second.is()) { + return default_value; + } + + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index f36bc151a3..19f26ff624 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -161,15 +161,26 @@ GenerationConfig::GenerationConfig(String config_json_str) { data_ = std::move(n); } -Optional GenerationConfig::FromJSON(const std::string& json_str, std::string* err, - const Conversation& conv_template) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !json_obj.has_value()) { +Optional GenerationConfig::Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config) { + std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); + if (!err->empty() || !optional_json_obj.has_value()) { return NullOpt; } + picojson::object& json_obj = optional_json_obj.value(); 
ObjectPtr n = make_object(); - // TODO(mlc-team): Pass the parameters from `json_obj` to `n`. + n->temperature = + json::LookupOrDefault(json_obj, "temperature", model_defined_gen_config->temperature); + n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); + n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", + model_defined_gen_config->frequency_penalty); + n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", + model_defined_gen_config->presence_penalty); + n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); + n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); + n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); // Copy stop str from conversation template to generation config for (auto& stop_str : conv_template.stop_str) { @@ -179,9 +190,6 @@ Optional GenerationConfig::FromJSON(const std::string& json_st n->stop_token_ids.push_back(stop_token_id); } - if (!err->empty()) { - return NullOpt; - } GenerationConfig gen_config; gen_config.data_ = std::move(n); return gen_config; diff --git a/cpp/serve/config.h b/cpp/serve/config.h index ef147b751b..6a3bdd8997 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -11,7 +11,7 @@ #include -#include "../json_ffi/conv_template.h" +#include "../json_ffi/config.h" namespace mlc { namespace llm { @@ -63,11 +63,13 @@ class GenerationConfig : public ObjectRef { explicit GenerationConfig(String config_json_str); /*! - * \brief Parse the generation config from the given JSON string. - * When parsing fails, errors are dumped to the input error string, and NullOpt is returned. + * \brief Create a generation config from a ChatCompletionRequest. + * If the request does not contain a generation config, the model-defined + * generation config will be used. */ - static Optional FromJSON(const std::string& json_str, std::string* err, - const Conversation& conv_template); + static Optional Create( + const std::string& json_str, std::string* err, const Conversation& conv_template, + const ModelDefinedGenerationConfig& model_defined_gen_config); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index d6ce4a4fcb..4a5168f971 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -223,7 +223,7 @@ class ChatCompletionRequest(BaseModel): @classmethod def check_penalty_range(cls, penalty_value: float) -> float: """Check if the penalty value is in range [-2, 2].""" - if penalty_value < -2 or penalty_value > 2: + if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 85720adcac..fb0a35ddd2 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -718,6 +718,28 @@ def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-local ) +def _infer_generation_config( + model_config_dicts: List[Dict[str, Any]] +) -> List[Tuple[float, float, float, float]]: + """Infer the generation config from the model config dictionaries. 
+ The returned four floats are: + - temperature + - top_p + - frequency_penalty + - presence_penalty + """ + generation_configs = [] + + for model_config in model_config_dicts: + temperature = model_config.get("temperature", 1.0) + top_p = model_config.get("top_p", 1.0) + frequency_penalty = model_config.get("frequency_penalty", 0.0) + presence_penalty = model_config.get("presence_penalty", 0.0) + generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) + + return generation_configs + + @dataclass class CallbackStreamOutput: """The output of MLCEngine._generate and AsyncMLCEngine._generate diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index f0247a6ef9..be0ee8af98 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -62,7 +62,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # search mlc-chat-config.json under path mlc_chat_config_json_path = mlc_chat_config_path / "mlc-chat-config.json" if not mlc_chat_config_json_path.exists(): - raise ValueError(f"Fail to find mlc_chat_config.json under {mlc_chat_config_path}.") + raise ValueError(f"Fail to find mlc-chat-config.json under {mlc_chat_config_path}.") else: mlc_chat_config_json_path = mlc_chat_config_path diff --git a/tests/python/json_ffi/_ffi_api.py b/tests/python/json_ffi/_ffi_api.py new file mode 100644 index 0000000000..3df07d6a1f --- /dev/null +++ b/tests/python/json_ffi/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for mlc.json_ffi""" +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.json_ffi" prefix. +# e.g. TVM_REGISTER_GLOBAL("mlc.serve.TextData") +tvm._ffi._init_api("mlc.json_ffi", __name__) # pylint: disable=protected-access diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index c0c749c0a7..f5235663be 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union import tvm +from tests.python.json_ffi import _ffi_api from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import engine_utils @@ -60,6 +61,32 @@ ] +@tvm._ffi.register_object( + "mlc.json_ffi.ModelDefinedGenerationConfig" +) # pylint: disable=protected-access +class ModelDefinedGenerationConfig(tvm.runtime.Object): + def __init__( # pylint: disable=too-many-arguments + self, temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ModelDefinedGenerationConfig, + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +@tvm._ffi.register_object("mlc.json_ffi.JSONFFIEngineConfig") # pylint: disable=protected-access +class JSONFFIEngineConfig(tvm.runtime.Object): + def __init__( # pylint: disable=too-many-arguments + self, conv_template: str, model_generation_cfgs: Dict[str, ModelDefinedGenerationConfig] + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.JSONFFIEngineConfig, conv_template, model_generation_cfgs + ) + + class EngineState: sync_queue: queue.Queue @@ -171,8 +198,22 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ) + + self.json_ffi_engine_config = JSONFFIEngineConfig( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: 
ModelDefinedGenerationConfig( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + self._ffi["init_background_engine"]( - self.conv_template.model_dump_json(), + self.json_ffi_engine_config, self.engine_config, self.state.get_request_stream_callback(), None, @@ -204,8 +245,8 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: str, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -214,8 +255,8 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, From 68505295ee9b260ef53f9df604fda2372f247a1b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 26 Apr 2024 10:59:16 -0400 Subject: [PATCH 238/531] [Sampler] Fix GPU sampler behavior when batch size is 0 (#2234) This PR adds the early exit for the GPU sampler, which ran into GPU kernels even when the batch size is 0 prior to this commit. The 0 batch size case can happen when parallel generation of a request and engine preemption exists. In this case, the GPU sampler should just synchronization and return, and not run into any GPU kernel. --- cpp/serve/sampler/gpu_sampler.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index c80a846b19..62911a7cd1 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -311,14 +311,20 @@ class GPUSampler : public SamplerObj { int vocab_size = probs_on_device->shape[1]; if (output_prob_dist != nullptr) { ICHECK(output_prob_dist->empty()); - output_prob_dist->reserve(num_probs); - for (int i = 0; i < num_probs; ++i) { + output_prob_dist->reserve(num_samples); + for (int i = 0; i < num_samples; ++i) { NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); - float* p_prob = static_cast(probs_on_device->data) + i * vocab_size; + float* p_prob = static_cast(probs_on_device->data) + sample_indices[i] * vocab_size; prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); output_prob_dist->push_back(std::move(prob_dist)); } } + if (num_samples == 0) { + // This synchronization is necessary for making sure that this round + // of model forward is finished. + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + return {}; + } ICHECK_EQ(request_ids.size(), num_samples); ICHECK_EQ(generation_cfg.size(), num_samples); ICHECK_EQ(rngs.size(), num_samples); @@ -580,7 +586,7 @@ class GPUSampler : public SamplerObj { } // Synchronize for CPU to get the correct array results. 
- TVMSynchronize(device_.device_type, device_.device_id, nullptr); + TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host}; } From ff72113272a4e5073c4ed18c6a11b80a3f677755 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 26 Apr 2024 12:43:57 -0400 Subject: [PATCH 239/531] [Pass] Support two-stage softmax (#2220) This PR introduces the compiler pass that rewrites the normal softmax to a two-stage softmax. This is based on our finding that when vocabulary size is large, the normal softmax cannot have high-enough parallelism on GPU. So we partition the workload into two stages for better parallelism and better performance. --- python/mlc_llm/compiler_pass/pipeline.py | 2 + .../mlc_llm/compiler_pass/rewrite_softmax.py | 190 ++++++++++++++++++ python/mlc_llm/support/max_thread_check.py | 2 +- tests/python/op/test_two_stage_softmax.py | 47 +++++ 4 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 python/mlc_llm/compiler_pass/rewrite_softmax.py create mode 100644 tests/python/op/test_two_stage_softmax.py diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index b85a6a2cf6..57b68f742d 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -33,6 +33,7 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize +from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -117,6 +118,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 2. 
Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.backend.DispatchSortScan(), + RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py new file mode 100644 index 0000000000..1a6e41eafc --- /dev/null +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -0,0 +1,190 @@ +"""A compiler pass that rewrites one-shot softmax into two-stage softmax.""" + +import math + +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relax.expr import Expr +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="RewriteTwoStageSoftmax") +class RewriteTwoStageSoftmax: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + gv = self.mod.get_global_var("softmax_with_temperature") + updated_func = self.visit_expr(self.mod[gv]) + self.builder_.update_func(gv, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-renamed + if call.op != tvm.ir.Op.get("relax.nn.softmax"): + return call + x = call.args[0] + if call.attrs.axis not in [-1, x.struct_info.ndim - 1]: + return call + # Currently the softmax input is 3-dim, and dtype is float32. 
+ assert x.struct_info.ndim == 3 + assert x.struct_info.dtype == "float32" + x_shape = x.struct_info.shape + new_shape = relax.ShapeExpr([x_shape[0] * x_shape[1], x_shape[2]]) + x_reshaped = relax.call_pure_packed( + "vm.builtin.reshape", + x, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(self.target, self.chunk_size) + chunked_lse = relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[x_reshaped], + out_sinfo=relax.TensorStructInfo( + (new_shape[0], (new_shape[1] + self.chunk_size - 1) // self.chunk_size), + x.struct_info.dtype, + ), + ) + softmax = relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_lse"), + args=[x_reshaped, chunked_lse], + out_sinfo=relax.TensorStructInfo(new_shape, x.struct_info.dtype), + ) + return relax.call_pure_packed( + "vm.builtin.reshape", softmax, x_shape, sinfo_args=x.struct_info + ) + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + log2e = math.log2(math.exp(1)) + + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + A[v0, v1 * T.int64(chunk_size) + v2], + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e + + @T.prim_func + def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), 
dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + lse = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_lse[v0, v1]) + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0]) + for l0 in T.serial(0, batch_size): + with T.block("log"): + v0 = T.axis.remap("S", [l0]) + lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0] + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2( + A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0] + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_lse") + sch.compute_inline("log") + l0, l1, l2 = sch.get_loops("pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_lse"] diff --git a/python/mlc_llm/support/max_thread_check.py b/python/mlc_llm/support/max_thread_check.py index 6c078c3bbf..6711fb5c55 100644 --- a/python/mlc_llm/support/max_thread_check.py +++ b/python/mlc_llm/support/max_thread_check.py @@ -3,7 +3,7 @@ from tvm.target import Target -def get_max_num_threads_per_block(target: Target): +def get_max_num_threads_per_block(target: Target) -> int: """ max(max_num_threads, max_threads_per_block); if latter does not exist, return max_num_threads. We add this method since some targets have both fields and `max_threads_per_block` is larger. 
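For reference, the two-stage computation implemented by `chunk_lse` and `softmax_with_chunked_lse` above can be sketched in plain NumPy as follows. This is an illustration only, not part of the patch; it uses base-e arithmetic for readability, whereas the TIR kernels work in base 2 (via `log2e`, `T.exp2`, `T.log2`) with a chunk size of 4096:

```python
import numpy as np

def two_stage_softmax(x: np.ndarray, chunk_size: int = 4096) -> np.ndarray:
    """Softmax over the last axis, computed via per-chunk log-sum-exp values."""
    batch, vocab = x.shape
    num_chunks = (vocab + chunk_size - 1) // chunk_size
    pad = num_chunks * chunk_size - vocab
    # Pad with -inf so padded entries contribute exp(-inf) = 0 to the sums.
    x_pad = np.pad(x.astype(np.float64), ((0, 0), (0, pad)), constant_values=-np.inf)
    x_pad = x_pad.reshape(batch, num_chunks, chunk_size)
    # Stage 1: one log-sum-exp value per chunk (corresponds to `chunk_lse`).
    chunk_max = x_pad.max(axis=-1)
    chunk_lse = np.log(np.exp(x_pad - chunk_max[..., None]).sum(axis=-1)) + chunk_max
    # Stage 2: combine the per-chunk values into the global normalizer and
    # exponentiate (corresponds to `softmax_with_chunked_lse`).
    global_max = chunk_lse.max(axis=-1, keepdims=True)
    lse = np.log(np.exp(chunk_lse - global_max).sum(axis=-1, keepdims=True)) + global_max
    return np.exp(x - lse)
```

The result matches `scipy.special.softmax(x, axis=-1)`, which is what the new test in `tests/python/op/test_two_stage_softmax.py` below verifies against the compiled kernels.
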
diff --git a/tests/python/op/test_two_stage_softmax.py b/tests/python/op/test_two_stage_softmax.py new file mode 100644 index 0000000000..1d3d55d8e3 --- /dev/null +++ b/tests/python/op/test_two_stage_softmax.py @@ -0,0 +1,47 @@ +import numpy as np +import scipy.special +import tvm +from tvm import dlight + +from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func + + +def test_two_stage_softmax(): + chunk_size = 4096 + target = tvm.target.Target("cuda") + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size) + mod = tvm.IRModule({"chunk_lse": f_chunk_lse, "softmax_with_chunked_lse": f_softmax_with_lse}) + with target: + mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod) + + runtime_mod = tvm.build(mod, target=target) + device = tvm.cuda() + + num_runs = 5 + vocab_size = 128256 + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + for _ in range(num_runs): + x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype( + "float32" + ) + y_np = scipy.special.softmax(x_np, axis=-1) + + x_nd = tvm.nd.array(x_np, device=device) + r_nd = tvm.nd.empty( + (batch_size, (vocab_size + chunk_size - 1) // chunk_size), + x_np.dtype, + device=device, + ) + y_nd = tvm.nd.empty(x_np.shape, x_np.dtype, device=device) + + runtime_mod["chunk_lse"](x_nd, r_nd) + runtime_mod["softmax_with_chunked_lse"](x_nd, r_nd, y_nd) + + y_nd_arr = y_nd.numpy() + np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6) + + print(f"pass batch size {batch_size}") + + +if __name__ == "__main__": + test_two_stage_softmax() From 3139fd7f25ce34e4f6aabe0e5c2af1c70f91e198 Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 26 Apr 2024 16:50:54 +0000 Subject: [PATCH 240/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d694451c58..ced07e8878 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d694451c580a931116a2c93571f21f7d791c7fa0 +Subproject commit ced07e88781c0d6416e276d9cd084bb46aaf3da5 From 470a42a382da6dea8e132c211f444c9a37e7e76e Mon Sep 17 00:00:00 2001 From: "Kimura (Yamakado) Nobuhiro" <37305503+nobuhiroYamakado@users.noreply.github.com> Date: Sat, 27 Apr 2024 01:55:00 +0900 Subject: [PATCH 241/531] [Docs] Update deploy/ios#bring-your-own-model-library (#2235) remove model metadata step (#1) * remove model metadata step and make minor fixes --- docs/deploy/ios.rst | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index c0217db9e9..75a5cdbdc7 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -341,10 +341,24 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con mlc_llm gen_config ./dist/models/phi-2/ \ --quantization q4f16_1 --conv-template phi-2 \ -o dist/phi-2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json + # 2. mkdir: create a directory to store the compiled model library + mkdir -p dist/libs + # 3. compile: compile model library with specification in mlc-chat-config.json mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar +Given the compiled library, it is possible to calculate an upper bound for the VRAM +usage during runtime. This useful to better understand if a model is able to fit particular +hardware. 
+That information will be displayed at the end of the console log when the ``compile`` is executed. +It might look something like this: + +.. code:: shell + + [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. Temporary buffer: 133.28 MB) + [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` + [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar + .. note:: When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ @@ -388,21 +402,7 @@ This would result in something like `phi-2-q4f16_1-MLC `_. -**Step 4. Calculate estimated VRAM usage** - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. We can calculate this estimate using the following command: - -.. code:: shell - - ~/mlc-llm > python -m mlc_llm.cli.model_metadata ./dist/libs/phi-2-q4f16_1-iphone.tar \ - > --memory-only --mlc-chat-config ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json - INFO model_metadata.py:90: Total memory usage: 3042.96 MB (Parameters: 1492.45 MB. KVCache: 640.00 MB. Temporary buffer: 910.51 MB) - INFO model_metadata.py:99: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - - -**Step 5. Register as a ModelRecord** +**Step 4. Register as a ModelRecord** Finally, we update the code snippet for `app-config.json `__ From 93c560b470a5e4c8105eddd34e0cc118d2b9d9e6 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Fri, 26 Apr 2024 20:18:05 -0400 Subject: [PATCH 242/531] [Op] Top-p cutoff pivot (#2221) This commit introduces the GPU top-p cutoff operator for efficient probability renormalization under top-p. --- python/mlc_llm/op/__init__.py | 1 + python/mlc_llm/op/top_p_pivot.py | 315 ++++++++++++++++++++++++++++ tests/python/op/test_top_p_pivot.py | 83 ++++++++ 3 files changed, 399 insertions(+) create mode 100644 python/mlc_llm/op/top_p_pivot.py create mode 100644 tests/python/op/test_top_p_pivot.py diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index b5db353a3b..850312a8a7 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -6,3 +6,4 @@ from .extern import configure, enable, get_store from .ft_gemm import faster_transformer_dequantize_gemm from .position_embedding import llama_rope +from .top_p_pivot import top_p_pivot, top_p_renorm diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py new file mode 100644 index 0000000000..9c97959bff --- /dev/null +++ b/python/mlc_llm/op/top_p_pivot.py @@ -0,0 +1,315 @@ +"""Operators for choosing the pivot to cut-off top-p percentile """ + +import tvm +from tvm.script import tir as T + +# mypy: disable-error-code="attr-defined,valid-type,name-defined" +# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda +# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches + + +def top_p_pivot(pN): + """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. 
+ + A valide pivot should satisfy the following conditions: + - lsum >= top_p + - top_p > lsum - cmin * lmin + where lsum is the sum of elements that are larger or equal to the pivot, + lmin is the minimum elements that is larger or equal to the pivot, + cmin is the count of elements that are equal to lmin, + + Parameters + ---------- + prob: + The probability vector + + top_p_global: + The top-p threshold + + init_pivots: + The initial pivot candidates + + final_pivot: + The final pivot to cut-off top-p percentile + """ + TX = 1024 + K = 32 + eps_LR = 1e-7 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + def valid(lsum, lmin, cmin, top_p): + return tvm.tir.all(lsum >= top_p, top_p > lsum - cmin * lmin) + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + top_p_global: T.buffer([1], dtype="float32"), + var_init_pivots: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + + with T.block("kernel"): + pivot = T.alloc_buffer((pN,), "float32", scope="local") + top_p = _var("float32") + + L = T.alloc_buffer((1,), "float32", scope="shared") + R = T.alloc_buffer((1,), "float32", scope="shared") + L_local = _var("float32") + R_local = _var("float32") + + q = _var("float32") + lsum = T.alloc_buffer((pN,), "float32", scope="local") + lmin_broadcast = T.alloc_buffer((1), "float32", scope="shared") + lmin_broadcast_local = _var("float32") + lmin = T.alloc_buffer((pN,), "float32", scope="local") + cmin = T.alloc_buffer((pN,), "int32", scope="local") + total_sum = _var("float32") + + it = _var("int32") + es_local = _var("bool") + es = T.alloc_buffer((1,), "bool", scope="shared") + find_pivot_local = _var("bool") + find_pivot = T.alloc_buffer((1,), "bool", scope="shared") + + total_sum_reduce = _var("float32") + lsum_reduce = _var("float32") + lmin_reduce = _var("float32") + cmin_reduce = _var("int32") + + for _bx in T.thread_binding(0, B, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + b, tx = T.axis.remap("SS", [_bx, _tx]) + + top_p[0] = top_p_global[0] + + if tx == 0: + # leader thread initializes L, R + L[0] = 1.0 - top_p[0] + R[0] = eps_LR + find_pivot[0] = False + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + for i in T.unroll(0, pN): + # pivots are in descending order + pivot[i] = init_pivots[i] + find_pivot_local[0] = False + + while T.tvm_thread_invariant( + L_local[0] - R_local[0] > eps_LR + and T.Not(find_pivot_local[0]) + ): + # sync before each iteration + T.tvm_storage_sync("shared") + + ### get lsum, lmin, total_sum + for pidx in T.unroll(0, pN): + lsum[pidx] = 0.0 + lmin[pidx] = 1.0 + cmin[pidx] = 0 + total_sum[0] = 0.0 + it[0] = 0 + es_local[0] = False + while it[0] < T.ceildiv(N, TX) and T.Not(es_local[0]): + idx = T.meta_var(it[0] * TX + tx) + q[0] = T.if_then_else(idx < N, prob[b, idx], 0.0) + total_sum[0] += q[0] + for pidx in T.unroll(0, pN): + if q[0] >= pivot[pidx]: + lsum[pidx] += q[0] + if lmin[pidx] > q[0]: + lmin[pidx] = q[0] + cmin[pidx] = 1 + elif lmin[pidx] == q[0]: + cmin[pidx] += 1 + it[0] += 1 + + # early stop every K iterations + if it[0] % K == 0: + # reduce total_sum 
over tx + # T.tvm_storage_sync("shared") + with T.block("block_cross_thread"): + T.reads(total_sum[0]) + T.writes(total_sum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), total_sum[0], True, total_sum_reduce[0], tx, dtype="handle") + # T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we can stop early + es[0] = 1 - total_sum_reduce[0] < pivot[pN - 1] + T.tvm_storage_sync("shared") + es_local[0] = es[0] + + T.tvm_storage_sync("shared") + + # reduce lsum, lmin, cmin, over tx + for pidx in T.serial(0, pN): + # reduce lsum over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lsum[pidx]) + T.writes(lsum_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], True, lsum_reduce[0], tx, dtype="handle") + + # reduce lmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(lmin[pidx]) + T.writes(lmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], True, lmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # broadcast lmin to all threads + lmin_broadcast[0] = lmin_reduce[0] + T.tvm_storage_sync("shared") + lmin_broadcast_local[0] = lmin_broadcast[0] + if lmin[pidx] > lmin_broadcast_local[0]: + cmin[pidx] = 0 + if tx == 0: + # only the leader thread updates lsum, lmin + lsum[pidx] = lsum_reduce[0] + lmin[pidx] = lmin_reduce[0] + + # reduce cmin over tx for pivot[j] + with T.block("block_cross_thread"): + T.reads(cmin[pidx]) + T.writes(cmin_reduce[0]) + T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.int32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ) + T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], True, cmin_reduce[0], tx, dtype="handle") + + if tx == 0: + # only the leader thread updates cmin + cmin[pidx] = cmin_reduce[0] + + T.tvm_storage_sync("shared") + + if tx == 0: + # leader thread checks if we have found the pivot, or updates L, R + it[0] = 0 + while it[0] < pN and T.Not(find_pivot_local[0]): + pidx = T.meta_var(it[0]) + if valid(lsum[pidx], lmin[pidx], cmin[pidx], top_p[0]): + find_pivot[0] = True + find_pivot_local[0] = True + # write back the pivot and lsum + final_pivot[b] = pivot[pidx] + final_lsum[b] = lsum[pidx] + elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: + R[0] = pivot[pidx] + elif lsum[pidx] < top_p[0]: + L[0] = pivot[pidx] + it[0] += 1 + + T.tvm_storage_sync("shared") + + L_local[0] = L[0] + R_local[0] = R[0] + find_pivot_local[0] = find_pivot[0] + # new pivots for next iteration + # uniform spacing between L and R + for pidx in T.unroll(0, pN): + pivot[pidx] = L[0] - (pidx + 1) * (L_local[0] - R_local[0]) / (pN + 1) + + if tx == 0: + # leader thread writes back the pivot + if T.Not(find_pivot_local[0]): + final_pivot[b] = -1e5 + # fmt: on + + return _func + + +def top_p_renorm(): + """Top-p renormalization function. This function renormalizes the probability vector. 
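For reference, the renormalization rule spelled out below has this plain-NumPy equivalent (an
illustration only, not the scheduled TIR kernel; `pivot` and `lsum` are the outputs of
`top_p_pivot` above, and `reference_top_p_renorm` is an illustrative name):

import numpy as np

def reference_top_p_renorm(prob: np.ndarray, pivot: float, lsum: float) -> np.ndarray:
    # Keep entries at or above the pivot, rescale them by lsum, and zero out the rest.
    return np.where(prob >= pivot, prob / lsum, 0.0)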
+ + Given the pivot, the probability vector is renormalized as follows: + - if prob >= pivot, renorm_prob = prob / lsum + - otherwise, renorm_prob = 0 + + Parameters + ---------- + prob: + The probability vector + + final_pivot: + The final pivot to cut-off top-p percentile + + final_lsum: + The sum of elements that are larger or equal to the pivot + + renorm_prob: + The renormalized probability vector + """ + TX = 1024 + CTA_COUNT = 512 + + def _var(dtype="int32"): + return T.alloc_buffer((1,), dtype, scope="local") + + # fmt: off + @T.prim_func(private=True) + def _func( + var_prob: T.handle, + var_final_pivot: T.handle, + var_final_lsum: T.handle, + var_renorm_prob: T.handle, + ): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True}) + B = T.int32() + N = T.int32() + prob = T.match_buffer(var_prob, (B, N,), "float32") + final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") + final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") + renorm_prob = T.match_buffer(var_renorm_prob, (B, N,), "float32") + + with T.block("kernel"): + pivot = _var("float32") + lsum = _var("float32") + BX = T.meta_var(T.ceildiv(CTA_COUNT, B)) + + for _by in T.thread_binding(0, B, thread="blockIdx.y"): + for _bx in T.thread_binding(0, BX, thread="blockIdx.x"): + for _tx in T.thread_binding(0, TX, thread="threadIdx.x"): + with T.block("CTA"): + by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx]) + + pivot[0] = final_pivot[by] + lsum[0] = final_lsum[by] + + for i in T.serial(T.ceildiv(N, BX * TX)): + idx = T.meta_var(i * BX * TX + bx * TX + tx) + if idx < N: + renorm_prob[by, idx] = T.if_then_else(prob[by, idx] >= pivot[0], prob[by, idx] / lsum[0], 0.0) + # fmt: on + + return _func diff --git a/tests/python/op/test_top_p_pivot.py b/tests/python/op/test_top_p_pivot.py new file mode 100644 index 0000000000..7cfeb60e9c --- /dev/null +++ b/tests/python/op/test_top_p_pivot.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest +import tvm +import tvm.testing + +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm + +# mypy: disable-error-code="var-annotated" + + +@pytest.mark.parametrize("batch_size", [32, 64]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 128]) +def test_top_p_renorm(batch_size, vocab): + top_p = 0.95 + init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32) + top_p_np = np.array([top_p]).astype(np.float32) + + p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32) + p_np /= np.sum(p_np, axis=-1, keepdims=True) + final_pivot_np = np.zeros(batch_size).astype(np.float32) + final_lsum_np = np.zeros(batch_size).astype(np.float32) + + dev = tvm.cuda(0) + var_prob = tvm.nd.array(p_np, dev) + var_init_pivots = tvm.nd.array(init_pivots_np, dev) + top_p_global = tvm.nd.array(top_p_np, dev) + var_final_pivot = tvm.nd.array(final_pivot_np, dev) + var_final_lsum = tvm.nd.array(final_lsum_np, dev) + + kernel = top_p_pivot(init_pivots_np.shape[0]) + mod = tvm.build(kernel, target="cuda") + mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum) + + final_pivot = var_final_pivot.asnumpy() + final_lsum = var_final_lsum.asnumpy() + + renorm_np = p_np.copy() + var_renorm = tvm.nd.array(renorm_np, dev) + + kernel_renorm = top_p_renorm() + mod_renorm = tvm.build(kernel_renorm, target="cuda") + mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm) + + renorm = var_renorm.asnumpy() + + def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray): + sorted_probs = np.sort(probs, axis=-1)[::-1] + 
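        # After the descending sort, entries >= pivot form a contiguous prefix; the checks
        # below mirror the kernel's two pivot-validity conditions (the kept mass reaches
        # top_p, and excluding the smallest kept value(s) drops below top_p) and compare
        # the renormalized output against a NumPy-computed renormalization.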
num_larger_than_pivot = np.sum(sorted_probs >= pivot) + filtered_sorted_probs = sorted_probs[:num_larger_than_pivot] + min_larger_than_pivot = min(filtered_sorted_probs) + + sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0)) + sum_larger_than_pivot_exclude_min = np.sum( + np.where(filtered_sorted_probs != min_larger_than_pivot, filtered_sorted_probs, 0) + ) + + probs[probs < pivot] = 0 + renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True) + try: + assert sum_larger_than_pivot >= top_p + assert sum_larger_than_pivot_exclude_min < top_p + assert abs(lsum - sum_larger_than_pivot) < 1e-6 + assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6) + except AssertionError: + print("Failed") + print("probs:", repr(probs)) + print("pivot:", pivot) + print("sorted_probs:", sorted_probs) + print("num_larger_than_pivot:", num_larger_than_pivot) + print("filtered_sorted_probs:", filtered_sorted_probs) + print("min_larger_than_pivot:", min_larger_than_pivot) + print("sum_larger_than_pivot:", sum_larger_than_pivot) + print("sum_larger_than_pivot_exclude_min:", sum_larger_than_pivot_exclude_min) + print("renom_prob:", renorm_prob) + print("renorm:", renorm) + raise + + for i in range(batch_size): + verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i]) + + +if __name__ == "__main__": + tvm.testing.main() From 8e7b38a6678fa831b347b5525b89571aa7a2f0df Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Sat, 27 Apr 2024 07:54:22 -0400 Subject: [PATCH 243/531] [Op] Batch Verify: accept proposal when p and q are close enough (#2236) * dev * dev --- python/mlc_llm/op/batch_spec_verify.py | 25 +++++++++++++++-------- tests/python/op/test_batch_spec_verify.py | 16 ++++++++++++++- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/python/mlc_llm/op/batch_spec_verify.py b/python/mlc_llm/op/batch_spec_verify.py index 9cdbe2be21..d1a57fc71c 100644 --- a/python/mlc_llm/op/batch_spec_verify.py +++ b/python/mlc_llm/op/batch_spec_verify.py @@ -51,7 +51,7 @@ def batch_spec_verify(vocab_size): token_tree_parent_ptr: Current parent ptr state """ - TX = 128 + TX = 1024 def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -142,7 +142,6 @@ def _func( model_prob_local[0] = model_probs[parent_ptr[0], k] draft_prob_local[0] = draft_probs[child_ptr[0], k] model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) - model_probs[parent_ptr[0], k] = model_prob_local[0] psum[0] += model_prob_local[0] with T.block("block_cross_thread"): @@ -155,13 +154,21 @@ def _func( ) T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype="handle") - # renormalize - for i in T.serial(T.ceildiv(vocab_size, TX)): - k = T.meta_var(i * TX + tx) - if k < vocab_size: - model_probs[parent_ptr[0], k] = model_probs[parent_ptr[0], k] / t0[0] - - child_ptr[0] = token_tree_next_sibling[child_ptr[0]] + if t0[0] < 1e-7: + # accept the proposal, we move to child + parent_ptr[0] = child_ptr[0] + child_ptr[0] = token_tree_first_child[child_ptr[0]] + else: + # renormalize + for i in T.serial(T.ceildiv(vocab_size, TX)): + k = T.meta_var(i * TX + tx) + if k < vocab_size: + model_prob_local[0] = model_probs[parent_ptr[0], k] + draft_prob_local[0] = draft_probs[child_ptr[0], k] + model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0) + model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0] + + child_ptr[0] = token_tree_next_sibling[child_ptr[0]] if tx == 0: token_tree_parent_ptr[b] = parent_ptr[0] diff --git 
a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py index 359fafdbd0..f35a39d71e 100644 --- a/tests/python/op/test_batch_spec_verify.py +++ b/tests/python/op/test_batch_spec_verify.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("nbatch", [32, 64]) -@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001]) +@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) @pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) def test_batch_spec_verify(nbatch, vocab, plist): def numpy_reference( @@ -141,6 +141,20 @@ def gen_full_binary_tree(height, base): token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 ) + time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) + print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") + print( + time_evaluator( + draft_probs_tvm, + draft_tokens_tvm, + model_probs_tvm, + token_tree_first_child_tvm, + token_tree_next_sibling_tvm, + uniform_samples_tvm, + token_tree_parent_ptr_tvm, + ) + ) + if __name__ == "__main__": tvm.testing.main() From 135bcf98dbd78268669fee9010b0249358a08361 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 27 Apr 2024 07:54:36 -0400 Subject: [PATCH 244/531] [Serving] Creating EngineConfig from JSON (#2237) This PR supports creating EngineConfig from a JSON string, which is useful for JSONFFIEngine and its API bindings. This commit also removes the device from the EngineConfig for better clarity. --- cpp/json_ffi/json_ffi_engine.cc | 4 +- cpp/serve/config.cc | 61 ++++++++++++++++--- cpp/serve/config.h | 12 ++-- cpp/serve/engine.cc | 20 +++--- cpp/serve/engine.h | 3 +- cpp/serve/threaded_engine.cc | 7 ++- cpp/serve/threaded_engine.h | 3 +- python/mlc_llm/serve/config.py | 5 -- python/mlc_llm/serve/engine_base.py | 2 +- python/mlc_llm/serve/sync_engine.py | 2 +- tests/python/json_ffi/_ffi_api.py | 6 -- tests/python/json_ffi/test_json_ffi_engine.py | 44 ++++++------- 12 files changed, 100 insertions(+), 69 deletions(-) delete mode 100644 tests/python/json_ffi/_ffi_api.py diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 1a21c2962d..d5fc53b8fa 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -123,7 +123,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_END(); void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, - Optional request_stream_callback, + Device device, Optional request_stream_callback, Optional trace_recorder) { std::optional conv_template = Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); @@ -150,7 +150,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine(std::move(request_stream_callback), + this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), std::move(trace_recorder)); this->engine_->Reload(std::move(engine_config)); } diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 19f26ff624..3bb809ad67 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -244,17 +244,16 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int 
max_num_sequence, - int max_total_sequence_length, int max_single_sequence_length, - int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, + int max_single_sequence_length, int prefill_chunk_size, + int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length) { ObjectPtr n = make_object(); n->model = std::move(model); n->model_lib_path = std::move(model_lib_path); n->additional_models = std::move(additional_models); n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->device = device; n->kv_cache_page_size = kv_cache_page_size; n->max_num_sequence = max_num_sequence; n->max_total_sequence_length = max_total_sequence_length; @@ -267,14 +266,60 @@ EngineConfig::EngineConfig(String model, String model_lib_path, Array ad data_ = std::move(n); } +EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { + picojson::value config_json; + std::string err = picojson::parse(config_json, json_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + + // Get json fields. + picojson::object config = config_json.get(); + String model = json::Lookup(config, "model"); + String model_lib_path = json::Lookup(config, "model_lib_path"); + std::vector additional_models; + std::vector additional_model_lib_paths; + int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); + int max_num_sequence = json::Lookup(config, "max_num_sequence"); + int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); + int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); + int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); + int max_history_size = json::Lookup(config, "max_history_size"); + KVStateKind kv_state_kind = + static_cast(json::Lookup(config, "kv_state_kind")); + SpeculativeMode speculative_mode = + static_cast(json::Lookup(config, "speculative_mode")); + int spec_draft_length = json::Lookup(config, "spec_draft_length"); + + picojson::array additional_models_arr = + json::Lookup(config, "additional_models"); + picojson::array additional_model_lib_paths_arr = + json::Lookup(config, "additional_model_lib_paths"); + CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) + << "The number of additional model lib paths does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_lib_paths.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_lib_paths.push_back( + json::Lookup(additional_model_lib_paths_arr, i)); + } + + return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, + additional_model_lib_paths, kv_cache_page_size, max_num_sequence, + max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +} + TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int 
max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, int max_history_size, int kv_state_kind, int speculative_mode, int spec_draft_length) { return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), device, kv_cache_page_size, + std::move(additional_model_lib_paths), kv_cache_page_size, max_num_sequence, max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), SpeculativeMode(speculative_mode), spec_draft_length); diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 6a3bdd8997..fd76dd49f0 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -106,11 +106,6 @@ class EngineConfigNode : public Object { /*! \brief The path to the additional models' libraries. */ Array additional_model_lib_paths; - /*************** Device ***************/ - - /*! \brief The device where the models run. */ - DLDevice device; - /*************** KV cache config and engine capacities ***************/ /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ @@ -152,12 +147,15 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, DLDevice device, - int kv_cache_page_size, int max_num_sequence, int max_total_sequence_length, + Array additional_model_lib_paths, int kv_cache_page_size, + int max_num_sequence, int max_total_sequence_length, int max_single_sequence_length, int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind, SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON string. */ + static EngineConfig FromJSONString(const std::string& json_str); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 0348f7f40a..c9588cc4e8 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -44,7 +44,8 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, Optional request_stream_callback, + explicit EngineImpl(EngineConfig engine_config, DLDevice device, + Optional request_stream_callback, Optional trace_recorder) { // Step 1. 
Initialize metadata and singleton states inside the engine this->estate_->Reset(); @@ -62,9 +63,9 @@ class EngineImpl : public Engine { this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &trace_recorder](const String& model_path, - const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), engine_config->device, + auto f_create_model = [this, &engine_config, &device, &trace_recorder]( + const String& model_path, const String& model_lib_path) { + Model model = Model::Create(model_lib_path, std::move(model_path), device, engine_config->max_num_sequence, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, @@ -339,10 +340,11 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, +std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - return std::make_unique(std::move(engine_config), std::move(request_stream_callback), + return std::make_unique(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } @@ -368,10 +370,10 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. */ - void Init(EngineConfig engine_config, Optional request_stream_callback, + void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_ = Engine::Create(std::move(engine_config), device, + std::move(request_stream_callback), std::move(trace_recorder)); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index bcc1b80988..2fc0a4d730 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -51,11 +51,12 @@ class Engine { /*! * \brief Create an engine in unique pointer. * \param engine_config The engine config. + * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. * \return The created Engine in pointer. */ - static std::unique_ptr Create(EngineConfig engine_config, + static std::unique_ptr Create(EngineConfig engine_config, Device device, Optional request_stream_callback, Optional trace_recorder); diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index f234dfbbc3..2f6f77a3a0 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -36,8 +36,9 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. 
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) final { + device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; request_stream_callback_ = request_stream_callback.value(); @@ -231,7 +232,7 @@ class ThreadedEngineImpl : public ThreadedEngine { }; Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create(std::move(engine_config), + background_engine_ = Engine::Create(std::move(engine_config), device_, std::move(request_stream_callback), trace_recorder_); } @@ -247,6 +248,8 @@ class ThreadedEngineImpl : public ThreadedEngine { } } + /*! \brief The device to run models on. */ + Device device_; /*! \brief The background normal engine for request processing. */ std::unique_ptr background_engine_; /*! \brief The request stream callback. */ diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index f3d9c2b70c..49ba8f2175 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -35,10 +35,11 @@ class ThreadedEngine { /*! * \brief Initialize the threaded engine from packed arguments in TVMArgs. + * \param device The device where to run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(Optional request_stream_callback, + virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) = 0; /*! diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 40c53e336a..6b808ac37b 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -164,9 +164,6 @@ class EngineConfig(tvm.runtime.Object): additional_model_lib_paths : List[str] The path to the additional models' libraries. - device : tvm.runtime.Device - The device where the models run. - kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. 
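For the `EngineConfig::FromJSONString` path added above, these config fields travel as a single
JSON object whose keys mirror the field names, with `kv_state_kind` and `speculative_mode`
serialized as integer enum values. A minimal sketch of assembling such a string in Python; the
paths, sizes, and enum integers below are placeholders for illustration rather than prescribed
defaults:

import json

engine_config_json = json.dumps(
    {
        "model": "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        "model_lib_path": "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
        "additional_models": [],
        "additional_model_lib_paths": [],
        "kv_cache_page_size": 16,
        "max_num_sequence": 4,
        "max_total_sequence_length": 4096,
        "max_single_sequence_length": 4096,
        "prefill_chunk_size": 1024,
        "max_history_size": 0,
        "kv_state_kind": 0,      # enum serialized as an integer (assumed value)
        "speculative_mode": 0,   # SpeculativeMode.DISABLE as an integer (assumed value)
        "spec_draft_length": 4,
    }
)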
@@ -203,7 +200,6 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path: str, additional_models: List[str], additional_model_lib_paths: List[str], - device: tvm.runtime.Device, kv_cache_page_size: int, max_num_sequence: int, max_total_sequence_length: int, @@ -220,7 +216,6 @@ def __init__( # pylint: disable=too-many-arguments model_lib_path, additional_models, additional_model_lib_paths, - device, kv_cache_page_size, max_num_sequence, max_total_sequence_length, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index fb0a35ddd2..65b41a66ac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -1070,6 +1070,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals } self.tokenizer = Tokenizer(model_args[0][0]) self._ffi["init_background_engine"]( + device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) @@ -1079,7 +1080,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 7469ddc241..1be841cb08 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -166,7 +166,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, @@ -177,6 +176,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, ), + device, request_stream_callback, self.trace_recorder, ) diff --git a/tests/python/json_ffi/_ffi_api.py b/tests/python/json_ffi/_ffi_api.py deleted file mode 100644 index 3df07d6a1f..0000000000 --- a/tests/python/json_ffi/_ffi_api.py +++ /dev/null @@ -1,6 +0,0 @@ -"""FFI APIs for mlc.json_ffi""" -import tvm._ffi - -# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.json_ffi" prefix. -# e.g. 
TVM_REGISTER_GLOBAL("mlc.serve.TextData") -tvm._ffi._init_api("mlc.json_ffi", __name__) # pylint: disable=protected-access diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index f5235663be..2220303e42 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -6,7 +6,6 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union import tvm -from tests.python.json_ffi import _ffi_api from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import engine_utils @@ -61,30 +60,23 @@ ] -@tvm._ffi.register_object( - "mlc.json_ffi.ModelDefinedGenerationConfig" -) # pylint: disable=protected-access -class ModelDefinedGenerationConfig(tvm.runtime.Object): - def __init__( # pylint: disable=too-many-arguments - self, temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ModelDefinedGenerationConfig, - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) -@tvm._ffi.register_object("mlc.json_ffi.JSONFFIEngineConfig") # pylint: disable=protected-access -class JSONFFIEngineConfig(tvm.runtime.Object): - def __init__( # pylint: disable=too-many-arguments - self, conv_template: str, model_generation_cfgs: Dict[str, ModelDefinedGenerationConfig] - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.JSONFFIEngineConfig, conv_template, model_generation_cfgs - ) +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) class EngineState: @@ -187,7 +179,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_lib_path=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - device=device, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, @@ -199,10 +190,10 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals spec_draft_length=spec_draft_length, ) - self.json_ffi_engine_config = JSONFFIEngineConfig( + self.json_ffi_engine_config = create_json_ffi_engine_config( conv_template=self.conv_template.model_dump_json(), model_generation_cfgs={ - model.model: ModelDefinedGenerationConfig( + model.model: create_model_defined_generation_config( temperature=model_config["temperature"], top_p=model_config["top_p"], frequency_penalty=model_config["frequency_penalty"], @@ -215,6 +206,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self._ffi["init_background_engine"]( self.json_ffi_engine_config, self.engine_config, + device, self.state.get_request_stream_callback(), None, ) From fd659733d3e681bebb925961c7af5b83c209e77b Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Sat, 27 Apr 2024 15:52:05 -0400 Subject: [PATCH 245/531] [Bugfix] layer_norm_eps in GPT2Config should be float (#2240) --- python/mlc_llm/model/gpt2/gpt2_model.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 28c34353e2..ede9dc350f 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -28,7 +28,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes n_embd: int n_layer: int n_head: int - layer_norm_epsilon: int + layer_norm_epsilon: float n_inner: int = -1 context_window_size: int = 0 prefill_chunk_size: int = 0 From 63a3804e772d179d6d26d53e154a1447cc61fd7a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 27 Apr 2024 17:51:27 -0400 Subject: [PATCH 246/531] [REFACTOR] Migrate JSONFFIEngine to formal namespace (#2241) This PR migrates JSONFFIEngine to a formal namespace. Also list TODOs to further simplify the JSONFFIEngine. --- python/mlc_llm/json_ffi/__init__.py | 8 + python/mlc_llm/json_ffi/engine.py | 310 ++++++++++++++++++ tests/python/json_ffi/test_json_ffi_engine.py | 296 +---------------- 3 files changed, 319 insertions(+), 295 deletions(-) create mode 100644 python/mlc_llm/json_ffi/__init__.py create mode 100644 python/mlc_llm/json_ffi/engine.py diff --git a/python/mlc_llm/json_ffi/__init__.py b/python/mlc_llm/json_ffi/__init__.py new file mode 100644 index 0000000000..8a7059153d --- /dev/null +++ b/python/mlc_llm/json_ffi/__init__.py @@ -0,0 +1,8 @@ +"""JSON FFI is a pure string based interface of MLC LLM Engine. + +We build interfacing with JSON FFI for both testing purposes +and internal use. For most python API usage, please use MLCEngine +and MLCAsyncEngine +""" + +from .engine import JSONFFIEngine diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py new file mode 100644 index 0000000000..0c604a2ef3 --- /dev/null +++ b/python/mlc_llm/json_ffi/engine.py @@ -0,0 +1,310 @@ +# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json +import queue +import threading +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union + +import tvm + +from mlc_llm.protocol import openai_api_protocol +from mlc_llm.serve import engine_utils +from mlc_llm.serve.engine_base import ( + EngineConfig, + SpeculativeMode, + _infer_kv_cache_config, + _parse_models, + _process_model_args, + detect_device, +) +from mlc_llm.tokenizer import Tokenizer + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# construction to not depend on any config and directly pass in JSON +# model defined generation config should be read from the JSONFFIEngine via Reload +def create_model_defined_generation_config( + temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( + temperature, + top_p, + frequency_penalty, + presence_penalty, + ) + + +# TODO(mlc-team): further minimize the JSONFFIEngine +# Engine config should be passed as json str +# and backend should have good default +# only model and model_lib should be mandatory +def create_json_ffi_engine_config( + conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] +) -> tvm.runtime.Object: + return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( + conv_template, model_generation_cfgs + ) + + +class EngineState: + sync_queue: queue.Queue + + def get_request_stream_callback(self) -> Callable[[List[str]], None]: + # 
ChatCompletionStreamResponse + + def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + self._sync_request_stream_callback(chat_completion_stream_responses_json_str) + + return _callback + + def _sync_request_stream_callback( + self, chat_completion_stream_responses_json_str: List[str] + ) -> None: + # Put the delta outputs to the queue in the unblocking way. + self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + + +class JSONFFIEngine: + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib_path: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + additional_models: Optional[List[str]] = None, + max_batch_size: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + max_history_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + spec_draft_length: int = 4, + gpu_memory_utilization: Optional[float] = None, + ) -> None: + # - Initialize model loading info. + models = _parse_models(model, model_lib_path, additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + ( + model_args, + model_config_paths, + self.conv_template, + ) = _process_model_args(models, device) + + # TODO(mlc-team) Remove the model config parsing, estimation below + # in favor of a simple direct passing of parameters into backend. + # JSONFFIEngine do not have to support automatic mode + # + # Instead, its config should default to interactive mode always + # and allow overrides of parameters through json config via reload + # + # This is to simplify the logic of users of JSONFFI + # since we won't have similar logics in android/iOS + # + # - Load the raw model config into dict + self.model_config_dicts = [] + for i, model_info in enumerate(models): + model_info.model_lib_path = model_args[i][1] + with open(model_config_paths[i], "r", encoding="utf-8") as file: + self.model_config_dicts.append(json.load(file)) + + # - Decide the KV cache config based on mode and user input. + ( + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_single_sequence_length, + max_history_size, + kv_state_kind, + ) = _infer_kv_cache_config( + mode, + max_batch_size, + max_total_sequence_length, + prefill_chunk_size, + max_history_size, + gpu_memory_utilization, + models, + device, + self.model_config_dicts, + model_config_paths, + ) + self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + + # - Initialize engine state and engine. 
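        # The block below wires up the Python side of the engine: it creates the
        # EngineState that buffers streamed responses, instantiates the C++ engine via the
        # "mlc.json_ffi.CreateJSONFFIEngine" global function, and caches the module's
        # member functions in self._ffi for the calls made later in this class.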
+ self.state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "get_last_error", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib_path=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + kv_state_kind=kv_state_kind, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + ) + + self.json_ffi_engine_config = create_json_ffi_engine_config( + conv_template=self.conv_template.model_dump_json(), + model_generation_cfgs={ + model.model: create_model_defined_generation_config( + temperature=model_config["temperature"], + top_p=model_config["top_p"], + frequency_penalty=model_config["frequency_penalty"], + presence_penalty=model_config["presence_penalty"], + ) + for model, model_config in zip(models, self.model_config_dicts) + }, + ) + + self._ffi["init_background_engine"]( + self.json_ffi_engine_config, + self.engine_config, + device, + self.state.get_request_stream_callback(), + None, + ) + + def _background_loop(): + self._ffi["run_background_loop"]() + + def _background_stream_back_loop(): + self._ffi["run_background_stream_back_loop"]() + + # Create the background engine-driving thread and start the loop. 
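        # Two background threads are started below: one drives the C++ engine loop and the
        # other streams responses back to the request callback; terminate() asks the engine
        # to exit and joins both, so callers should call terminate() when they are done.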
+ self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_stream_back_loop_thread: threading.Thread = threading.Thread( + target=_background_stream_back_loop + ) + self._background_loop_thread.start() + self._background_stream_back_loop_thread.start() + self._terminated = False + + def terminate(self): + self._terminated = True + self._ffi["exit_background_loop"]() + self._background_loop_thread.join() + self._background_stream_back_loop_thread.join() + + def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + self, + *, + messages: List[Dict[str, Any]], + model: str, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + logprobs: bool = False, + top_logprobs: int = 0, + logit_bias: Optional[Dict[int, float]] = None, + max_tokens: Optional[int] = None, + n: int = 1, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: bool = False, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, + user: Optional[str] = None, + ignore_eos: bool = False, + response_format: Optional[Dict[str, Any]] = None, + request_id: Optional[str] = None, + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + if request_id is None: + request_id = f"chatcmpl-{engine_utils.random_uuid()}" + + chatcmpl_generator = self._handle_chat_completion( + openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + seed=seed, + stop=stop, + stream=stream, + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + ignore_eos=ignore_eos, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + ).model_dump_json(), + n=n, + request_id=request_id, + ) + for response in chatcmpl_generator: + yield response + + def _handle_chat_completion( + self, request_json_str: str, n: int, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + self.state.sync_queue = queue.Queue() + num_unfinished_requests = n + + success = bool(self._ffi["chat_completion"](request_json_str, request_id)) + + try: + while num_unfinished_requests > 0: + chat_completion_stream_responses_json_str = self.state.sync_queue.get() + for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( + chat_completion_response_json_str + ) + ) + for choice in chat_completion_response.choices: + if choice.finish_reason is not None: + num_unfinished_requests -= 1 + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + self._ffi["abort"](request_id) + raise exception + + def _test_reload(self): + self._ffi["reload"](self.engine_config) + + def _test_reset(self): + self._ffi["reset"]() + + def _test_unload(self): + self._ffi["unload"]() diff --git 
a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index 2220303e42..c52571b522 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,23 +1,6 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json -import queue -import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union -import tvm - -from mlc_llm.protocol import openai_api_protocol -from mlc_llm.serve import engine_utils -from mlc_llm.serve.engine_base import ( - EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, - _parse_models, - _process_model_args, - detect_device, -) -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.json_ffi import JSONFFIEngine chat_completion_prompts = [ "What is the meaning of life?", @@ -60,279 +43,6 @@ ] -def create_model_defined_generation_config( - temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) - - -def create_json_ffi_engine_config( - conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( - conv_template, model_generation_cfgs - ) - - -class EngineState: - sync_queue: queue.Queue - - def get_request_stream_callback(self) -> Callable[[List[str]], None]: - # ChatCompletionStreamResponse - - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: - self._sync_request_stream_callback(chat_completion_stream_responses_json_str) - - return _callback - - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: - # Put the delta outputs to the queue in the unblocking way. - self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) - - -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - max_history_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. 
- ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. - self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "reload", - "unload", - "reset", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) - - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - - self.json_ffi_engine_config = create_json_ffi_engine_config( - conv_template=self.conv_template.model_dump_json(), - model_generation_cfgs={ - model.model: create_model_defined_generation_config( - temperature=model_config["temperature"], - top_p=model_config["top_p"], - frequency_penalty=model_config["frequency_penalty"], - presence_penalty=model_config["presence_penalty"], - ) - for model, model_config in zip(models, self.model_config_dicts) - }, - ) - - self._ffi["init_background_engine"]( - self.json_ffi_engine_config, - self.engine_config, - device, - self.state.get_request_stream_callback(), - None, - ) - - def _background_loop(): - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() - - # Create the background engine-driving thread and start the loop. 
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) - self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop - ) - self._background_loop_thread.start() - self._background_stream_back_loop_thread.start() - self._terminated = False - - def terminate(self): - self._terminated = True - self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() - - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals - self, - *, - messages: List[Dict[str, Any]], - model: str, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - logprobs: bool = False, - top_logprobs: int = 0, - logit_bias: Optional[Dict[int, float]] = None, - max_tokens: Optional[int] = None, - n: int = 1, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, - user: Optional[str] = None, - ignore_eos: bool = False, - response_format: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - if request_id is None: - request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), - n=n, - request_id=request_id, - ) - for response in chatcmpl_generator: - yield response - - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) - - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception - - def _test_reload(self): - self._ffi["reload"](self.engine_config) - - def _test_reset(self): - self._ffi["reset"]() - - def _test_unload(self): - self._ffi["unload"]() - - def run_chat_completion( 
engine: JSONFFIEngine, model: str, @@ -374,10 +84,8 @@ def run_chat_completion( def test_chat_completion(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) @@ -394,10 +102,8 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, max_total_sequence_length=1024, ) From 1a8bad0152ff4bc1e02b0533e19a3974bd761992 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 28 Apr 2024 13:52:43 -0700 Subject: [PATCH 247/531] [Serving] Share disco sessions among multiple model function tables (#2242) --- cpp/serve/engine.cc | 67 +++++++++++++++++++++++++++++++++---- cpp/serve/function_table.cc | 24 ++----------- cpp/serve/function_table.h | 3 +- cpp/serve/model.cc | 52 ++++++++++++++-------------- cpp/serve/model.h | 15 +++++++-- 5 files changed, 106 insertions(+), 55 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index c9588cc4e8..d82c886355 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -63,10 +64,19 @@ class EngineImpl : public Engine { this->models_.clear(); this->model_workspaces_.clear(); - auto f_create_model = [this, &engine_config, &device, &trace_recorder]( - const String& model_path, const String& model_lib_path) { - Model model = Model::Create(model_lib_path, std::move(model_path), device, - engine_config->max_num_sequence, + std::vector model_configs; + model_configs.push_back(Model::LoadModelConfig(engine_config->model)); + for (const auto& model_path : engine_config->additional_models) { + model_configs.push_back(Model::LoadModelConfig(model_path)); + } + + Optional session = CreateDiscoSession(model_configs, device); + + auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, + &session](const String& model_path, const String& model_lib_path, + int model_index) { + Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], + device, engine_config->max_num_sequence, session, /*trace_enabled=*/trace_recorder.defined()); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, @@ -81,13 +91,13 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); }; - f_create_model(engine_config->model, engine_config->model_lib_path); + f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); CHECK_EQ(engine_config->additional_models.size(), engine_config->additional_model_lib_paths.size()) << "The additional model and lib path list has mismatched size."; for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i]); + engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } int max_num_tokens = engine_config->max_num_sequence; @@ -287,6 +297,51 @@ class EngineImpl : public Engine { "action (e.g. prefill, decode, etc.) 
but it does not."; } + /************** Utility Functions **************/ + Optional CreateDiscoSession(std::vector model_configs, Device device) { + const auto& base_model_config = model_configs[0]; + + auto f_get_num_shards = [](const picojson::object& model_config) -> int { + constexpr auto kNumShardsKey = "tensor_parallel_shards"; + if (model_config.count(kNumShardsKey)) { + const auto& val = model_config.at(kNumShardsKey); + CHECK(val.is()); + return static_cast(val.get()); + } else { + LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; + } + throw; + }; + + int num_shards = std::transform_reduce( + model_configs.begin(), model_configs.end(), 1, [](int a, int b) { return std::max(a, b); }, + f_get_num_shards); + Optional session = NullOpt; + if (num_shards > 1) { + constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; + if (Registry::Get(f_create_process_pool) == nullptr) { + LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " + << "Multi-GPU inference depends on MLC LLM Python API to launch process."; + } + std::string ccl; + if (device.device_type == kDLCUDA) { + ccl = "nccl"; + } else if (device.device_type == kDLROCM) { + ccl = "rccl"; + } else { + LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) + << " is not supported. Currently, only NCCL and RCCL are integrated."; + } + std::vector device_ids(num_shards); + for (int i = 0; i < num_shards; ++i) { + device_ids[i] = i; + } + session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); + session.value()->InitCCL(ccl, ShapeTuple(device_ids)); + } + return session; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index b721eae7c3..3267f1dd38 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -69,7 +69,8 @@ PackedFunc FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, }); } -void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config) { +void FunctionTable::Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session) { local_gpu_device = device; Device null_device{DLDeviceType(0), 0}; int num_shards; @@ -85,27 +86,8 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->cached_buffers = Map(); if (num_shards > 1) { - constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool"; - if (Registry::Get(f_create_process_pool) == nullptr) { - LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. " - << "Multi-GPU inference depends on MLC LLM Python API to launch process."; - } - std::string ccl; - if (device.device_type == kDLCUDA) { - ccl = "nccl"; - } else if (device.device_type == kDLROCM) { - ccl = "rccl"; - } else { - LOG(FATAL) << "ValueError: Multi-GPU on device " << DLDeviceType2Str(device.device_type) - << " is not supported. 
Currently, only NCCL and RCCL are integrated."; - } - std::vector device_ids(num_shards); - for (int i = 0; i < num_shards; ++i) { - device_ids[i] = i; - } + this->sess = session.value(); this->use_disco = true; - this->sess = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker"); - this->sess->InitCCL(ccl, ShapeTuple(device_ids)); this->disco_mod = sess->CallPacked(sess->GetGlobalFunc("runtime.disco.load_vm_module"), reload_lib_path, null_device); this->mod_get_func = [this, diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index b6ea3287ad..bc2b4f21c8 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -41,7 +41,8 @@ using namespace tvm::runtime; struct FunctionTable { static PackedFunc SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name); - void Init(String reload_lib_path, Device device, picojson::object model_config); + void Init(String reload_lib_path, Device device, picojson::object model_config, + Optional session); ObjectRef LoadParams(const std::string& model_path, Device device); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 27a0043850..6f34220219 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -26,10 +26,27 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); -Model Model::Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) { - return Model( - make_object(reload_lib_path, model_path, device, max_num_sequence, trace_enabled)); +Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, + max_num_sequence, session, trace_enabled)); +} + +picojson::object Model::LoadModelConfig(const String& model_path) { + picojson::object model_config; + std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); + std::ostringstream config_ostream; + ICHECK(config_istream); + config_ostream << config_istream.rdbuf(); + std::string config_str = config_ostream.str(); + picojson::value config_json; + std::string err = picojson::parse(config_json, config_str); + if (!err.empty()) { + LOG(FATAL) << err; + } + picojson::object config = config_json.get(); + return config; } class ModelImpl : public ModelObj { @@ -38,23 +55,16 @@ class ModelImpl : public ModelObj { * \brief Constructor of ModelImpl. * \sa Model::Create */ - explicit ModelImpl(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled) + explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, + DLDevice device, int max_num_sequence, const Optional& session, + bool trace_enabled) : device_(device) { // Step 1. Process model config json string. - picojson::object model_config; - { - std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); - std::ostringstream config_ostream; - ICHECK(config_istream); - config_ostream << config_istream.rdbuf(); - std::string config_str = config_ostream.str(); - model_config = LoadModelConfigJSON(config_str); - } + LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. 
- this->ft_.Init(reload_lib_path, device_, model_config); + this->ft_.Init(reload_lib_path, device_, model_config, session); // Step 3. Load params in nd-array cache. this->params_ = ft_.LoadParams(model_path, device_); // Step 4. Set max_num_sequence @@ -891,15 +901,7 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. */ - picojson::object LoadModelConfigJSON(const std::string& config_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_str); - if (!err.empty()) { - LOG(FATAL) << err; - } - - // Get json fields. - picojson::object config = config_json.get(); + picojson::object LoadModelConfigJSON(picojson::object config) { if (config.count("context_window_size")) { CHECK(config["context_window_size"].is()); this->max_window_size_ = config["context_window_size"].get(); diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 045daff874..bc63840a74 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -319,13 +319,24 @@ class Model : public ObjectRef { * \brief Create the runtime module for LLM functions. * \param reload_lib_path The model library path. * \param model_path The path to the model weight parameters. + * \param model_config The model config json object. * \param device The device to run the model on. * \param max_num_sequence The maximum number of sequences to be processed + * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ - TVM_DLL static Model Create(String reload_lib_path, String model_path, DLDevice device, - int max_num_sequence, bool trace_enabled); + TVM_DLL static Model Create(String reload_lib_path, String model_path, + const picojson::object& model_config, DLDevice device, + int max_num_sequence, const Optional& session, + bool trace_enabled); + + /*! + * Load the model config from the given model path. + * \param model_path The path to the model weight parameters. + * \return The model config json object. + */ + static picojson::object LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; From 5a26795382e23986d9958e76eb033410b01dab48 Mon Sep 17 00:00:00 2001 From: Wei Tao <1136862851@qq.com> Date: Mon, 29 Apr 2024 19:37:28 +0800 Subject: [PATCH 248/531] [DOC] Improve Install via environment variable (#2245) improve Install via environment variable --- docs/install/mlc_llm.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 7b64dce9fb..ce15616957 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -214,7 +214,9 @@ There are two ways to do so: .. code-tab :: bash Install via environment variable - export PYTHONPATH=/path-to-mlc-llm/python:$PYTHONPATH + export MLC_LLM_HOME=/path-to-mlc-llm + export PYTHONPATH=$MLC_LLM_HOME/python:$PYTHONPATH + alias mlc_llm="python -m mlc_llm" .. code-tab :: bash Install via pip local project From 3cb2ee83324dca47e7490209484b0a314372145d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 29 Apr 2024 08:09:52 -0400 Subject: [PATCH 249/531] [Sampler] FlashInfer sampling func integration (#2224) This PR integrates the sampling function in FlashInfer. We integrate the one without top-p for now. 
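Editor's note for readers following this patch: the integration uses an optional-dispatch pattern. The GPU sampler looks up "flashinfer.sampling.parallel_sampling_from_prob" in the TVM global registry and falls back to the built-in multinomial kernel when FlashInfer is not registered, so there is no hard build-time dependency on FlashInfer. Below is a minimal sketch of that pattern, not the patch itself; the function name SampleOnGPU, the default_sample_func parameter, and the preallocated `out` buffer are assumptions made for illustration.

    // Sketch: prefer an optionally registered FlashInfer kernel, otherwise use the default one.
    #include <tvm/runtime/ndarray.h>
    #include <tvm/runtime/packed_func.h>
    #include <tvm/runtime/registry.h>

    using tvm::runtime::NDArray;
    using tvm::runtime::PackedFunc;
    using tvm::runtime::Registry;

    // `default_sample_func` stands in for the built-in multinomial sampling kernel and
    // `out` for a preallocated output buffer; both are assumptions of this sketch.
    NDArray SampleOnGPU(const PackedFunc& default_sample_func, NDArray probs,
                        NDArray uniform_samples, NDArray sample_indices, NDArray out) {
      // Resolved once from the global registry; nullptr when FlashInfer is not available.
      static const PackedFunc* flashinfer_func =
          Registry::Get("flashinfer.sampling.parallel_sampling_from_prob");
      if (flashinfer_func != nullptr) {
        // FlashInfer path: the kernel writes sampled token ids into the preallocated buffer.
        (*flashinfer_func)(probs, uniform_samples, sample_indices, out);
        return out;
      }
      // Default path: the built-in kernel returns a freshly allocated array of token ids.
      return default_sample_func(probs, uniform_samples, sample_indices);
    }

Resolving the kernel through the registry keeps the decision at runtime, which is why the patch only caches a const PackedFunc* and branches on nullptr at every sampling call.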
--- cpp/serve/sampler/gpu_sampler.cc | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 62911a7cd1..58a27c24f7 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -51,6 +51,9 @@ class GPUSampler : public SamplerObj { ICHECK(gpu_sample_with_top_p_func_.defined()); ICHECK(gpu_sampler_take_probs_func_.defined()); + flashinfer_multinomial_sample_func_ = + Registry::Get("flashinfer.sampling.parallel_sampling_from_prob"); + DLDevice device_cpu{DLDeviceType::kDLCPU, /*device_id=*/0}; // We support at most 5 top prob results for each sequence. // Initialize auxiliary arrays on CPU. @@ -76,6 +79,7 @@ class GPUSampler : public SamplerObj { token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_parent_ptr_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); + sampled_token_ids_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); // If the device is CUDA/ROCm, we create a standalone copy stream, in // purpose to hide the latency of auxiliary stream copy. @@ -495,8 +499,15 @@ class GPUSampler : public SamplerObj { if (!need_top_p && !need_prob_values) { // - Short path: If top_p and prob values are not needed, we directly sample from multinomial. SyncCopyStream(device_, compute_stream_, copy_stream_); - sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device, top_prob_indices_device}; } @@ -531,8 +542,15 @@ class GPUSampler : public SamplerObj { uniform_samples_device, sample_indices_device, top_p_device); } else { // - Sample without top_p. 
- sampled_token_ids_device = gpu_multinomial_from_uniform_func_( - probs_on_device, uniform_samples_device, sample_indices_device); + if (flashinfer_multinomial_sample_func_ != nullptr) { + sampled_token_ids_device = + sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_); + (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device, + sample_indices_device, sampled_token_ids_device); + } else { + sampled_token_ids_device = gpu_multinomial_from_uniform_func_( + probs_on_device, uniform_samples_device, sample_indices_device); + } } if (need_prob_values) { @@ -604,6 +622,7 @@ class GPUSampler : public SamplerObj { PackedFunc gpu_sampler_take_probs_func_; PackedFunc gpu_verify_draft_tokens_func_; PackedFunc gpu_renormalize_by_top_p_func_; + const PackedFunc* flashinfer_multinomial_sample_func_; // Auxiliary NDArrays on CPU NDArray uniform_samples_host_; NDArray sample_indices_host_; @@ -627,6 +646,7 @@ class GPUSampler : public SamplerObj { NDArray token_tree_first_child_device_; NDArray token_tree_next_sibling_device_; NDArray token_tree_parent_ptr_device_; + NDArray sampled_token_ids_device_; // The event trace recorder for requests. */ Optional trace_recorder_; // The device stream for the default computation operations. From d3d264d4b05d73e9757375013b842254f052c6ed Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Mon, 29 Apr 2024 14:27:38 -0400 Subject: [PATCH 250/531] Model Library Delivery (#2139) * add model lib delivery * fix lint --- python/mlc_llm/cli/lib_delivery.py | 200 +++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 python/mlc_llm/cli/lib_delivery.py diff --git a/python/mlc_llm/cli/lib_delivery.py b/python/mlc_llm/cli/lib_delivery.py new file mode 100644 index 0000000000..a5d678fbe2 --- /dev/null +++ b/python/mlc_llm/cli/lib_delivery.py @@ -0,0 +1,200 @@ +"""Continuous model delivery for MLC LLM models.""" + +import argparse +import dataclasses +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, List + +from mlc_llm.support import logging +from mlc_llm.support.argparse import ArgumentParser +from mlc_llm.support.constants import MLC_TEMP_DIR +from mlc_llm.support.style import bold, green, red + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: # pylint: disable=too-many-instance-attributes + """Necessary information for the model delivery""" + + model_id: str + model: Path + quantization: str + device: str + # overrides the `context_window_size`, `prefill_chunk_size`, + # `sliding_window_size`, `attention_sink_size`, `max_batch_size` + # and `tensor_parallel_shards in mlc-chat-config.json + overrides: Dict[str, int] + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return 
Path(temp_dir) + + +def _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool: + """Run the compilation of the model library.""" + + def get_lib_ext(device: str) -> str: + if device in ["cuda", "vulkan", "metal"]: + return ".so" + if device in ["android", "ios"]: + return ".tar" + if device in ["webgpu"]: + return ".wasm" + + return "" + + succeeded = True + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir: + log_path = Path(temp_dir) / "logs.txt" + model_lib_name = f"{model_info.model_id}-{model_info.quantization}-{model_info.device}" + lib_ext = get_lib_ext(model_info.device) + if lib_ext == "": + raise ValueError(f"Unsupported device: {model_info.device}") + model_lib_name += lib_ext + with log_path.open("a", encoding="utf-8") as log_file: + overrides = ";".join(f"{key}={value}" for key, value in model_info.overrides.items()) + cmd = [ + sys.executable, + "-m", + "mlc_llm", + "compile", + str(model_info.model), + "--device", + model_info.device, + "--quantization", + model_info.quantization, + "--overrides", + overrides, + "--output", + os.path.join(temp_dir, model_lib_name), + ] + print(" ".join(cmd), file=log_file, flush=True) + subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT) + logger.info("[MLC] Compilation Complete!") + if not (Path(temp_dir) / model_lib_name).exists(): + logger.error( + "[%s] Model %s. Device %s. No compiled library found.", + red("FAILED"), + model_info.model_id, + model_info.device, + ) + succeeded = False + return succeeded + + # overwrite git repo file with the compiled library + repo_filepath = repo_dir / model_info.model_id / model_lib_name + if not repo_filepath.parent.exists(): + repo_filepath.parent.mkdir(parents=True, exist_ok=True) + # copy lib from Path(temp_dir) / model_lib_name to repo_filepath + shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath) + logger.info("Saved library %s at %s", model_lib_name, repo_filepath) + return succeeded + + +def _main( # pylint: disable=too-many-locals + spec: Dict[str, Any], +): + """Compile the model libs in the spec and save them to the binary_libs_dir.""" + failed_cases: List[Any] = [] + for task_index, task in enumerate(spec["tasks"], 1): + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model_info = { + "model_id": task["model_id"], + "model": task["model"], + } + for compile_opt in spec["default_compile_options"] + task.get("compile_options", []): + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info["quantization"] = quantization + model_info["device"] = compile_opt["device"] + model_info["overrides"] = compile_opt.get("overrides", {}) + logger.info( + "[Config] " + + bold("model_id: ") + + model_info["model_id"] + + bold(", quantization: ") + + model_info["quantization"] + + bold(", device: ") + + model_info["device"] + + bold(", overrides: ") + + json.dumps(model_info["overrides"]) + ) + + result = _run_compilation( + ModelInfo(**model_info), + repo_dir=Path(spec["binary_libs_dir"]), + ) + if not result: + failed_cases.append(model_info) + + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for case in failed_cases: + logger.info( + "model_id %s, quantization %s, device %s, overrides %s", + case["model_id"], + case["quantization"], + case["device"], + json.dumps(case["overrides"]), + ) + + +def main(): + """Entry point.""" + + def 
_load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous library delivery") + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + spec=parsed.spec, + ) + + +if __name__ == "__main__": + main() From 248996422773c0bf9d78177ec069e3052bfe81a4 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Tue, 30 Apr 2024 05:31:54 -0700 Subject: [PATCH 251/531] [Support] Simplify function names in encoding.h (#2251) This PR simplifies the tool function names in encoding.h. The new names are - PrintAsUTF8 - PrintAsEscaped - ParseNextUTF8 - ParseUTF8 - ParseNextUTF8OrEscaped Also make ParseNextUTF8 return the new char pointer instead of the number of chars processed to make the interface simpler. --- cpp/serve/grammar/grammar_parser.cc | 11 ++--- cpp/serve/grammar/grammar_serializer.cc | 4 +- cpp/serve/grammar/grammar_state_matcher.cc | 10 ++--- .../grammar/grammar_state_matcher_base.h | 8 ++-- .../grammar/grammar_state_matcher_preproc.h | 2 +- cpp/support/encoding.cc | 42 +++++++++---------- cpp/support/encoding.h | 14 +++---- 7 files changed, 46 insertions(+), 45 deletions(-) diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 1ece99099e..55ab0a1dff 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -156,14 +156,14 @@ int32_t EBNFParserImpl::ParseCharacterClass() { continue; } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_, kCustomEscapeMap); + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { - ThrowParseError("Invalid utf8 sequence"); + ThrowParseError("Invalid UTF8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); if (past_is_hyphen) { ICHECK(!elements.empty()); if (elements.back().lower > codepoint) { @@ -194,14 +194,15 @@ int32_t EBNFParserImpl::ParseString() { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } - auto [codepoint, len] = Utf8OrEscapeToCodepoint(cur_); + + auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { ThrowParseError("Invalid utf8 sequence"); } if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { ThrowParseError("Invalid escape sequence"); } - Consume(len); + Consume(new_cur - cur_); character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); } if (character_classes.empty()) { diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index fd41517863..c3c2c88baa 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -59,12 +59,12 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { result += "^"; } for (auto i = 0; i < rule_expr.data_len; i += 2) { - result += CodepointToPrintable(rule_expr[i], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; } result += "-"; - result += CodepointToPrintable(rule_expr[i + 
1], kCustomEscapeMap); + result += PrintAsEscaped(rule_expr[i + 1], kCustomEscapeMap); } result += "]"; return result; diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 5c4ef98efe..451127e746 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -510,7 +510,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") bool MatchCompleteString(GrammarStateMatcher matcher, String str) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = Utf8StringToCodepoints(str.c_str()); + auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; for (auto codepoint : codepoints) { if (!mutable_node->AcceptCodepoint(codepoint, false)) { @@ -553,9 +553,9 @@ void PrintAcceptedRejectedTokens( // First cast to unsigned, then cast to int std::cerr << static_cast(static_cast(token[0])); } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; @@ -571,9 +571,9 @@ void PrintAcceptedRejectedTokens( if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { std::cerr << (int)(unsigned char)token[0]; } else { - auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); for (auto c : codepoints) { - std::cerr << CodepointToPrintable(c); + std::cerr << PrintAsEscaped(c); } } std::cerr << "> "; diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 55c986bb10..5b774d33a4 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -156,15 +156,15 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Rejected" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" + << std::endl; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << CodepointToPrintable(codepoint) - << "\" Accepted" << std::endl; + std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" + << std::endl; std::cout << "Stack after accepting: " << PrintStackState() << std::endl; } #if TVM_LOG_DEBUG diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index c853ac7e04..f63eee2c5c 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -268,7 +268,7 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC ptr->special_token_ids.push_back(i); } else { // First replace the special underscore with space. 
- auto codepoints = Utf8StringToCodepoints(token.c_str()); + auto codepoints = ParseUTF8(token.c_str()); DCHECK(!codepoints.empty() && codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) << "Invalid token: " << token; diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index 0509c1eb2a..d9420bbbd5 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -11,7 +11,7 @@ namespace mlc { namespace llm { -std::string CodepointToUtf8(TCodepoint codepoint) { +std::string PrintAsUTF8(TCodepoint codepoint) { ICHECK(codepoint <= 0x10FFFF) << "Invalid codepoint: " << codepoint; std::string utf8; if (codepoint <= 0x7F) { @@ -36,8 +36,8 @@ std::string CodepointToUtf8(TCodepoint codepoint) { return utf8; } -std::string CodepointToPrintable( - TCodepoint codepoint, const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped(TCodepoint codepoint, + const std::unordered_map& custom_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, @@ -63,10 +63,10 @@ std::string CodepointToPrintable( return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; } -std::pair Utf8ToCodepoint(const char* utf8) { - const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; +std::pair ParseNextUTF8(const char* utf8) { + static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off - const std::array kUtf8Bytes = { + static const std::array kUtf8Bytes = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -89,7 +89,7 @@ std::pair Utf8ToCodepoint(const char* utf8) { auto bytes = kUtf8Bytes[static_cast(utf8[0])]; if (bytes == -1) { // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; } TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; @@ -100,23 +100,23 @@ std::pair Utf8ToCodepoint(const char* utf8) { } res = (res << 6) | (static_cast(utf8[i]) & 0x3F); } - return {res, bytes}; + return {res, utf8 + bytes}; } -std::vector Utf8StringToCodepoints(const char* utf8) { +std::vector ParseUTF8(const char* utf8) { std::vector codepoints; while (*utf8 != 0) { - auto [codepoint, bytes] = Utf8ToCodepoint(utf8); + TCodepoint codepoint; + std::tie(codepoint, utf8) = ParseNextUTF8(utf8); if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { return {codepoint}; } codepoints.push_back(codepoint); - utf8 += bytes; } return codepoints; } -int HexCharToInt(char c) { +inline int HexCharToInt(char c) { if (c >= '0' && c <= '9') { return c - '0'; } else if (c >= 'a' && c <= 'f') { @@ -128,22 +128,22 @@ int HexCharToInt(char c) { } } -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return Utf8ToCodepoint(utf8); + return ParseNextUTF8(utf8); } auto escape_sequence = std::string(utf8, 2); if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if 
(auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { - return {it->second, 2}; + return {it->second, utf8 + 2}; } if (utf8[1] == 'x') { @@ -159,9 +159,9 @@ std::pair Utf8OrEscapeToCodepoint( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { // 4- or 8-digit hex int len = utf8[1] == 'u' ? 4 : 8; @@ -170,13 +170,13 @@ std::pair Utf8OrEscapeToCodepoint( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } codepoint = codepoint * 16 + digit; } - return {codepoint, len + 2}; + return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), 0}; + return {static_cast(CharHandlingError::kInvalidEscape), utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index f28aae6d74..790040e97e 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -21,7 +21,7 @@ using TCodepoint = int32_t; * \param codepoint The codepoint. * \return The UTF-8 string. */ -std::string CodepointToUtf8(TCodepoint codepoint); +std::string PrintAsUTF8(TCodepoint codepoint); /*! * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be @@ -29,10 +29,10 @@ std::string CodepointToUtf8(TCodepoint codepoint); * specify more escape sequences using custom_escape_map. * \param codepoint The codepoint. * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {'-', "\\-"}. + * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. * \return The printable string. */ -std::string CodepointToPrintable( +std::string PrintAsEscaped( TCodepoint codepoint, const std::unordered_map& custom_escape_map = {}); @@ -53,9 +53,9 @@ enum class CharHandlingError : TCodepoint { * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the * function returns (CharHandlingError::kInvalidUtf8, 0). */ -std::pair Utf8ToCodepoint(const char* utf8); +std::pair ParseNextUTF8(const char* utf8); -std::vector Utf8StringToCodepoints(const char* utf8); +std::vector ParseUTF8(const char* utf8); /*! * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function @@ -63,12 +63,12 @@ std::vector Utf8StringToCodepoints(const char* utf8); * using custom_escape_map. * \param utf8 The UTF-8 string or the escape sequence. * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {"\\-", '-'}. + * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape * sequence is invalid, the function returns * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). 
*/ -std::pair Utf8OrEscapeToCodepoint( +std::pair ParseNextUTF8OrEscaped( const char* utf8, const std::unordered_map& custom_escape_map = {}); } // namespace llm From afde65c8dc03c724691cf56c6b1e7595260e6116 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Apr 2024 05:45:05 -0700 Subject: [PATCH 252/531] [Serving] Introduce DraftTokenWorkspaceManager (#2250) Using DraftTokenWorkspaceManager to maintain workspace for draft probs and hidden states (if needed). This allows states of the draft token to be kept fully on GPU. --- cpp/serve/draft_token_workspace_manager.cc | 54 +++++++++++ cpp/serve/draft_token_workspace_manager.h | 95 ++++++++++++++++++ cpp/serve/engine.cc | 55 ++++++----- cpp/serve/engine_actions/action.h | 29 ++++-- cpp/serve/engine_actions/action_commons.cc | 13 ++- cpp/serve/engine_actions/action_commons.h | 13 ++- cpp/serve/engine_actions/batch_decode.cc | 2 +- cpp/serve/engine_actions/batch_draft.cc | 28 ++++-- cpp/serve/engine_actions/batch_verify.cc | 42 +++++--- cpp/serve/engine_actions/eagle_batch_draft.cc | 75 ++++++--------- .../engine_actions/eagle_batch_verify.cc | 96 +++++++++++-------- .../eagle_new_request_prefill.cc | 50 ++++++---- .../engine_actions/new_request_prefill.cc | 2 +- cpp/serve/function_table.cc | 11 ++- cpp/serve/function_table.h | 17 +++- cpp/serve/logit_processor.cc | 4 +- cpp/serve/model.cc | 77 ++++++++++++--- cpp/serve/model.h | 40 +++++++- cpp/serve/request_state.cc | 13 +-- cpp/serve/request_state.h | 27 ++---- cpp/serve/sampler/cpu_sampler.cc | 12 +-- cpp/serve/sampler/gpu_sampler.cc | 23 ++--- cpp/serve/sampler/sampler.h | 7 +- .../attach_spec_decode_aux_funcs.py | 66 +++++++++++++ python/mlc_llm/compiler_pass/pipeline.py | 2 + .../mlc_llm/compiler_pass/rewrite_softmax.py | 5 +- 26 files changed, 627 insertions(+), 231 deletions(-) create mode 100644 cpp/serve/draft_token_workspace_manager.cc create mode 100644 cpp/serve/draft_token_workspace_manager.h create mode 100644 python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py diff --git a/cpp/serve/draft_token_workspace_manager.cc b/cpp/serve/draft_token_workspace_manager.cc new file mode 100644 index 0000000000..185b899e14 --- /dev/null +++ b/cpp/serve/draft_token_workspace_manager.cc @@ -0,0 +1,54 @@ +/*! 
+ * Copyright (c) 2024 by Contributors + * \file serve/draft_token_workspace_manager.cc + */ + +#include "draft_token_workspace_manager.h" + +#include "model.h" + +namespace mlc { +namespace llm { +namespace serve { + +DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, + int hidden_size, + DLDataType hidden_states_dtype, + DLDevice device, + const FunctionTable& ft) + : max_num_tokens_(max_num_tokens), + vocab_size_(vocab_size), + hidden_size_(hidden_size), + hidden_states_dtype_(hidden_states_dtype), + device_(device), + ft_(ft) { + free_slots_.resize(max_num_tokens); + std::iota(free_slots_.begin(), free_slots_.end(), 0); +} + +void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector* result) { + ICHECK_LE(num_slots, free_slots_.size()); + result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots); + std::vector allocated(free_slots_.begin(), free_slots_.begin() + num_slots); + free_slots_.resize(free_slots_.size() - num_slots); +} + +void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector& slots) { + std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_)); +} + +void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace, + bool require_hidden_states) { + workspace->draft_probs = + NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); + workspace->draft_probs_storage = + NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); + if (require_hidden_states) { + workspace->draft_hidden_states_storage = + NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); + } +} + +} // namespace serve +} // namespace llm +} // namespace mlc diff --git a/cpp/serve/draft_token_workspace_manager.h b/cpp/serve/draft_token_workspace_manager.h new file mode 100644 index 0000000000..1a1dfbc8e0 --- /dev/null +++ b/cpp/serve/draft_token_workspace_manager.h @@ -0,0 +1,95 @@ +/*! + * Copyright (c) 2024 by Contributors + * \file serve/draft_token_workspace_manager.h + */ + +#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ +#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ +#include + +#include +#include +#include + +#include "data.h" +#include "function_table.h" +namespace mlc { +namespace llm { +namespace serve { + +using tvm::Device; +using namespace tvm::runtime; + +struct ModelWorkspace; + +/*! + * \brief Managing the workspace for draft token generation. + * + * The workspace is used to store the associated states for each draft token, including the + * probability distribution of the draft token, the hidden states, etc. The workspace manager + * maintains a pool of slots for the draft tokens to store the states. + */ +class DraftTokenWorkspaceManagerObj : public Object { + public: + /*! + * \brief Constructor + * \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace. + * \param vocab_size The size of the vocabulary. + * \param hidden_size The size of the hidden states. + * \param hidden_states_dtype The data type of the hidden states. + * \param device The device running the model. + * \param ft The function table. + */ + DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size, + DLDataType hidden_states_dtype, DLDevice device, + const FunctionTable& ft); + + /*! + * \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure. + * \param workspace The object to stored the allocated draft token workspace. 
+ * \param require_hidden_states Whether to allocate workspace for the hidden states. + */ + void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states); + + /*! + * \brief Allocate slots for the draft tokens. + * \param num_slots The number of slots to allocate. + * \param result The vector to store the allocated slots. + */ + void AllocSlots(int num_slots, std::vector* result); + + /*! + * \brief Free the slots. + * \param slots The slots to free. + */ + void FreeSlots(const std::vector& slots); + + static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager"; + + private: + std::vector free_slots_; + int max_num_tokens_; + int vocab_size_; + int hidden_size_; + DataType hidden_states_dtype_; + DLDevice device_; + const FunctionTable& ft_; +}; + +class DraftTokenWorkspaceManager : public ObjectRef { + public: + DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size, + DLDataType hidden_states_dtype, DLDevice device, + const FunctionTable& ft) { + data_ = make_object(max_num_tokens, vocab_size, hidden_size, + hidden_states_dtype, device, ft); + } + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef, + DraftTokenWorkspaceManagerObj); +}; + +} // namespace serve +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_ diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index d82c886355..9703dda472 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -101,8 +101,13 @@ class EngineImpl : public Engine { } int max_num_tokens = engine_config->max_num_sequence; + DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; + draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + draft_token_workspace_manager->AllocWorkspace( + &model_workspaces_[0], + /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); } LogitProcessor logit_processor = this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); @@ -114,30 +119,36 @@ class EngineImpl : public Engine { ICHECK_GT(this->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft( - this->models_, logit_processor, sampler, this->model_workspaces_, - this->trace_recorder_, engine_config->spec_draft_length), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, engine_config, - this->trace_recorder_)}; + this->actions_ = { + EngineAction::EagleNewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + draft_token_workspace_manager, // + engine_config, // + this->trace_recorder_), + EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_, + engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; break; default: - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - 
this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->trace_recorder_), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - engine_config, this->trace_recorder_)}; + this->actions_ = { + EngineAction::NewRequestPrefill(this->models_, // + logit_processor, // + sampler, // + this->model_workspaces_, // + engine_config, // + this->trace_recorder_), + EngineAction::BatchDraft(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + this->trace_recorder_), + EngineAction::BatchVerify(this->models_, logit_processor, sampler, + this->model_workspaces_, draft_token_workspace_manager, + engine_config, this->trace_recorder_)}; } } else { this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 79359c5741..c69c508810 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -8,6 +8,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_ #include "../config.h" +#include "../draft_token_workspace_manager.h" #include "../engine_state.h" #include "../event_trace_recorder.h" #include "../model.h" @@ -72,15 +73,16 @@ class EngineAction : public ObjectRef { * \param logit_processor The logit processor. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ - static EngineAction EagleNewRequestPrefill(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder); + static EngineAction EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder); /*! * \brief Create the action that runs one-step decode for requests in the * `running_queue` of engine state. Preempt low-priority requests @@ -104,13 +106,16 @@ class EngineAction : public ObjectRef { * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, - int draft_length = 4); + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length = 4); /*! * \brief Create the action that runs one-step speculative draft proposal for @@ -120,12 +125,14 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. 
+ * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * \param draft_length The number of draft proposal rounds. * \return The created action object. */ static EngineAction EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length = 4); @@ -135,13 +142,17 @@ class EngineAction : public ObjectRef { * accordingly when it is impossible to decode all the running requests. * \param models The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. + * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param sampler The sampler to sample new tokens. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. */ static EngineAction BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder); /*! @@ -152,6 +163,7 @@ class EngineAction : public ObjectRef { * models, the `Step` function of the created action will not take effect. * \param sampler The sampler to sample new tokens. * \param model_workspaces The workspace of each model. + * \param draft_token_workspace_manager The draft token workspace manager. * \param engine_config The engine config. * \param trace_recorder The event trace recorder for requests. * \return The created action object. @@ -159,6 +171,7 @@ class EngineAction : public ObjectRef { static EngineAction EagleBatchVerify(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder); diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 6eb7a3d84a..af0dfe978d 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -142,9 +142,10 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder) { +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder) { ICHECK(!estate->running_queue.empty()); Request request = estate->running_queue.back(); @@ -168,8 +169,12 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate, // - Update `inputs` for future prefill. 
RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt"); rsentry->status = RequestStateStatus::kPending; + std::vector draft_token_slots; for (RequestModelState mstate : rsentry->mstates) { - mstate->RemoveAllDraftTokens(); + if (draft_token_workspace_manager.defined()) { + mstate->RemoveAllDraftTokens(&draft_token_slots); + draft_token_workspace_manager.value()->FreeSlots(draft_token_slots); + } std::vector committed_token_ids; committed_token_ids.reserve(mstate->committed_tokens.size()); for (const SampleResult& committed_token : mstate->committed_tokens) { diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 78e3937d0b..07bef2d2d9 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -7,6 +7,7 @@ #define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_ #include "../../tokenizers.h" +#include "../draft_token_workspace_manager.h" #include "../engine.h" #include "../engine_state.h" #include "../event_trace_recorder.h" @@ -52,12 +53,14 @@ void ActionStepPostProcess(Array requests, EngineState estate, Array& models, - Optional trace_recorder); +RequestStateEntry PreemptLastRunningRequestStateEntry( + EngineState estate, const Array& models, + Optional draft_token_workspace_manager, + Optional trace_recorder); /*! \brief Get the running request entries from the engine state. */ inline std::vector GetRunningRequestStateEntries(const EngineState& estate) { diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index 36acc6b06e..ecff914baa 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -48,7 +48,7 @@ class BatchDecodeActionObj : public EngineActionObj { running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index c1ddeb6e4e..513a0fe447 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -23,10 +23,14 @@ namespace serve { class BatchDraftActionObj : public EngineActionObj { public: explicit BatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -41,8 +45,8 @@ class BatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. 
std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -123,8 +127,11 @@ class BatchDraftActionObj : public EngineActionObj { ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); for (int i = 0; i < num_rsentries; ++i) { - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i]); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -156,18 +163,27 @@ class BatchDraftActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchDraft(Array models, LogitProcessor logit_processor, - Sampler sampler, Optional trace_recorder, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + Optional trace_recorder, int draft_length) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(trace_recorder), - draft_length)); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 42c9bbe018..6f27a50394 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -28,11 +28,15 @@ namespace serve { class BatchVerifyActionObj : public EngineActionObj { public: explicit BatchVerifyActionObj(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), + model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -61,14 +65,13 @@ class BatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_verify_length); 
verify_request_mstates.reserve(num_rsentries); rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -76,18 +79,22 @@ class BatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!verify_lengths.empty()); ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1); - ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_prob_dist.size() + 1); + ICHECK_EQ(verify_lengths[i], draft_mstate->draft_token_slots.size() + 1); // the last committed token + all the draft tokens. + draft_token_slots_.push_back(0); // placeholder for the last committed token all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); RECORD_EVENT(trace_recorder_, request_ids, "start verify embedding"); ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed( @@ -123,7 +130,7 @@ class BatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -149,7 +156,8 @@ class BatchVerifyActionObj : public EngineActionObj { // clear the draft model state entries for (int i = 0; i < num_rsentries; ++i) { - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); } auto tend = std::chrono::high_resolution_clock::now(); @@ -194,8 +202,8 @@ class BatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_verify_length -= verify_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -222,6 +230,10 @@ class BatchVerifyActionObj : public EngineActionObj { LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; + /*! \brief The model workspaces. */ + std::vector model_workspaces_; + /*! 
\brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ @@ -232,14 +244,20 @@ class BatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::BatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, EngineConfig engine_config, + Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, + EngineConfig engine_config, Optional trace_recorder) { return EngineAction(make_object( - std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config), - std::move(trace_recorder))); + std::move(models), std::move(logit_processor), std::move(sampler), + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index fde314a5c5..7ad66a045c 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -24,11 +24,13 @@ class EagleBatchDraftActionObj : public EngineActionObj { public: explicit EagleBatchDraftActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), trace_recorder_(std::move(trace_recorder)), draft_length_(draft_length) { ICHECK_GT(draft_length_, 0); @@ -43,8 +45,8 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Preempt request state entries when decode cannot apply. std::vector running_rsentries = GetRunningRequestStateEntries(estate); while (!CanDecode(running_rsentries.size())) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { running_rsentries.pop_back(); } @@ -81,21 +83,20 @@ class EagleBatchDraftActionObj : public EngineActionObj { mstates.push_back(rsentry->mstates[model_id]); } // draft_length_ rounds of draft proposal. 
- NDArray hidden_states_nd{nullptr}; ObjectRef last_hidden_states{nullptr}; - ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; + NDArray hidden_states = Downcast(model_workspaces_[model_id].hidden_states); // Concat last hidden_states - std::vector previous_hidden_on_device; - for (int i = 0; i < num_rsentries; ++i) { - previous_hidden_on_device.push_back(mstates[i]->draft_last_hidden_on_device.back()); + draft_token_slots_.clear(); + if (draft_length_ > 1) { + for (int i = 0; i < num_rsentries; ++i) { + draft_token_slots_.push_back(mstates[i]->draft_token_slots.back()); + } + hidden_states = Downcast(models_[model_id]->GatherHiddenStates( + model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states)); + ICHECK(hidden_states->ndim == 2); + last_hidden_states = hidden_states.CreateView( + {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); } - hidden_states_nd = - models_[model_id]->ConcatLastHidden(previous_hidden_on_device, &hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); - last_hidden_states = hidden_states_nd; // The first draft token has been generated in prefill/verify stage for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { // prepare new input tokens @@ -115,17 +116,17 @@ class EagleBatchDraftActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = + hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); - last_hidden_states = hidden_states_nd; + last_hidden_states = hidden_states; NDArray logits; if (models_[model_id]->CanGetLogits()) { - logits = models_[model_id]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -152,12 +153,12 @@ class EagleBatchDraftActionObj : public EngineActionObj { ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + // No need to save hidden states as they are not used by subsequent engine actions for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -183,26 +184,6 @@ class EagleBatchDraftActionObj : public EngineActionObj { return true; } - /*! 
- * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The model to run draft generation in speculative decoding. */ Array models_; /*! \brief The logit processor. */ @@ -211,20 +192,26 @@ class EagleBatchDraftActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief Event trace recorder. */ Optional trace_recorder_; /*! \brief Draft proposal length */ int draft_length_; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; EngineAction EngineAction::EagleBatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, Optional trace_recorder, int draft_length) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(trace_recorder), draft_length)); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(trace_recorder), draft_length)); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index b259417050..d52f60d5c7 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -29,12 +29,14 @@ class EagleBatchVerifyActionObj : public EngineActionObj { public: explicit EagleBatchVerifyActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)), rng_(RandomGenerator::GetInstance()) {} @@ -70,7 +72,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rngs.reserve(num_rsentries); generation_cfg.reserve(num_rsentries); draft_output_tokens.reserve(num_rsentries); - draft_output_prob_dist.reserve(num_rsentries); + draft_token_slots_.clear(); for (int i = 0; i < num_rsentries; ++i) { RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_]; @@ -78,19 +80,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { request_internal_ids.push_back(verify_mstate->internal_id); ICHECK(!draft_lengths.empty()); ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size()); - 
ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_prob_dist.size()); + ICHECK_EQ(draft_lengths[i], draft_mstate->draft_token_slots.size()); // the last committed token + all the draft tokens but the last one. all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().sampled_token_id.first); + draft_token_slots_.push_back(0); // placeholder for the last committed token for (int j = 0; j < static_cast(draft_mstate->draft_output_tokens.size()); ++j) { all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].sampled_token_id.first); + draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]); } verify_request_mstates.push_back(verify_mstate); generation_cfg.push_back(rsentries[i]->request->generation_cfg); rngs.push_back(&rsentries[i]->rng); draft_output_tokens.push_back(draft_mstate->draft_output_tokens); - draft_output_prob_dist.push_back(draft_mstate->draft_output_prob_dist); } + NDArray draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs( + model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs); + std::vector cum_verify_lengths = {0}; cum_verify_lengths.reserve(num_rsentries + 1); std::vector verify_lengths; @@ -135,10 +142,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { std::vector> sample_results_arr = sampler_->BatchVerifyDraftTokensWithProbAfterTopP( renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs, - draft_output_tokens, draft_output_prob_dist); + draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); - std::vector last_hidden_states; + std::vector last_accepted_hidden_positions; + last_accepted_hidden_positions.reserve(num_rsentries); for (int i = 0; i < num_rsentries; ++i) { const std::vector& sample_results = sample_results_arr[i]; int accept_length = sample_results.size(); @@ -163,24 +171,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); } // clear the draft model state entries - rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(); - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = - GetTokenHidden(hidden_states, (cum_verify_lengths[i] + accept_length - 1)); - last_hidden_states.push_back(last_hidden_on_device); + rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); + draft_token_workspace_manager_->FreeSlots(draft_token_slots_); + // - Slice and save hidden_states_for_sample + last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } { // One step draft for the following steps - NDArray hidden_states_nd{nullptr}; - ObjectRef next_hidden_states = model_workspaces_[draft_model_id_].hidden_states; - // Concat last hidden_states - hidden_states_nd = - models_[draft_model_id_]->ConcatLastHidden(last_hidden_states, &next_hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 2); - ICHECK_EQ(hidden_states_nd->shape[0], num_rsentries); - hidden_states_nd = hidden_states_nd.CreateView( - {hidden_states_nd->shape[0], 1, hidden_states_nd->shape[1]}, hidden_states_nd->dtype); + NDArray last_hidden_states_nd = hidden_states.CreateView( + {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, + hidden_states->dtype); + + hidden_states = Downcast(models_[draft_model_id_]->GatherHiddenStates( + last_hidden_states_nd, last_accepted_hidden_positions, + 
&model_workspaces_[draft_model_id_].hidden_states)); + ICHECK(hidden_states->ndim == 2); + hidden_states = hidden_states.CreateView( + {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); std::vector input_tokens; Array mstates; @@ -203,17 +211,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( - embeddings, hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states_nd = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, - request_internal_ids); + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, + request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, + logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } else { // - Use base model's head. - logits = - models_[0]->GetLogits(hidden_states_nd, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + logits = models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); ICHECK_EQ(logits->ndim, 3); @@ -239,13 +246,21 @@ class EagleBatchVerifyActionObj : public EngineActionObj { renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); ICHECK_EQ(sample_results.size(), num_rsentries); + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); + models_[draft_model_id_]->ScatterDraftProbs( + renormalized_probs, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_probs_storage); + ICHECK(hidden_states->ndim == 3); + hidden_states = hidden_states.CreateView( + {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, + hidden_states->dtype); + models_[draft_model_id_]->ScatterHiddenStates( + hidden_states, draft_token_slots_, + &model_workspaces_[verify_model_id_].draft_hidden_states_storage); // - Add draft token to the state. for (int i = 0; i < num_rsentries; ++i) { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_nd, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - mstates[i]->AddDraftToken(sample_results[i], prob_dist[i], last_hidden_on_device); + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -292,8 +307,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { total_required_pages += num_require_pages; } while (!CanVerify(total_required_pages)) { - RequestStateEntry preempted = - PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_); + RequestStateEntry preempted = PreemptLastRunningRequestStateEntry( + estate, models_, draft_token_workspace_manager_, trace_recorder_); if (preempted.same_as(running_rsentries.back())) { total_draft_length -= draft_lengths.back(); total_required_pages -= num_page_requirement.back(); @@ -342,6 +357,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. 
*/ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ @@ -352,16 +369,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj { const int verify_model_id_ = 0; const int draft_model_id_ = 1; const float eps_ = 1e-5; + /*! \brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleBatchVerify(Array models, LogitProcessor logit_processor, - Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleBatchVerify( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index a687e7eb7f..57310f7986 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -24,12 +24,14 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) : models_(std::move(models)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), engine_config_(std::move(engine_config)), trace_recorder_(std::move(trace_recorder)) {} @@ -107,7 +109,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. if (rsentry->parent_idx == -1) { @@ -286,8 +288,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. 
auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - if (model_id == 0) { + if (model_id == 0) { + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { for (int mid = 0; mid < static_cast(models_.size()); ++mid) { rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); if (!rsentry_activated[i]) { @@ -301,13 +303,24 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { rsentries_for_sample[i]->tprefill_finish = tnow; } - } else { - // - Slice hidden_states_for_sample - NDArray last_hidden_on_device = GetTokenHidden(hidden_states_for_sample, i); - CHECK(i < static_cast(prob_dist.size())); - CHECK(prob_dist[i].defined()); - rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], prob_dist[i], - last_hidden_on_device); + } + } else { + // - Slice and save hidden_states_for_sample + draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), + &draft_token_slots_); + models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->spec_draft_length > 1) { + hidden_states_for_sample = hidden_states_for_sample.CreateView( + {hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1], + hidden_states_for_sample->shape[2]}, + hidden_states_for_sample->dtype); + models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], + draft_token_slots_[i]); estate->stats.total_draft_length += 1; } } @@ -582,20 +595,25 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; + /*! \brief The draft token workspace manager. */ + DraftTokenWorkspaceManager draft_token_workspace_manager_; /*! \brief The engine config. */ EngineConfig engine_config_; /*! \brief Event trace recorder. */ Optional trace_recorder_; + /*! 
\brief Temporary buffer to store the slots of the current draft tokens */ + std::vector draft_token_slots_; }; -EngineAction EngineAction::EagleNewRequestPrefill(Array models, - LogitProcessor logit_processor, Sampler sampler, - std::vector model_workspaces, - EngineConfig engine_config, - Optional trace_recorder) { +EngineAction EngineAction::EagleNewRequestPrefill( + Array models, LogitProcessor logit_processor, Sampler sampler, + std::vector model_workspaces, + DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, + Optional trace_recorder) { return EngineAction(make_object( std::move(models), std::move(logit_processor), std::move(sampler), - std::move(model_workspaces), std::move(engine_config), std::move(trace_recorder))); + std::move(model_workspaces), std::move(draft_token_workspace_manager), + std::move(engine_config), std::move(trace_recorder))); } } // namespace serve diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index b4192a04f1..f801b1e282 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -100,7 +100,7 @@ class NewRequestPrefillActionObj : public EngineActionObj { } ICHECK(mstate->draft_output_tokens.empty()); - ICHECK(mstate->draft_output_prob_dist.empty()); + ICHECK(mstate->draft_token_slots.empty()); if (status_before_prefill[i] == RequestStateStatus::kPending) { // Add the sequence to the model, or fork the sequence from its parent. if (rsentry->parent_idx == -1) { diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 3267f1dd38..4e0301eb2d 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -93,7 +93,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc("runtime.ModuleGetFunction")]( const std::string& name) -> PackedFunc { - DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, false); + DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, true); bool exists = (func->DebugGetFromRemote(0).operator PackedFunc()) != nullptr; if (!exists) { return PackedFunc(nullptr); @@ -259,6 +259,11 @@ void FunctionTable::_InitFunctions() { this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; + + this->gather_probs_func_ = mod->GetFunction("gather_probs", true); + this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true); + this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true); + this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true); } ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const { @@ -272,8 +277,8 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) } ObjectRef FunctionTable::CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape) { - if (this->use_disco) { + ShapeTuple max_reserved_shape, bool local_only) { + if (this->use_disco && !local_only) { Device null_device{DLDeviceType(0), 0}; DRef buffer(nullptr); auto it = this->cached_buffers.find(buffer_cache_key); diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index bc2b4f21c8..e368edcb9c 100644 --- a/cpp/serve/function_table.h +++ 
b/cpp/serve/function_table.h @@ -50,8 +50,18 @@ struct FunctionTable { ObjectRef Empty(ShapeTuple shape, DataType dtype, Device device) const; + /*! + * \brief Copy a host array to the worker or local gpu. + * \param host_array The host array to be copied. + * \param buffer_cache_key The key to the buffer cache. + * \param max_reserved_shape The maximum shape to be reserved in the buffer cache. + * \param local_only Whether to copy the array to the local gpu only. If true, the use_disco + * flag will be ignored. This can be useful for functions that run only on the + * local gpu when disco is enabled. + * \return The array on the worker or local gpu. + */ ObjectRef CopyToWorker0(const NDArray& host_array, String buffer_cache_key, - ShapeTuple max_reserved_shape); + ShapeTuple max_reserved_shape, bool local_only = false); void DebugCallFuncOnAllAllWorker(const String& func_name) const; @@ -110,6 +120,11 @@ struct FunctionTable { PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; + // Auxiliary functions for speculative decoding. + PackedFunc gather_probs_func_; + PackedFunc scatter_probs_func_; + PackedFunc gather_hidden_states_func_; + PackedFunc scatter_hidden_states_func_; }; } // namespace serve diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index f7190d50ac..7ce70a0d26 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -289,7 +289,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty; ++num_token_for_penalty; if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (num_token_to_process != 1) { @@ -368,7 +368,7 @@ class LogitProcessorImpl : public LogitProcessorObj { p_seq_ids[token_start_offset + j] = 1; } if (j > 0) { - mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], NDArray(), NDArray()); + mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1); } } if (token_number != 1) { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 6f34220219..8918cecdc4 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -246,14 +246,8 @@ class ModelImpl : public ModelObj { } NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - // This step runs on the engine thread. - // By temporarily turning off the disco flag, this copies the logit_pos_nd to the cached device - // tensor without actually copying to the worker. - bool use_disco = ft_.use_disco; - ft_.use_disco = false; - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - ft_.use_disco = use_disco; + ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos_local", + {max_num_sequence_}, /*local_only=*/true); CHECK(ft_.batch_select_last_hidden_func_.defined()) << "`batch_select_last_hidden_states` function is not found in the model."; @@ -870,20 +864,21 @@ class ModelImpl : public ModelObj { // Allocate the hidden_states tensor. // Use the same function as embeddings. ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); + NDArray hidden_states_nd{nullptr}; // Get the shape of the hidden_states tensor for hidden size. 
- ShapeTuple hidden_states_shape; if (ft_.use_disco) { ICHECK(hidden_states->IsInstance()); - ObjectRef shape_ref = ft_.nd_get_shape_func_(hidden_states); - hidden_states_shape = Downcast(shape_ref)->DebugGetFromRemote(0); + hidden_states_nd = Downcast(hidden_states)->DebugGetFromRemote(0); } else { - NDArray hidden_states_nd = Downcast(hidden_states); - hidden_states_shape = hidden_states_nd.Shape(); + hidden_states_nd = Downcast(hidden_states); } + ShapeTuple hidden_states_shape = hidden_states_nd.Shape(); ICHECK_EQ(hidden_states_shape.size(), 2); ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; - return hidden_states; + this->hidden_states_dtype_ = hidden_states_nd->dtype; + // TODO(wuwei): We can keep hidden_states on the worker after refactor + return hidden_states_nd; } void Reset() final { @@ -893,6 +888,59 @@ class ModelImpl : public ModelObj { } } + /********************** Utilities for speculative decoding **********************/ + + DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_tokens) { + return DraftTokenWorkspaceManager(max_num_tokens, vocab_size_, hidden_size_, + hidden_states_dtype_, device_, ft_); + } + + ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + NDArray dst_view = Downcast(*dst).CreateView( + {static_cast(indices.size()), hidden_size_}, hidden_states_dtype_); + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.gather_hidden_states_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.scatter_hidden_states_func_(input, indices_device, *dst); + } + + NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray dst_view = + dst->CreateView({static_cast(indices.size()), vocab_size_}, DataType::Float(32)); + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.gather_probs_func_(input, indices_device, dst_view); + return dst_view; + } + + void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) final { + NDArray indices_nd = + logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); + indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ObjectRef indices_device = + ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ft_.scatter_probs_func_(input, indices_device, *dst); + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { @@ -951,6 +999,7 @@ class ModelImpl : public ModelObj { int max_num_sequence_ = -1; 
int prefill_chunk_size_ = -1; int hidden_size_ = -1; + DLDataType hidden_states_dtype_; int vocab_size_ = -1; int image_embed_size_ = -1; //---------------------------- diff --git a/cpp/serve/model.h b/cpp/serve/model.h index bc63840a74..d672739581 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,6 +12,7 @@ #include "../base.h" #include "config.h" +#include "draft_token_workspace_manager.h" #include "event_trace_recorder.h" #include "function_table.h" #include "logit_processor.h" @@ -40,10 +41,26 @@ struct ModelWorkspace { */ ObjectRef embeddings{nullptr}; /*! - * \brief The hidden_states tensor. It can be either an NDArray when tensor + * \brief The hidden_states tensor for the current batch. It can be either an NDArray when tensor * model parallelism is not enabled, or a DRef when using tensor model parallelism. */ ObjectRef hidden_states{nullptr}; + + /*! + * \brief The draft token probabilities tensor for the current batch. + */ + NDArray draft_probs{nullptr}; + + /*! + * \brief The hidden_states tensor storing the hidden_states of draft tokens of all requests. + */ + ObjectRef draft_hidden_states_storage{nullptr}; + + /*! + * \brief The draft token probabilities tensor storing the probabilities of draft tokens of all + * requests. + */ + NDArray draft_probs_storage{nullptr}; }; /*! @@ -302,6 +319,27 @@ class ModelObj : public Object { /*! \brief Reset the model KV cache and other statistics. */ virtual void Reset() = 0; + /*********************** Utilities for speculative decoding. ***********************/ + + virtual DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_token) = 0; + + /*! \brief Gather the hidden_states of the given indices and in-place update the dst tensor. */ + virtual ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Scatter the hidden_states of the given indices to the dst tensor. */ + virtual void ScatterHiddenStates(const ObjectRef& input, const std::vector& indices, + ObjectRef* dst) = 0; + + /*! \brief Gather the draft token probabilities of the given indices and in-place update the dst + * tensor. */ + virtual NDArray GatherDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + + /*! \brief Scatter the draft token probabilities of the given indices to the dst tensor. */ + virtual void ScatterDraftProbs(const NDArray& input, const std::vector& indices, + NDArray* dst) = 0; + /************** Debug/Profile **************/ /*! \brief Call the given global function on all workers. Only for debug purpose. 
*/ diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index b1f5ae27a2..4c59ae52a2 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -59,11 +59,9 @@ void RequestModelStateNode::CommitToken(SampleResult sampled_token) { } } -void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray last_hidden_on_device) { +void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot) { draft_output_tokens.push_back(std::move(sampled_token)); - draft_output_prob_dist.push_back(std::move(prob_dist)); - draft_last_hidden_on_device.push_back(std::move(last_hidden_on_device)); + draft_token_slots.push_back(draft_token_slot); appeared_token_ids[sampled_token.sampled_token_id.first] += 1; } @@ -71,14 +69,17 @@ void RequestModelStateNode::RemoveLastDraftToken() { ICHECK(!draft_output_tokens.empty()); auto it = appeared_token_ids.find(draft_output_tokens.back().sampled_token_id.first); draft_output_tokens.pop_back(); - draft_output_prob_dist.pop_back(); CHECK(it != appeared_token_ids.end()); if (--it->second == 0) { appeared_token_ids.erase(it); } } -void RequestModelStateNode::RemoveAllDraftTokens() { +void RequestModelStateNode::RemoveAllDraftTokens(std::vector* removed_draft_token_slots) { + if (removed_draft_token_slots != nullptr) { + removed_draft_token_slots->assign(draft_token_slots.begin(), draft_token_slots.end()); + } + draft_token_slots.clear(); while (!draft_output_tokens.empty()) { RemoveLastDraftToken(); } diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 950bb6e290..79abcb1a24 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -62,20 +62,8 @@ class RequestModelStateNode : public Object { * result of speculation. */ std::vector draft_output_tokens; - /*! - * \brief The probability distribution on each position in the - * draft. We keep the distributions for stochastic sampling when merging - * speculations from multiple models. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_output_prob_dist; - /*! - * \brief The last hidden_states used to get probs in drafting. - * \note We only need this value when we have multiple parallel small models - * and draft outputs in speculative inference settings. - */ - std::vector draft_last_hidden_on_device; + /*! \brief The storage slots for the associated states of draft tokens. */ + std::vector draft_token_slots; /*! \brief The appeared committed and draft tokens and their occurrence times. */ std::unordered_map appeared_token_ids; @@ -101,17 +89,18 @@ class RequestModelStateNode : public Object { /*! \brief Commit a new token into committed_tokens. Update appeared_token_ids. */ void CommitToken(SampleResult sampled_token); /*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */ - void AddDraftToken(SampleResult sampled_token, NDArray prob_dist, - NDArray draft_last_hidden_on_device = NDArray()); - /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ - void RemoveLastDraftToken(); + void AddDraftToken(SampleResult sampled_token, int draft_token_slot); /*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. 
*/ - void RemoveAllDraftTokens(); + void RemoveAllDraftTokens(std::vector* removed_draft_token_slots = nullptr); static constexpr const char* _type_key = "mlc.serve.RequestModelState"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_BASE_OBJECT_INFO(RequestModelStateNode, Object); + + private: + /*! \brief Remove the last token from draft_output_tokens. Update appeared_token_ids. */ + void RemoveLastDraftToken(); }; class RequestModelState : public ObjectRef { diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 98080c979d..196a6dd695 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -430,7 +430,7 @@ class CPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { // probs_on_host: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start draft verification"); CHECK_EQ(probs_on_host->ndim, 2); @@ -438,8 +438,8 @@ class CPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); + NDArray draft_probs_on_host = draft_probs_on_device.CopyTo(DLDevice{kDLCPU, 0}); std::vector> sample_results; sample_results.resize(num_sequence); @@ -451,6 +451,7 @@ class CPUSampler : public SamplerObj { [&](int i) { int verify_start = cum_verify_lengths[i]; int verify_end = cum_verify_lengths[i + 1]; + int cur_token_idx = 0; // Sub 1 to ignore the last prediction. 
for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) { @@ -477,12 +478,9 @@ class CPUSampler : public SamplerObj { // normalize a new probability distribution double sum_v = 0.0; - NDArray q_dist = draft_output_prob_dist[i][cur_token_idx]; - ICHECK(q_dist->device.device_type == kDLCPU); - ICHECK(q_dist->ndim == 1); - ICHECK(vocab_size == q_dist->shape[q_dist->ndim - 1]); const float* __restrict p_qdist = - static_cast(__builtin_assume_aligned(q_dist->data, 4)); + static_cast(__builtin_assume_aligned(draft_probs_on_host->data, 4)) + + (verify_start + cur_token_idx + 1) * vocab_size; for (int j = 0; j < vocab_size; ++j) { p_probs[j] = std::max(p_probs[j] - p_qdist[j], 0.0f); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 58a27c24f7..c6f463eb32 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -167,7 +167,7 @@ class GPUSampler : public SamplerObj { const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) final { + NDArray draft_probs_on_device) final { NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP"); std::vector> sample_results; // probs_on_device: (n, v) @@ -177,38 +177,27 @@ class GPUSampler : public SamplerObj { int num_sequence = static_cast(cum_verify_lengths.size()) - 1; CHECK_EQ(rngs.size(), num_sequence); CHECK_EQ(draft_output_tokens.size(), num_sequence); - CHECK_EQ(draft_output_prob_dist.size(), num_sequence); sample_results.resize(num_sequence); int num_nodes = cum_verify_lengths.back(); + CHECK_EQ(draft_probs_on_device->shape[0], num_nodes); NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); - NDArray draft_probs_device = - draft_probs_device_.CreateView({num_nodes, vocab_size_}, dtype_f32_); NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); - // Concat draft prob distributions to a ragged tensor (num_nodes, vocab_size) + // Copy draft tokens to GPU + int* p_draft_tokens_host = static_cast(draft_tokens_host->data); for (int i = 0; i < num_sequence; i++) { const std::vector& draft_output_tokens_i = draft_output_tokens[i]; - const std::vector& draft_output_prob_dist_i = draft_output_prob_dist[i]; int start = cum_verify_lengths[i]; int end = cum_verify_lengths[i + 1]; // start/end is the range of the sequence i in probs_on_device, which includes the prob dist // of the draft tokens and the last committed token ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start); - ICHECK_EQ(draft_output_prob_dist_i.size() + 1, end - start); for (int j = 0; j < end - start - 1; j++) { - // Copy prob dist - ICHECK_EQ(draft_probs_device->dtype.bits, 32); - float* p_draft_probs = - static_cast(draft_probs_device->data) + - (j + start + 1) * - vocab_size_; // shift by one, q of the last committed token is undefined // Copy sampled token id - draft_output_prob_dist_i[j].CopyToBytes(p_draft_probs, vocab_size_ * sizeof(float)); - *(static_cast(draft_tokens_host->data) + j + start + 1) = - draft_output_tokens_i[j].sampled_token_id.first; + p_draft_tokens_host[start + j + 1] = draft_output_tokens_i[j].sampled_token_id.first; } } CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); @@ -262,7 +251,7 @@ class 
GPUSampler : public SamplerObj { SyncCopyStream(device_, compute_stream_, copy_stream_); - gpu_verify_draft_tokens_func_(draft_probs_device, draft_tokens_device, probs_on_device, + gpu_verify_draft_tokens_func_(draft_probs_on_device, draft_tokens_device, probs_on_device, token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device, token_tree_parent_ptr_device); diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 7943231e55..59e433ac47 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -108,15 +108,16 @@ class SamplerObj : public Object { * \param rngs The random number generator of each sequence. * \param draft_output_tokens The draft tokens generated by the small model for * each sequence. - * \param draft_output_prob_dist The probability distribution computed from the - * small model for each sequence. + * \param draft_probs_on_device The probability distribution computed from the + * small model for each sequence. Concatenated tensor of shape (total_verify_length, vocab_size). + * It includes the slot for the last committed token that has undefined probablity value. * \return The list of accepted tokens for each request. */ virtual std::vector> BatchVerifyDraftTokensWithProbAfterTopP( NDArray probs, const Array& request_ids, const std::vector& cum_verify_lengths, const Array& generation_cfg, const std::vector& rngs, const std::vector>& draft_output_tokens, - const std::vector>& draft_output_prob_dist) = 0; + NDArray draft_probs_on_device) = 0; static constexpr const char* _type_key = "mlc.serve.Sampler"; static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py new file mode 100644 index 0000000000..b7cfd76fa3 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -0,0 +1,66 @@ +"""The pass that attaches logit processor functions to the IRModule.""" + +import tvm +from tvm import IRModule +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs") +class AttachSpecDecodeAuxFuncs: # pylint: disable=too-few-public-methods + """Attach logit processing TIR functions to IRModule.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + mod["scatter_probs"] = _get_scatter_2d_inplace( + dtype="float32", global_symbol="scatter_probs" + ) + mod["gather_probs"] = _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs") + if "prefill_to_last_hidden_states" in mod: + hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[ + 0 + ] # pylint: disable=no-member + dtype = hidden_states_struct_info.dtype + mod["scatter_hidden_states"] = _get_scatter_2d_inplace( + dtype, global_symbol="scatter_hidden_states" + ) + mod["gather_hidden_states"] = _get_gather_2d_inplace( + dtype, global_symbol="gather_hidden_states" + ) + return mod + + +def _get_scatter_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (batch_size, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = 
T.match_buffer(var_dst, (m, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("scatter_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[indices[vb], vj] = src[vb, vj] + + return _scatter_2d + + +def _get_gather_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (m, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (batch_size, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("gather_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[vb, vj] = src[indices[vb], vj] + + return _gather_2d diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 57b68f742d..3c80d2c4df 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -15,6 +15,7 @@ from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc from .attach_logit_processor import AttachLogitProcessFunc from .attach_sampler import AttachGPUSamplingFunc +from .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs from .attach_support_info import ( AttachAdditionalPrimFuncs, AttachCUDAGraphSymbolicCaptureHints, @@ -104,6 +105,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), + AttachSpecDecodeAuxFuncs(), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py index 1a6e41eafc..82e6cf863b 100644 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -34,7 +34,10 @@ def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: def transform(self) -> IRModule: """Entry point""" - gv = self.mod.get_global_var("softmax_with_temperature") + func_name = "softmax_with_temperature" + if func_name not in self.mod: + return self.mod + gv = self.mod.get_global_var(func_name) updated_func = self.visit_expr(self.mod[gv]) self.builder_.update_func(gv, updated_func) return self.builder_.get() From 6a4357087dc5eb3828e6756276ede7fbf348ff4a Mon Sep 17 00:00:00 2001 From: Kevin_Xiong Date: Tue, 30 Apr 2024 20:46:54 +0800 Subject: [PATCH 253/531] [Fix] fix a typo in event_trace_recorder (#2253) * Fix typo in event_tracer --- cpp/serve/event_trace_recorder.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/serve/event_trace_recorder.cc b/cpp/serve/event_trace_recorder.cc index 8a930002fe..e0311716fd 100644 --- a/cpp/serve/event_trace_recorder.cc +++ b/cpp/serve/event_trace_recorder.cc @@ -51,7 +51,7 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { void AddEvent(const Array& request_ids, const std::string& event) final { double event_time = std::chrono::duration_cast>( std::chrono::system_clock::now().time_since_epoch()) - .count(); + .count(); // in seconds { std::lock_guard lock(mutex_); @@ -96,16 +96,16 @@ class EventTraceRecorderImpl : public EventTraceRecorderObj { name = event; phase = "i"; } - 
int64_t event_time_in_ms = static_cast(event_time * 1e6); + int64_t event_time_in_us = static_cast(event_time * 1e6); picojson::object event_json; event_json["name"] = picojson::value(name); event_json["ph"] = picojson::value(phase); - event_json["ts"] = picojson::value(event_time_in_ms); + event_json["ts"] = picojson::value(event_time_in_us); event_json["pid"] = picojson::value(static_cast(1)); event_json["tid"] = picojson::value(request_id); - events_to_sort.push_back({event_time_in_ms, picojson::value(event_json)}); + events_to_sort.push_back({event_time_in_us, picojson::value(event_json)}); } std::sort(events_to_sort.begin(), events_to_sort.end(), fcmp_events); for (auto [timestamp, event] : events_to_sort) { From ca7cdcc2652844381181ccdd3e1e8a5aca2aa0a8 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Tue, 30 Apr 2024 07:55:24 -0700 Subject: [PATCH 254/531] [Tokenizer] Support ByteLevel BPE in tokenizer token table (#2248) --- cpp/serve/engine.cc | 20 ++++- cpp/tokenizers.cc | 105 +++++++++++++++++++++---- cpp/tokenizers.h | 21 ++++- python/mlc_llm/interface/gen_config.py | 74 ++++++++++++++++- 4 files changed, 198 insertions(+), 22 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 9703dda472..755af998cd 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -56,9 +56,7 @@ class EngineImpl : public Engine { } this->request_stream_callback_ = std::move(request_stream_callback); this->trace_recorder_ = trace_recorder; - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); - this->token_table_ = tokenizer_->TokenTable(); - this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + // Step 2. Initialize each model independently. // Create the logit processor and sampler. this->models_.clear(); @@ -100,6 +98,21 @@ class EngineImpl : public Engine { engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } + // Step 3. Initialize tokenizer and grammar + this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + std::string token_table_postproc_method; + if (model_configs[0].count("token_table_postproc_method") == 0) { + // Backward compatibility: use "byte-fallback" by default + token_table_postproc_method = "byte-fallback"; + } else { + token_table_postproc_method = + model_configs[0].at("token_table_postproc_method").get(); + } + this->token_table_ = + Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); + this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); + + // Step 4. Initialize engine actions that represent state transitions. int max_num_tokens = engine_config->max_num_sequence; DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { @@ -113,7 +126,6 @@ class EngineImpl : public Engine { this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); Sampler sampler = this->models_[0]->CreateSampler( max_num_tokens, static_cast(this->models_.size()), trace_recorder); - // Step 3. Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. 
ICHECK_GT(this->models_.size(), 1U); diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index ef866f3bfc..6fe9217520 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -9,10 +9,12 @@ #include #include +#include #include #include #include +#include "./support/encoding.h" #include "./support/load_bytes_from_file.h" namespace mlc { @@ -91,13 +93,8 @@ Tokenizer Tokenizer::FromPath(const String& _path) { LOG(FATAL) << "Cannot find any tokenizer under: " << _path; } -/*! - * \brief Post-process a raw token (which may be a raw byte or contain lower - * one eights block) to the actual token. - * We do this in order to conform with the tokenizers' setup. - */ -inline std::string PostProcessToken(std::string token) { - // 1. The token represents a byte. +/*! \brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */ +inline std::string ByteFallbackDecoder(const std::string& token) { if (token.length() == 6 && token.substr(0, 3) == "<0x" && token.back() == '>') { int byte = 0; for (int i = 0; i < 2; ++i) { @@ -108,15 +105,82 @@ inline std::string PostProcessToken(std::string token) { ICHECK(byte >= 0 && byte < 256); return std::string(/*n=*/1, static_cast(byte)); } + return token; +} - // 2. The token contains "\u2581" which means space. - static const std::string& lower_one_eighth_block = "\u2581"; - size_t pos = token.find(lower_one_eighth_block); - while (pos != std::string::npos) { - token.replace(pos, /*n=*/lower_one_eighth_block.length(), /*str=*/" "); - pos = token.find(lower_one_eighth_block); +/*! \brief SpaceReplacer decoder: transform "\u2581" back to space */ +inline std::string SpaceReplacerDecoder(const std::string& token) { + // \u2581 is the unicode for "lower one eighth block" + // UTF8 encoding for \u2581 is 0xE2 0x96 0x81 + std::string result; + for (size_t i = 0; i < token.size(); ++i) { + if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) && + token[i + 2] == char(0x81)) { + result += ' '; + i += 2; + } else { + result += token[i]; + } + } + return result; +} + +/*! \brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding + * process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + */ +inline std::string ByteLevelDecoder(const std::string& token) { + // clang-format off + // The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode. 
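For orientation, the forward bytes-to-unicode mapping that this table inverts follows the GPT-2 scheme referenced in the comment above; a small Python sketch (illustrative only, mirroring the linked Hugging Face implementation) is:

    def bytes_to_unicode():
        # Printable byte values map to themselves; the remaining bytes are
        # shifted into the 256+ codepoint range so that every byte gets a
        # visible unicode character in the token table.
        bs = (list(range(ord("!"), ord("~") + 1))
              + list(range(0xA1, 0xAC + 1))
              + list(range(0xAE, 0xFF + 1)))
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)
                n += 1
        return dict(zip(bs, (chr(c) for c in cs)))

    byte_to_unicode = bytes_to_unicode()
    unicode_to_byte = {c: b for b, c in byte_to_unicode.items()}
    # A byte-level BPE token such as "Ġof" decodes back to " of":
    print(bytes(unicode_to_byte[ch] for ch in "Ġof").decode())  # " of"
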
+ static const std::array unicode_to_byte_map = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, + 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, + 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1, + 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, + 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, + 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, + 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, + 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, + 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173 + }; + // clang-format on + + auto unicode_codepoints = ParseUTF8(token.c_str()); + std::string decoded; + + for (auto unicode_codepoint : unicode_codepoints) { + ICHECK(unicode_codepoint >= 0 && + unicode_codepoint < static_cast(unicode_to_byte_map.size())); + int byte = unicode_to_byte_map[unicode_codepoint]; + if (byte == -1) { + // If there is no mapping, add the codepoint itself to the result string + // Some tokenizer like Phi-2 have raw tokens like \t\t + decoded += static_cast(unicode_codepoint); + } else { + decoded += static_cast(byte); + } + } + return decoded; +} + +/*! + * \brief Post-process a raw token to the actual token with the given post-processing method. 
+ */ +inline std::string PostProcessToken(const std::string& token, const std::string& postproc_method) { + if (postproc_method == "byte_fallback") { + return SpaceReplacerDecoder(ByteFallbackDecoder(token)); + } else if (postproc_method == "byte_level") { + return ByteLevelDecoder(token); + } else { + LOG(FATAL) << "Unknown post-processing method: " << postproc_method; } - return token; } const std::vector& TokenizerObj::TokenTable() { @@ -127,12 +191,21 @@ const std::vector& TokenizerObj::TokenTable() { int vocab_size = tokenizer->GetVocabSize(); token_table_.reserve(vocab_size); for (int32_t token_id = 0; token_id < vocab_size; ++token_id) { - std::string token = tokenizer->IdToToken(token_id); - token_table_.push_back(PostProcessToken(token)); + token_table_.push_back(tokenizer->IdToToken(token_id)); } return token_table_; } +std::vector Tokenizer::PostProcessTokenTable( + const std::vector& token_table, const std::string& postproc_method) { + std::vector postprocessed_token_table; + postprocessed_token_table.reserve(token_table.size()); + for (const std::string& token : token_table) { + postprocessed_token_table.push_back(PostProcessToken(token, postproc_method)); + } + return postprocessed_token_table; +} + TVM_REGISTER_GLOBAL("mlc.Tokenizer").set_body_typed([](const String& path) { return Tokenizer::FromPath(path); }); diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index 16d9ba456b..36fc0c23db 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -30,7 +30,7 @@ class TokenizerObj : public Object { std::vector Encode(const std::string& text) const; /*! \brief Decode token ids into text. */ std::string Decode(const std::vector& token_ids) const; - /*! \brief Return the token table of the tokenizer. */ + /*! \brief Return the token table of the tokenizer. Special tokens are included. */ const std::vector& TokenTable(); /*! @@ -64,6 +64,25 @@ class Tokenizer : public ObjectRef { /*! \brief Create a tokenizer from a directory path on disk. */ MLC_LLM_DLL static Tokenizer FromPath(const String& path); + /*! + * \brief Convert raw tokens provided by the tokenizer to their original string to simplify + * later processing. E.g. For LLaMA-2, convert "▁of" to " of". + * + * \param token_table The raw token table. + * \param postproc_method The postprocessing method to use. Now we only support "byte-fallback" + * and "byte-level", which refers to the type of the decoder of the tokenizer. + * - "byte-fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used + * by LLaMA-2, Mixtral-7b, etc. This method: 1) transform tokens like <0x1B> to hex char + * byte 1B. (known as the byte-fallback method); 2) transform \\u2581 to space. + * - "byte-level": Use the decoding method in the byte-level BPE tokenizer. This is used by + * LLaMA-3, GPT-2, Phi-2, etc. This method inverses the bytes-to-unicode transformation in + * the encoding process as in + * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + * \returns The postprocessed token table containing the original strings. 
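To make the two post-processing modes concrete, here is a minimal standalone Python sketch of what they do to raw tokens. It is illustrative only: the helper names are not from the patch, and the byte-level map is rebuilt with the standard GPT-2 bytes_to_unicode construction rather than the hard-coded inverse table in tokenizers.cc.

```python
# Minimal sketch (not part of the patch) of the two token post-processing modes.

def byte_fallback_postproc(token: str) -> str:
    """byte_fallback: map <0xXX> tokens to the raw byte, then U+2581 to space."""
    if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        return chr(int(token[3:5], 16))
    return token.replace("\u2581", " ")


def _unicode_to_byte() -> dict:
    """Inverse of the GPT-2 bytes_to_unicode mapping."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("\xa1"), ord("\xac") + 1))
        + list(range(ord("\xae"), ord("\xff") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return {chr(c): b for b, c in zip(bs, cs)}


_U2B = _unicode_to_byte()


def byte_level_postproc(token: str) -> str:
    """byte_level: invert the bytes-to-unicode transform of byte-level BPE."""
    out = bytearray()
    for ch in token:
        # Unmapped codepoints (e.g. raw "\t" tokens in Phi-2) are kept as-is,
        # mirroring the fallback branch of the C++ decoder.
        out.append(_U2B.get(ch, ord(ch)) & 0xFF)
    return out.decode("utf-8", errors="replace")


print(byte_fallback_postproc("\u2581of"))  # " of"
print(byte_level_postproc("\u0120of"))     # " of"
```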
+ */ + static std::vector PostProcessTokenTable(const std::vector& token_table, + const std::string& postproc_method); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Tokenizer, ObjectRef, TokenizerObj); private: diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 8e617fc3d2..13f0e1215f 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -5,7 +5,7 @@ import re import shutil from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model @@ -51,7 +51,11 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes pad_token_id: int = None bos_token_id: int = None eos_token_id: int = None + # Tokenizer configuration tokenizer_files: List[str] = dataclasses.field(default_factory=list) + # The method to post-process the token table. See + # cpp/tokenizers.h::Tokenizer::PostProcessTokenTable for details + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = None # Version control version: str = VERSION @@ -129,6 +133,70 @@ def json2rwkv_tokenizer(vocab: Path, out: Path) -> None: msgpack.pack(idx2token, f) +def detect_token_table_postproc_method(output_path: Path) -> Literal["byte_fallback", "byte_level"]: + """Detect the token table postprocessing method from tokenizer.json that is found under + output_path. If not detected, use ByteFallback as default. + + Check the decoder field of the tokenizer. If it uses ByteFallback decoder, return + "byte_fallback". If it uses ByteLevel decoder, return "byte_level". Otherwise, use + ByteFallback as default. + + See also cpp/tokenizers.h::Tokenizer::PostProcessTokenTable. 
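For reference, the decoder layouts that this heuristic distinguishes typically look like the fragments below. They are illustrative rather than taken from any specific model: real tokenizer.json files carry more fields and the exact decoder sequence varies, but only the decoder "type" values matter for detection.

```python
# Illustrative "decoder" fragments from tokenizer.json (not taken from the
# patch); only the "type" fields are relevant to the detection logic below.

# SentencePiece BPE with byte fallback (LLaMA/LLaMA-2 style) -> "byte_fallback"
byte_fallback_decoder = {
    "type": "Sequence",
    "decoders": [
        {"type": "Replace", "pattern": {"String": "\u2581"}, "content": " "},
        {"type": "ByteFallback"},
        {"type": "Fuse"},
    ],
}

# Byte-level BPE (GPT-2, LLaMA-3, Phi-2 style) -> "byte_level"
byte_level_decoder = {"type": "ByteLevel"}
```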
+ """ + output_tokenizer_path = output_path / "tokenizer.json" + if not output_tokenizer_path.exists(): + logger.warning( + "Tokenizer token table postprocessing method is not detected as tokenizer.json " + "is not found, use ByteFallback (the same as LLaMA/LLaMA2) by default" + ) + return "byte_fallback" + + with output_tokenizer_path.open("r", encoding="utf-8") as in_file: + tokenizer_json = json.load(in_file) + + # Find all decoders in tokenizer.json + decoders = [] + + if "decoder" not in tokenizer_json: + logger.warning( + "Decoder field is not found in tokenizer.json, use ByteFallback (the same as " + "LLaMA/LLaMA2) as the token table postprocessing method by default" + ) + return "byte_fallback" + + decoders_json = tokenizer_json["decoder"] + assert "type" in decoders_json, "Decoder type is not specified in tokenizer.json" + if decoders_json["type"] == "Sequence": + assert "decoders" in decoders_json + decoders = decoders_json["decoders"] + else: + decoders = [decoders_json] + + is_byte_level = False + is_byte_fallback = False + + for decoder in decoders: + if decoder["type"] == "ByteLevel": + is_byte_level = True + if decoder["type"] == "ByteFallback": + is_byte_fallback = True + assert not ( + is_byte_level and is_byte_fallback + ), "Tokenizer decoder cannot have both type ByteLevel and type ByteFallback" + + if is_byte_level: + return "byte_level" + if is_byte_fallback: + return "byte_fallback" + + logger.warning( + "Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json, use " + "ByteFallback (the same as LLaMA/LLaMA2) as the token table postprocessing method " + "by default" + ) + return "byte_fallback" + + def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements config: Path, model: Model, @@ -255,6 +323,10 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b except Exception: # pylint: disable=broad-exception-caught logger.exception("%s with the exception below. Skipping", FAILED) + # 3.4. Find the token table postprocessing method from tokenizer.json if it exists. If not + # detected, use "byte_fallback" as default. + mlc_chat_config.token_table_postproc_method = detect_token_table_postproc_method(output) + # Step 4. Load system default value mlc_chat_config.apply_defaults() # Step 5. 
Dump the configuration file to output directory From 51391c3c1b720378694c876ca8b84d8cc9400907 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Apr 2024 17:58:58 -0700 Subject: [PATCH 255/531] [Eagle] Avoid worker - engine transfer for hidden states (#2256) --- cpp/serve/draft_token_workspace_manager.cc | 2 +- cpp/serve/engine_actions/eagle_batch_draft.cc | 22 +- .../engine_actions/eagle_batch_verify.cc | 55 +-- .../eagle_new_request_prefill.cc | 70 ++- cpp/serve/function_table.cc | 7 +- cpp/serve/function_table.h | 1 + cpp/serve/model.cc | 407 +++++------------- cpp/serve/model.h | 46 +- cpp/serve/sampler/gpu_sampler.cc | 2 - python/mlc_llm/interface/compile.py | 7 +- python/mlc_llm/model/eagle/eagle_model.py | 4 +- python/mlc_llm/model/llama/llama_model.py | 15 +- 12 files changed, 188 insertions(+), 450 deletions(-) diff --git a/cpp/serve/draft_token_workspace_manager.cc b/cpp/serve/draft_token_workspace_manager.cc index 185b899e14..d004e91ee5 100644 --- a/cpp/serve/draft_token_workspace_manager.cc +++ b/cpp/serve/draft_token_workspace_manager.cc @@ -45,7 +45,7 @@ void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace, NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_); if (require_hidden_states) { workspace->draft_hidden_states_storage = - NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); + ft_.Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_); } } diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index 7ad66a045c..b4e7ec4c39 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -83,19 +83,15 @@ class EagleBatchDraftActionObj : public EngineActionObj { mstates.push_back(rsentry->mstates[model_id]); } // draft_length_ rounds of draft proposal. - ObjectRef last_hidden_states{nullptr}; - NDArray hidden_states = Downcast(model_workspaces_[model_id].hidden_states); + ObjectRef hidden_states = model_workspaces_[model_id].hidden_states; // Concat last hidden_states draft_token_slots_.clear(); if (draft_length_ > 1) { for (int i = 0; i < num_rsentries; ++i) { draft_token_slots_.push_back(mstates[i]->draft_token_slots.back()); } - hidden_states = Downcast(models_[model_id]->GatherHiddenStates( - model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states)); - ICHECK(hidden_states->ndim == 2); - last_hidden_states = hidden_states.CreateView( - {hidden_states->shape[0], 1, hidden_states->shape[1]}, hidden_states->dtype); + hidden_states = models_[model_id]->GatherHiddenStates( + model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states); } // The first draft token has been generated in prefill/verify stage for (int draft_id = 1; draft_id < draft_length_; ++draft_id) { @@ -114,11 +110,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { // - Invoke model decode. 
RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, last_hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = - models_[model_id]->BatchDecodeToLastHidden(fused_hidden_states, request_internal_ids); - last_hidden_states = hidden_states; + ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states, + request_internal_ids); NDArray logits; if (models_[model_id]->CanGetLogits()) { logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, @@ -145,11 +140,10 @@ class EagleBatchDraftActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index d52f60d5c7..f7c858192d 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -65,7 +65,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { Array generation_cfg; std::vector rngs; std::vector> draft_output_tokens; - std::vector> draft_output_prob_dist; request_internal_ids.reserve(num_rsentries); all_tokens_to_verify.reserve(total_draft_length); verify_request_mstates.reserve(num_rsentries); @@ -113,12 +112,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding"); RECORD_EVENT(trace_recorder_, request_ids, "start verify"); - ObjectRef fused_hidden_states = models_[verify_model_id_]->FuseEmbedHidden( - embeddings, NDArray(), 1, cum_verify_lengths[num_rsentries]); - NDArray hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( - fused_hidden_states, request_internal_ids, verify_lengths); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); + ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( + embeddings, request_internal_ids, verify_lengths); NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); @@ -179,16 +174,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { { // One step draft for the following steps - NDArray last_hidden_states_nd = hidden_states.CreateView( - {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, - hidden_states->dtype); - hidden_states = Downcast(models_[draft_model_id_]->GatherHiddenStates( - last_hidden_states_nd, last_accepted_hidden_positions, - &model_workspaces_[draft_model_id_].hidden_states)); - ICHECK(hidden_states->ndim == 2); - hidden_states = hidden_states.CreateView( - {hidden_states->shape[0], 1, 
hidden_states->shape[1]}, hidden_states->dtype); + // Gather hidden states for the last accepted tokens. + hidden_states = models_[draft_model_id_]->GatherHiddenStates( + hidden_states, last_accepted_hidden_positions, + &model_workspaces_[draft_model_id_].hidden_states); std::vector input_tokens; Array mstates; @@ -210,10 +200,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Invoke model decode. RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden(fused_hidden_states, - request_internal_ids); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, @@ -239,11 +229,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector prob_dist; NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Slice and save hidden_states_for_sample @@ -251,10 +240,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { models_[draft_model_id_]->ScatterDraftProbs( renormalized_probs, draft_token_slots_, &model_workspaces_[verify_model_id_].draft_probs_storage); - ICHECK(hidden_states->ndim == 3); - hidden_states = hidden_states.CreateView( - {hidden_states->shape[0] * hidden_states->shape[1], hidden_states->shape[2]}, - hidden_states->dtype); models_[draft_model_id_]->ScatterHiddenStates( hidden_states, draft_token_slots_, &model_workspaces_[verify_model_id_].draft_hidden_states_storage); @@ -326,26 +311,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj { return num_required_pages <= num_available_pages; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! * \brief The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. 
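The refactor above replaces the explicit per-token copies (GetTokenHidden plus ConcatLastHidden) with slot-indexed gather/scatter calls on a hidden-states workspace that never leaves the worker. The following numpy sketch shows only the indexing pattern, purely for intuition: on the engine the storage is an NDArray or DRef and the gather/scatter are compiled kernels, not numpy operations.

```python
# Minimal numpy sketch (illustrative only) of the gather/scatter pattern used
# for the draft hidden-states workspace.
import numpy as np

hidden_size = 8
workspace = np.zeros((16, hidden_size), dtype=np.float16)  # draft_hidden_states_storage

# ScatterHiddenStates: write the freshly produced hidden states of each request
# into its draft token slot.
draft_token_slots = np.array([3, 7, 12])
new_hidden = np.random.rand(3, hidden_size).astype(np.float16)
workspace[draft_token_slots] = new_hidden

# GatherHiddenStates: read them back (e.g. the slots of the last accepted
# tokens) as a (num_requests, hidden_size) batch for the next draft step.
gathered = workspace[draft_token_slots]
assert np.allclose(gathered, new_hidden)
```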
diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 57310f7986..80de254ca8 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -83,8 +83,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - Get embedding and run prefill for each model. std::vector prefill_lengths; prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1); - NDArray hidden_states_for_input{nullptr}; - NDArray hidden_states_for_sample{nullptr}; + ObjectRef hidden_states_for_input{nullptr}; + ObjectRef hidden_states_for_sample{nullptr}; NDArray logits_for_sample{nullptr}; // A map used to record the entry and child_idx pair needed to fork sequence. // The base model (id 0) should record all the pairs and all the small models @@ -167,14 +167,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { } RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); - ObjectRef fused_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); - NDArray hidden_states = models_[model_id]->BatchPrefillToLastHidden( - fused_hidden_states, request_internal_ids, prefill_lengths); + ObjectRef embedding_or_hidden_states{nullptr}; + if (model_id == 0) { + embedding_or_hidden_states = embeddings; + } else { + embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden( + embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + } + // hidden_states: (b * s, h) + ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden( + embedding_or_hidden_states, request_internal_ids, prefill_lengths); RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], cum_prefill_length); if (model_id == 0) { // We only need to sample for model 0 in prefill. @@ -183,14 +186,23 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // Whether to use base model to get logits. int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; - hidden_states_for_sample = models_[sample_model_id]->BatchSelectLastHidden( - hidden_states, request_internal_ids, prefill_lengths); + + std::vector logit_positions; + { + // Prepare the logit positions + logit_positions.reserve(prefill_lengths.size()); + int total_len = 0; + for (int i = 0; i < prefill_lengths.size(); ++i) { + total_len += prefill_lengths[i]; + logit_positions.push_back(total_len - 1); + } + } + // hidden_states_for_sample: (b * s, h) + hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( + hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); + // logits_for_sample: (b * s, v) logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); - ICHECK_EQ(hidden_states_for_sample->ndim, 3); - ICHECK_EQ(hidden_states_for_sample->shape[0], 1); - ICHECK_EQ(hidden_states_for_sample->shape[1], num_rsentries); - // - Update logits. 
ICHECK(logits_for_sample.defined()); Array generation_cfg; @@ -278,11 +290,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector prob_dist; + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. @@ -311,10 +323,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, &model_workspaces_[0].draft_probs_storage); if (engine_config_->spec_draft_length > 1) { - hidden_states_for_sample = hidden_states_for_sample.CreateView( - {hidden_states_for_sample->shape[0] * hidden_states_for_sample->shape[1], - hidden_states_for_sample->shape[2]}, - hidden_states_for_sample->dtype); models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, &model_workspaces_[0].draft_hidden_states_storage); } @@ -567,26 +575,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { ICHECK(false) << "Cannot reach here"; } - /*! - * \brief Get one item from a hidden_states array, which corresponds to the last token. - * \param hidden_states The hidden_states of all the tokens. - * \param token_pos The desired token position in the sequence. - * \return The desired token's hidden_states - */ - NDArray GetTokenHidden(NDArray hidden_states, int token_pos) { - ICHECK_EQ(hidden_states->ndim, 3); - NDArray last_hidden_on_device = - NDArray::Empty({hidden_states->shape[2]}, hidden_states->dtype, hidden_states->device); - - int64_t ndata = hidden_states->shape[2]; - const int16_t* __restrict p_hidden = - static_cast(__builtin_assume_aligned(hidden_states->data, 2)) + - (token_pos * ndata); - - last_hidden_on_device.CopyFromBytes(p_hidden, ndata * sizeof(int16_t)); - return last_hidden_on_device; - } - /*! \brief The models to run prefill in. */ Array models_; /*! \brief The logit processor. */ diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 4e0301eb2d..16db4a8a03 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -218,7 +218,7 @@ void FunctionTable::_InitFunctions() { Module mod = this->use_disco ? 
this->disco_mod->DebugGetFromRemote(0) : this->local_vm; this->get_logits_func_ = mod_get_func("get_logits"); this->batch_get_logits_func_ = mod_get_func("batch_get_logits"); - this->batch_select_last_hidden_func_ = mod->GetFunction("batch_select_last_hidden_states", true); + this->batch_select_last_hidden_func_ = mod_get_func("batch_select_last_hidden_states"); this->softmax_func_ = mod->GetFunction("softmax_with_temperature", true); this->apply_logit_bias_func_ = mod->GetFunction("apply_logit_bias_inplace", true); this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true); @@ -259,11 +259,12 @@ void FunctionTable::_InitFunctions() { this->nd_get_shape_func_ = get_global_func("vm.builtin.shape_of"); this->nd_copy_embedding_to_offset_func_ = get_global_func("mlc.copy_embedding_to_offset"); support_backtracking_kv_ = true; + this->tuple_getitem_func_ = get_global_func("vm.builtin.tuple_getitem"); this->gather_probs_func_ = mod->GetFunction("gather_probs", true); this->scatter_probs_func_ = mod->GetFunction("scatter_probs", true); - this->gather_hidden_states_func_ = mod->GetFunction("gather_hidden_states", true); - this->scatter_hidden_states_func_ = mod->GetFunction("scatter_hidden_states", true); + this->gather_hidden_states_func_ = mod_get_func("gather_hidden_states"); + this->scatter_hidden_states_func_ = mod_get_func("scatter_hidden_states"); } ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) const { diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h index e368edcb9c..2350f3d37a 100644 --- a/cpp/serve/function_table.h +++ b/cpp/serve/function_table.h @@ -120,6 +120,7 @@ struct FunctionTable { PackedFunc nd_view_func_; PackedFunc nd_get_shape_func_; PackedFunc nd_copy_embedding_to_offset_func_; + PackedFunc tuple_getitem_func_; // Auxiliary functions for speculative decoding. PackedFunc gather_probs_func_; PackedFunc scatter_probs_func_; diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 8918cecdc4..be76b40e2e 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -5,7 +5,6 @@ */ #include "model.h" -#include #include #include #include @@ -137,35 +136,23 @@ class ModelImpl : public ModelObj { return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); } - NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) final { + NDArray GetLogits(const ObjectRef& hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("GetLogits"); CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. 
- hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); + ObjectRef hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && hidden_states->IsInstance()) { + hidden_states_dref_or_nd = Downcast(hidden_states)->DebugGetFromRemote(0); + } else { + hidden_states_dref_or_nd = hidden_states; + } ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } NDArray logits{nullptr}; - if (ret->IsInstance()) { + if (ft_.use_disco) { logits = Downcast(ret)->DebugGetFromRemote(0); } else { logits = Downcast(ret); @@ -177,142 +164,11 @@ class ModelImpl : public ModelObj { return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); } - NDArray BatchGetLogits(const ObjectRef& last_hidden_states, const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchGetLogits"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - ObjectRef logit_pos_dref_or_nd = - ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); - - CHECK(ft_.batch_get_logits_func_.defined()) - << "`batch_get_logits` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states = hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - // This copy can be avoided by not copying the hidden states to engine. 
- hidden_states_dref_or_nd = ft_.CopyToWorker0( - hidden_states, "hidden_states", {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.batch_get_logits_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd, params_); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray logits; - logits = Downcast(ret); - CHECK(logits.defined()); - // logits: (b * s, v) - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], num_sequences); - return logits.CreateView({1, num_sequences, logits->shape[1]}, logits->dtype); - } - - NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) { - NVTXScopedRange nvtx_scope("BatchSelectLastHidden"); - CHECK(!seq_ids.empty()); - CHECK_EQ(seq_ids.size(), lengths.size()); - int num_sequences = seq_ids.size(); - int total_length = 0; - - int* p_logit_pos = static_cast(logit_pos_arr_->data); - for (int i = 0; i < num_sequences; ++i) { - total_length += lengths[i]; - p_logit_pos[i] = total_length - 1; - } - NDArray logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32)); - - ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos_local", - {max_num_sequence_}, /*local_only=*/true); - - CHECK(ft_.batch_select_last_hidden_func_.defined()) - << "`batch_select_last_hidden_states` function is not found in the model."; - - ObjectRef hidden_states_dref_or_nd; - CHECK(!last_hidden_states->IsInstance()); - // hidden_states: (b, s, h) - NDArray hidden_states = Downcast(last_hidden_states); - ICHECK_NE(hidden_size_, -1); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], 1); - ICHECK_EQ(hidden_states->shape[1], total_length); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - - hidden_states_dref_or_nd = - hidden_states.CreateView({total_length, hidden_size_}, hidden_states->dtype); - - ObjectRef ret = - ft_.batch_select_last_hidden_func_(hidden_states_dref_or_nd, logit_pos_dref_or_nd); - if (trace_enabled_) { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); - } - - NDArray hidden; - hidden = Downcast(ret); - // hidden: (b * s, v) - ICHECK_EQ(hidden->ndim, 2); - ICHECK_EQ(hidden->shape[0], num_sequences); - return hidden.CreateView({1, num_sequences, hidden->shape[1]}, hidden->dtype); - } - - NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) final { - NVTXScopedRange nvtx_scope("ConcatLastHidden"); - - CHECK(dst->defined()); - - int cum_length = 0; - ICHECK_GE(hidden_states.size(), 1); - for (auto hidden : hidden_states) { - ICHECK_EQ(hidden->ndim, 1); - // No ICHECK_EQ(hidden->shape[0], hidden_size_) here to allow different hidden_sizes. 
- hidden = hidden.CreateView({1, hidden_size_}, hidden->dtype); - // Reuse the copy embedding function - ObjectRef hidden_dref_or_nd = - ft_.CopyToWorker0(hidden, "hidden_for_concat", {1, hidden_size_}); - ft_.nd_copy_embedding_to_offset_func_(hidden_dref_or_nd, *dst, cum_length); - cum_length += 1; - } - NDArray ret{nullptr}; - if ((*dst)->IsInstance()) { - ret = Downcast(*dst)->DebugGetFromRemote(0); - } else { - ret = Downcast(*dst); - } - ret = ret.CreateView({cum_length, hidden_size_}, hidden_states[0]->dtype); - return ret; - } - ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("FuseEmbedHidden"); - ObjectRef embeddings_dref_or_nd; + ObjectRef embeddings_dref_or_nd{nullptr}; if (!embeddings->IsInstance()) { // embeddings: (n, h) NDArray embeddings_nd = Downcast(embeddings); @@ -320,51 +176,33 @@ class ModelImpl : public ModelObj { ICHECK_EQ(embeddings_nd->ndim, 2); ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len); ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); - ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); - ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); embeddings_dref_or_nd = embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - return embeddings_nd.CreateView({batch_size, seq_len, hidden_size_}, embeddings_nd->dtype); - } } else { ShapeTuple embedding_shape{batch_size * seq_len, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); - - if (!ft_.fuse_embed_hidden_func_.defined() || !previous_hidden_states.defined()) { - // Model has no support for fuse_embed_hidden_states or this is the first model (base model) - ShapeTuple embedding_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(embeddings, embedding_shape); - } } - NDArray hidden_states = Downcast(previous_hidden_states); - CHECK(hidden_states.defined()); - ICHECK_EQ(hidden_states->ndim, 3); - ICHECK_EQ(hidden_states->shape[0], batch_size); - ICHECK_EQ(hidden_states->shape[1], seq_len); - ICHECK_EQ(hidden_states->shape[2], hidden_size_); - ICHECK_EQ(hidden_states->device.device_type, device_.device_type); - ICHECK_EQ(hidden_states->device.device_id, device_.device_id); - NDArray hidden_states_2d = - hidden_states.CreateView({batch_size * seq_len, hidden_size_}, hidden_states->dtype); - auto hidden_states_dref_or_nd = - ft_.CopyToWorker0(hidden_states_2d, "hidden_states_2d", - {max_num_sequence_ * prefill_chunk_size_, hidden_size_}); - - ObjectRef ret = - ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, hidden_states_dref_or_nd, params_); + ObjectRef previous_hidden_states_dref_or_nd{nullptr}; + if (!ft_.use_disco && previous_hidden_states->IsInstance()) { + previous_hidden_states_dref_or_nd = + Downcast(previous_hidden_states)->DebugGetFromRemote(0); + } else { + previous_hidden_states_dref_or_nd = previous_hidden_states; + } + ObjectRef fused = ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd, + previous_hidden_states_dref_or_nd, params_); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - if (!ret->IsInstance()) { - NDArray fused = Downcast(ret); - return fused.CreateView({batch_size, seq_len, hidden_size_}, fused->dtype); + ShapeTuple out_shape{batch_size, 
seq_len, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(fused, out_shape); } else { - ShapeTuple fused_shape{batch_size, seq_len, hidden_size_}; - return ft_.nd_view_func_(ret, fused_shape); + NDArray fused_nd = Downcast(fused); + ICHECK_EQ(fused_nd->ndim, 2); + ICHECK_EQ(fused_nd->shape[0], batch_size * seq_len); + return fused_nd.CreateView(out_shape, fused_nd->dtype); } } @@ -439,9 +277,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchPrefillToLastHidden(const ObjectRef& embedding_or_hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchPrefillToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -452,19 +290,15 @@ class ModelImpl : public ModelObj { total_length += lengths[i]; } - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embedding_or_hidden_states_dref_or_nd{nullptr}; + ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray embedding_or_hidden_states_nd = Downcast(embedding_or_hidden_states); + embedding_or_hidden_states_dref_or_nd = embedding_or_hidden_states_nd.CreateView( + hidden_states_shape, embedding_or_hidden_states_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + embedding_or_hidden_states_dref_or_nd = + ft_.nd_view_func_(embedding_or_hidden_states, hidden_states_shape); } CHECK(ft_.prefill_to_last_hidden_func_.defined()) @@ -479,32 +313,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined()) << "`single_batch_prefill_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); + result = ft_.single_batch_prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, + kv_cache_, params_); } else { - ret = ft_.prefill_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, kv_cache_, + params_); } + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], 
total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { @@ -567,8 +403,8 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) final { + ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd, + const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); int num_sequence = seq_ids.size(); @@ -578,21 +414,6 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); - ICHECK_EQ(hidden_states_nd->shape[1], 1); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({num_sequence, 1, hidden_size_}, hidden_states_nd->dtype); - } else { - ShapeTuple hidden_states_shape{num_sequence, 1, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); - } - // Reserve in KV cache for the lengths of the input. // Begin forward with the sequence ids and new lengths. 
IntTuple seq_ids_tuple(seq_ids); @@ -600,32 +421,34 @@ class ModelImpl : public ModelObj { ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, kv_cache, params - ObjectRef ret; + ObjectRef result{nullptr}; if (seq_ids.size() == 1) { CHECK(ft_.single_batch_decode_to_last_hidden_func_.defined()) << "`decode_to_last_hidden_states` function is not found in the model."; - ret = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, - params_); - } else { - ret = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - } - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); + result = ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, + params_); } else { - last_hidden_states = Downcast>(ret)[0]; + result = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); } + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); + if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (b, 1, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], num_sequence); - ICHECK_EQ(last_hidden_states->shape[1], 1); - return last_hidden_states; + // hidden_states: (b, 1, v) to (b, v) + ShapeTuple out_shape{num_sequence, hidden_size_}; + if (ft_.use_disco) { + return ft_.nd_view_func_(hidden_states, out_shape); + } else { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], num_sequence); + ICHECK_EQ(hidden_states_nd->shape[1], 1); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } } NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, @@ -688,9 +511,9 @@ class ModelImpl : public ModelObj { return logits; } - NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) final { + ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings, + const std::vector& seq_ids, + const std::vector& lengths) final { NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); @@ -706,45 +529,46 @@ class ModelImpl : public ModelObj { ICHECK(ft_.kv_cache_end_forward_func_.defined()); ICHECK(kv_cache_.defined()) << "KV cache has not been initialized."; - ObjectRef hidden_states_dref_or_nd; - if (!hidden_states->IsInstance()) { - // hidden_states: (1, n, h) - NDArray hidden_states_nd = Downcast(hidden_states); - ICHECK_EQ(hidden_states_nd->ndim, 3); - ICHECK_EQ(hidden_states_nd->shape[0], 1); - ICHECK_EQ(hidden_states_nd->shape[1], total_length); - ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); - hidden_states_dref_or_nd = - hidden_states_nd.CreateView({1, total_length, hidden_size_}, hidden_states_nd->dtype); + ObjectRef embeddings_dref_or_nd; + if (!embeddings->IsInstance()) { + // embeddings: (1, n, h) + NDArray embeddings_nd = Downcast(embeddings); + ICHECK_NE(hidden_size_, -1); + ICHECK_EQ(embeddings_nd->ndim, 2); + ICHECK_GE(embeddings_nd->shape[0], total_length); + ICHECK_EQ(embeddings_nd->shape[1], hidden_size_); + ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type); + 
ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id); + embeddings_dref_or_nd = + embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype); } else { - ShapeTuple hidden_states_shape{1, total_length, hidden_size_}; - hidden_states_dref_or_nd = ft_.nd_view_func_(hidden_states, hidden_states_shape); + ShapeTuple embedding_shape{1, total_length, hidden_size_}; + embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); } - // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); // args: embeddings, logit_pos, kv_cache, params - ObjectRef ret = ft_.verify_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_); - NDArray last_hidden_states; - if (ft_.use_disco) { - Array result = Downcast(ret)->DebugGetFromRemote(0); - last_hidden_states = Downcast(result[0]); - } else { - last_hidden_states = Downcast>(ret)[0]; - } + ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_); + ft_.kv_cache_end_forward_func_(kv_cache_); + ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0); if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - ft_.kv_cache_end_forward_func_(kv_cache_); - // hidden_states: (1, total_length, v) - ICHECK_EQ(last_hidden_states->ndim, 3); - ICHECK_EQ(last_hidden_states->shape[0], 1); - ICHECK_EQ(last_hidden_states->shape[1], total_length); - return last_hidden_states; + ShapeTuple out_shape{total_length, hidden_size_}; + if (!ft_.use_disco) { + NDArray hidden_states_nd = Downcast(hidden_states); + ICHECK_EQ(hidden_states_nd->ndim, 3); + ICHECK_EQ(hidden_states_nd->shape[0], 1); + ICHECK_EQ(hidden_states_nd->shape[1], total_length); + ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_); + return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype); + } else { + return ft_.nd_view_func_(hidden_states, out_shape); + } } /*********************** KV Cache Management ***********************/ @@ -877,8 +701,7 @@ class ModelImpl : public ModelObj { ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; this->hidden_states_dtype_ = hidden_states_nd->dtype; - // TODO(wuwei): We can keep hidden_states on the worker after refactor - return hidden_states_nd; + return hidden_states; } void Reset() final { @@ -897,13 +720,18 @@ class ModelImpl : public ModelObj { ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector& indices, ObjectRef* dst) final { - NDArray dst_view = Downcast(*dst).CreateView( - {static_cast(indices.size()), hidden_size_}, hidden_states_dtype_); + ObjectRef dst_view{nullptr}; + ShapeTuple out_shape{static_cast(indices.size()), hidden_size_}; + if ((*dst)->IsInstance()) { + dst_view = ft_.nd_view_func_(*dst, out_shape); + } else { + NDArray dst_nd = Downcast(*dst); + dst_view = dst_nd.CreateView(out_shape, hidden_states_dtype_); + } NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); - ObjectRef indices_device = - ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.gather_hidden_states_func_(input, indices_device, dst_view); return dst_view; } @@ -913,8 
+741,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); - ObjectRef indices_device = - ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); + ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.scatter_hidden_states_func_(input, indices_device, *dst); } diff --git a/cpp/serve/model.h b/cpp/serve/model.h index d672739581..f587969bfb 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -7,6 +7,7 @@ #ifndef MLC_LLM_SERVE_MODEL_H_ #define MLC_LLM_SERVE_MODEL_H_ +#include #include #include @@ -139,35 +140,6 @@ class ModelObj : public Object { */ virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; - /*! - * \brief Compute logits for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to compute logits for. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The computed logits. - */ - virtual NDArray BatchGetLogits(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Select desired hidden_states for last hidden_states in a batch. - * \param last_hidden_states The last hidden_states to select from. - * \param seq_ids The id of the sequence in the KV cache. - * \param lengths The length of each sequence to prefill. - * \return The last hidden_states for the batch. - */ - virtual NDArray BatchSelectLastHidden(const ObjectRef& last_hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; - - /*! - * \brief Concat a list of 1D hidden_states to 2D tensor. - * \param hidden_states The hidden_states to concat. - * \param dst The copy destination. - */ - virtual NDArray ConcatLastHidden(std::vector& hidden_states, ObjectRef* dst) = 0; - /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -188,9 +160,9 @@ class ModelObj : public Object { * \param lengths The length of each sequence to prefill. * \return The hidden_states for the next token. */ - virtual NDArray BatchPrefillToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchPrefillToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*! * \brief Batch decode function. Embedding in, logits out. @@ -209,8 +181,8 @@ class ModelObj : public Object { * \param seq_id The id of the sequence in the KV cache. * \return The hidden_states for the next token for each sequence in the batch. */ - virtual NDArray BatchDecodeToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids) = 0; + virtual ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids) = 0; /*! * \brief Batch verify function. Embedding in, logits out. @@ -236,9 +208,9 @@ class ModelObj : public Object { * That is to say, it does not accept "running a verify step for a subset * of the full batch". 
*/ - virtual NDArray BatchVerifyToLastHidden(const ObjectRef& hidden_states, - const std::vector& seq_ids, - const std::vector& lengths) = 0; + virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states, + const std::vector& seq_ids, + const std::vector& lengths) = 0; /*********************** KV Cache Management ***********************/ diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index c6f463eb32..87a9a31d30 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -74,7 +74,6 @@ class GPUSampler : public SamplerObj { sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); - draft_probs_device_ = NDArray::Empty({max_num_sample, vocab_size}, dtype_f32_, device); draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_next_sibling_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); @@ -630,7 +629,6 @@ class GPUSampler : public SamplerObj { NDArray sample_indices_device_; NDArray top_p_device_; NDArray top_prob_offsets_device_; - NDArray draft_probs_device_; NDArray draft_tokens_device_; NDArray token_tree_first_child_device_; NDArray token_tree_next_sibling_device_; diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 4e8bcabd9e..7be9dadd39 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -1,4 +1,5 @@ """Python entrypoint of compilation.""" + import dataclasses import math from io import StringIO @@ -162,7 +163,11 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: logger.info("Running optimizations using TVM Unity") additional_tirs = _apply_preproc_to_params(named_params, model_config) variable_bounds = _get_variable_bounds(model_config) - cuda_graph_symbolic_capture_hints = {"batch_decode": ["batch_size"]} + cuda_graph_symbolic_capture_hints = { + "batch_decode": ["batch_size"], + "batch_decode_to_last_hidden_states": ["batch_size"], + "batch_verify_to_last_hidden_states": ["batch_size", "seq_len"], + } metadata = { "model_type": args.model.name, "quantization": args.quantization.name, diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index 355618df09..9d7820b841 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -190,8 +190,8 @@ def get_default_spec(self): }, }, "fuse_embed_hidden_states": { - "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), - "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 18238f688e..60c8f138d1 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -248,16 +248,11 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states(self, 
hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) - return self.get_logits(hidden_states) - - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): - op_ext.configure() - hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -368,14 +363,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "batch_get_logits": { - "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), - "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, "batch_select_last_hidden_states": { "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), From eb4d6242518369850bcfa0d57ab7006edbe0e7ff Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Apr 2024 18:33:38 -0700 Subject: [PATCH 256/531] [Serving] Add engine stats for speculative decoding (#2257) --- cpp/serve/engine_actions/batch_verify.cc | 2 ++ .../engine_actions/eagle_batch_verify.cc | 2 ++ cpp/serve/engine_state.cc | 26 +++++++++++++++-- cpp/serve/engine_state.h | 12 ++++++++ cpp/serve/threaded_engine.cc | 6 ++++ cpp/serve/threaded_engine.h | 3 ++ python/mlc_llm/serve/engine_base.py | 4 +++ .../serve/entrypoints/debug_entrypoints.py | 28 +++++++++++++++++++ 8 files changed, 81 insertions(+), 2 deletions(-) diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 6f27a50394..42524d46b2 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -141,6 +141,8 @@ class BatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } estate->stats.total_accepted_length += accept_length; + estate->stats.UpdateSpecDecodingStats(cum_verify_lengths[i + 1] - cum_verify_lengths[i], + accept_length); int rollback_length = std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); // rollback kv cache diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index f7c858192d..6b23035f78 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -150,6 +150,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result); rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result); } + estate->stats.UpdateSpecDecodingStats(cum_verify_lengths[i + 1] - cum_verify_lengths[i], + accept_length); estate->stats.total_accepted_length += accept_length - 1; // - Minus one because the last draft token has no kv cache entry // - Take max with 0 in case of all accepted. diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index 563f0e7b13..4304ca48af 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -13,15 +13,24 @@ namespace serve { String EngineStats::AsJSON() const { picojson::object config; config["single_token_prefill_latency"] = - picojson::value(request_total_prefill_time / total_prefill_length); + picojson::value(total_prefill_length > 0 ? 
request_total_prefill_time / total_prefill_length : 0.0); config["single_token_decode_latency"] = - picojson::value(request_total_decode_time / total_decode_length); + picojson::value(total_decode_length > 0 ? request_total_decode_time / total_decode_length : 0.0); config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time); config["engine_total_decode_time"] = picojson::value(engine_total_decode_time); config["total_prefill_tokens"] = picojson::value(total_prefill_length); config["total_decode_tokens"] = picojson::value(total_decode_length); config["total_accepted_tokens"] = picojson::value(total_accepted_length); config["total_draft_tokens"] = picojson::value(total_draft_length); + auto f_vector_to_array = [](const std::vector& vec) { + picojson::array arr; + for (int64_t v : vec) { + arr.push_back(picojson::value(v)); + } + return picojson::value(arr); + }; + config["accept_count"] = f_vector_to_array(accept_count); + config["draft_count"] = f_vector_to_array(draft_count); return picojson::value(config).serialize(true); } @@ -54,6 +63,19 @@ RequestState EngineStateObj::GetRequestState(Request request) { return it->second; } +void EngineStats::UpdateSpecDecodingStats(int draft_length, int accept_length) { + if (accept_count.size() < draft_length) { + this->accept_count.resize(draft_length, 0); + this->draft_count.resize(draft_length, 0); + } + for (int j = 0; j < draft_length; ++j) { + if (j < accept_length) { + this->accept_count[j]++; + } + this->draft_count[j]++; + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_state.h b/cpp/serve/engine_state.h index ff955a264f..8218cbd73d 100644 --- a/cpp/serve/engine_state.h +++ b/cpp/serve/engine_state.h @@ -34,6 +34,10 @@ struct EngineStats { int64_t total_accepted_length = 0; /*! \brief The total number of speculated draft tokens. */ int64_t total_draft_length = 0; + /*! \brief The number of accepted tokens in speculative decoding. */ + std::vector accept_count; + /*! \brief The number of draft tokens in speculative decoding. */ + std::vector draft_count; /*! * \brief Return the engine runtime statistics in JSON string. @@ -49,6 +53,14 @@ struct EngineStats { String AsJSON() const; /*! \brief Reset all the statistics. */ void Reset(); + + /*! + * \brief Update the statistics of speculative decoding. + * \param draft_length The number of draft tokens (including the last prediction by the base + * model) + * \param accept_length The number of accepted tokens in the speculative decoding. + */ + void UpdateSpecDecodingStats(int draft_length, int accept_length); }; /*! \brief The manager of internal id for requests in engine. 
*/ diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 2f6f77a3a0..080853d465 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -214,6 +214,11 @@ class ThreadedEngineImpl : public ThreadedEngine { } } + String Stats() final { + std::lock_guard lock(background_loop_mutex_); + return background_engine_->Stats(); + } + private: void EngineReloadImpl(EngineConfig engine_config) { auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { @@ -314,6 +319,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); + TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_END(); }; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index 49ba8f2175..d0f2ebe2d7 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -77,6 +77,9 @@ class ThreadedEngine { /*! \brief Call the given global function on all workers. Only for debug purpose. */ virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + + /*! \brief Print the statistics of the engine. */ + virtual String Stats() = 0; }; } // namespace serve diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 65b41a66ac..2b24d8f1c4 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -1066,6 +1066,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "init_background_engine", "exit_background_loop", "debug_call_func_on_all_worker", + "stats", ] } self.tokenizer = Tokenizer(model_args[0][0]) @@ -1118,6 +1119,9 @@ def _debug_call_func_on_all_worker(self, func_name: str) -> None: """Call the given global function on all workers. Only for debug purpose.""" self._ffi["debug_call_func_on_all_worker"](func_name) + def stats(self): + return self._ffi["stats"]() + def process_chat_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.ChatCompletionRequest, diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index af1613c027..9f6508ea42 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -79,3 +79,31 @@ async def debug_cuda_profiler_stop(_request: fastapi.Request): "mlc.debug_cuda_profiler_stop" ) break + + +@app.post("/debug/dump_engine_stats") +async def debug_dump_engine_stats(request: fastapi.Request): + """Dump the engine stats for the engine. Only for debug purpose.""" + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # - Check the requested model. 
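+    # The "model" field picks which served engine to query. The string returned
+    # by async_engine.stats() below is the engine stats JSON produced by
+    # EngineStats::AsJSON in this patch, so it also carries the per-position
+    # "accept_count" / "draft_count" arrays recorded for speculative decoding.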
+ model = request_dict["model"] + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + res = async_engine.stats() + print(res) + return json.loads(res) From d206c44f78236aa9556bcc12af32bbd979e21800 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Apr 2024 19:34:23 -0700 Subject: [PATCH 257/531] [Serving] Fix lints (#2258) --- cpp/serve/engine_state.cc | 8 ++++---- python/mlc_llm/serve/engine_base.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index 4304ca48af..7847f53fd5 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -12,10 +12,10 @@ namespace serve { String EngineStats::AsJSON() const { picojson::object config; - config["single_token_prefill_latency"] = - picojson::value(total_prefill_length > 0 ? request_total_prefill_time / total_prefill_length : 0.0); - config["single_token_decode_latency"] = - picojson::value(total_decode_length > 0 ? request_total_decode_time / total_decode_length : 0.0); + config["single_token_prefill_latency"] = picojson::value( + total_prefill_length > 0 ? request_total_prefill_time / total_prefill_length : 0.0); + config["single_token_decode_latency"] = picojson::value( + total_decode_length > 0 ? request_total_decode_time / total_decode_length : 0.0); config["engine_total_prefill_time"] = picojson::value(engine_total_prefill_time); config["engine_total_decode_time"] = picojson::value(engine_total_decode_time); config["total_prefill_tokens"] = picojson::value(total_prefill_length); diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 2b24d8f1c4..7f3f7e1331 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -1120,6 +1120,7 @@ def _debug_call_func_on_all_worker(self, func_name: str) -> None: self._ffi["debug_call_func_on_all_worker"](func_name) def stats(self): + """Get the engine stats.""" return self._ffi["stats"]() From 9941b4fff01d533809cb0924baf551b4dee577a3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 1 May 2024 10:13:06 -0700 Subject: [PATCH 258/531] [Sampler] Avoid unnecessary sync in GPU verifier (#2260) --- cpp/serve/sampler/gpu_sampler.cc | 161 ++++++++++++++++++++----------- 1 file changed, 105 insertions(+), 56 deletions(-) diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 87a9a31d30..a1c7a308bc 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -179,9 +179,9 @@ class GPUSampler : public SamplerObj { sample_results.resize(num_sequence); int num_nodes = cum_verify_lengths.back(); + ICHECK(num_nodes <= max_num_sample_); CHECK_EQ(draft_probs_on_device->shape[0], num_nodes); - NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_nodes}, dtype_f32_); - NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_nodes}, dtype_f32_); + NDArray uniform_samples_device = GenerateUniformSamples(rngs, cum_verify_lengths); NDArray draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_); NDArray draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_); @@ -201,16 +201,6 @@ class GPUSampler : public SamplerObj { } CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_); - float* p_uniform_samples = static_cast(uniform_samples_host->data); - for (int i = 0; i < num_sequence; ++i) { - int start = cum_verify_lengths[i]; - int end = cum_verify_lengths[i + 1]; - for 
(int j = start; j < end; j++) { - p_uniform_samples[j] = rngs[i]->GetRandomNumber(); - } - } - CopyArray(uniform_samples_host, uniform_samples_device, copy_stream_); - NDArray token_tree_first_child_host = token_tree_first_child_host_.CreateView({num_nodes}, dtype_i32_); NDArray token_tree_first_child_device = @@ -254,10 +244,44 @@ class GPUSampler : public SamplerObj { token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device, token_tree_parent_ptr_device); - CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, compute_stream_); - TVMSynchronize(device_.device_type, device_.device_id, compute_stream_); + CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, copy_stream_); - std::vector sample_indices; + std::vector additional_sample_result; + { + additional_sample_result.reserve(num_sequence); + // Sample one additional token for each sequence using the probablity at the last accepted + // token. + uniform_samples_device = GenerateUniformSamples(rngs, num_sequence); + const NDArray& sample_indices_device = token_tree_parent_ptr_device; + // Check need_prob_values + bool need_prob_values = false; + for (int i = 0; i < num_sequence; i++) { + need_prob_values |= generation_cfg[i]->logprobs; + } + std::vector top_prob_offset_indptr; + if (!need_prob_values) { + top_prob_offset_indptr.resize(num_sequence + 1, 0); + } else { + // Slow path: if any of the generation config requires prob values, we need to copy + // sample_indices to host to compute top_prob_offset_indptr. + TVMSynchronize(device_.device_type, device_.device_id, copy_stream_); + std::vector sample_indices; + sample_indices.reserve(num_sequence); + const int* p_token_tree_parent_ptr = static_cast(token_tree_parent_ptr_host->data); + for (int i = 0; i < num_sequence; i++) { + sample_indices.push_back(p_token_tree_parent_ptr[i]); + } + CheckProbValues(generation_cfg, sample_indices, num_nodes, num_sequence, vocab_size_, + &top_prob_offset_indptr); + } + auto device_arrays = + SampleOnGPU(probs_on_device, uniform_samples_device, sample_indices_device, + /*need_top_p=*/false, need_prob_values, num_nodes, top_prob_offset_indptr); + auto host_arrays = CopyArraysToCPU(device_arrays, num_sequence, need_prob_values, + top_prob_offset_indptr.back()); + additional_sample_result = + CollectSampleResult(host_arrays, num_sequence, need_prob_values, top_prob_offset_indptr); + } for (int i = 0; i < num_sequence; i++) { int start = cum_verify_lengths[i]; @@ -270,11 +294,9 @@ class GPUSampler : public SamplerObj { num_accepted++; } std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted); - sample_indices.push_back(last_accepted); } - std::vector additional_sample_result; - additional_sample_result = this->BatchSampleTokensWithProbAfterTopP( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + + // Append the additional sample result to the sample_results ICHECK_EQ(additional_sample_result.size(), num_sequence); for (int i = 0; i < num_sequence; i++) { sample_results[i].push_back(additional_sample_result[i]); @@ -347,6 +369,36 @@ class GPUSampler : public SamplerObj { return sample_results; } + /*! \brief Collect the sampling results from the computed NDArray results. 
*/ + std::vector CollectSampleResult(const std::vector& host_arrays, + int num_samples, bool need_prob_values, + const std::vector top_prob_offset_indptr) { + const int* p_sampled_token_ids = static_cast(host_arrays[0]->data); + const float* p_sampled_probs = nullptr; + const float* p_top_prob_probs = nullptr; + const int* p_top_prob_indices = nullptr; + if (need_prob_values) { + p_sampled_probs = static_cast(host_arrays[1]->data); + p_top_prob_probs = static_cast(host_arrays[2]->data); + p_top_prob_indices = static_cast(host_arrays[3]->data); + } + std::vector sample_results; + sample_results.reserve(num_samples); + ICHECK_EQ(top_prob_offset_indptr.size(), num_samples + 1); + for (int i = 0; i < num_samples; ++i) { + // Note: we set the probability in SampleResult to 1.0 since prob value is not needed. + float sampled_prob = need_prob_values ? p_sampled_probs[i] : 1.0; + std::vector top_prob_tokens; + top_prob_tokens.reserve(top_prob_offset_indptr[i + 1] - top_prob_offset_indptr[i]); + for (int j = top_prob_offset_indptr[i]; j < top_prob_offset_indptr[i + 1]; ++j) { + top_prob_tokens.emplace_back(p_top_prob_indices[j], p_top_prob_probs[j]); + } + sample_results.push_back( + SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens}); + } + return sample_results; + } + std::vector ChunkSampleTokensImpl(NDArray probs_on_device, // const std::vector& sample_indices, // const Array& generation_cfg, // @@ -359,8 +411,8 @@ class GPUSampler : public SamplerObj { // - Generate random numbers. // Copy the random numbers and sample indices. - auto [uniform_samples_device, sample_indices_device] = - CopySamplesAndIndicesToGPU(sample_indices, rngs, num_samples); + auto uniform_samples_device = GenerateUniformSamples(rngs, num_samples); + auto sample_indices_device = CopySampleIndicesToGPU(sample_indices); // - Check if there is need for applying top p or prob values, // so that argsort is needed. @@ -383,52 +435,49 @@ class GPUSampler : public SamplerObj { top_prob_offset_indptr.back()); // - Collect the sampling results. - const int* p_sampled_token_ids = static_cast(host_arrays[0]->data); - const float* p_sampled_probs = nullptr; - const float* p_top_prob_probs = nullptr; - const int* p_top_prob_indices = nullptr; - if (need_prob_values) { - p_sampled_probs = static_cast(host_arrays[1]->data); - p_top_prob_probs = static_cast(host_arrays[2]->data); - p_top_prob_indices = static_cast(host_arrays[3]->data); - } - std::vector sample_results; - sample_results.reserve(num_samples); - ICHECK_EQ(top_prob_offset_indptr.size(), num_samples + 1); - for (int i = 0; i < num_samples; ++i) { - // Note: we set the probability in SampleResult to 1.0 since prob value is not needed. - float sampled_prob = need_prob_values ? p_sampled_probs[i] : 1.0; - std::vector top_prob_tokens; - top_prob_tokens.reserve(top_prob_offset_indptr[i + 1] - top_prob_offset_indptr[i]); - for (int j = top_prob_offset_indptr[i]; j < top_prob_offset_indptr[i + 1]; ++j) { - top_prob_tokens.emplace_back(p_top_prob_indices[j], p_top_prob_probs[j]); - } - sample_results.push_back( - SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens}); - } - - return sample_results; + return CollectSampleResult(host_arrays, num_samples, need_prob_values, top_prob_offset_indptr); } - /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. */ - std::pair CopySamplesAndIndicesToGPU(const std::vector& sample_indices, - const std::vector& rngs, - int num_samples) { - // Generate random numbers. + /*! 
\brief Generate num_samples uniform random numbers, and copy them to GPU. */ + NDArray GenerateUniformSamples(const std::vector& rngs, int num_samples) { float* p_uniform_samples = static_cast(uniform_samples_host_->data); - int* p_sample_indices = static_cast(sample_indices_host_->data); for (int i = 0; i < num_samples; ++i) { p_uniform_samples[i] = rngs[i]->GetRandomNumber(); - p_sample_indices[i] = sample_indices[i]; } - // Copy the random numbers and sample indices to GPU. NDArray uniform_samples_host = uniform_samples_host_.CreateView({num_samples}, dtype_f32_); NDArray uniform_samples_device = uniform_samples_device_.CreateView({num_samples}, dtype_f32_); + CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_); + return uniform_samples_device; + } + + /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. The + * number of samples for each random generator is given by `cum_num_samples`. */ + NDArray GenerateUniformSamples(const std::vector& rngs, + const std::vector& cum_num_samples) { + float* p_uniform_samples = static_cast(uniform_samples_host_->data); + int total_samples = cum_num_samples.back(); + for (int i = 0; i + 1 < static_cast(cum_num_samples.size()); ++i) { + for (int j = cum_num_samples[i]; j < cum_num_samples[i + 1]; ++j) { + p_uniform_samples[j] = rngs[i]->GetRandomNumber(); + } + } + NDArray uniform_samples_host = uniform_samples_host_.CreateView({total_samples}, dtype_f32_); + NDArray uniform_samples_device = + uniform_samples_device_.CreateView({total_samples}, dtype_f32_); + CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_); + return uniform_samples_device; + } + + /*! \brief Generate uniform random numbers, and copy the numbers and sample indices to GPU. */ + NDArray CopySampleIndicesToGPU(const std::vector& sample_indices) { + int* p_sample_indices = static_cast(sample_indices_host_->data); + std::copy(sample_indices.begin(), sample_indices.end(), p_sample_indices); + // Copy the sample indices to GPU. + int num_samples = static_cast(sample_indices.size()); NDArray sample_indices_host = sample_indices_host_.CreateView({num_samples}, dtype_i32_); NDArray sample_indices_device = sample_indices_device_.CreateView({num_samples}, dtype_i32_); - CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_); CopyArray(/*src=*/sample_indices_host, /*dst=*/sample_indices_device, copy_stream_); - return {uniform_samples_device, sample_indices_device}; + return sample_indices_device; } /*! \brief Check if top p is needed. Update host top p array in place. 
*/ From cfd3b2ca462ffdee575477496e46146f4147375b Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 1 May 2024 10:13:45 -0700 Subject: [PATCH 259/531] Fix typo in token_postproc_method names (#2261) --- cpp/serve/engine.cc | 4 ++-- cpp/tokenizers.h | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 755af998cd..297eba8b10 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -102,8 +102,8 @@ class EngineImpl : public Engine { this->tokenizer_ = Tokenizer::FromPath(engine_config->model); std::string token_table_postproc_method; if (model_configs[0].count("token_table_postproc_method") == 0) { - // Backward compatibility: use "byte-fallback" by default - token_table_postproc_method = "byte-fallback"; + // Backward compatibility: use "byte_fallback" by default + token_table_postproc_method = "byte_fallback"; } else { token_table_postproc_method = model_configs[0].at("token_table_postproc_method").get(); diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h index 36fc0c23db..b2e7446358 100644 --- a/cpp/tokenizers.h +++ b/cpp/tokenizers.h @@ -69,12 +69,12 @@ class Tokenizer : public ObjectRef { * later processing. E.g. For LLaMA-2, convert "▁of" to " of". * * \param token_table The raw token table. - * \param postproc_method The postprocessing method to use. Now we only support "byte-fallback" - * and "byte-level", which refers to the type of the decoder of the tokenizer. - * - "byte-fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used + * \param postproc_method The postprocessing method to use. Now we only support "byte_fallback" + * and "byte_level", which refers to the type of the decoder of the tokenizer. + * - "byte_fallback": Use the decoding method in the byte-fallback BPE tokenizer. This is used * by LLaMA-2, Mixtral-7b, etc. This method: 1) transform tokens like <0x1B> to hex char * byte 1B. (known as the byte-fallback method); 2) transform \\u2581 to space. - * - "byte-level": Use the decoding method in the byte-level BPE tokenizer. This is used by + * - "byte_level": Use the decoding method in the byte-level BPE tokenizer. This is used by * LLaMA-3, GPT-2, Phi-2, etc. 
This method inverses the bytes-to-unicode transformation in * the encoding process as in * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 From 8e5af29a91b2f8c5d490e7134ec3d01f2e00202b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 1 May 2024 16:50:22 -0700 Subject: [PATCH 260/531] [Sampler] Add missing sync in gpu verifier (#2262) --- cpp/serve/function_table.cc | 2 +- cpp/serve/sampler/gpu_sampler.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index 16db4a8a03..bdf28dfdb5 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -135,7 +135,7 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object static_cast(tvm::runtime::memory::AllocatorType::kPooled), static_cast(kDLCPU), 0, static_cast(tvm::runtime::memory::AllocatorType::kPooled)); this->mod_get_func = [this](const std::string& name) -> PackedFunc { - return this->local_vm->GetFunction(name, false); + return this->local_vm->GetFunction(name, true); }; this->get_global_func = [](const std::string& name) -> PackedFunc { const auto* f = tvm::runtime::Registry::Get(name); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index a1c7a308bc..36cb6e5c0a 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -244,6 +244,7 @@ class GPUSampler : public SamplerObj { token_tree_first_child_device, token_tree_next_sibling_device, uniform_samples_device, token_tree_parent_ptr_device); + DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, copy_stream_); CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, copy_stream_); std::vector additional_sample_result; From e756f23992baf1cd2f28e676c2108ff165d68283 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 2 May 2024 12:36:55 -0700 Subject: [PATCH 261/531] [Model] Remove redundant space in llama2 tokenizer (#2263) --- cpp/conv_templates.cc | 4 ++-- python/mlc_llm/conversation_template.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 729e6f3b38..6ef8038cf4 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -97,12 +97,12 @@ Conversation Llama2() { Conversation conv; conv.name = "llama-2"; conv.system = - ("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\n "); + ("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\n"); conv.roles = {"[INST]", "[/INST]"}; conv.messages = {}; conv.offset = 0; conv.separator_style = SeparatorStyle::kSepRoleMsg; - conv.seps = {" "}; + conv.seps = {"", " "}; conv.role_msg_sep = " "; conv.role_empty_sep = " "; conv.stop_tokens = {2}; diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 1c599fa875..e5af9773bc 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -64,7 +64,7 @@ def get_conv_template(name: str) -> Optional[Conversation]: system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n", system_message="You are a helpful, respectful and honest assistant.", roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, - seps=[" "], + seps=["", " "], role_content_sep=" ", role_empty_sep=" ", stop_str=["[INST]"], From 878be83e4007e58c47009d0f1e4eb9c718a5fc6d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 
2 May 2024 17:57:27 -0700 Subject: [PATCH 262/531] [Model] Fix llama2 chat template and remove redundant separator added by engine (#2264) * [Model] Fix llama2 chat template and remove redundant separator added by engine --- cpp/conv_templates.cc | 4 ++-- python/mlc_llm/conversation_template.py | 4 ++-- python/mlc_llm/protocol/conversation_protocol.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 6ef8038cf4..7947a2fc24 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -98,11 +98,11 @@ Conversation Llama2() { conv.name = "llama-2"; conv.system = ("[INST] <>\nYou are a helpful, respectful and honest assistant.\n<>\n\n"); - conv.roles = {"[INST]", "[/INST]"}; + conv.roles = {"[INST]", "[/INST]"}; conv.messages = {}; conv.offset = 0; conv.separator_style = SeparatorStyle::kSepRoleMsg; - conv.seps = {"", " "}; + conv.seps = {" ", " "}; conv.role_msg_sep = " "; conv.role_empty_sep = " "; conv.stop_tokens = {2}; diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index e5af9773bc..56547ec1c3 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -63,8 +63,8 @@ def get_conv_template(name: str) -> Optional[Conversation]: name="llama-2", system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n", system_message="You are a helpful, respectful and honest assistant.", - roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, - seps=["", " "], + roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, + seps=[" ", " "], role_content_sep=" ", role_empty_sep=" ", stop_str=["[INST]"], diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 482cce54c8..e1ba1ce513 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -135,7 +135,6 @@ def as_prompt(self, config=None) -> List[Any]: separators.append(separators[0]) if system_msg != "": - system_msg += separators[0] message_list.append(system_msg) for i, (role, content) in enumerate(self.messages): # pylint: disable=not-an-iterable From b310ee1cccd92fe5939d4f5825063e7cca10cc0f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 3 May 2024 08:36:09 -0400 Subject: [PATCH 263/531] [Refactor][Serving] EngineConfig refactor and "model-lib-path" rename (#2268) * This PR refactors the EngineConfig to allow minimal JSON string passing. This is helpful for the JSONFFIEngine construction. * This PR moves the automatic engine config inference from Python side to C++ side, so that we don't have duplicate code on multiple platforms. * This PR renames `model_lib_path` to `model_lib`. * This PR makes the reload/unload of ThreadedEngine act in a blocking style. * This PR refactors the default generation config process flow, and unifies everything to C++. 
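As a rough caller-side sketch of the new flow (assuming `threaded_engine` is an
already-created ThreadedEngine; the model paths and the "local" mode string are
illustrative placeholders rather than values taken from this patch, while the
field names follow EngineConfig::FromJSONAndInferredConfig and the new
Reload(engine_config_json_str) entry point below):

    std::string engine_config_json = R"({
      "model":     "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
      "model_lib": "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
      "mode":      "local"
    })";
    // Reload is blocking after this PR; limits such as max_num_sequence,
    // max_total_sequence_length, prefill_chunk_size, max_history_size and
    // kv_state_kind are inferred on the C++ side instead of being required
    // from the caller.
    threaded_engine->Reload(engine_config_json);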
--- android/library/prepare_model_lib.py | 9 +- cpp/json_ffi/{config.cc => conv_template.cc} | 43 +- cpp/json_ffi/{config.h => conv_template.h} | 57 +- cpp/json_ffi/json_ffi_engine.cc | 71 +- cpp/json_ffi/json_ffi_engine.h | 4 +- cpp/json_ffi/openai_api_protocol.cc | 2 +- cpp/json_ffi/openai_api_protocol.h | 4 +- cpp/metadata/model.cc | 27 +- cpp/metadata/model.h | 9 + cpp/serve/config.cc | 965 ++++++++++++++---- cpp/serve/config.h | 194 +++- cpp/serve/engine.cc | 325 ++++-- cpp/serve/engine.h | 23 +- cpp/serve/grammar/grammar_parser.cc | 2 +- cpp/serve/model.cc | 131 ++- cpp/serve/model.h | 28 +- cpp/serve/request.cc | 6 +- cpp/serve/threaded_engine.cc | 88 +- cpp/serve/threaded_engine.h | 20 +- cpp/{metadata => support}/json_parser.h | 17 +- cpp/support/result.h | 77 ++ docs/compilation/compile_models.rst | 6 +- docs/compilation/convert_weights.rst | 2 +- docs/deploy/cli.rst | 4 +- docs/deploy/ide_integration.rst | 2 +- docs/deploy/ios.rst | 3 +- docs/deploy/python_chat_module.rst | 10 +- docs/deploy/python_engine.rst | 6 +- docs/deploy/rest.rst | 6 +- docs/get_started/introduction.rst | 18 +- examples/python/sample_mlc_chat.py | 4 +- python/mlc_llm/chat_module.py | 42 +- python/mlc_llm/cli/bench.py | 7 +- python/mlc_llm/cli/benchmark.py | 3 +- python/mlc_llm/cli/chat.py | 7 +- python/mlc_llm/cli/serve.py | 19 +- python/mlc_llm/help.py | 16 +- python/mlc_llm/interface/bench.py | 5 +- python/mlc_llm/interface/chat.py | 5 +- python/mlc_llm/interface/serve.py | 7 +- python/mlc_llm/json_ffi/engine.py | 127 +-- .../mlc_llm/protocol/openai_api_protocol.py | 19 +- python/mlc_llm/protocol/protocol_utils.py | 3 +- python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 154 +-- python/mlc_llm/serve/engine.py | 210 ++-- python/mlc_llm/serve/engine_base.py | 750 ++------------ python/mlc_llm/serve/request.py | 9 +- python/mlc_llm/serve/server/popen_server.py | 16 +- python/mlc_llm/serve/sync_engine.py | 64 +- python/mlc_llm/testing/debug_chat.py | 12 +- python/mlc_llm/testing/debug_compare.py | 6 +- rust/src/chat_module.rs | 20 +- tests/python/json_ffi/test_json_ffi_engine.py | 8 +- tests/python/serve/evaluate_engine.py | 6 +- tests/python/serve/server/conftest.py | 10 +- tests/python/serve/server/test_server.py | 6 +- .../serve/server/test_server_function_call.py | 6 +- .../python/serve/server/test_server_image.py | 8 +- tests/python/serve/test_radix_tree.py | 3 - tests/python/serve/test_serve_async_engine.py | 20 +- .../serve/test_serve_async_engine_spec.py | 14 +- tests/python/serve/test_serve_engine.py | 48 +- .../python/serve/test_serve_engine_grammar.py | 8 +- tests/python/serve/test_serve_engine_image.py | 4 +- tests/python/serve/test_serve_engine_spec.py | 108 +- tests/python/serve/test_serve_sync_engine.py | 20 +- 67 files changed, 2075 insertions(+), 1860 deletions(-) rename cpp/json_ffi/{config.cc => conv_template.cc} (86%) rename cpp/json_ffi/{config.h => conv_template.h} (66%) rename cpp/{metadata => support}/json_parser.h (92%) create mode 100644 cpp/support/result.h diff --git a/android/library/prepare_model_lib.py b/android/library/prepare_model_lib.py index dc14397a16..9f143d7357 100644 --- a/android/library/prepare_model_lib.py +++ b/android/library/prepare_model_lib.py @@ -1,5 +1,6 @@ import json import os + from tvm.contrib import ndk @@ -23,8 +24,8 @@ def main(): tar_list = [] model_set = set() - for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): - path = os.path.join(artifact_path, model_lib_path) + for model, model_lib in 
app_config["model_lib_path_for_prepare_libs"].items(): + path = os.path.join(artifact_path, model_lib) if not os.path.isfile(path): raise RuntimeError(f"Cannot find android library {path}") tar_list.append(path) @@ -58,11 +59,11 @@ def main(): model_prefix_pattern not in global_symbol_map and "_" + model_prefix_pattern not in global_symbol_map ): - model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + model_lib = app_config["model_lib_path_for_prepare_libs"][model_lib] print( "ValidationError:\n" f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" - f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tspecifically the model_lib for {model_lib} in model_lib_path_for_prepare_libs.\n" f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" ) error_happened = True diff --git a/cpp/json_ffi/config.cc b/cpp/json_ffi/conv_template.cc similarity index 86% rename from cpp/json_ffi/config.cc rename to cpp/json_ffi/conv_template.cc index 8f5c0e1062..9511bb5b64 100644 --- a/cpp/json_ffi/config.cc +++ b/cpp/json_ffi/conv_template.cc @@ -1,8 +1,8 @@ -#include "config.h" +#include "conv_template.h" #include -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { @@ -10,27 +10,6 @@ namespace json_ffi { using namespace mlc::llm; -/****************** Model-defined generation config ******************/ - -TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode); - -ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p, - double frequency_penalty, - double presence_penalty) { - ObjectPtr n = make_object(); - n->temperature = temperature; - n->top_p = top_p; - n->frequency_penalty = frequency_penalty; - n->presence_penalty = presence_penalty; - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig") - .set_body_typed([](double temperature, double top_p, double frequency_penalty, - double presence_penalty) { - return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty); - }); - /****************** Conversation template ******************/ std::map PLACEHOLDERS = { @@ -334,24 +313,6 @@ std::optional Conversation::FromJSON(const std::string& json_str, return Conversation::FromJSON(json_obj.value(), err); } -/****************** JSON FFI engine config ******************/ - -TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode); - -JSONFFIEngineConfig::JSONFFIEngineConfig( - String conv_template, Map model_generation_cfgs) { - ObjectPtr n = make_object(); - n->conv_template = conv_template; - n->model_generation_cfgs = model_generation_cfgs; - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig") - .set_body_typed([](String conv_template, - Map model_generation_cfgs) { - return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs)); - }); - } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/config.h b/cpp/json_ffi/conv_template.h similarity index 66% rename from cpp/json_ffi/config.h rename to cpp/json_ffi/conv_template.h index fe5e4e42e2..eeb348831c 100644 --- a/cpp/json_ffi/config.h +++ b/cpp/json_ffi/conv_template.h @@ -1,9 +1,5 @@ -#ifndef MLC_LLM_JSON_FFI_CONFIG_H -#define MLC_LLM_JSON_FFI_CONFIG_H - -#include -#include -#include +#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H +#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H #include #include @@ -22,35 +18,11 @@ 
namespace mlc { namespace llm { namespace json_ffi { -/****************** Model-defined generation config ******************/ - -class ModelDefinedGenerationConfigNode : public Object { - public: - double temperature; - double top_p; - double frequency_penalty; - double presence_penalty; - - static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object); -}; - -class ModelDefinedGenerationConfig : public ObjectRef { - public: - explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty, - double presence_penalty); - - TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef, - ModelDefinedGenerationConfigNode); -}; - /****************** Conversation template ******************/ enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; -MessagePlaceholders messagePlaceholderFromString(const std::string& role); +MessagePlaceholders MessagePlaceholderFromString(const std::string& role); class Message { public: @@ -144,29 +116,8 @@ struct Conversation { static std::optional FromJSON(const std::string& json_str, std::string* err); }; -/****************** JSON FFI engine config ******************/ - -class JSONFFIEngineConfigNode : public Object { - public: - String conv_template; - Map model_generation_cfgs; - - static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object); -}; - -class JSONFFIEngineConfig : public ObjectRef { - public: - explicit JSONFFIEngineConfig(String conv_template, - Map model_generation_cfgs); - - TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode); -}; - } // namespace json_ffi } // namespace llm } // namespace mlc -#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */ +#endif // MLC_LLM_JSON_FFI_CONV_TEMPLATE_H diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index d5fc53b8fa..6b2676ee3f 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -4,6 +4,10 @@ #include #include +#include "../serve/model.h" +#include "../support/json_parser.h" +#include "../support/result.h" + namespace mlc { namespace llm { namespace json_ffi { @@ -83,13 +87,27 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request Array inputs = inputs_obj.value(); // generation_cfg - Optional generation_cfg = GenerationConfig::Create( - request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]); - if (!generation_cfg.defined()) { - return false; + Array stop_strs; + stop_strs.reserve(conv_template.stop_str.size()); + for (const std::string& stop_str : conv_template.stop_str) { + stop_strs.push_back(stop_str); + } + if (request.stop.has_value()) { + stop_strs.reserve(stop_strs.size() + request.stop.value().size()); + for (const std::string& stop_str : request.stop.value()) { + stop_strs.push_back(stop_str); + } } - Request engine_request(request_id, inputs, generation_cfg.value()); + GenerationConfig generation_cfg(request.n, request.temperature, request.top_p, + request.frequency_penalty, request.presence_penalty, + /*repetition_penalty=*/std::nullopt, 
request.logprobs, + request.top_logprobs, request.logit_bias, request.seed, + request.ignore_eos, request.max_tokens, std::move(stop_strs), + conv_template.stop_token_ids, /*response_format=*/std::nullopt, + this->default_generation_cfg_json_str_); + + Request engine_request(request_id, inputs, generation_cfg); this->engine_->AddRequest(engine_request); return true; @@ -122,22 +140,8 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, - Device device, Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) { - std::optional conv_template = - Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); - if (!conv_template.has_value()) { - LOG(FATAL) << "Invalid conversation template JSON: " << err_; - } - this->conv_template_ = conv_template.value(); - this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs; - - // Todo(mlc-team): decouple InitBackgroundEngine into two functions - // by removing `engine_config` from arguments, after properly handling - // streamers. - this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); - CHECK(request_stream_callback.defined()) << "JSONFFIEngine requires request stream callback function, but it is not given."; this->request_stream_callback_ = request_stream_callback.value(); @@ -150,12 +154,31 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), - std::move(trace_recorder)); - this->engine_->Reload(std::move(engine_config)); + this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), + std::move(trace_recorder)); } - void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } + void Reload(String engine_config_json_str) { + this->engine_->Reload(engine_config_json_str); + this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString(); + picojson::object engine_config_json = + json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString()); + + // Load conversation template. + Result model_config_json = + serve::Model::LoadModelConfig(json::Lookup(engine_config_json, "model")); + CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); + std::optional conv_template = Conversation::FromJSON( + json::Lookup(model_config_json.Unwrap(), "conv_template"), &err_); + if (!conv_template.has_value()) { + LOG(FATAL) << "Invalid conversation template JSON: " << err_; + } + this->conv_template_ = conv_template.value(); + // Create streamer. + // Todo(mlc-team): Create one streamer for each request, instead of a global one. 
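+    // Note: both the conversation template above and this streamer are rebuilt
+    // on every Reload from the "model" path in the complete engine config
+    // reported back by the engine, so they track the currently loaded model.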
+ this->streamer_ = + TextStreamer(Tokenizer::FromPath(json::Lookup(engine_config_json, "model"))); + } void Unload() { this->engine_->Unload(); } diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index d57384abb5..e805cb6e8a 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,7 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" -#include "config.h" +#include "conv_template.h" #include "openai_api_protocol.h" namespace mlc { @@ -49,7 +49,7 @@ class JSONFFIEngine { PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request Conversation conv_template_; - Map model_generation_cfgs; + String default_generation_cfg_json_str_; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 13f4b140ce..4547108eb5 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -5,7 +5,7 @@ */ #include "openai_api_protocol.h" -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 429050da3c..70ef2fb22f 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,7 +13,7 @@ #include #include -#include "config.h" +#include "conv_template.h" #include "picojson.h" namespace mlc { @@ -94,7 +94,7 @@ class ChatCompletionRequest { std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; - std::optional> logit_bias = std::nullopt; + std::optional>> logit_bias = std::nullopt; std::optional max_tokens = std::nullopt; int n = 1; std::optional seed = std::nullopt; diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 8c2cf66a80..2daf1d0338 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -4,7 +4,7 @@ #include -#include "./json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { @@ -39,6 +39,16 @@ ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& para return result; } +ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON( + const picojson::object& json) { + KVCacheMetadata kv_cache_metadata; + kv_cache_metadata.num_hidden_layers = json::Lookup(json, "num_hidden_layers"); + kv_cache_metadata.head_dim = json::Lookup(json, "head_dim"); + kv_cache_metadata.num_attention_heads = json::Lookup(json, "num_attention_heads"); + kv_cache_metadata.num_key_value_heads = json::Lookup(json, "num_key_value_heads"); + return kv_cache_metadata; +} + ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, const picojson::object& model_config) { ModelMetadata result; @@ -53,6 +63,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib result.attention_sink_size = json::Lookup(metadata, "attention_sink_size"); result.tensor_parallel_shards = json::Lookup(metadata, "tensor_parallel_shards"); + result.kv_cache_metadata = + KVCacheMetadata::FromJSON(json::Lookup(metadata, "kv_cache")); { std::vector& params = result.params; picojson::array json_params = json::Lookup(metadata, "params"); @@ -76,17 +88,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module, const 
picojson::object& model_config) { std::string json_str = ""; - try { - TypedPackedFunc pf = module.GetFunction("_metadata"); - if (pf == nullptr) { - // legacy path - // TODO: remove this after full SLMify - return ModelMetadata(); - } - json_str = pf(); - } catch (...) { - return ModelMetadata(); // TODO: add a warning message about legacy usecases - } + TypedPackedFunc pf = module.GetFunction("_metadata"); + json_str = pf(); picojson::object json = json::ParseToJsonObject(json_str); try { return ModelMetadata::FromJSON(json, model_config); diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index 2472cb7d36..ede06b6b3f 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -32,6 +32,14 @@ struct ModelMetadata { static Param FromJSON(const picojson::object& param_obj, const picojson::object& model_config); }; + struct KVCacheMetadata { + int64_t num_hidden_layers; + int64_t num_attention_heads; + int64_t num_key_value_heads; + int64_t head_dim; + static KVCacheMetadata FromJSON(const picojson::object& json); + }; + std::string model_type; std::string quantization; int64_t context_window_size; @@ -41,6 +49,7 @@ struct ModelMetadata { int64_t attention_sink_size; std::vector params; std::unordered_map memory_usage; + KVCacheMetadata kv_cache_metadata; static ModelMetadata FromJSON(const picojson::object& json_str, const picojson::object& model_config); diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 3bb809ad67..30a3617a8d 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -5,12 +5,14 @@ #include "config.h" #include +#include #include +#include #include #include "../json_ffi/openai_api_protocol.h" -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" #include "data.h" namespace mlc { @@ -21,178 +23,174 @@ namespace serve { TVM_REGISTER_OBJECT_TYPE(GenerationConfigNode); -GenerationConfig::GenerationConfig(String config_json_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_json_str); - if (!err.empty()) { - LOG(FATAL) << err; - return; +GenerationConfig::GenerationConfig( + std::optional n, std::optional temperature, std::optional top_p, + std::optional frequency_penalty, std::optional presense_penalty, + std::optional repetition_penalty, std::optional logprobs, + std::optional top_logprobs, std::optional>> logit_bias, + std::optional seed, std::optional ignore_eos, std::optional max_tokens, + std::optional> stop_strs, std::optional> stop_token_ids, + std::optional response_format, Optional default_config_json_str) { + ObjectPtr obj = make_object(); + GenerationConfig default_config; + if (default_config_json_str.defined()) { + default_config = GenerationConfig(default_config_json_str.value(), NullOpt); + } else { + default_config = GenerationConfig(obj); } - ObjectPtr n = make_object(); - - picojson::object config = config_json.get(); - if (config.count("n")) { - CHECK(config["n"].is()); - n->n = config["n"].get(); - CHECK_GT(n->n, 0) << "\"n\" should be at least 1"; - } - if (config.count("temperature")) { - CHECK(config["temperature"].is()); - n->temperature = config["temperature"].get(); - } - if (config.count("top_p")) { - CHECK(config["top_p"].is()); - n->top_p = config["top_p"].get(); - } - if (config.count("frequency_penalty")) { - CHECK(config["frequency_penalty"].is()); - n->frequency_penalty = config["frequency_penalty"].get(); - CHECK(std::fabs(n->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; - } - if (config.count("presence_penalty")) { - 
CHECK(config["presence_penalty"].is()); - n->presence_penalty = config["presence_penalty"].get(); - CHECK(std::fabs(n->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; - } - if (config.count("repetition_penalty")) { - CHECK(config["repetition_penalty"].is()); - n->repetition_penalty = config["repetition_penalty"].get(); - CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; - } - if (config.count("logprobs")) { - CHECK(config["logprobs"].is()); - n->logprobs = config["logprobs"].get(); - } - if (config.count("top_logprobs")) { - CHECK(config["top_logprobs"].is()); - n->top_logprobs = config["top_logprobs"].get(); - CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5) - << "At most 5 top logprob tokens are supported"; - CHECK(n->top_logprobs == 0 || n->logprobs) - << "\"logprobs\" must be true to support \"top_logprobs\""; - } - if (config.count("logit_bias")) { - CHECK(config["logit_bias"].is() || config["logit_bias"].is()); - if (config["logit_bias"].is()) { - picojson::object logit_bias_json = config["logit_bias"].get(); - std::vector> logit_bias; - logit_bias.reserve(logit_bias_json.size()); - for (auto [token_id_str, bias] : logit_bias_json) { - CHECK(bias.is()); - double bias_value = bias.get(); - CHECK_LE(std::fabs(bias_value), 100.0) - << "Logit bias value should be in range [-100, 100]."; - logit_bias.emplace_back(std::stoi(token_id_str), bias_value); - } - n->logit_bias = std::move(logit_bias); - } + obj->n = n.value_or(default_config->n); + CHECK_GT(obj->n, 0) << "\"n\" should be at least 1"; + obj->temperature = temperature.value_or(default_config->temperature); + CHECK_GE(obj->temperature, 0) << "\"temperature\" should be non-negative"; + obj->top_p = top_p.value_or(default_config->top_p); + CHECK(obj->top_p >= 0 && obj->top_p <= 1) << "\"top_p\" should be in range [0, 1]"; + obj->frequency_penalty = frequency_penalty.value_or(default_config->frequency_penalty); + CHECK(std::fabs(obj->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; + obj->presence_penalty = presense_penalty.value_or(default_config->presence_penalty); + CHECK(std::fabs(obj->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; + obj->repetition_penalty = repetition_penalty.value_or(default_config->repetition_penalty); + CHECK(obj->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; + obj->logprobs = logprobs.value_or(default_config->logprobs); + obj->top_logprobs = top_logprobs.value_or(default_config->top_logprobs); + CHECK(obj->top_logprobs >= 0 && obj->top_logprobs <= 5) + << "At most 5 top logprob tokens are supported"; + CHECK(obj->top_logprobs == 0 || obj->logprobs) + << "\"logprobs\" must be true to support \"top_logprobs\""; + + obj->logit_bias = logit_bias.value_or(default_config->logit_bias); + for (auto [token_id_str, bias] : obj->logit_bias) { + CHECK_LE(std::fabs(bias), 100.0) << "Logit bias value should be in range [-100, 100]."; } - if (config.count("max_tokens")) { - if (config["max_tokens"].is()) { - n->max_tokens = config["max_tokens"].get(); - } else { - CHECK(config["max_tokens"].is()) << "Unrecognized max_tokens"; - // "-1" means the generation will not stop until exceeding - // model capability or hit any stop criteria. - n->max_tokens = -1; - } + + obj->seed = seed.value_or(std::random_device{}()); + // "ignore_eos" is for benchmarking. Not the part of OpenAI API spec. 
+ obj->ignore_eos = ignore_eos.value_or(default_config->ignore_eos); + // "-1" means the generation will not stop until exceeding + // model capability or hit any stop criteria. + obj->max_tokens = max_tokens.value_or(-1); + + obj->stop_strs = stop_strs.value_or(default_config->stop_strs); + obj->stop_token_ids = stop_token_ids.value_or(default_config->stop_token_ids); + obj->response_format = response_format.value_or(default_config->response_format); + + data_ = std::move(obj); +} + +GenerationConfig::GenerationConfig(String config_json_str, + Optional default_config_json_str) { + picojson::object config = json::ParseToJsonObject(config_json_str); + ObjectPtr n = make_object(); + GenerationConfig default_config; + if (default_config_json_str.defined()) { + default_config = GenerationConfig(default_config_json_str.value(), NullOpt); + } else { + default_config = GenerationConfig(n); } - if (config.count("seed")) { - if (config["seed"].is()) { - n->seed = config["seed"].get(); - } else { - CHECK(config["seed"].is()) << "Unrecognized seed"; - n->seed = std::random_device{}(); + + n->n = json::LookupOrDefault(config, "n", default_config->n); + CHECK_GT(n->n, 0) << "\"n\" should be at least 1"; + n->temperature = + json::LookupOrDefault(config, "temperature", default_config->temperature); + CHECK_GE(n->temperature, 0) << "\"temperature\" should be non-negative"; + n->top_p = json::LookupOrDefault(config, "top_p", default_config->top_p); + CHECK(n->top_p >= 0 && n->top_p <= 1) << "\"top_p\" should be in range [0, 1]"; + n->frequency_penalty = + json::LookupOrDefault(config, "frequency_penalty", default_config->frequency_penalty); + CHECK(std::fabs(n->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; + n->presence_penalty = + json::LookupOrDefault(config, "presence_penalty", default_config->presence_penalty); + CHECK(std::fabs(n->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; + n->repetition_penalty = json::LookupOrDefault(config, "repetition_penalty", + default_config->repetition_penalty); + CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; + n->logprobs = json::LookupOrDefault(config, "logprobs", default_config->logprobs); + n->top_logprobs = + json::LookupOrDefault(config, "top_logprobs", default_config->top_logprobs); + CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5) + << "At most 5 top logprob tokens are supported"; + CHECK(n->top_logprobs == 0 || n->logprobs) + << "\"logprobs\" must be true to support \"top_logprobs\""; + + std::optional logit_bias_obj = + json::LookupOptional(config, "logit_bias"); + if (logit_bias_obj.has_value()) { + std::vector> logit_bias; + logit_bias.reserve(logit_bias_obj.value().size()); + for (auto [token_id_str, bias] : logit_bias_obj.value()) { + CHECK(bias.is()); + double bias_value = bias.get(); + CHECK_LE(std::fabs(bias_value), 100.0) << "Logit bias value should be in range [-100, 100]."; + logit_bias.emplace_back(std::stoi(token_id_str), bias_value); } + n->logit_bias = std::move(logit_bias); } else { - n->seed = std::random_device{}(); + n->logit_bias = default_config->logit_bias; } - if (config.count("stop_strs")) { - CHECK(config["stop_strs"].is()) - << "Invalid stop_strs. Stop strs should be an array of strings"; - picojson::array stop_strs_arr = config["stop_strs"].get(); + + n->seed = json::LookupOrDefault(config, "seed", std::random_device{}()); + // "ignore_eos" is for benchmarking. Not the part of OpenAI API spec. 
+ n->ignore_eos = json::LookupOrDefault(config, "ignore_eos", default_config->ignore_eos); + // "-1" means the generation will not stop until exceeding + // model capability or hit any stop criteria. + n->max_tokens = json::LookupOrDefault(config, "max_tokens", -1); + + std::optional stop_strs_arr = + json::LookupOptional(config, "stop_strs"); + if (stop_strs_arr.has_value()) { Array stop_strs; - stop_strs.reserve(stop_strs_arr.size()); - for (const picojson::value& v : stop_strs_arr) { + stop_strs.reserve(stop_strs_arr.value().size()); + for (const picojson::value& v : stop_strs_arr.value()) { CHECK(v.is()) << "Invalid stop string in stop_strs"; stop_strs.push_back(v.get()); } n->stop_strs = std::move(stop_strs); + } else { + n->stop_strs = default_config->stop_strs; } - if (config.count("stop_token_ids")) { - CHECK(config["stop_token_ids"].is()) - << "Invalid stop_token_ids. Stop tokens should be an array of integers"; - picojson::array stop_token_ids_arr = config["stop_token_ids"].get(); + std::optional stop_token_ids_arr = + json::LookupOptional(config, "stop_token_ids"); + if (stop_token_ids_arr.has_value()) { std::vector stop_token_ids; - stop_token_ids.reserve(stop_token_ids_arr.size()); - for (const picojson::value& v : stop_token_ids_arr) { + stop_token_ids.reserve(stop_token_ids_arr.value().size()); + for (const picojson::value& v : stop_token_ids_arr.value()) { CHECK(v.is()) << "Invalid stop token in stop_token_ids"; stop_token_ids.push_back(v.get()); } n->stop_token_ids = std::move(stop_token_ids); + } else { + n->stop_token_ids = default_config->stop_token_ids; } - // Params for benchmarking. Not the part of openai spec. - if (config.count("ignore_eos")) { - CHECK(config["ignore_eos"].is()); - n->ignore_eos = config["ignore_eos"].get(); - } - - if (config.count("response_format")) { - CHECK(config["response_format"].is()); - picojson::object response_format_json = config["response_format"].get(); + std::optional response_format_obj = + json::LookupOptional(config, "response_format"); + if (response_format_obj.has_value()) { ResponseFormat response_format; - if (response_format_json.count("type")) { - CHECK(response_format_json["type"].is()); - response_format.type = response_format_json["type"].get(); - } - if (response_format_json.count("schema")) { - if (response_format_json["schema"].is()) { - response_format.schema = NullOpt; - } else { - CHECK(response_format_json["schema"].is()); - response_format.schema = response_format_json["schema"].get(); - } + response_format.type = json::LookupOrDefault(response_format_obj.value(), "type", + response_format.type); + std::optional schema = + json::LookupOptional(response_format_obj.value(), "schema"); + if (schema.has_value()) { + response_format.schema = schema.value(); } n->response_format = response_format; + } else { + n->response_format = default_config->response_format; } data_ = std::move(n); } -Optional GenerationConfig::Create( - const std::string& json_str, std::string* err, const Conversation& conv_template, - const ModelDefinedGenerationConfig& model_defined_gen_config) { - std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !optional_json_obj.has_value()) { - return NullOpt; - } - picojson::object& json_obj = optional_json_obj.value(); +GenerationConfig GenerationConfig::GetDefaultFromModelConfig( + const picojson::object& model_config_json) { ObjectPtr n = make_object(); - - n->temperature = - json::LookupOrDefault(json_obj, "temperature", 
model_defined_gen_config->temperature); - n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); - n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", - model_defined_gen_config->frequency_penalty); - n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", - model_defined_gen_config->presence_penalty); - n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); - n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); - n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); - - // Copy stop str from conversation template to generation config - for (auto& stop_str : conv_template.stop_str) { - n->stop_strs.push_back(stop_str); - } - for (auto& stop_token_id : conv_template.stop_token_ids) { - n->stop_token_ids.push_back(stop_token_id); - } - - GenerationConfig gen_config; - gen_config.data_ = std::move(n); - return gen_config; + n->temperature = json::LookupOrDefault(model_config_json, "temperature", n->temperature); + n->top_p = json::LookupOrDefault(model_config_json, "top_p", n->top_p); + n->frequency_penalty = + json::LookupOrDefault(model_config_json, "frequency_penalty", n->frequency_penalty); + n->presence_penalty = + json::LookupOrDefault(model_config_json, "presence_penalty", n->presence_penalty); + return GenerationConfig(n); } String GenerationConfigNode::AsJSONString() const { @@ -243,87 +241,638 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); -EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, - int max_history_size, KVStateKind kv_state_kind, - SpeculativeMode speculative_mode, int spec_draft_length) { +EngineConfig EngineConfig::FromJSONAndInferredConfig( + const picojson::object& json, const InferrableEngineConfig& inferred_config) { + CHECK(inferred_config.max_num_sequence.has_value()); + CHECK(inferred_config.max_total_sequence_length.has_value()); + CHECK(inferred_config.max_single_sequence_length.has_value()); + CHECK(inferred_config.prefill_chunk_size.has_value()); + CHECK(inferred_config.max_history_size.has_value()); + CHECK(inferred_config.kv_state_kind.has_value()); ObjectPtr n = make_object(); - n->model = std::move(model); - n->model_lib_path = std::move(model_lib_path); - n->additional_models = std::move(additional_models); - n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->kv_cache_page_size = kv_cache_page_size; - n->max_num_sequence = max_num_sequence; - n->max_total_sequence_length = max_total_sequence_length; - n->max_single_sequence_length = max_single_sequence_length; - n->prefill_chunk_size = prefill_chunk_size; - n->max_history_size = max_history_size; - n->kv_state_kind = kv_state_kind; - n->spec_draft_length = spec_draft_length; - n->speculative_mode = speculative_mode; - data_ = std::move(n); + + // - Get models and model libs. 
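  // Illustrative only (paths and values are hypothetical): a JSON object that
  // this constructor consumes could look like
  //   {
  //     "model": "path/to/model",
  //     "model_lib": "path/to/model_lib.so",
  //     "additional_models": [],
  //     "additional_model_libs": [],
  //     "mode": "local",
  //     "gpu_memory_utilization": 0.85,
  //     "kv_cache_page_size": 16,
  //     "speculative_mode": "disable",
  //     "spec_draft_length": 4,
  //     "verbose": false
  //   }
  // The capacity-related fields (max_num_sequence, max_total_sequence_length,
  // max_single_sequence_length, prefill_chunk_size, max_history_size,
  // kv_state_kind) come from `inferred_config` rather than from the JSON.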
+ n->model = json::Lookup(json, "model"); + n->model_lib = json::Lookup(json, "model_lib"); + std::vector additional_models; + std::vector additional_model_libs; + picojson::array additional_models_arr = + json::LookupOrDefault(json, "additional_models", picojson::array()); + picojson::array additional_model_libs_arr = + json::LookupOrDefault(json, "additional_model_libs", picojson::array()); + CHECK_EQ(additional_models_arr.size(), additional_model_libs_arr.size()) + << "The number of additional model libs does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_libs.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_libs.push_back(json::Lookup(additional_model_libs_arr, i)); + } + n->additional_models = additional_models; + n->additional_model_libs = additional_model_libs; + n->mode = EngineModeFromString(json::Lookup(json, "mode")); + + // - Other fields with default value. + n->gpu_memory_utilization = + json::LookupOrDefault(json, "gpu_memory_utilization", n->gpu_memory_utilization); + n->kv_cache_page_size = + json::LookupOrDefault(json, "kv_cache_page_size", n->kv_cache_page_size); + n->speculative_mode = SpeculativeModeFromString(json::LookupOrDefault( + json, "speculative_mode", SpeculativeModeToString(n->speculative_mode))); + n->spec_draft_length = + json::LookupOrDefault(json, "spec_draft_length", n->spec_draft_length); + n->verbose = json::LookupOrDefault(json, "verbose", n->verbose); + + // - Fields from the inferred engine config. + n->max_num_sequence = inferred_config.max_num_sequence.value(); + n->max_total_sequence_length = inferred_config.max_total_sequence_length.value(); + n->max_single_sequence_length = inferred_config.max_single_sequence_length.value(); + n->prefill_chunk_size = inferred_config.prefill_chunk_size.value(); + n->max_history_size = inferred_config.max_history_size.value(); + n->kv_state_kind = inferred_config.kv_state_kind.value(); + + return EngineConfig(n); } -EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { +Result>> +EngineConfig::GetModelsAndModelLibsFromJSONString(const std::string& json_str) { + using TResult = Result>>; picojson::value config_json; std::string err = picojson::parse(config_json, json_str); if (!err.empty()) { - LOG(FATAL) << err; + return TResult::Error(err); } - // Get json fields. + // Get the models and model libs from JSON. 
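  // Illustrative only (a sketch, not a definitive usage): a caller might
  // consume the result of this helper roughly as
  //   auto res = EngineConfig::GetModelsAndModelLibsFromJSONString(json_str);
  //   if (res.IsErr()) {
  //     LOG(FATAL) << res.UnwrapErr();
  //   }
  //   for (const auto& [model, model_lib] : res.Unwrap()) {
  //     // the first entry is the main model; the rest are the additional models
  //   }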
picojson::object config = config_json.get(); String model = json::Lookup(config, "model"); - String model_lib_path = json::Lookup(config, "model_lib_path"); - std::vector additional_models; - std::vector additional_model_lib_paths; - int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); - int max_num_sequence = json::Lookup(config, "max_num_sequence"); - int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); - int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); - int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); - int max_history_size = json::Lookup(config, "max_history_size"); - KVStateKind kv_state_kind = - static_cast(json::Lookup(config, "kv_state_kind")); - SpeculativeMode speculative_mode = - static_cast(json::Lookup(config, "speculative_mode")); - int spec_draft_length = json::Lookup(config, "spec_draft_length"); - + String model_lib = json::Lookup(config, "model_lib"); picojson::array additional_models_arr = - json::Lookup(config, "additional_models"); - picojson::array additional_model_lib_paths_arr = - json::Lookup(config, "additional_model_lib_paths"); - CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) - << "The number of additional model lib paths does not match the number of additional models"; + json::LookupOrDefault(config, "additional_models", picojson::array()); + picojson::array additional_model_libs_arr = + json::LookupOrDefault(config, "additional_model_libs", picojson::array()); + if (additional_models_arr.size() != additional_model_libs_arr.size()) { + return TResult::Error( + "The number of additional model libs does not match the number of additional models"); + } + int num_additional_models = additional_models_arr.size(); - additional_models.reserve(num_additional_models); - additional_model_lib_paths.reserve(num_additional_models); + std::vector> models_and_model_libs; + models_and_model_libs.reserve(num_additional_models + 1); + models_and_model_libs.emplace_back(model, model_lib); for (int i = 0; i < num_additional_models; ++i) { - additional_models.push_back(json::Lookup(additional_models_arr, i)); - additional_model_lib_paths.push_back( - json::Lookup(additional_model_lib_paths_arr, i)); + models_and_model_libs.emplace_back(json::Lookup(additional_models_arr, i), + json::Lookup(additional_model_libs_arr, i)); + } + return TResult::Ok(models_and_model_libs); +} + +String EngineConfigNode::AsJSONString() const { + picojson::object config; + + // - Models and model libs + config["model"] = picojson::value(this->model); + config["model_lib"] = picojson::value(this->model_lib); + picojson::array additional_models_arr; + picojson::array additional_model_libs_arr; + additional_models_arr.reserve(this->additional_models.size()); + additional_model_libs_arr.reserve(this->additional_models.size()); + for (int i = 0; i < static_cast(this->additional_models.size()); ++i) { + additional_models_arr.push_back(picojson::value(this->additional_models[i])); + additional_model_libs_arr.push_back(picojson::value(this->additional_model_libs[i])); } + config["additional_models"] = picojson::value(additional_models_arr); + config["additional_model_libs"] = picojson::value(additional_model_libs_arr); + + // - Other fields + config["mode"] = picojson::value(EngineModeToString(this->mode)); + config["gpu_memory_utilization"] = picojson::value(this->gpu_memory_utilization); + config["kv_cache_page_size"] = picojson::value(static_cast(this->kv_cache_page_size)); + 
config["max_num_sequence"] = picojson::value(static_cast(this->max_num_sequence)); + config["max_total_sequence_length"] = + picojson::value(static_cast(this->max_total_sequence_length)); + config["max_single_sequence_length"] = + picojson::value(static_cast(this->max_single_sequence_length)); + config["prefill_chunk_size"] = picojson::value(static_cast(this->prefill_chunk_size)); + config["max_history_size"] = picojson::value(static_cast(this->max_history_size)); + config["kv_state_kind"] = picojson::value(KVStateKindToString(this->kv_state_kind)); + config["speculative_mode"] = picojson::value(SpeculativeModeToString(this->speculative_mode)); + config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); + config["verbose"] = picojson::value(static_cast(this->verbose)); + + return picojson::value(config).serialize(true); +} + +/****************** InferrableEngineConfig ******************/ + +/*! \brief The class for config limitation from models. */ +struct ModelConfigLimits { + int64_t model_max_single_sequence_length; + int64_t model_max_prefill_chunk_size; + int64_t model_max_batch_size; +}; + +/*! \brief Convert the bytes to megabytes, keeping 3 decimals. */ +inline std::string BytesToMegabytesString(double bytes) { + std::string str; + str.resize(20); + std::sprintf(&str[0], "%.3f", bytes / 1024 / 1024); + str.resize(std::strlen(str.c_str())); + return str; +} - return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, - additional_model_lib_paths, kv_cache_page_size, max_num_sequence, - max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, - max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +/*! + * \brief Get the upper bound of single sequence length, prefill size and batch size + * from model config. + */ +Result GetModelConfigLimits(const std::vector& model_configs) { + int64_t model_max_single_sequence_length = std::numeric_limits::max(); + int64_t model_max_prefill_chunk_size = std::numeric_limits::max(); + int64_t model_max_batch_size = std::numeric_limits::max(); + for (int i = 0; i < static_cast(model_configs.size()); ++i) { + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + // - The maximum single sequence length is the minimum context window size among all models. + int64_t runtime_context_window_size = + json::Lookup(model_configs[i], "context_window_size"); + int64_t compile_time_context_window_size = + json::Lookup(compile_time_model_config, "context_window_size"); + if (runtime_context_window_size > compile_time_context_window_size) { + return Result::Error( + "Model " + std::to_string(i) + "'s runtime context window size (" + + std::to_string(runtime_context_window_size) + + ") is larger than the context window size used at compile time (" + + std::to_string(compile_time_context_window_size) + ")."); + } + if (runtime_context_window_size == -1 && compile_time_context_window_size != -1) { + return Result::Error( + "Model " + std::to_string(i) + + "'s runtime context window size (infinite) is larger than the context " + "window size used at compile time (" + + std::to_string(compile_time_context_window_size) + ")."); + } + if (runtime_context_window_size != -1) { + model_max_single_sequence_length = + std::min(model_max_single_sequence_length, runtime_context_window_size); + } + // - The maximum prefill chunk size is the minimum prefill chunk size among all models. 
+ int64_t runtime_prefill_chunk_size = + json::Lookup(model_configs[i], "prefill_chunk_size"); + int64_t compile_time_prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + if (runtime_prefill_chunk_size > compile_time_prefill_chunk_size) { + return Result::Error( + "Model " + std::to_string(i) + "'s runtime prefill chunk size (" + + std::to_string(runtime_prefill_chunk_size) + + ") is larger than the prefill chunk size used at compile time (" + + std::to_string(compile_time_prefill_chunk_size) + ")."); + } + model_max_prefill_chunk_size = + std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size); + // - The maximum batch size is the minimum max batch size among all models. + model_max_batch_size = std::min( + model_max_batch_size, json::Lookup(compile_time_model_config, "max_batch_size")); + } + ICHECK_NE(model_max_prefill_chunk_size, std::numeric_limits::max()); + ICHECK_NE(model_max_batch_size, std::numeric_limits::max()); + return Result::Ok( + {model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size}); } -TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") - .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int max_history_size, - int kv_state_kind, int speculative_mode, int spec_draft_length) { - return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), kv_cache_page_size, - max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), - SpeculativeMode(speculative_mode), spec_draft_length); - }); +/*! \brief The class for memory usage estimation result. */ +struct MemUsageEstimationResult { + double total_memory_bytes; + double kv_cache_memory_bytes; + double temp_memory_bytes; + InferrableEngineConfig inferred_config; +}; + +Result EstimateMemoryUsageOnMode( + EngineMode mode, Device device, double gpu_memory_utilization, int64_t params_bytes, + int64_t temp_buffer_bytes, + const std::vector& model_configs, // + const std::vector& model_metadata, // + ModelConfigLimits model_config_limits, // + InferrableEngineConfig init_config, bool verbose) { + std::ostringstream os; + InferrableEngineConfig inferred_config = init_config; + // - 1. max_mum_sequence + if (!init_config.max_num_sequence.has_value()) { + if (mode == EngineMode::kLocal) { + inferred_config.max_num_sequence = + std::min(static_cast(4), model_config_limits.model_max_batch_size); + } else if (mode == EngineMode::kInteractive) { + inferred_config.max_num_sequence = 1; + } else { + inferred_config.max_num_sequence = model_config_limits.model_max_batch_size; + } + os << "max batch size will be set to " << inferred_config.max_num_sequence.value() << ", "; + } else { + os << "max batch size " << inferred_config.max_num_sequence.value() + << " is specified by user, "; + } + int64_t max_num_sequence = inferred_config.max_num_sequence.value(); + // - 2. 
max_single_sequence_length + if (!init_config.max_single_sequence_length.has_value()) { + inferred_config.max_single_sequence_length = + model_config_limits.model_max_single_sequence_length; + } else { + inferred_config.max_single_sequence_length = + std::min(inferred_config.max_single_sequence_length.value(), + model_config_limits.model_max_single_sequence_length); + } + // - 3. infer the maximum total sequence length that can fit GPU memory. + double kv_bytes_per_token = 0; + double kv_aux_workspace_bytes = 0; + double model_workspace_bytes = 0; + double logit_processor_workspace_bytes = 0; + ICHECK_EQ(model_configs.size(), model_metadata.size()); + int num_models = model_configs.size(); + for (int i = 0; i < num_models; ++i) { + // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation). + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + int64_t vocab_size = json::Lookup(compile_time_model_config, "vocab_size"); + int64_t prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + // - Calculate KV cache memory usage. + int64_t num_layers = model_metadata[i].kv_cache_metadata.num_hidden_layers; + int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim; + int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads; + int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads; + int64_t hidden_size = head_dim * num_qo_heads; + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25; + kv_aux_workspace_bytes += + (max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 + + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024; + model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2; + logit_processor_workspace_bytes += + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125; + } + // Get single-card GPU size. + TVMRetValue rv; + DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv); + int64_t gpu_size_bytes = rv; + // Compute the maximum total sequence length under the GPU memory budget. + int64_t model_max_total_sequence_length = + static_cast((gpu_size_bytes * gpu_memory_utilization // + - params_bytes // + - temp_buffer_bytes // + - kv_aux_workspace_bytes // + - model_workspace_bytes // + - logit_processor_workspace_bytes) / + kv_bytes_per_token); + if (model_max_total_sequence_length <= 0) { + if (verbose) { + LOG(INFO) << "temp_buffer = " << BytesToMegabytesString(temp_buffer_bytes); + LOG(INFO) << "kv_aux workspace = " << BytesToMegabytesString(kv_aux_workspace_bytes); + LOG(INFO) << "model workspace = " << BytesToMegabytesString(model_workspace_bytes); + LOG(INFO) << "logit processor workspace = " + << BytesToMegabytesString(logit_processor_workspace_bytes); + } + return Result::Error( + "Insufficient GPU memory error: " + "The available single GPU memory is " + + BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) + + " MB, " + "which is less than the sum of model weight size (" + + BytesToMegabytesString(params_bytes) + " MB) and temporary buffer size (" + + BytesToMegabytesString(temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes + + logit_processor_workspace_bytes) + + " MB).\n" + "1. You can set a larger \"gpu_memory_utilization\" value.\n" + "2. 
If the model weight size is too large, please enable tensor parallelism by passing " + "`--tensor-parallel-shards $NGPU` to `mlc_llm gen_config` or use quantization.\n" + "3. If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` " + "in `mlc_llm gen_config`."); + } + if (device.device_type == DLDeviceType::kDLMetal) { + // NOTE: Metal runtime has severe performance issues with large buffers. + // To work around the issue, we limit the KV cache capacity to 32768. + model_max_total_sequence_length = + std::min(model_max_total_sequence_length, static_cast(32768)); + } + // Compute the total memory usage except the KV cache part. + double total_mem_usage_except_kv_cache = + (params_bytes + temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes + + logit_processor_workspace_bytes); + + // - 4. max_total_sequence_length + if (!init_config.max_total_sequence_length.has_value()) { + if (mode == EngineMode::kLocal) { + inferred_config.max_total_sequence_length = std::min( + {model_max_total_sequence_length, model_config_limits.model_max_single_sequence_length, + static_cast(8192)}); + } else if (mode == EngineMode::kInteractive) { + inferred_config.max_total_sequence_length = std::min( + model_max_total_sequence_length, model_config_limits.model_max_single_sequence_length); + } else { + inferred_config.max_total_sequence_length = + std::min(model_max_total_sequence_length, + max_num_sequence * model_config_limits.model_max_single_sequence_length); + } + os << "max KV cache token capacity will be set to " + << inferred_config.max_total_sequence_length.value() << ", "; + } else { + os << "max KV cache token capacity " << inferred_config.max_total_sequence_length.value() + << " is specified by user, "; + } + // - 5. prefill_chunk_size + if (!init_config.prefill_chunk_size.has_value()) { + if (mode == EngineMode::kLocal || mode == EngineMode::kInteractive) { + inferred_config.prefill_chunk_size = + std::min({model_config_limits.model_max_prefill_chunk_size, + inferred_config.max_total_sequence_length.value(), + model_config_limits.model_max_single_sequence_length}); + } else { + inferred_config.prefill_chunk_size = model_config_limits.model_max_prefill_chunk_size; + } + os << "prefill chunk size will be set to " << inferred_config.prefill_chunk_size.value() + << ". "; + } else { + os << "prefill chunk size " << inferred_config.prefill_chunk_size.value() + << " is specified by user. "; + } + + // - Print logging message + if (verbose) { + LOG(INFO) << "Under mode \"" << EngineModeToString(mode) << "\", " << os.str(); + } + + return Result::Ok( + {total_mem_usage_except_kv_cache + + inferred_config.max_total_sequence_length.value() * kv_bytes_per_token, + kv_bytes_per_token * inferred_config.max_total_sequence_length.value() + + kv_aux_workspace_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_buffer_bytes, + inferred_config}); +} + +Result InferrableEngineConfig::InferForKVCache( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose) { + // - Check if max_history_size is not set. 
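  // Illustrative only (numbers are hypothetical): the estimation above reduces to
  //   max_total_sequence_length ~ (gpu_memory * gpu_memory_utilization
  //                                - params_bytes - temp_buffer_bytes
  //                                - kv_aux_workspace - model_workspace
  //                                - logit_processor_workspace) / kv_bytes_per_token
  // For example, a 24 GB GPU at 0.85 utilization with 14 GB of weights and about
  // 2 GB of workspaces leaves roughly 4.4 GB for KV data; at around 160 KB of KV
  // data per token that corresponds to a capacity of roughly 28K tokens.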
+ if (init_config.max_history_size.has_value() && init_config.max_history_size.value() != 0) { + return Result::Error( + "KV cache does not support max_history_size, while it is set to " + + std::to_string(init_config.max_history_size.value()) + " in the input EngineConfig"); + } + // - Get the upper bound of single sequence length, prefill size and batch size + // from model config. + Result model_config_limits_res = GetModelConfigLimits(model_configs); + if (model_config_limits_res.IsErr()) { + return Result::Error(model_config_limits_res.UnwrapErr()); + } + ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap(); + // - Get total model parameter size and temporary in-function buffer + // size in bytes on single GPU. + int64_t params_bytes = 0; + int64_t temp_buffer_bytes = 0; + for (const ModelMetadata& metadata : model_metadata) { + for (const ModelMetadata::Param& param : metadata.params) { + int64_t param_size = param.dtype.bytes(); + for (int64_t v : param.shape) { + ICHECK_GE(v, 0); + param_size *= v; + } + params_bytes += param_size; + } + for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) { + temp_buffer_bytes = std::max(temp_buffer_bytes, temp_buffer_size); + } + } + // Magnify the temp buffer by a factor of 2 for safety. + temp_buffer_bytes *= 2; + + // - Infer the engine config and estimate memory usage for each mode. + Result local_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kLocal, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + Result interactive_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kInteractive, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + Result server_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kServer, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + // - Pick the estimation result according to the mode. + std::string mode_name; + Result final_estimation_result; + if (mode == EngineMode::kLocal) { + final_estimation_result = std::move(local_mode_estimation_result); + } else if (mode == EngineMode::kInteractive) { + final_estimation_result = std::move(interactive_mode_estimation_result); + } else { + final_estimation_result = std::move(server_mode_estimation_result); + } + if (final_estimation_result.IsErr()) { + return Result::Error(final_estimation_result.UnwrapErr()); + } + // - Print log message. + MemUsageEstimationResult final_estimation = final_estimation_result.Unwrap(); + InferrableEngineConfig inferred_config = std::move(final_estimation.inferred_config); + if (verbose) { + LOG(INFO) << "The actual engine mode is \"" << EngineModeToString(mode) + << "\". So max batch size is " << inferred_config.max_num_sequence.value() + << ", max KV cache token capacity is " + << inferred_config.max_total_sequence_length.value() << ", prefill chunk size is " + << inferred_config.prefill_chunk_size.value() << "."; + LOG(INFO) << "Estimated total single GPU memory usage: " + << BytesToMegabytesString(final_estimation.total_memory_bytes) + << " MB (Parameters: " << BytesToMegabytesString(params_bytes) + << " MB. KVCache: " << BytesToMegabytesString(final_estimation.kv_cache_memory_bytes) + << " MB. 
Temporary buffer: " + << BytesToMegabytesString(final_estimation.temp_memory_bytes) + << " MB). The actual usage might be slightly larger than the estimated number."; + } + + inferred_config.kv_state_kind = KVStateKind::kKVCache; + inferred_config.max_history_size = 0; + return Result::Ok(inferred_config); +} + +Result InferrableEngineConfig::InferForRNNState( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose) { + // - Check max_single_sequence_length is not set. + if (init_config.max_single_sequence_length.has_value()) { + return Result::Error( + "RNN state does not support max_single_sequence_length, while it is set to " + + std::to_string(init_config.max_single_sequence_length.value()) + + " in the input EngineConfig"); + } + // - Get the upper bound of single sequence length, prefill size and batch size + // from model config. + Result model_config_limits_res = GetModelConfigLimits(model_configs); + if (model_config_limits_res.IsErr()) { + return Result::Error(model_config_limits_res.UnwrapErr()); + } + ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap(); + + std::ostringstream os; + InferrableEngineConfig inferred_config = init_config; + // - 1. prefill_chunk_size + if (!init_config.prefill_chunk_size.has_value()) { + inferred_config.prefill_chunk_size = + std::min(model_config_limits.model_max_prefill_chunk_size, static_cast(4096)); + os << "prefill chunk size will be set to " << inferred_config.prefill_chunk_size.value() + << ", "; + } else { + os << "prefill chunk size " << inferred_config.prefill_chunk_size.value() + << " is specified by user, "; + } + // - 2. max_batch_size + if (!init_config.max_num_sequence.has_value()) { + inferred_config.max_num_sequence = + mode == EngineMode::kInteractive + ? 1 + : std::min(static_cast(4), model_config_limits.model_max_batch_size); + os << "max batch size will be set to " << inferred_config.max_num_sequence.value() << ", "; + } else { + os << "max batch size " << inferred_config.max_num_sequence.value() + << " is specified by user, "; + } + int64_t max_num_sequence = inferred_config.max_num_sequence.value(); + // - 3. max_total_sequence_length + if (!init_config.max_total_sequence_length.has_value()) { + inferred_config.max_total_sequence_length = 32768; + os << "max RNN state token capacity will be set to " + << inferred_config.max_total_sequence_length.value() << ". "; + } else { + os << "max RNN state token capacity " << inferred_config.max_total_sequence_length.value() + << " is specified by user. "; + } + + // - Extra logging message + if (mode == EngineMode::kLocal) { + os << "We choose small max batch size and RNN state capacity to use less GPU memory."; + } else if (mode == EngineMode::kInteractive) { + os << "We fix max batch size to 1 for interactive single sequence use."; + } else { + os << "We use as much GPU memory as possible (within the limit of gpu_memory_utilization)."; + } + if (verbose) { + LOG(INFO) << "Under mode \"" << EngineModeToString(mode) << "\", " << os.str(); + } + + // - Get total model parameter size and temporary in-function buffer + // size in bytes on single GPU. 
+ int64_t params_bytes = 0; + int64_t temp_buffer_bytes = 0; + for (const ModelMetadata& metadata : model_metadata) { + for (const ModelMetadata::Param& param : metadata.params) { + int64_t param_size = param.dtype.bytes(); + for (int64_t v : param.shape) { + ICHECK_GE(v, 0); + param_size *= v; + } + params_bytes += param_size; + } + for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) { + temp_buffer_bytes += temp_buffer_size; + } + } + // - 4. max_history_size + double rnn_state_base_bytes = 0; // The memory usage for rnn state when history = 1. + double model_workspace_bytes = 0; + double logit_processor_workspace_bytes = 0; + ICHECK_EQ(model_configs.size(), model_metadata.size()); + int num_models = model_configs.size(); + for (int i = 0; i < num_models; ++i) { + // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation). + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + int64_t vocab_size = json::Lookup(compile_time_model_config, "vocab_size"); + int64_t prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + int64_t head_size = json::Lookup(compile_time_model_config, "head_size"); + int64_t num_heads = json::Lookup(compile_time_model_config, "num_heads"); + int64_t num_layers = json::Lookup(compile_time_model_config, "num_hidden_layers"); + int64_t hidden_size = json::Lookup(compile_time_model_config, "hidden_size"); + // - Calculate RNN state memory usage. + rnn_state_base_bytes += (max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2); + model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2; + logit_processor_workspace_bytes += + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125; + } + // Get single-card GPU size. + TVMRetValue rv; + DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv); + int64_t gpu_size_bytes = rv; + // Compute the maximum history size length under the GPU memory budget. + int64_t model_max_history_size = static_cast((gpu_size_bytes * gpu_memory_utilization // + - params_bytes // + - temp_buffer_bytes // + - model_workspace_bytes // + - logit_processor_workspace_bytes) / + rnn_state_base_bytes); + if (model_max_history_size <= 0) { + return Result::Error( + "Insufficient GPU memory error: " + "The available single GPU memory is " + + BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) + + " MB, " + "which is less than the sum of model weight size (" + + BytesToMegabytesString(params_bytes) + " MB) and temporary buffer size (" + + BytesToMegabytesString( + (temp_buffer_bytes + model_workspace_bytes + logit_processor_workspace_bytes)) + + " MB). " + "If the model weight size is too large, please use quantization. " + "If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` in " + "`mlc_llm gen_config`."); + } + if (!init_config.max_history_size.has_value()) { + inferred_config.max_history_size = model_max_history_size; + } else { + inferred_config.max_history_size = + std::min(inferred_config.max_history_size.value(), model_max_history_size); + } + if (verbose) { + LOG(INFO) << "The actual engine mode is \"" << EngineModeToString(mode) + << "\". 
So max batch size is " << inferred_config.max_num_sequence.value() + << ", max RNN state token capacity is " + << inferred_config.max_total_sequence_length.value() << ", prefill chunk size is " + << inferred_config.prefill_chunk_size.value() << "."; + LOG(INFO) << "Estimated total single GPU memory usage: " + << BytesToMegabytesString(params_bytes + temp_buffer_bytes + + inferred_config.max_history_size.value() * + rnn_state_base_bytes) + << " MB (Parameters: " << BytesToMegabytesString(params_bytes) << " MB. RNN state: " + << BytesToMegabytesString(inferred_config.max_history_size.value() * + rnn_state_base_bytes) + << " MB. Temporary buffer: " + << BytesToMegabytesString(model_workspace_bytes + logit_processor_workspace_bytes + + temp_buffer_bytes) + << " MB). The actual usage might be slightly larger than the estimated number."; + } + + inferred_config.kv_state_kind = KVStateKind::kRNNState; + return Result::Ok(inferred_config); +} + +/****************** Config utils ******************/ + +Result ModelsUseKVCache(const std::vector& model_configs) { + ICHECK_GE(model_configs.size(), 1); + std::string model_type = json::Lookup(model_configs[0], "model_type"); + bool use_kv_cache = model_type.find("rwkv") == std::string::npos; + for (int i = 1; i < static_cast(model_configs.size()); ++i) { + if ((json::Lookup(model_configs[i], "model_type").find("rwkv") == + std::string::npos) != use_kv_cache) { + return Result::Error( + "Invalid models in EngineConfig. Models must be all RNN model or none model is RNN " + "model."); + } + } + return Result::Ok(use_kv_cache); +} } // namespace serve } // namespace llm diff --git a/cpp/serve/config.h b/cpp/serve/config.h index fd76dd49f0..8437232d37 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -5,13 +5,15 @@ #ifndef MLC_LLM_SERVE_CONFIG_H_ #define MLC_LLM_SERVE_CONFIG_H_ +#include #include #include #include #include -#include "../json_ffi/config.h" +#include "../metadata/model.h" +#include "../support/result.h" namespace mlc { namespace llm { @@ -19,7 +21,6 @@ namespace serve { using namespace tvm; using namespace tvm::runtime; -using namespace mlc::llm::json_ffi; /****************** GenerationConfig ******************/ @@ -60,22 +61,51 @@ class GenerationConfigNode : public Object { class GenerationConfig : public ObjectRef { public: - explicit GenerationConfig(String config_json_str); + TVM_DLL explicit GenerationConfig( + std::optional n, std::optional temperature, std::optional top_p, + std::optional frequency_penalty, std::optional presense_penalty, + std::optional repetition_penalty, std::optional logprobs, + std::optional top_logprobs, std::optional>> logit_bias, + std::optional seed, std::optional ignore_eos, std::optional max_tokens, + std::optional> stop_strs, std::optional> stop_token_ids, + std::optional response_format, Optional default_config_json_str); - /*! - * \brief Create a generation config from a ChatCompletionRequest. - * If the request does not contain a generation config, the model-defined - * generation config will be used. - */ - static Optional Create( - const std::string& json_str, std::string* err, const Conversation& conv_template, - const ModelDefinedGenerationConfig& model_defined_gen_config); + TVM_DLL explicit GenerationConfig(String config_json_str, + Optional default_config_json_str); + + /*! \brief Get the default generation config from the model config. 
*/ + TVM_DLL static GenerationConfig GetDefaultFromModelConfig(const picojson::object& json); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; /****************** Engine config ******************/ +/*! + * \brief The engine mode in MLC LLM. + * We provide three preset modes: "local", "interactive" and "server". + * The default mode is "local". + * The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + * and "prefill_chunk_size" when they are not explicitly specified. + * 1. Mode "local" refers to the local server deployment which has low + * request concurrency. So the max batch size will be set to 4, and max + * total sequence length and prefill chunk size are set to the context + * window size (or sliding window size) of the model. + * 2. Mode "interactive" refers to the interactive use of server, which + * has at most 1 concurrent request. So the max batch size will be set to 1, + * and max total sequence length and prefill chunk size are set to the context + * window size (or sliding window size) of the model. + * 3. Mode "server" refers to the large server use case which may handle + * many concurrent request and want to use GPU memory as much as possible. + * In this mode, we will automatically infer the largest possible max batch + * size and max total sequence length. + */ +enum class EngineMode : int { + kLocal = 0, + kInteractive = 1, + kServer = 2, +}; + /*! \brief The speculative mode. */ enum class SpeculativeMode : int { /*! \brief Disable speculative decoding. */ @@ -87,11 +117,13 @@ enum class SpeculativeMode : int { }; /*! \brief The kind of cache. */ -enum KVStateKind { - kAttention = 0, +enum class KVStateKind : int { + kKVCache = 0, kRNNState = 1, }; +class InferrableEngineConfig; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -99,44 +131,61 @@ class EngineConfigNode : public Object { /*! \brief The path to the model directory. */ String model; - /*! \brief The path to the model library. */ - String model_lib_path; + /*! \brief The path or identifier to the model library. */ + String model_lib; /*! \brief The path to the additional models' directories. */ Array additional_models; /*! \brief The path to the additional models' libraries. */ - Array additional_model_lib_paths; + Array additional_model_libs; /*************** KV cache config and engine capacities ***************/ + /*! + * \brief The engine mode in MLC LLM. + * \sa EngineMode + */ + EngineMode mode = EngineMode::kLocal; + /*! + * \brief A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + * It is used to infer to maximum possible KV cache capacity. + * When it is unspecified, it defaults to 0.85. + * Under mode "local" or "interactive", the actual memory usage may be + * significantly smaller than this number. Under mode "server", the actual + * memory usage may be slightly larger than this number. + */ + float gpu_memory_utilization = 0.85; /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ - int kv_cache_page_size; + int kv_cache_page_size = 16; /*! * \brief The maximum number of sequences that are allowed to be * processed by the KV cache at any time. */ - int max_num_sequence; + int max_num_sequence = 4; /*! \brief The maximum length allowed for a single sequence in the engine. */ - int max_total_sequence_length; + int max_total_sequence_length = 4096; /*! 
* \brief The maximum total number of tokens whose KV data are allowed * to exist in the KV cache at any time. */ - int max_single_sequence_length; + int max_single_sequence_length = 4096; /*! \brief The maximum total sequence length in a prefill. */ - int prefill_chunk_size; + int prefill_chunk_size = 1024; /*! \brief The maximum history size for RNN state. KV cache does not need this. */ - int max_history_size; + int max_history_size = 0; /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ - KVStateKind kv_state_kind; + KVStateKind kv_state_kind = KVStateKind::kKVCache; /*************** Speculative decoding ***************/ /*! \brief The speculative mode. */ - SpeculativeMode speculative_mode; + SpeculativeMode speculative_mode = SpeculativeMode::kDisable; /*! \brief The number of tokens to generate in speculative proposal (draft). */ int spec_draft_length = 4; - String AsJSONString() const; + /*************** Debug ***************/ + bool verbose = false; + + TVM_DLL String AsJSONString() const; static constexpr const char* _type_key = "mlc.serve.EngineConfig"; static constexpr const bool _type_has_method_sequal_reduce = false; @@ -146,19 +195,98 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: - explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, - int max_history_size, KVStateKind kv_state_kind, - SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON object and inferred config. */ + TVM_DLL static EngineConfig FromJSONAndInferredConfig( + const picojson::object& json, const InferrableEngineConfig& inferred_config); - /*! \brief Create EngineConfig from JSON string. */ - static EngineConfig FromJSONString(const std::string& json_str); + /*! + * \brief Get all the models and model libs from the JSON string for engine initialization. + * \return The parsed models/model libs from config or error message. + */ + TVM_DLL static Result>> + GetModelsAndModelLibsFromJSONString(const std::string& json_str); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; +/*! \brief A subset of engine config that is inferrable. */ +struct InferrableEngineConfig { + std::optional max_num_sequence; + std::optional max_total_sequence_length; + std::optional max_single_sequence_length; + std::optional prefill_chunk_size; + std::optional max_history_size; + std::optional kv_state_kind; + + /*! \brief Infer the config for KV cache from a given initial config. */ + TVM_DLL static Result InferForKVCache( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose); + /*! \brief Infer the config for RNN state from a given initial config. */ + TVM_DLL static Result InferForRNNState( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose); +}; + +/****************** Config utils ******************/ + +/*! \brief Check if the models use KV cache or RNN state. */ +Result ModelsUseKVCache(const std::vector& model_configs); + +inline std::string EngineModeToString(EngineMode mode) { + return mode == EngineMode::kLocal ? 
"local" + : mode == EngineMode::kInteractive ? "interactive" + : "server"; +} + +inline EngineMode EngineModeFromString(const std::string& mode) { + if (mode == "local") { + return EngineMode::kLocal; + } else if (mode == "interactive") { + return EngineMode::kInteractive; + } else if (mode == "server") { + return EngineMode::kServer; + } else { + LOG(FATAL) << "Invalid engine mode string: " << mode; + } +} + +inline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) { + return speculative_mode == SpeculativeMode::kDisable ? "disable" + : speculative_mode == SpeculativeMode::kSmallDraft ? "small_draft" + : "eagle"; +} + +inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) { + if (speculative_mode == "disable") { + return SpeculativeMode::kDisable; + } else if (speculative_mode == "small_draft") { + return SpeculativeMode::kSmallDraft; + } else if (speculative_mode == "eagle") { + return SpeculativeMode::kEagle; + } else { + LOG(FATAL) << "Invalid speculative mode string: " << speculative_mode; + } +} + +inline std::string KVStateKindToString(KVStateKind kv_state_kind) { + return kv_state_kind == KVStateKind::kKVCache ? "kv_cache" : "rnn_State"; +} + +inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) { + if (kv_state_kind == "kv_cache") { + return KVStateKind::kKVCache; + } else if (kv_state_kind == "rnn_state") { + return KVStateKind::kRNNState; + } else { + LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind; + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 297eba8b10..6fd6188562 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -17,6 +17,8 @@ #include #include +#include "../support/json_parser.h" +#include "../support/result.h" #include "../tokenizers.h" #include "engine_actions/action.h" #include "engine_actions/action_commons.h" @@ -45,61 +47,71 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, DLDevice device, - Optional request_stream_callback, - Optional trace_recorder) { - // Step 1. Initialize metadata and singleton states inside the engine - this->estate_->Reset(); - // Being "-1" means there is no limit on single sequence length. - if (engine_config->max_single_sequence_length == -1) { - engine_config->max_single_sequence_length = std::numeric_limits::max(); + static Result Create(const std::string& engine_config_json_str, + DLDevice device, + Optional request_stream_callback, + Optional trace_recorder) { + using TResult = Result; + std::unique_ptr n = std::make_unique(); + + // - Read the models and model libs from the EngineConfig JSON string. + Result>> models_and_model_libs_res = + EngineConfig::GetModelsAndModelLibsFromJSONString(engine_config_json_str); + if (models_and_model_libs_res.IsErr()) { + return TResult::Error(models_and_model_libs_res.UnwrapErr()); } - this->request_stream_callback_ = std::move(request_stream_callback); - this->trace_recorder_ = trace_recorder; - - // Step 2. Initialize each model independently. - // Create the logit processor and sampler. - this->models_.clear(); - this->model_workspaces_.clear(); - + std::vector> models_and_model_libs = + models_and_model_libs_res.Unwrap(); + ICHECK_GE(models_and_model_libs.size(), 1); + // - Initialize singleton states inside the engine. 
+ n->estate_->Reset(); + n->request_stream_callback_ = std::move(request_stream_callback); + n->trace_recorder_ = trace_recorder; + n->device_ = device; + // - Load model config, create a shared disco session when tensor + // parallelism is enabled. std::vector model_configs; - model_configs.push_back(Model::LoadModelConfig(engine_config->model)); - for (const auto& model_path : engine_config->additional_models) { - model_configs.push_back(Model::LoadModelConfig(model_path)); + for (int i = 0; i < static_cast(models_and_model_libs.size()); ++i) { + const auto& [model_str, model_lib] = models_and_model_libs[i]; + Result model_config_res = Model::LoadModelConfig(model_str); + if (model_config_res.IsErr()) { + return TResult::Error("Model " + std::to_string(i) + + " has invalid mlc-chat-config.json: " + model_config_res.UnwrapErr()); + } + model_configs.push_back(model_config_res.Unwrap()); } - - Optional session = CreateDiscoSession(model_configs, device); - - auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, - &session](const String& model_path, const String& model_lib_path, - int model_index) { - Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], - device, engine_config->max_num_sequence, session, + Optional session = n->CreateDiscoSession(model_configs, device); + // - Initialize each model independently. + n->models_.clear(); + for (int i = 0; i < static_cast(models_and_model_libs.size()); ++i) { + const auto& [model_str, model_lib] = models_and_model_libs[i]; + Model model = Model::Create(model_lib, model_str, model_configs[i], device, session, /*trace_enabled=*/trace_recorder.defined()); + n->models_.push_back(model); + } + // - Automatically infer the missing fields in EngineConfig JSON strings + // and get the final EngineConfig. + Result engine_config_res = + n->AutoDecideEngineConfig(engine_config_json_str, model_configs); + if (engine_config_res.IsErr()) { + return TResult::Error(engine_config_res.UnwrapErr()); + } + EngineConfig engine_config = engine_config_res.Unwrap(); + // - Load model weights, create KV cache and workspace. 
+ n->model_workspaces_.clear(); + for (const Model& model : n->models_) { + model->LoadParams(); + model->SetMaxNumSequence(engine_config->max_num_sequence); + model->SetPrefillChunkSize(engine_config->prefill_chunk_size); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, engine_config->prefill_chunk_size, engine_config->max_history_size, engine_config->kv_state_kind); - CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) - << "The window size of the model, " << model->GetMaxWindowSize() - << ", is smaller than the pre-defined max single sequence length, " - << engine_config->max_single_sequence_length; - this->models_.push_back(model); - this->model_workspaces_.push_back( + n->model_workspaces_.push_back( ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); - }; - - f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); - CHECK_EQ(engine_config->additional_models.size(), - engine_config->additional_model_lib_paths.size()) - << "The additional model and lib path list has mismatched size."; - for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { - f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } - - // Step 3. Initialize tokenizer and grammar - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + // - Initialize tokenizer and grammar + n->tokenizer_ = Tokenizer::FromPath(engine_config->model); std::string token_table_postproc_method; if (model_configs[0].count("token_table_postproc_method") == 0) { // Backward compatibility: use "byte_fallback" by default @@ -108,73 +120,77 @@ class EngineImpl : public Engine { token_table_postproc_method = model_configs[0].at("token_table_postproc_method").get(); } - this->token_table_ = - Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); - this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); - - // Step 4. Initialize engine actions that represent state transitions. + n->token_table_ = + Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method); + n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_); + // - Create the logit processor and sampler, and + // the DraftTokenWorkspaceManager for speculative decoding. 
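    // Illustrative only (numbers are hypothetical): with max_num_sequence = 4 and
    // spec_draft_length = 4, the speculative-decoding branch below sizes the logit
    // processor and sampler for 4 * (4 + 1) = 20 tokens, while the non-speculative
    // path keeps max_num_tokens at 4.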
int max_num_tokens = engine_config->max_num_sequence; DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; - draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + draft_token_workspace_manager = + n->models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); draft_token_workspace_manager->AllocWorkspace( - &model_workspaces_[0], + &n->model_workspaces_[0], /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); } LogitProcessor logit_processor = - this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); - Sampler sampler = this->models_[0]->CreateSampler( - max_num_tokens, static_cast(this->models_.size()), trace_recorder); + n->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); + Sampler sampler = n->models_[0]->CreateSampler( + max_num_tokens, static_cast(n->models_.size()), trace_recorder); + // - Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. - ICHECK_GT(this->models_.size(), 1U); + ICHECK_GT(n->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = { - EngineAction::EagleNewRequestPrefill(this->models_, // + n->actions_ = { + EngineAction::EagleNewRequestPrefill(n->models_, // logit_processor, // sampler, // - this->model_workspaces_, // + n->model_workspaces_, // draft_token_workspace_manager, // engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - this->trace_recorder_, - engine_config->spec_draft_length), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - engine_config, this->trace_recorder_)}; + n->trace_recorder_), + EngineAction::EagleBatchDraft(n->models_, logit_processor, sampler, + n->model_workspaces_, draft_token_workspace_manager, + n->trace_recorder_, engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(n->models_, logit_processor, sampler, + n->model_workspaces_, draft_token_workspace_manager, + engine_config, n->trace_recorder_)}; break; default: - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - this->trace_recorder_), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - engine_config, this->trace_recorder_)}; + n->actions_ = { + EngineAction::NewRequestPrefill(n->models_, // + logit_processor, // + sampler, // + n->model_workspaces_, // + engine_config, // + n->trace_recorder_), + EngineAction::BatchDraft(n->models_, logit_processor, sampler, n->model_workspaces_, + draft_token_workspace_manager, n->trace_recorder_), + EngineAction::BatchVerify(n->models_, logit_processor, sampler, n->model_workspaces_, + draft_token_workspace_manager, engine_config, + n->trace_recorder_)}; } } else { - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - 
sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDecode(this->models_, logit_processor, sampler, - this->trace_recorder_)}; + n->actions_ = { + EngineAction::NewRequestPrefill(n->models_, // + logit_processor, // + sampler, // + n->model_workspaces_, // + engine_config, // + n->trace_recorder_), + EngineAction::BatchDecode(n->models_, logit_processor, sampler, n->trace_recorder_)}; } - // Step 4. Automatically set the threading backend max concurrency. - this->engine_config_ = engine_config; - SetThreadMaxConcurrency(); + // - Automatically set the threading backend max concurrency. + n->engine_config_ = engine_config; + n->SetThreadMaxConcurrency(); + // - Get the default generation config from the first model. + GenerationConfig default_generation_cfg = + GenerationConfig::GetDefaultFromModelConfig(model_configs[0]); + return TResult::Ok({std::move(n), std::move(engine_config), std::move(default_generation_cfg)}); } void Reset() final { @@ -321,7 +337,8 @@ class EngineImpl : public Engine { } /************** Utility Functions **************/ - Optional CreateDiscoSession(std::vector model_configs, Device device) { + Optional CreateDiscoSession(const std::vector& model_configs, + Device device) { const auto& base_model_config = model_configs[0]; auto f_get_num_shards = [](const picojson::object& model_config) -> int { @@ -373,6 +390,95 @@ class EngineImpl : public Engine { } private: + Result AutoDecideEngineConfig(const std::string& engine_config_json_str, + const std::vector& model_configs) { + using TResult = Result; + picojson::value config_json; + std::string err = picojson::parse(config_json, engine_config_json_str); + if (!err.empty()) { + return TResult::Error(err); + } + picojson::object config = config_json.get(); + ObjectPtr n = make_object(); + + // - Get the engine mode and maximum GPU utilization for inference. + EngineMode mode = EngineModeFromString(json::Lookup(config, "mode")); + double gpu_memory_utilization = + json::LookupOrDefault(config, "gpu_memory_utilization", n->gpu_memory_utilization); + bool verbose = json::LookupOrDefault(config, "verbose", n->verbose); + + // - Get the config fields that can be automatically inferred. + std::optional max_num_sequence = + json::LookupOptional(config, "max_num_sequence"); + std::optional max_total_sequence_length = + json::LookupOptional(config, "max_total_sequence_length"); + std::optional max_single_sequence_length = + json::LookupOptional(config, "max_single_sequence_length"); + std::optional prefill_chunk_size = + json::LookupOptional(config, "prefill_chunk_size"); + std::optional max_history_size = + json::LookupOptional(config, "max_history_size"); + std::optional kv_state_kind_str = + json::LookupOptional(config, "kv_state_kind"); + std::optional kv_state_kind; + if (kv_state_kind_str.has_value()) { + kv_state_kind = KVStateKindFromString(kv_state_kind_str.value()); + } + InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length, + max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind}; + + // - Get the model metadata. + std::vector model_metadata; + for (const Model& model : models_) { + model_metadata.push_back(model->GetMetadata()); + } + // - Select from kv cache or RNN state. 
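    // Illustrative only: because the capacity-related fields above are read as
    // optionals, a hypothetical minimal engine config JSON such as
    //   {"model": "path/to/model", "model_lib": "path/to/lib.so", "mode": "interactive"}
    // leaves all of them unset, and they are then filled in by
    // InferForKVCache or InferForRNNState below.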
+ Result use_kv_cache = ModelsUseKVCache(model_configs); + if (use_kv_cache.IsErr()) { + return TResult::Error(use_kv_cache.UnwrapErr()); + } + KVStateKind inferred_kv_state_kind; + Result inferrable_cfg_res; + if (use_kv_cache.Unwrap()) { + inferred_kv_state_kind = KVStateKind::kKVCache; + // - Check if the kv state kind from config is valid. + if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { + return TResult::Error( + "Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is " + "specified in EngineConfig."); + } + // - Infer configuration. + inferrable_cfg_res = InferrableEngineConfig::InferForKVCache( + mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, + verbose); + } else { + inferred_kv_state_kind = KVStateKind::kRNNState; + // - Check if the kv state kind from config is valid. + if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { + return TResult::Error( + "Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is " + "specified in EngineConfig."); + } + // - Infer configuration. + inferrable_cfg_res = InferrableEngineConfig::InferForRNNState( + mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, + verbose); + } + + if (inferrable_cfg_res.IsErr()) { + return TResult::Error(inferrable_cfg_res.UnwrapErr()); + } + inferrable_cfg = inferrable_cfg_res.Unwrap(); + ICHECK(inferrable_cfg.max_num_sequence.has_value()); + ICHECK(inferrable_cfg.max_total_sequence_length.has_value()); + ICHECK(inferrable_cfg.max_single_sequence_length.has_value()); + ICHECK(inferrable_cfg.prefill_chunk_size.has_value()); + ICHECK(inferrable_cfg.max_history_size.has_value()); + ICHECK(inferrable_cfg.kv_state_kind.has_value()); + return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg)); + } + /*! \brief Set the maximum threading backend concurrency. */ void SetThreadMaxConcurrency() { int host_cpu_usage = 1; @@ -408,6 +514,8 @@ class EngineImpl : public Engine { GrammarInitContextStorage grammar_init_context_storage_; // Models Array models_; + // Device that the models run on. + Device device_; // Workspace of each model. std::vector model_workspaces_; // Request stream callback function @@ -418,12 +526,12 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, - Optional request_stream_callback, - Optional trace_recorder) { - return std::make_unique(std::move(engine_config), device, - std::move(request_stream_callback), - std::move(trace_recorder)); +Result Engine::Create(const std::string& engine_config_json_str, + Device device, + Optional request_stream_callback, + Optional trace_recorder) { + return EngineImpl::Create(engine_config_json_str, device, std::move(request_stream_callback), + std::move(trace_recorder)); } /*! \brief Clear global memory manager */ @@ -445,13 +553,21 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_ENTRY("reset", &EngineModule::Reset); TVM_MODULE_VTABLE_ENTRY("get_request_stream_callback", &EngineModule::GetRequestStreamCallback); TVM_MODULE_VTABLE_ENTRY("set_request_stream_callback", &EngineModule::SetRequestStreamCallback); + TVM_MODULE_VTABLE_ENTRY("get_default_generation_config", + &EngineModule::GetDefaultGenerationConfigJSONString); TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. 
*/ - void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, + void Init(const std::string& engine_config_json_str, Device device, + Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), device, - std::move(request_stream_callback), std::move(trace_recorder)); + Result output_res = + Engine::Create(engine_config_json_str, device, std::move(request_stream_callback), + std::move(trace_recorder)); + CHECK(output_res.IsOk()) << output_res.UnwrapErr(); + EngineCreationOutput output = output_res.Unwrap(); + this->engine_ = std::move(output.reloaded_engine); + this->default_generation_cfg_json_str_ = output.default_generation_cfg->AsJSONString(); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } @@ -473,6 +589,12 @@ class EngineModule : public ModuleNode { void Reset() { return GetEngine()->Reset(); } /*! \brief Redirection to `Engine::Stats` */ String Stats() { return GetEngine()->Stats(); } + /*! \brief Return the default generation config string. */ + String GetDefaultGenerationConfigJSONString() { + CHECK(!default_generation_cfg_json_str_.empty()) + << "The default generation config has not been set."; + return default_generation_cfg_json_str_; + } private: Engine* GetEngine() { @@ -481,6 +603,7 @@ class EngineModule : public ModuleNode { } std::unique_ptr engine_ = nullptr; + String default_generation_cfg_json_str_; }; TVM_REGISTER_GLOBAL("mlc.serve.create_engine").set_body_typed(EngineModule::Create); diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 2fc0a4d730..7bbe942227 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -21,6 +21,18 @@ using namespace tvm::runtime; typedef TypedPackedFunc)> FRequestStreamCallback; +class Engine; + +/*! + * \brief The output of engine creation, including the created engine and + * the default generation config for requests. + */ +struct EngineCreationOutput { + std::unique_ptr reloaded_engine; + EngineConfig completed_engine_config; + GenerationConfig default_generation_cfg; +}; + /*! * \brief The engine interface for request serving in MLC LLM. * The engine can run one or multiple LLM models internally for @@ -50,15 +62,16 @@ class Engine { /*! * \brief Create an engine in unique pointer. - * \param engine_config The engine config. + * \param engine_config_json_str The serialized JSON string of the engine config. * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. - * \return The created Engine in pointer. + * \return The created Engine in pointer, and the default generation config. */ - static std::unique_ptr Create(EngineConfig engine_config, Device device, - Optional request_stream_callback, - Optional trace_recorder); + static Result Create(const std::string& engine_config_json_str, + Device device, + Optional request_stream_callback, + Optional trace_recorder); /*! \brief Reset the engine, clean up all running data and statistics. 
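Engine::Create now returns a Result<EngineCreationOutput> rather than a bare unique_ptr, bundling the reloaded engine, the completed engine config, and the default generation config. A condensed caller-side sketch of unpacking that output, mirroring EngineModule::Init above (variable names are illustrative placeholders):

    Result<EngineCreationOutput> output_res =
        Engine::Create(engine_config_json_str, device, request_stream_callback, trace_recorder);
    CHECK(output_res.IsOk()) << output_res.UnwrapErr();
    EngineCreationOutput output = output_res.Unwrap();
    std::unique_ptr<Engine> engine = std::move(output.reloaded_engine);
    // Both configs are surfaced to callers as JSON strings.
    String engine_cfg_json = output.completed_engine_config->AsJSONString();
    String default_generation_cfg_json = output.default_generation_cfg->AsJSONString();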
*/ virtual void Reset() = 0; diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 55ab0a1dff..a0ae4d98f3 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -5,8 +5,8 @@ #include "grammar_parser.h" -#include "../../metadata/json_parser.h" #include "../../support/encoding.h" +#include "../../support/json_parser.h" #include "grammar_builder.h" namespace mlc { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index be76b40e2e..0bd4126b40 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -12,6 +12,7 @@ #include +#include "../support/json_parser.h" #include "config.h" #include "logit_processor.h" @@ -26,13 +27,13 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, - DLDevice device, int max_num_sequence, const Optional& session, - bool trace_enabled) { - return Model(make_object(reload_lib_path, model_path, model_config, device, - max_num_sequence, session, trace_enabled)); + DLDevice device, const Optional& session, bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, session, + trace_enabled)); } -picojson::object Model::LoadModelConfig(const String& model_path) { +Result Model::LoadModelConfig(const String& model_path) { + using TResult = Result; picojson::object model_config; std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); std::ostringstream config_ostream; @@ -42,10 +43,10 @@ picojson::object Model::LoadModelConfig(const String& model_path) { picojson::value config_json; std::string err = picojson::parse(config_json, config_str); if (!err.empty()) { - LOG(FATAL) << err; + return TResult::Error(err); } picojson::object config = config_json.get(); - return config; + return TResult::Ok(config); } class ModelImpl : public ModelObj { @@ -55,34 +56,21 @@ class ModelImpl : public ModelObj { * \sa Model::Create */ explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, - DLDevice device, int max_num_sequence, const Optional& session, - bool trace_enabled) - : device_(device) { + DLDevice device, const Optional& session, bool trace_enabled) + : model_(model_path), device_(device) { // Step 1. Process model config json string. LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. this->ft_.Init(reload_lib_path, device_, model_config, session); - // Step 3. Load params in nd-array cache. - this->params_ = ft_.LoadParams(model_path, device_); - // Step 4. Set max_num_sequence - this->max_num_sequence_ = max_num_sequence; - // Step 5. Reset + // Step 3. Reset this->Reset(); - // Step 6. Initialize the shared NDArray. - Device device_host{DLDeviceType::kDLCPU, 0}; - memory::Allocator* allocator = - memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); - ICHECK_NOTNULL(allocator); - token_ids_storage_ = memory::Storage( - allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); - this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); - // Step 7. Set model type - if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + // Step 4. 
Set model type + if (json::Lookup(model_config, "model_type").find("rwkv") != std::string::npos) { this->kind = KVStateKind::kRNNState; } else { - this->kind = KVStateKind::kAttention; + this->kind = KVStateKind::kKVCache; } } @@ -104,6 +92,7 @@ class ModelImpl : public ModelObj { } ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); + ICHECK_NE(prefill_chunk_size_, -1); auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {prefill_chunk_size_}); ObjectRef embeddings = ft_.embed_func_(token_ids_dref_or_nd, params_); @@ -249,6 +238,7 @@ class ModelImpl : public ModelObj { ShapeTuple embedding_shape{1, total_length, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); } + ICHECK_NE(max_num_sequence_, -1); ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); // args: embeddings, logit_pos, kv_cache, params @@ -576,7 +566,7 @@ class ModelImpl : public ModelObj { void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind) final { - if (kv_state_kind == KVStateKind::kAttention) { + if (kv_state_kind == KVStateKind::kKVCache) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; @@ -619,6 +609,8 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ + ModelMetadata GetMetadata() const final { return ft_.model_metadata_; } + int GetNumAvailablePages() const final { if (this->kind == KVStateKind::kRNNState) { // RNNState does not introduce new page at runtime @@ -639,14 +631,32 @@ class ModelImpl : public ModelObj { /*********************** Utilities ***********************/ + void LoadParams() final { this->params_ = ft_.LoadParams(model_, device_); } + + void SetMaxNumSequence(int max_num_sequence) final { + this->max_num_sequence_ = max_num_sequence; + this->logit_pos_arr_ = + NDArray::Empty({max_num_sequence}, DataType::Int(32), Device{DLDeviceType::kDLCPU, 0}); + } + + void SetPrefillChunkSize(int prefill_chunk_size) final { + this->prefill_chunk_size_ = prefill_chunk_size; + Device device_host{DLDeviceType::kDLCPU, 0}; + memory::Allocator* allocator = + memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); + ICHECK_NOTNULL(allocator); + token_ids_storage_ = memory::Storage( + allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); + } + LogitProcessor CreateLogitProcessor(int max_num_token, - Optional trace_recorder) { + Optional trace_recorder) final { return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } Sampler CreateSampler(int max_num_sample, int num_models, - Optional trace_recorder) { + Optional trace_recorder) final { if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); @@ -660,11 +670,6 @@ class ModelImpl : public ModelObj { return num_shards_ > 1 ? num_shards_ : 0; } - int GetMaxWindowSize() const final { - // Being "-1" means there is no limit on the window size. - return max_window_size_ != -1 ? max_window_size_ : std::numeric_limits::max(); - } - ObjectRef AllocEmbeddingTensor() final { // Allocate the embedding tensor. 
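With weight loading and size-dependent buffers moved out of the constructor, a model can be constructed before the engine config is finalized and completed afterwards. A rough sketch of the resulting two-phase setup, assuming the caller drives it in this order; the inferred_* values are placeholders produced by the engine-config inference step:

    // Phase 1: construct the model from its config only; no weights or sized buffers yet.
    Result<picojson::object> config_res = Model::LoadModelConfig(model_path);
    CHECK(config_res.IsOk()) << config_res.UnwrapErr();  // or propagate the error as a Result
    Model model = Model::Create(reload_lib_path, model_path, config_res.Unwrap(), device,
                                /*session=*/NullOpt, /*trace_enabled=*/false);
    // Phase 2: after config inference (which may consult model->GetMetadata()),
    // supply the inferred sizes and then load the weights.
    model->SetMaxNumSequence(inferred_max_num_sequence);
    model->SetPrefillChunkSize(inferred_prefill_chunk_size);
    model->LoadParams();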
ObjectRef embedding = ft_.alloc_embedding_tensor_func_(); @@ -678,6 +683,7 @@ class ModelImpl : public ModelObj { NDArray embedding_nd = Downcast(embedding); embedding_shape = embedding_nd.Shape(); } + ICHECK_NE(prefill_chunk_size_, -1); ICHECK_EQ(embedding_shape.size(), 2); ICHECK_GE(embedding_shape[0], prefill_chunk_size_); this->hidden_size_ = embedding_shape[1]; @@ -697,8 +703,9 @@ class ModelImpl : public ModelObj { hidden_states_nd = Downcast(hidden_states); } ShapeTuple hidden_states_shape = hidden_states_nd.Shape(); + ICHECK_NE(prefill_chunk_size_, -1); ICHECK_EQ(hidden_states_shape.size(), 2); - ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); + ICHECK_GE(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; this->hidden_states_dtype_ = hidden_states_nd->dtype; return hidden_states; @@ -731,6 +738,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.gather_hidden_states_func_(input, indices_device, dst_view); return dst_view; @@ -741,6 +749,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.scatter_hidden_states_func_(input, indices_device, *dst); } @@ -752,6 +761,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); ft_.gather_probs_func_(input, indices_device, dst_view); @@ -763,6 +773,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); ft_.scatter_probs_func_(input, indices_device, *dst); @@ -776,50 +787,22 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. 
*/ - picojson::object LoadModelConfigJSON(picojson::object config) { - if (config.count("context_window_size")) { - CHECK(config["context_window_size"].is()); - this->max_window_size_ = config["context_window_size"].get(); - } else { - LOG(FATAL) << "Key \"context_window_size\" not found."; - } - if (config.count("sliding_window_size")) { - CHECK(config["sliding_window_size"].is()); - this->sliding_window_size_ = config["sliding_window_size"].get(); - CHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0) - << "Sliding window should be either -1 (which means disabled) of positive"; - } - if (config.count("attention_sink_size")) { - CHECK(config["attention_sink_size"].is()); - this->attention_sink_size_ = config["attention_sink_size"].get(); - this->attention_sink_size_ = std::max(this->attention_sink_size_, 0); - } - if (config.count("tensor_parallel_shards")) { - CHECK(config["tensor_parallel_shards"].is()); - this->num_shards_ = config["tensor_parallel_shards"].get(); - } else { - LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; - } - if (config.count("prefill_chunk_size")) { - CHECK(config["prefill_chunk_size"].is()); - this->prefill_chunk_size_ = config["prefill_chunk_size"].get(); - } else { - LOG(FATAL) << "Key \"prefill_chunk_size\" not found."; - } - if (config.count("vocab_size")) { - CHECK(config["vocab_size"].is()); - this->vocab_size_ = config["vocab_size"].get(); - } else { - LOG(FATAL) << "Key \"vocab_size\" not found."; - } - - return config; + void LoadModelConfigJSON(const picojson::object& config) { + this->sliding_window_size_ = + json::LookupOrDefault(config, "sliding_window_size", this->sliding_window_size_); + CHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0) + << "Sliding window should be either -1 (which means disabled) of positive"; + this->attention_sink_size_ = + json::LookupOrDefault(config, "attention_sink_size", this->attention_sink_size_); + this->attention_sink_size_ = std::max(this->attention_sink_size_, 0); + this->num_shards_ = json::Lookup(config, "tensor_parallel_shards"); + this->vocab_size_ = json::Lookup(config, "vocab_size"); } //---------------------------- // Model configurations //---------------------------- - int max_window_size_ = -1; + std::string model_; int sliding_window_size_ = -1; int attention_sink_size_ = 0; int num_shards_ = -1; diff --git a/cpp/serve/model.h b/cpp/serve/model.h index f587969bfb..1ac4e4001c 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,6 +12,7 @@ #include #include "../base.h" +#include "../support/result.h" #include "config.h" #include "draft_token_workspace_manager.h" #include "event_trace_recorder.h" @@ -254,6 +255,9 @@ class ModelObj : public Object { /************** Raw Info Query **************/ + /*! \brief Return the metadata JSON object of the model. */ + virtual ModelMetadata GetMetadata() const = 0; + /*! \brief Get the number of available pages in KV cache. */ virtual int GetNumAvailablePages() const = 0; @@ -262,6 +266,21 @@ class ModelObj : public Object { /*********************** Utilities ***********************/ + /*! \brief Load the model's weight parameters, which is not loaded at construction time. */ + virtual void LoadParams() = 0; + + /*! + * \brief Set the maximum number of sequences to be processed for the model, + * which is not initialized at construction time. + */ + virtual void SetMaxNumSequence(int max_num_sequence) = 0; + + /*! + * \brief Set the prefill chunk size for the model, + * which is not initialized at construction time. 
+ */ + virtual void SetPrefillChunkSize(int prefill_chunk_size) = 0; + /*! \brief Create a logit processor from this model. */ virtual LogitProcessor CreateLogitProcessor(int max_num_token, Optional trace_recorder) = 0; @@ -279,9 +298,6 @@ class ModelObj : public Object { */ virtual int EstimateHostCPURequirement() const = 0; - /*! \brief Get the max window size of the model. "-1" means infinite length. */ - virtual int GetMaxWindowSize() const = 0; - /*! \brief Allocate an embedding tensor with the prefill chunk size. */ virtual ObjectRef AllocEmbeddingTensor() = 0; @@ -331,22 +347,20 @@ class Model : public ObjectRef { * \param model_path The path to the model weight parameters. * \param model_config The model config json object. * \param device The device to run the model on. - * \param max_num_sequence The maximum number of sequences to be processed * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ TVM_DLL static Model Create(String reload_lib_path, String model_path, const picojson::object& model_config, DLDevice device, - int max_num_sequence, const Optional& session, - bool trace_enabled); + const Optional& session, bool trace_enabled); /*! * Load the model config from the given model path. * \param model_path The path to the model weight parameters. * \return The model config json object. */ - static picojson::object LoadModelConfig(const String& model_path); + TVM_DLL static Result LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/request.cc b/cpp/serve/request.cc index 8ecd20b18e..bd955ec846 100644 --- a/cpp/serve/request.cc +++ b/cpp/serve/request.cc @@ -67,9 +67,11 @@ Request Request::FromUntokenized(const Request& request, const Tokenizer& tokeni } TVM_REGISTER_GLOBAL("mlc.serve.Request") - .set_body_typed([](String id, Array inputs, String generation_cfg_json) { + .set_body_typed([](String id, Array inputs, String generation_cfg_json_str, + Optional default_generation_cfg_json_str) { return Request(std::move(id), std::move(inputs), - GenerationConfig(std::move(generation_cfg_json))); + GenerationConfig(std::move(generation_cfg_json_str), + std::move(default_generation_cfg_json_str))); }); TVM_REGISTER_GLOBAL("mlc.serve.RequestGetInputs").set_body_typed([](Request request) { diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 080853d465..8c3cadd358 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -13,6 +13,7 @@ #include #include +#include "../support/result.h" #include "engine.h" #include "request.h" @@ -36,8 +37,8 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. 
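The mlc.serve.Request packed function now accepts an optional second JSON string carrying the engine's default generation config, presumably so that fields left unset in the per-request JSON can fall back to the engine-level defaults. A small sketch of the C++-side construction (identifiers follow the patch; surrounding setup omitted):

    // generation_cfg_json_str comes from the user request; default_generation_cfg_json_str
    // is the string exposed via the engine's get_default_generation_config entry.
    GenerationConfig generation_cfg(generation_cfg_json_str, default_generation_cfg_json_str);
    Request request(request_id, inputs, generation_cfg);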
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(Device device, Optional request_stream_callback, - Optional trace_recorder) final { + void InitThreadedEngine(Device device, Optional request_stream_callback, + Optional trace_recorder) final { device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; @@ -45,17 +46,23 @@ class ThreadedEngineImpl : public ThreadedEngine { trace_recorder_ = trace_recorder; } - void Reload(EngineConfig engine_config) final { + void Reload(String engine_config_json_str) final { bool need_notify = false; { std::lock_guard lock(background_loop_mutex_); - instruction_queue_.emplace_back(InstructionKind::kReloadEngine, std::move(engine_config)); + instruction_queue_.emplace_back(InstructionKind::kReloadEngine, + std::move(engine_config_json_str)); ++pending_request_operation_cnt_; need_notify = engine_waiting_; } if (need_notify) { background_loop_cv_.notify_one(); } + { + std::unique_lock lock(reload_unload_mutex_); + reload_finished_ = false; + reload_unload_cv_.wait(lock, [this] { return reload_finished_; }); + } } void Unload() final { @@ -69,6 +76,11 @@ class ThreadedEngineImpl : public ThreadedEngine { if (need_notify) { background_loop_cv_.notify_one(); } + { + std::unique_lock lock(reload_unload_mutex_); + unload_finished_ = false; + reload_unload_cv_.wait(lock, [this] { return unload_finished_; }); + } } void Reset() final { @@ -140,7 +152,7 @@ class ThreadedEngineImpl : public ThreadedEngine { EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { EngineUnloadImpl(); - EngineReloadImpl(Downcast(arg)); + EngineReloadImpl(Downcast(arg)); } else if (kind == InstructionKind::kResetEngine) { if (background_engine_ != nullptr) { background_engine_->Reset(); @@ -199,7 +211,23 @@ class ThreadedEngineImpl : public ThreadedEngine { request_stream_callback_cv_.notify_one(); } - /************** Debug/Profile **************/ + /************** Query/Profile/Debug **************/ + + String GetDefaultGenerationConfigJSONString() const final { + CHECK(!default_generation_cfg_json_str_.empty()) + << "The default generation config has not been set."; + return default_generation_cfg_json_str_; + }; + + String GetCompleteEngineConfigJSONString() const final { + CHECK(!complete_engine_config_json_str_.empty()) << "The engine config has not been set."; + return complete_engine_config_json_str_; + }; + + String Stats() final { + std::lock_guard lock(background_loop_mutex_); + return background_engine_->Stats(); + } void DebugCallFuncOnAllAllWorker(const String& func_name) final { bool need_notify = false; @@ -214,13 +242,8 @@ class ThreadedEngineImpl : public ThreadedEngine { } } - String Stats() final { - std::lock_guard lock(background_loop_mutex_); - return background_engine_->Stats(); - } - private: - void EngineReloadImpl(EngineConfig engine_config) { + void EngineReloadImpl(const std::string& engine_config_json_str) { auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.size(), 1); Array delta_outputs = args[0]; @@ -237,8 +260,19 @@ class ThreadedEngineImpl : public ThreadedEngine { }; Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create(std::move(engine_config), device_, - std::move(request_stream_callback), trace_recorder_); + Result output_res = Engine::Create( + engine_config_json_str, device_, 
std::move(request_stream_callback), trace_recorder_); + CHECK(output_res.IsOk()) << output_res.UnwrapErr(); + EngineCreationOutput output = output_res.Unwrap(); + background_engine_ = std::move(output.reloaded_engine); + default_generation_cfg_json_str_ = output.default_generation_cfg->AsJSONString(); + complete_engine_config_json_str_ = output.completed_engine_config->AsJSONString(); + { + // Wake up the thread waiting for reload finish. + std::lock_guard lock(reload_unload_mutex_); + reload_finished_ = true; + reload_unload_cv_.notify_one(); + } } void EngineUnloadImpl() { @@ -250,6 +284,14 @@ class ThreadedEngineImpl : public ThreadedEngine { tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear"); ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear"; (*fclear_memory_manager)(); + default_generation_cfg_json_str_ = ""; + complete_engine_config_json_str_ = ""; + } + { + // Wake up the thread waiting for unload finish. + std::lock_guard lock(reload_unload_mutex_); + unload_finished_ = true; + reload_unload_cv_.notify_one(); } } @@ -261,13 +303,19 @@ class ThreadedEngineImpl : public ThreadedEngine { PackedFunc request_stream_callback_; /*! \brief Event trace recorder. */ Optional trace_recorder_; + /*! \brief The complete engine config JSON string. */ + String complete_engine_config_json_str_; + /*! \brief The default generation config JSON string. */ + String default_generation_cfg_json_str_; /*! \brief The mutex ensuring only one thread can access critical regions. */ std::mutex background_loop_mutex_; std::mutex request_stream_callback_mutex_; + std::mutex reload_unload_mutex_; /*! \brief The condition variable preventing threaded engine from spinning. */ std::condition_variable background_loop_cv_; std::condition_variable request_stream_callback_cv_; + std::condition_variable reload_unload_cv_; /*! \brief A boolean flag denoting if the engine needs to exit background loop. */ std::atomic exit_now_ = false; @@ -303,13 +351,17 @@ class ThreadedEngineImpl : public ThreadedEngine { bool engine_waiting_ = false; /*! \brief A boolean flag indicating if the stream callback loop is waiting. */ bool stream_callback_waiting_ = false; + /*! \brief A boolean indicating if the engine reload has finished. */ + bool reload_finished_ = false; + /*! \brief A boolean indicating if the engine unload has finished. */ + bool unload_finished_ = false; }; /*! \brief The implementation of ThreadedEngine. 
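Reload() and Unload() now block until the background loop has actually finished the operation, using a dedicated mutex, condition variable, and finished flag. Since the two halves of that handshake sit in different methods of the class, here is a condensed view of the pattern as it appears in the patch (member names follow the patch; this is an illustration, not additional code):

    // Foreground thread (Reload/Unload): queue the instruction, then wait for completion.
    {
      std::unique_lock<std::mutex> lock(reload_unload_mutex_);
      reload_finished_ = false;
      reload_unload_cv_.wait(lock, [this] { return reload_finished_; });
    }

    // Background loop (EngineReloadImpl/EngineUnloadImpl): signal completion when done.
    {
      std::lock_guard<std::mutex> lock(reload_unload_mutex_);
      reload_finished_ = true;
      reload_unload_cv_.notify_one();
    }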
*/ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); - TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("init_threaded_engine", &ThreadedEngineImpl::InitThreadedEngine); TVM_MODULE_VTABLE_ENTRY("reload", &ThreadedEngineImpl::Reload); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); @@ -317,9 +369,13 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", &ThreadedEngineImpl::RunBackgroundStreamBackLoop); TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("get_default_generation_config", + &ThreadedEngineImpl::GetDefaultGenerationConfigJSONString); + TVM_MODULE_VTABLE_ENTRY("get_complete_engine_config", + &ThreadedEngineImpl::GetCompleteEngineConfigJSONString); + TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); - TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_END(); }; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index d0f2ebe2d7..b6afdcbb7c 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -39,14 +39,14 @@ class ThreadedEngine { * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, - Optional trace_recorder) = 0; + virtual void InitThreadedEngine(Device device, Optional request_stream_callback, + Optional trace_recorder) = 0; /*! * \brief Reload the engine with the new engine config. - * \param engine_config The engine config. + * \param engine_config_json_str The engine config JSON string. */ - virtual void Reload(EngineConfig engine_config) = 0; + virtual void Reload(String engine_config_json_str) = 0; /*! \brief Unload the background engine. */ virtual void Unload() = 0; @@ -73,13 +73,19 @@ class ThreadedEngine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; - /************** Debug/Profile **************/ + /************** Query/Profile/Debug **************/ - /*! \brief Call the given global function on all workers. Only for debug purpose. */ - virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + /*! \brief Return the default generation config JSON string. */ + virtual String GetDefaultGenerationConfigJSONString() const = 0; + + /*! \brief Return the complete engine config JSON string. */ + virtual String GetCompleteEngineConfigJSONString() const = 0; /*! \brief Print the statistics of the engine. */ virtual String Stats() = 0; + + /*! \brief Call the given global function on all workers. Only for debug purpose. 
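Because Reload() now blocks until the background engine is ready, the newly added query methods can be called as soon as it returns. A short usage sketch against the ThreadedEngine interface above (the threaded_engine pointer and config string are placeholders):

    threaded_engine->Reload(engine_config_json_str);
    // Safe once Reload has returned: both strings are populated by EngineReloadImpl.
    String complete_engine_cfg = threaded_engine->GetCompleteEngineConfigJSONString();
    String default_generation_cfg = threaded_engine->GetDefaultGenerationConfigJSONString();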
*/ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; }; } // namespace serve diff --git a/cpp/metadata/json_parser.h b/cpp/support/json_parser.h similarity index 92% rename from cpp/metadata/json_parser.h rename to cpp/support/json_parser.h index 99a284fc42..f71757435a 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/support/json_parser.h @@ -2,8 +2,8 @@ * \file json_parser.h * \brief Helps to parse JSON strings and objects. */ -#ifndef MLC_LLM_CPP_JSON_PARSER_H_ -#define MLC_LLM_CPP_JSON_PARSER_H_ +#ifndef MLC_LLM_SUPPORT_JSON_PARSER_H_ +#define MLC_LLM_SUPPORT_JSON_PARSER_H_ #include #include @@ -165,6 +165,17 @@ inline ValueType LookupOrDefault(const picojson::object& json, const std::string return it->second.get(); } +template +inline std::optional LookupOptional(const picojson::object& json, + const std::string& key) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return std::nullopt; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; @@ -209,4 +220,4 @@ inline picojson::object ParseToJsonObject(const std::string& json_str) { } // namespace llm } // namespace mlc -#endif // MLC_LLM_CPP_JSON_PARSER_H_ +#endif // MLC_LLM_SUPPORT_JSON_PARSER_H_ diff --git a/cpp/support/result.h b/cpp/support/result.h new file mode 100644 index 0000000000..c6def39525 --- /dev/null +++ b/cpp/support/result.h @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file result.h + * \brief The header for the Result class in MLC LLM. + */ +#ifndef MLC_LLM_SUPPORT_RESULT_H_ +#define MLC_LLM_SUPPORT_RESULT_H_ + +#include + +#include +#include + +namespace mlc { +namespace llm { + +/*! + * \brief The result class in MLC LLM. + * Each instance is either an okay value or an error. + * \tparam T The okay value type of the result. + * \tparam E The error type of the result. + */ +template +class Result { + public: + /*! \brief Create a result with an okay value. */ + static Result Ok(T value) { + Result result; + result.ok_value_ = std::move(value); + return result; + } + /*! \brief Create a result with an error value. */ + static Result Error(E error) { + Result result; + result.err_value_ = std::move(error); + return result; + } + /*! \brief Check if the result is okay or not. */ + bool IsOk() const { return ok_value_.has_value(); } + /*! \brief Check if the result is an error or not. */ + bool IsErr() const { return err_value_.has_value(); } + /*! + * \brief Unwrap the result and return the okay value. + * Throwing exception if it is an error. + * \note This function returns the ok value by moving, so a Result can be unwrapped only once. + */ + T Unwrap() { + ICHECK(ok_value_.has_value()) << "Cannot unwrap result on an error value."; + ICHECK(!unwrapped_) << "Cannot unwrap a Result instance twice."; + unwrapped_ = true; + return std::move(ok_value_.value()); + } + /*! + * \brief Unwrap the result and return the error value. + * Throwing exception if it is an okay value. + * \note This function returns the error value by moving, so a Result can be unwrapped only once. + */ + E UnwrapErr() { + ICHECK(err_value_.has_value()) << "Cannot unwrap result on an okay value."; + ICHECK(!unwrapped_) << "Cannot unwrap a Result instance twice."; + unwrapped_ = true; + return std::move(err_value_.value()); + } + + private: + /*! 
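For reference, a minimal usage sketch of the Result class introduced above. It assumes the error type parameter defaults to std::string, matching how the serve code instantiates Result elsewhere in this patch; the SafeDivide helper is purely illustrative:

    Result<double> SafeDivide(double numerator, double denominator) {
      if (denominator == 0.0) {
        return Result<double>::Error("division by zero");
      }
      return Result<double>::Ok(numerator / denominator);
    }

    void Demo() {
      Result<double> res = SafeDivide(1.0, 4.0);
      CHECK(res.IsOk()) << res.UnwrapErr();  // same handling pattern as EngineModule::Init
      double value = res.Unwrap();           // Unwrap() moves the value out; call it only once
      (void)value;
    }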
\brief A boolean flag indicating if the result is okay or error. */ + bool unwrapped_ = false; + /*! \brief The internal optional okay value. */ + std::optional ok_value_; + /*! \brief The internal optional error value. */ + std::optional err_value_; +}; + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SUPPORT_RESULT_H_ diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 4706e09811..560ca17255 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -285,7 +285,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") >>> cm.generate("hi") 'Hi! How can I assist you today?' @@ -312,7 +312,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") >>> cm.generate("hi") 'Hi! How can I assist you today?' @@ -340,7 +340,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") >>> cm.generate("hi") 'Hi! How can I assist you today?' diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index aa65256fd6..1518f5145a 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -177,6 +177,6 @@ Running the distributed models are similar to running prebuilt model weights and python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ - model_lib_path="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend + model_lib="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend >>> cm.generate("hi") 'Hi! How can I assist you today?' diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index a7ebe28d6d..f978581707 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -92,13 +92,13 @@ For models other than the prebuilt ones we provided: Once you have the model locally compiled with a model library and model weights, to run ``mlc_llm``, simply - Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model`` -- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib-path`` +- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib`` .. 
code:: shell mlc_llm chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ --device "cuda:0" --overrides context_window_size=1024 \ - --model-lib-path dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so + --model-lib dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so # Metal on macOS: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-metal.so # Same rule applies for other platforms diff --git a/docs/deploy/ide_integration.rst b/docs/deploy/ide_integration.rst index 866dfa3cbe..7e0735d8e0 100644 --- a/docs/deploy/ide_integration.rst +++ b/docs/deploy/ide_integration.rst @@ -112,7 +112,7 @@ You can now locally deploy your compiled model with the MLC serve module. To fin python -m mlc_llm.serve.server \ --model dist/CodeLlama-7b-hf-q4f16_1-MLC \ - --model-lib-path ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so + --model-lib ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so Configure the IDE Extension --------------------------- diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 75a5cdbdc7..2bcf7997d3 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -273,7 +273,7 @@ We simply specify the Huggingface link as ``model_url``, while reusing the ``mod "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", "model_lib": "mistral_q3f16_1", - "model_lib_path": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", + "model_lib": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", "estimated_vram_bytes": 3316000000 } ] @@ -421,7 +421,6 @@ rounded up to MB. "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", "model_id": "phi-2-q4f16_1", "model_lib": "phi_msft_q4f16_1", - "model_lib_path": "lib/phi-2/phi-2-q4f16_1-iphone.tar", "estimated_vram_bytes": 3043000000 } ] diff --git a/docs/deploy/python_chat_module.rst b/docs/deploy/python_chat_module.rst index 5776e29138..14e9f3ed03 100644 --- a/docs/deploy/python_chat_module.rst +++ b/docs/deploy/python_chat_module.rst @@ -95,7 +95,7 @@ file ``sample_mlc_llm.py`` and paste the following lines: # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -106,7 +106,7 @@ file ``sample_mlc_llm.py`` and paste the following lines: # Here WizardMath reuses Mistral's model library # cm = ChatModule( # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" - # model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" + # model_lib="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" # ) # Generate a response for a given prompt @@ -200,7 +200,7 @@ We provide an example below. 
cm = ChatModule( chat_config=chat_config, model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -275,7 +275,7 @@ We provide an example below. cm = ChatModule( chat_config=chat_config, model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -320,7 +320,7 @@ We provide an example below. # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index 89c60ac422..2ef4d5bd23 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -219,15 +219,15 @@ you can construct a :class:`mlc_llm.MLCEngine` as follows: **Specify Model Library Path.** Further, if you build the model library on your own, -you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. +you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib``. .. code:: python from mlc_llm import MLCEngine model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) + model_lib = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib=model_lib) The same applies to :class:`mlc_llm.AsyncMLCEngine`. diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 07d39dbfad..a82c914004 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -28,7 +28,7 @@ This section provides a quick start guide to work with MLC-LLM REST API. To laun .. code:: bash - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] + mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB] where ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process `. Information about other arguments can be found under :ref:`Launch the server ` section. @@ -66,14 +66,14 @@ To launch the MLC Server for MLC-LLM, run the following command in your terminal .. 
code:: bash - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] + mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model folder. In the former case, we will use the provided name to search for the model folder over possible paths. ---model-lib-path A field to specify the full path to the model library file to use (e.g. a ``.so`` file). +--model-lib A field to specify the full path to the model library file to use (e.g. a ``.so`` file). --device The description of the device to run on. User should provide a string in the form of 'device_name:device_id' or 'device_name', where 'device_name' is one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 29060d5a60..bcba8f631e 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -240,20 +240,20 @@ Below is an example command of compiling model libraries in MLC LLM: .. code:: bash - export $MODEL_LIB_PATH=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. - # ".dll" for Windows. - # ".wasm" for web. - # ".tar" for iPhone/Android. - mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB_PATH + export $MODEL_LIB=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. + # ".dll" for Windows. + # ".wasm" for web. + # ".tar" for iPhone/Android. + mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB At runtime, we need to specify this model library path to use it. For example, .. code:: bash # For chat CLI - mlc_llm chat $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH + mlc_llm chat $MLC_MODEL_PATH --model-lib $MODEL_LIB # For REST server - mlc_llm serve $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH + mlc_llm serve $MLC_MODEL_PATH --model-lib $MODEL_LIB .. code:: python @@ -261,8 +261,8 @@ At runtime, we need to specify this model library path to use it. 
For example, # For Python API model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) + model_lib = "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib=model_lib) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different diff --git a/examples/python/sample_mlc_chat.py b/examples/python/sample_mlc_chat.py index de00e84ff6..f4e49bb2bd 100644 --- a/examples/python/sample_mlc_chat.py +++ b/examples/python/sample_mlc_chat.py @@ -7,7 +7,7 @@ # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so", # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -18,7 +18,7 @@ # Here WizardMath reuses Mistral's model library # cm = ChatModule( # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" -# model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" +# model_lib="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" # ) # Generate a response for a given prompt diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 24ad8faecf..2efc3ec9b9 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -442,7 +442,7 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi if field_name == "model_lib": warn_msg = ( 'WARNING: Do not override "model_lib" in ChatConfig. ' - "This override will be ignored. Please use ChatModule.model_lib_path to " + "This override will be ignored. Please use ChatModule.model_lib to " "override the full model library path instead." ) warnings.warn(warn_msg) @@ -493,7 +493,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments model: str, model_path: str, chat_config: ChatConfig, - model_lib_path: Optional[str], + model_lib: Optional[str], device_name: str, config_file_path: str, ) -> str: @@ -507,7 +507,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments Model path found by `_get_model_path`. chat_config : ChatConfig Chat config after potential overrides. Returned by ``_get_chat_config``. - model_lib_path : Optional[str] + model_lib : Optional[str] User's input. Supposedly a full path to model library. Prioritized to use. device_name : str User's input. Used to construct the library model file name. @@ -516,20 +516,20 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments Returns ------- - model_lib_path : str + model_lib : str The path pointing to the model library we find. Raises ------ FileNotFoundError: if we cannot find a valid model library file. """ - # 1. Use user's model_lib_path if provided - if model_lib_path is not None: - if os.path.isfile(model_lib_path): - logger.info("Using library model: %s", model_lib_path) - return model_lib_path + # 1. 
Use user's model_lib if provided + if model_lib is not None: + if os.path.isfile(model_lib): + logger.info("Using library model: %s", model_lib) + return model_lib raise FileNotFoundError( - f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\n" + f"The `model_lib` you passed in is not a file: {model_lib}.\n" f"Please refer to {_PYTHON_GET_STARTED_TUTORIAL_URL} as tutorial on model loading." ) @@ -584,7 +584,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments err_msg += f"- {candidate}\n" err_msg += ( "If you would like to directly specify the model library path, you may " - "consider passing in the `ChatModule.model_lib_path` parameter.\n" + "consider passing in the `ChatModule.model_lib` parameter.\n" f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example " "on how to load a model." ) @@ -654,12 +654,12 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): +def _inspect_model_lib_metadata_memory_usage(model_lib, config_file_path): cmd = [ sys.executable, "-m", "mlc_llm.cli.model_metadata", - model_lib_path, + model_lib, "--memory-only", "--mlc-chat-config", config_file_path, @@ -716,7 +716,7 @@ class ChatModule: # pylint: disable=too-many-instance-attributes A ``ChatConfig`` instance partially filled. Will be used to override the ``mlc-chat-config.json``. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. @@ -727,7 +727,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: str = "auto", chat_config: Optional[ChatConfig] = None, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, ): # 0. Get device: # Retrieve device_name and device_id (if any, default 0) from device arg @@ -768,12 +768,12 @@ def __init__( # pylint: disable=too-many-arguments self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 4. Look up model library - if model_lib_path is not None: - self.model_lib_path = _get_lib_module_path( + if model_lib is not None: + self.model_lib = _get_lib_module_path( model, self.model_path, self.chat_config, - model_lib_path, + model_lib, self.device.MASK2STR[self.device.device_type], self.config_file_path, ) @@ -781,20 +781,20 @@ def __init__( # pylint: disable=too-many-arguments logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - self.model_lib_path = str( + self.model_lib = str( jit.jit( model_path=Path(self.model_path), chat_config=asdict(self.chat_config), device=self.device, ) ) - _inspect_model_lib_metadata_memory_usage(self.model_lib_path, self.config_file_path) + _inspect_model_lib_metadata_memory_usage(self.model_lib, self.config_file_path) # 5. 
Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template ) - self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str) + self._reload(self.model_lib, self.model_path, user_chat_config_json_str) def generate( self, diff --git a/python/mlc_llm/cli/bench.py b/python/mlc_llm/cli/bench.py index 26b74b1f10..0e42048ff2 100644 --- a/python/mlc_llm/cli/bench.py +++ b/python/mlc_llm/cli/bench.py @@ -1,4 +1,5 @@ """Command line entrypoint of benchmark.""" + from mlc_llm.help import HELP from mlc_llm.interface.bench import bench from mlc_llm.interface.chat import ChatConfigOverride @@ -45,10 +46,10 @@ def main(argv): help=HELP["generate_length"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) bench( @@ -58,5 +59,5 @@ def main(argv): opt=parsed.opt, overrides=parsed.overrides, generate_length=parsed.generate_length, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, ) diff --git a/python/mlc_llm/cli/benchmark.py b/python/mlc_llm/cli/benchmark.py index 72c86fab03..aa22bae68c 100644 --- a/python/mlc_llm/cli/benchmark.py +++ b/python/mlc_llm/cli/benchmark.py @@ -1,4 +1,5 @@ """A command line tool for benchmarking a chat model.""" + import argparse from pathlib import Path @@ -74,7 +75,7 @@ def main(): model=args.model, device=args.device, chat_config=ChatConfig(tensor_parallel_shards=args.tensor_parallel_shards), - model_lib_path=args.model_lib, + model_lib=args.model_lib, ) prompt = _load_prompt(args.prompt) output = chat_module.benchmark_generate(prompt, generate_length=args.generate_length) diff --git a/python/mlc_llm/cli/chat.py b/python/mlc_llm/cli/chat.py index 13c83a64ec..34fb5daa09 100644 --- a/python/mlc_llm/cli/chat.py +++ b/python/mlc_llm/cli/chat.py @@ -1,4 +1,5 @@ """Command line entrypoint of chat.""" + from mlc_llm.help import HELP from mlc_llm.interface.chat import ChatConfigOverride, chat from mlc_llm.support.argparse import ArgumentParser @@ -32,10 +33,10 @@ def main(argv): help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) chat( @@ -43,5 +44,5 @@ def main(argv): device=parsed.device, opt=parsed.opt, overrides=parsed.overrides, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, ) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 6663a0c230..9ba0e01e3d 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -4,7 +4,6 @@ from mlc_llm.help import HELP from mlc_llm.interface.serve import serve -from mlc_llm.serve.config import SpeculativeMode from mlc_llm.support.argparse import ArgumentParser @@ -24,10 +23,10 @@ def main(argv): help=HELP["device_deploy"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parser.add_argument( "--mode", @@ -44,18 +43,16 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) 
parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) - parser.add_argument( - "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] - ) + parser.add_argument("--max-history-size", type=int, help=HELP["max_history_size_serve"]) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) parser.add_argument( "--speculative-mode", type=str, - choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], - default="DISABLE", - help=HELP["speculative_mode_serve"], + choices=["disable", "small_draft", "eable"], + default="disable", + help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) parser.add_argument( "--spec-draft-length", type=int, default=4, help=HELP["spec_draft_length_serve"] @@ -97,7 +94,7 @@ def main(argv): serve( model=parsed.model, device=parsed.device, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, mode=parsed.mode, additional_models=parsed.additional_models, max_batch_size=parsed.max_batch_size, @@ -105,7 +102,7 @@ def main(argv): prefill_chunk_size=parsed.prefill_chunk_size, max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, - speculative_mode=SpeculativeMode[parsed.speculative_mode], + speculative_mode=parsed.speculative_mode, spec_draft_length=parsed.spec_draft_length, enable_tracing=parsed.enable_tracing, host=parsed.host, diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 86930fa5ea..f6ef6c38af 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -25,9 +25,9 @@ A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. """.strip(), - "model_lib_path": """ + "model_lib": """ The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use -the provided ``model`` to search over possible paths. It the model lib path is not found, it will be +the provided ``model`` to search over possible paths. It the model lib is not found, it will be compiled in a JIT manner. """.strip(), "model_type": """ @@ -186,8 +186,8 @@ When engine is enabled with speculative decoding, additional models are needed. The way of specifying additional models is: "--additional-models model_path_1 model_path_2 ..." or -"--additional-models model_path_1:model_lib_path_1 model_path_2 ...". -When the model lib path of a model is not given, JIT model compilation will be activated +"--additional-models model_path_1:model_lib_1 model_path_2 ...". +When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. """, "gpu_memory_utilization_serve": """ @@ -199,10 +199,10 @@ """, "speculative_mode_serve": """ The speculative decoding mode. Right now three options are supported: - - DISABLE, where speculative decoding is not enabled, - - SMALL_DRAFT, denoting the normal speculative decoding (small draft) style, - - EAGLE, denoting the eagle-style speculative decoding. -The default mode is "DISABLE". + - "disable", where speculative decoding is not enabled, + - "small_draft", denoting the normal speculative decoding (small draft) style, + - "eagle", denoting the eagle-style speculative decoding. +The default mode is "disable". """, "spec_draft_length_serve": """ The number of draft tokens to generate in speculative proposal. The default values is 4. 
diff --git a/python/mlc_llm/interface/bench.py b/python/mlc_llm/interface/bench.py index 6a7d833447..baa350df05 100644 --- a/python/mlc_llm/interface/bench.py +++ b/python/mlc_llm/interface/bench.py @@ -1,4 +1,5 @@ """Python entrypoint of benchmark.""" + from typing import Optional from mlc_llm.chat_module import ChatConfig, ChatModule @@ -13,7 +14,7 @@ def bench( # pylint: disable=too-many-arguments opt: str, overrides: ChatConfigOverride, generate_length: int, - model_lib_path: Optional[str], + model_lib: Optional[str], ): """run the benchmarking""" # Set up chat config @@ -21,7 +22,7 @@ def bench( # pylint: disable=too-many-arguments # Apply overrides config = overrides.apply(config) # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + cm = ChatModule(model, device, chat_config=config, model_lib=model_lib) output = cm.benchmark_generate(prompt, generate_length=generate_length) print(f"Generated text:\n{output}\n") diff --git a/python/mlc_llm/interface/chat.py b/python/mlc_llm/interface/chat.py index 9c0763a6ef..75985ec27a 100644 --- a/python/mlc_llm/interface/chat.py +++ b/python/mlc_llm/interface/chat.py @@ -1,4 +1,5 @@ """Python entrypoint of chat.""" + import dataclasses from typing import List, Optional, Union @@ -121,7 +122,7 @@ def chat( device: str, opt: str, overrides: ChatConfigOverride, - model_lib_path: Optional[str], + model_lib: Optional[str], ): """chat with a model.""" # Set up chat config and generate config @@ -130,7 +131,7 @@ def chat( # Apply overrides config = overrides.apply(config) # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + cm = ChatModule(model, device, chat_config=config, model_lib=model_lib) _print_help_str() cm._process_system_prompts() # pylint: disable=protected-access diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index 40fa9fdda8..d1cde12678 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -8,7 +8,6 @@ from mlc_llm.protocol import error_protocol from mlc_llm.serve import engine -from mlc_llm.serve.config import SpeculativeMode from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -16,7 +15,7 @@ def serve( model: str, device: str, - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], additional_models: List[str], max_batch_size: Optional[int], @@ -24,7 +23,7 @@ def serve( prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, + speculative_mode: Literal["disable", "small_draft", "eagle"], spec_draft_length: int, enable_tracing: bool, host: str, @@ -39,7 +38,7 @@ def serve( async_engine = engine.AsyncMLCEngine( model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py index 0c604a2ef3..237319a926 100644 --- a/python/mlc_llm/json_ffi/engine.py +++ b/python/mlc_llm/json_ffi/engine.py @@ -1,6 +1,5 @@ # pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json import queue import threading from typing import Any, Callable, Dict, 
Iterator, List, Literal, Optional, Union @@ -11,8 +10,6 @@ from mlc_llm.serve import engine_utils from mlc_llm.serve.engine_base import ( EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, _parse_models, _process_model_args, detect_device, @@ -20,32 +17,6 @@ from mlc_llm.tokenizer import Tokenizer -# TODO(mlc-team): further minimize the JSONFFIEngine -# construction to not depend on any config and directly pass in JSON -# model defined generation config should be read from the JSONFFIEngine via Reload -def create_model_defined_generation_config( - temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) - - -# TODO(mlc-team): further minimize the JSONFFIEngine -# Engine config should be passed as json str -# and backend should have good default -# only model and model_lib should be mandatory -def create_json_ffi_engine_config( - conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( - conv_template, model_generation_cfgs - ) - - class EngineState: sync_queue: queue.Queue @@ -70,27 +41,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model: str, device: Union[str, tvm.runtime.Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, max_history_size: Optional[int] = None, prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, gpu_memory_utilization: Optional[float] = None, ) -> None: # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) + model_args = _process_model_args(models, device)[0] # TODO(mlc-team) Remove the model config parsing, estimation below # in favor of a simple direct passing of parameters into backend. @@ -103,33 +70,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # since we won't have similar logics in android/iOS # # - Load the raw model config into dict - self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. 
- ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + model_info.model_lib = model_args[i][1] # - Initialize engine state and engine. self.state = EngineState() @@ -151,43 +93,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals } self.tokenizer = Tokenizer(model_args[0][0]) - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - - self.json_ffi_engine_config = create_json_ffi_engine_config( - conv_template=self.conv_template.model_dump_json(), - model_generation_cfgs={ - model.model: create_model_defined_generation_config( - temperature=model_config["temperature"], - top_p=model_config["top_p"], - frequency_penalty=model_config["frequency_penalty"], - presence_penalty=model_config["presence_penalty"], - ) - for model, model_config in zip(models, self.model_config_dicts) - }, - ) - - self._ffi["init_background_engine"]( - self.json_ffi_engine_config, - self.engine_config, - device, - self.state.get_request_stream_callback(), - None, - ) - def _background_loop(): self._ffi["run_background_loop"]() @@ -203,6 +108,26 @@ def _background_stream_back_loop(): self._background_stream_back_loop_thread.start() self._terminated = False + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + verbose=False, + ) + + self._ffi["init_background_engine"](device, self.state.get_request_stream_callback(), None) + self._ffi["reload"](self.engine_config.asjson()) + def terminate(self): self._terminated = True self._ffi["exit_background_loop"]() @@ -301,7 +226,7 @@ def _handle_chat_completion( raise exception def _test_reload(self): - self._ffi["reload"](self.engine_config) + self._ffi["reload"](self.engine_config.asjson()) def _test_reset(self): self._ffi["reset"]() diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 4a5168f971..9a0a724ea1 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -107,9 +107,9 @@ class CompletionRequest(BaseModel): @field_validator("frequency_penalty", "presence_penalty") @classmethod - def 
check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" - if penalty_value < -2 or penalty_value > 2: + if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value @@ -221,7 +221,7 @@ class ChatCompletionRequest(BaseModel): @field_validator("frequency_penalty", "presence_penalty") @classmethod - def check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") @@ -386,7 +386,7 @@ def openai_api_get_unsupported_fields( def openai_api_get_generation_config( - request: Union[CompletionRequest, ChatCompletionRequest], model_config: Dict[str, Any] + request: Union[CompletionRequest, ChatCompletionRequest] ) -> Dict[str, Any]: """Create the generation config from the given request.""" from ..serve.config import ResponseFormat # pylint: disable=import-outside-toplevel @@ -407,17 +407,6 @@ def openai_api_get_generation_config( ] for arg_name in arg_names: kwargs[arg_name] = getattr(request, arg_name) - - # If per-request generation config values are missing, try loading from model config. - # If still not found, then use the default OpenAI API value - if kwargs["temperature"] is None: - kwargs["temperature"] = model_config.get("temperature", 1.0) - if kwargs["top_p"] is None: - kwargs["top_p"] = model_config.get("top_p", 1.0) - if kwargs["frequency_penalty"] is None: - kwargs["frequency_penalty"] = model_config.get("frequency_penalty", 0.0) - if kwargs["presence_penalty"] is None: - kwargs["presence_penalty"] = model_config.get("presence_penalty", 0.0) if kwargs["max_tokens"] is None: # Setting to -1 means the generation will not stop until # exceeding model capability or hit any stop criteria. diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index 3005909bbd..f4273d0302 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -23,14 +23,13 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: def get_generation_config( request: RequestProtocol, - model_config: Dict[str, Any], extra_stop_token_ids: Optional[List[int]] = None, extra_stop_str: Optional[List[str]] = None, ) -> GenerationConfig: """Create the generation config in MLC LLM out from the input request protocol.""" kwargs: Dict[str, Any] if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): - kwargs = openai_api_get_generation_config(request, model_config) + kwargs = openai_api_get_generation_config(request) else: raise RuntimeError("Cannot reach here") diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 59358c1646..ec6899ea26 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. 
import base -from .config import EngineConfig, GenerationConfig, SpeculativeMode +from .config import EngineConfig, GenerationConfig from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 6b808ac37b..916403839a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,14 +1,9 @@ """Configuration dataclasses used in MLC LLM serving""" -import enum import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Literal, Optional -import tvm - -from . import _ffi_api - @dataclass class ResponseFormat: @@ -43,19 +38,19 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes n : int How many chat completion choices to generate for each input message. - temperature : float + temperature : Optional[float] The value that applies to logits and modulates the next token probabilities. - top_p : float + top_p : Optional[float] In sampling, only the most probable tokens with probabilities summed up to `top_p` are kept for sampling. - frequency_penalty : float + frequency_penalty : Optional[float] Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - presence_penalty : float + presence_penalty : Optional[float] Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. @@ -101,10 +96,10 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes """ n: int = 1 - temperature: float = 0.8 - top_p: float = 0.95 - frequency_penalty: float = 0.0 - presence_penalty: float = 0.0 + temperature: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None repetition_penalty: float = 1.0 logprobs: bool = False top_logprobs: int = 0 @@ -128,26 +123,8 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) -class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods - """Possible kinds of KV state.""" - - ATTENTION = 0 - RNNSTATE = 1 - - -class SpeculativeMode(enum.IntEnum): - """The speculative mode.""" - - # Disable speculative decoding. - DISABLE = 0 - # The normal speculative decoding (small draft) mode. - SMALL_DRAFT = 1 - # The eagle-style speculative decoding. - EAGLE = 2 - - -@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access -class EngineConfig(tvm.runtime.Object): +@dataclass +class EngineConfig: # pylint: disable=too-many-instance-attributes """The class of MLCEngine execution configuration. Parameters @@ -155,74 +132,103 @@ class EngineConfig(tvm.runtime.Object): model : str The path to the model directory. - model_lib_path : str + model_lib : str The path to the model library. additional_models : List[str] The path to the additional models' directories. - additional_model_lib_paths : List[str] + additional_model_libs : List[str] The path to the additional models' libraries. + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". 
+ The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent requests and wants to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatically inferred values. + + gpu_memory_utilization : float + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer the maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.85. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. - max_num_sequence : int + max_num_sequence : Optional[int] The maximum number of sequences that are allowed to be processed by the KV cache at any time. - max_total_sequence_length : int + max_total_sequence_length : Optional[int] The maximum length allowed for a single sequence in the engine. - max_single_sequence_length : int + max_single_sequence_length : Optional[int] The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. - prefill_chunk_size : int + prefill_chunk_size : Optional[int] The maximum total sequence length in a prefill. - max_history_size: int + max_history_size: Optional[int] The maximum history size for RNN state to roll back. - kv_state_kind: KVStateKind + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] The kind of cache. - speculative_mode : SpeculativeMode + speculative_mode : Literal["disable", "small_draft", "eagle"] The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). + + verbose : bool + A boolean indicating whether to print logging info in engine. 
""" - def __init__( # pylint: disable=too-many-arguments - self, - model: str, - model_lib_path: str, - additional_models: List[str], - additional_model_lib_paths: List[str], - kv_cache_page_size: int, - max_num_sequence: int, - max_total_sequence_length: int, - max_single_sequence_length: int, - prefill_chunk_size: int, - max_history_size: int, - kv_state_kind: KVStateKind, - speculative_mode: SpeculativeMode, - spec_draft_length: int, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member - model, - model_lib_path, - additional_models, - additional_model_lib_paths, - kv_cache_page_size, - max_num_sequence, - max_total_sequence_length, - max_single_sequence_length, - prefill_chunk_size, - max_history_size, - kv_state_kind, - speculative_mode, - spec_draft_length, - ) + model: str + model_lib: str + additional_models: List[str] = field(default_factory=list) + additional_model_libs: List[str] = field(default_factory=list) + mode: Literal["local", "interactive", "server"] = "local" + gpu_memory_utilization: Optional[float] = None + kv_cache_page_size: int = 16 + max_num_sequence: Optional[int] = None + max_total_sequence_length: Optional[int] = None + max_single_sequence_length: Optional[int] = None + prefill_chunk_size: Optional[int] = None + max_history_size: Optional[int] = None + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] = None + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable" + spec_draft_length: int = 4 + verbose: bool = True + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "EngineConfig": + """Construct a config from JSON string.""" + return EngineConfig(**json.loads(json_str)) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 413c856db1..8b63a65130 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -22,7 +22,7 @@ from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import GenerationConfig from mlc_llm.serve.request import Request from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -63,8 +63,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -72,8 +72,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -112,8 +112,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: 
bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -122,8 +122,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -161,8 +161,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -171,8 +171,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -240,8 +240,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -249,8 +249,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -289,8 +289,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -299,8 +299,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -336,8 +336,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -346,8 +346,8 @@ def create( # pylint: 
disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -417,8 +417,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -427,8 +427,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -467,8 +467,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -478,8 +478,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -515,8 +515,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -526,8 +526,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -596,8 +596,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -606,8 +606,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: 
Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -646,8 +646,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -657,8 +657,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -694,8 +694,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -705,8 +705,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -758,7 +758,7 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. @@ -767,10 +767,10 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + It the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. @@ -798,8 +798,8 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): The model paths and (optional) model library paths of additional models (other than the main model). When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will + Each string in the list is either in form "model_path" or "model_path:model_lib". + When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. max_batch_size : Optional[int] @@ -827,15 +827,20 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): significantly smaller than this number. 
Under mode "server", the actual memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + speculative_mode : Literal["disable", "small_draft", "eagle"] + The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. + + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments @@ -843,7 +848,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -851,15 +856,16 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, + verbose: bool = True, ) -> None: super().__init__( "async", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, @@ -870,6 +876,7 @@ def __init__( # pylint: disable=too-many-arguments speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, + verbose=verbose, ) self.chat = Chat(weakref.ref(self)) self.completions = AsyncCompletion(weakref.ref(self)) @@ -889,8 +896,8 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -899,8 +906,8 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -1012,8 +1019,8 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, 
logit_bias: Optional[Dict[int, float]] = None, @@ -1023,8 +1030,8 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -1194,7 +1201,6 @@ async def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) @@ -1264,7 +1270,9 @@ async def _generate( # Create the request with the given id, input data, generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = Request( + request_id, input_data, generation_config, self.default_generation_cfg_json_str + ) # Create the unique async request stream of the request. stream = engine_base.AsyncRequestStream() @@ -1309,7 +1317,7 @@ class MLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. @@ -1318,10 +1326,10 @@ class MLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + It the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. @@ -1349,8 +1357,8 @@ class MLCEngine(engine_base.MLCEngineBase): The model paths and (optional) model library paths of additional models (other than the main model). When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will + Each string in the list is either in form "model_path" or "model_path:model_lib". + When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. max_batch_size : Optional[int] @@ -1375,15 +1383,20 @@ class MLCEngine(engine_base.MLCEngineBase): significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + speculative_mode : Literal["disable", "small_draft", "eagle"] + The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. 
+ + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments @@ -1391,7 +1404,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -1399,15 +1412,16 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, + verbose: bool = True, ) -> None: super().__init__( "sync", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, @@ -1418,6 +1432,7 @@ def __init__( # pylint: disable=too-many-arguments speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, + verbose=verbose, ) self.chat = Chat(weakref.ref(self)) self.completions = Completion(weakref.ref(self)) @@ -1437,8 +1452,8 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -1447,8 +1462,8 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -1560,8 +1575,8 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -1571,8 +1586,8 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -1737,7 +1752,6 @@ def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) @@ -1804,7 +1818,9 @@ def _generate( # pylint: disable=too-many-locals # Create the request with the given id, input data, 
generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = Request( + request_id, input_data, generation_config, self.default_generation_cfg_json_str + ) # Record the stream in the tracker self.state.sync_output_queue = queue.Queue() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 7f3f7e1331..e0d7160ece 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -6,7 +6,6 @@ import asyncio import json import queue -import subprocess import sys import threading from dataclasses import asdict, dataclass @@ -20,12 +19,7 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import ( - EngineConfig, - GenerationConfig, - KVStateKind, - SpeculativeMode, -) +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -49,25 +43,25 @@ class ModelInfo: or a full path to a model directory (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - model_lib_path : Optional[str] + model_lib : Optional[str] The path to the compiled library of the model. E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" """ model: str - model_lib_path: Optional[str] = None + model_lib: Optional[str] = None def _parse_models( - model: str, model_lib_path: Optional[str], additional_models: Optional[List[str]] + model: str, model_lib: Optional[str], additional_models: Optional[List[str]] ) -> List[ModelInfo]: - """Parse the specified model paths and model lib paths. + """Parse the specified model paths and model libs. Return a list of ModelInfo, which is a wrapper class of the model path + lib path. Each additional model is expected to follow the format of either - "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB_PATH}". + "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB}". 
""" - models = [ModelInfo(model, model_lib_path)] + models = [ModelInfo(model, model_lib)] if additional_models is not None: for additional_model in additional_models: splits = additional_model.split(":", maxsplit=1) @@ -95,30 +89,30 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template - if model.model_lib_path is not None: - # do model lib search if the model lib path is provided + if model.model_lib is not None: + # do model lib search if the model lib is provided # error out if file not found - model_lib_path = _get_lib_module_path( + model_lib = _get_lib_module_path( model=model.model, model_path=model_path, chat_config=chat_config, - model_lib_path=model.model_lib_path, + model_lib=model.model_lib, device_name=device.MASK2STR[device.device_type], config_file_path=config_file_path, ) else: # TODO(mlc-team) add logging information - # Run jit if model_lib_path is not provided + # Run jit if model_lib is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - model_lib_path = str( + model_lib = str( jit.jit( model_path=Path(model_path), chat_config=asdict(chat_config), device=device, ) ) - return model_path, model_lib_path + return model_path, model_lib model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] @@ -126,618 +120,43 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, float, float, int]: - """Estimate the memory usage and the max total sequence length (capacity) - that the KV cache can support. - """ - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. 
- model_config = model_config_dict["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 - ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.85 - - model_max_total_sequence_length = int( - ( - int(gpu_size_bytes) * gpu_memory_utilization - - params_bytes - - temp_func_bytes - - kv_aux_workspace_bytes - - model_workspace_bytes - - logit_processor_workspace_bytes - ) - / kv_bytes_per_token - ) - if model_max_total_sequence_length <= 0: - raise ValueError( - f"The model weight size {params_bytes} may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." - ) - - if device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - model_max_total_sequence_length = min(model_max_total_sequence_length, 32768) - - total_mem_usage_except_kv_cache = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - ) - return ( - total_mem_usage_except_kv_cache, - params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - int(model_max_total_sequence_length), - ) - - -def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, int]: - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 - - rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 - param_bytes = 0.0 - temp_func_bytes = 0.0 - model_workspace_bytes = 0.0 - logit_processor_workspace_bytes = 0.0 - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. 
- cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - param_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - model_config = model_config_dict["model_config"] - vocab_size = model_config_dict["vocab_size"] - head_size = model_config["head_size"] - num_heads = model_config["num_heads"] - num_layers = model_config["num_hidden_layers"] - hidden_size = model_config["hidden_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - - rnn_state_base_bytes += ( - max_num_sequence * hidden_size * num_layers * 2 * 2 - + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 - ) - - max_history_size = int( - ( - gpu_size_bytes * gpu_memory_utilization - - logit_processor_workspace_bytes - - model_workspace_bytes - - param_bytes - - temp_func_bytes - ) - / rnn_state_base_bytes - ) - if max_history_size < 1: - raise ValueError( - f"Memory required by models may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." - ) - - return ( - param_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - rnn_state_base_bytes, - max_history_size, - ) - - -def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: - """Read the model config dictionaries, and return the maximum single - sequence length the models can support, the maximum prefill chunk - size the models can support, and the max batch size the models can support. - - Returns - ------- - model_max_single_sequence_length : int - The maximum single sequence length the models can support. - model_max_prefill_chunk_size : int - The maximum prefill chunk size the models can support. - model_max_batch_size : int - The max batch size the models can support. 
- """ - model_max_single_sequence_length = int(1e9) - model_max_prefill_chunk_size = int(1e9) - model_max_batch_size = int(1e9) - for i, config in enumerate(model_config_dicts): - runtime_context_window_size = config["context_window_size"] - compile_time_context_window_size = config["model_config"]["context_window_size"] - if runtime_context_window_size > compile_time_context_window_size: - raise ValueError( - f"Model {i}'s runtime context window size ({runtime_context_window_size}) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size == -1 and compile_time_context_window_size != -1: - raise ValueError( - f"Model {i}'s runtime context window size (infinite) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size != -1: - model_max_single_sequence_length = min( - model_max_single_sequence_length, runtime_context_window_size - ) - - runtime_prefill_chunk_size = config["prefill_chunk_size"] - compile_time_prefill_chunk_size = config["model_config"]["prefill_chunk_size"] - if runtime_prefill_chunk_size > compile_time_prefill_chunk_size: - raise ValueError( - f"Model {i}'s runtime prefill chunk size ({runtime_prefill_chunk_size}) is " - "larger than the prefill chunk size used at compile time " - f"({compile_time_prefill_chunk_size})" - ) - model_max_prefill_chunk_size = min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) - - model_max_batch_size = min(model_max_batch_size, config["model_config"]["max_batch_size"]) - - assert model_max_prefill_chunk_size != int(1e9) - assert model_max_batch_size != int(1e9) - return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size - - -def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, KVStateKind, int]: - """Initialize the KV cache config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - model_max_single_sequence_length - """ - ( - model_max_single_sequence_length, - model_max_prefill_chunk_size, - model_max_batch_size, - ) = _get_model_config_limit(model_config_dicts) - - def infer_args_under_mode( - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: - logging_msg = "" - # - max_batch_size - if max_batch_size is None: - max_batch_size = ( - min(4, model_max_batch_size) - if mode == "local" - else (1 if mode == "interactive" else model_max_batch_size) - ) - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " - # - infer the maximum total sequence length that can fit GPU memory. 
- ( - total_mem_usage_except_kv_cache, - model_params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - temp_workspace_bytes, - model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - # - max_total_sequence_length - if max_total_sequence_length is None: - if mode == "local": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length, 8192 - ) - elif mode == "interactive": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length - ) - else: - max_total_sequence_length = min( - model_max_total_sequence_length, - max_batch_size * model_max_single_sequence_length, - ) - logging_msg += f"max KV cache token capacity is set to {max_total_sequence_length}, " - else: - logging_msg += ( - f"max KV cache token capacity {max_total_sequence_length} is specified by user. " - ) - # - prefill_chunk_size - if prefill_chunk_size is None: - if mode in ["local", "interactive"]: - prefill_chunk_size = min( - model_max_prefill_chunk_size, - model_max_total_sequence_length, - model_max_single_sequence_length, - ) - else: - prefill_chunk_size = model_max_prefill_chunk_size - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - - if mode == "local": - logging_msg += ( - "We choose small max batch size and KV cache capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" - " limit of gpu_memory_utilization)." - ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - # - Construct the KV cache config - # - Estimate total GPU memory usage on single GPU. - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - KVStateKind.ATTENTION, - ), [ - total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, - model_params_bytes, - kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, - temp_workspace_bytes, - ] - - # - Infer KV cache config and estimate memory usage for each mode. - local_kv_cache_config, local_mem_usage_list = infer_args_under_mode( - "local", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - interactive_kv_cache_config, interactive_mem_usage_list = infer_args_under_mode( - "interactive", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - server_kv_cache_config, server_mem_usage_list = infer_args_under_mode( - "server", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - - # - Select the config based on the actual mode. +def _print_engine_mode_logging_msg(mode: Literal["local", "interactive", "server"]) -> None: + """Print the logging info for engine mode selection.""" if mode == "local": - kv_cache_config = local_kv_cache_config - mem_usage_list = local_mem_usage_list + logger.info( + "The selected engine mode is %s. 
" + "We choose small max batch size and KV cache capacity to use less GPU memory.", + green(mode), + ) elif mode == "interactive": - kv_cache_config = interactive_kv_cache_config - mem_usage_list = interactive_mem_usage_list - else: - kv_cache_config = server_kv_cache_config - mem_usage_list = server_mem_usage_list - - logger.info( - 'The actual engine mode is "%s". So max batch size is %s, ' - "max KV cache token capacity is %s, prefill chunk size is %s.", - green(mode), - green(str(kv_cache_config[0])), - green(str(kv_cache_config[1])), - green(str(kv_cache_config[2])), - ) - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - *list(mem_usage / 1024 / 1024 for mem_usage in mem_usage_list), - ) - # - Final messages - override_msg = "Please override the arguments if you have particular values to set." - if mode in ["local", "interactive"]: logger.info( - 'Please switch to mode "server" if you want to use more GPU memory ' - "and support more concurrent requests. %s", - override_msg, + "The selected engine mode is %s. " + "We fix max batch size to 1 for interactive single sequence use.", + green(mode), ) else: logger.info( - 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' - "or do not have many concurrent requests to process. %s", - override_msg, + "The selected engine mode is %s. " + "We use as much GPU memory as possible (within the limit " + "of gpu_memory_utilization).", + green(mode), ) - return *kv_cache_config, model_max_single_sequence_length - - -def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, KVStateKind, int]: - """Initialize the RNN state config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - max_history_size - """ - logging_msg = "" - prefill_chunk_size = 0 - - if prefill_chunk_size is None: - prefill_chunk_size = min( - config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 - for config in model_config_dicts - ) - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - if max_batch_size is None: - max_batch_size = 1 if mode == "interactive" else 4 - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " - - if mode == "local": - logging_msg += ( - "We choose small max batch size and RNN state capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." 
+ if mode != "local": + logger.info( + "If you have low concurrent requests and want to use less GPU memory, " + 'please select mode "local".' ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - ( - model_param_bytes, - model_temp_bytes, - model_rnn_state_base_bytes, - model_max_history_size, - ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - if max_history_size is None: - max_history_size = model_max_history_size - else: - max_history_size = min(max_history_size, model_max_history_size) - max_total_sequence_length = 32768 - prefill_chunk_size = 0 - kind = KVStateKind.RNNSTATE - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, - model_param_bytes / 1024 / 1024, - max_history_size * model_rnn_state_base_bytes / 1024 / 1024, - model_temp_bytes / 1024 / 1024, - ) - - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kind, - max_history_size, - ) - - -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, int, int, KVStateKind]: - """Initialize the cache config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - max_single_sequence_length - - max_history_size - - kv_state_kind - """ - if all("rwkv" not in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_single_sequence_length, - ) = _infer_kv_cache_config_for_kv_cache( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, + if mode != "interactive": + logger.info( + "If you don't have concurrent requests and only use the engine interactively, " + 'please select mode "interactive".' ) - max_history_size = 0 # KV cache doesn't need this - elif all("rwkv" in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_history_size, - ) = _infer_kv_cache_config_for_rnn_state( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, + if mode != "server": + logger.info( + "If you have high concurrent requests and want to maximize the GPU memory utilization, " + 'please select mode "server".' 
) - max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this - else: - raise ValueError("The models should be either all KV cache models or all RNN state models.") - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) - - -def _infer_generation_config( - model_config_dicts: List[Dict[str, Any]] -) -> List[Tuple[float, float, float, float]]: - """Infer the generation config from the model config dictionaries. - The returned four floats are: - - temperature - - top_p - - frequency_penalty - - presence_penalty - """ - generation_configs = [] - - for model_config in model_config_dicts: - temperature = model_config.get("temperature", 1.0) - top_p = model_config.get("top_p", 1.0) - frequency_penalty = model_config.get("frequency_penalty", 0.0) - presence_penalty = model_config.get("presence_penalty", 0.0) - generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) - - return generation_configs @dataclass @@ -1000,7 +419,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals kind: Literal["async", "sync"], model: str, device: Union[str, tvm.runtime.Device], - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], additional_models: Optional[List[str]], max_batch_size: Optional[int], @@ -1008,12 +427,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, + speculative_mode: Literal["disable", "small_draft", "eagle"], spec_draft_length: int, enable_tracing: bool, + verbose: bool, ) -> None: # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, Device) @@ -1026,31 +446,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. + if verbose: + _print_engine_mode_logging_msg(mode) # - Initialize engine state and engine. 
self.state = EngineState(enable_tracing) @@ -1063,35 +465,20 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "run_background_loop", "run_background_stream_back_loop", "reload", - "init_background_engine", + "init_threaded_engine", "exit_background_loop", - "debug_call_func_on_all_worker", + "get_default_generation_config", + "get_complete_engine_config", "stats", + "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) - self._ffi["init_background_engine"]( + self._ffi["init_threaded_engine"]( device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) - self._ffi["reload"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - ) def _background_loop(): self._ffi["run_background_loop"]() @@ -1108,6 +495,31 @@ def _background_stream_back_loop(): self._background_stream_back_loop_thread.start() self._terminated = False + self._ffi["reload"]( + EngineConfig( + model=model_args[0][0], + model_lib=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + verbose=verbose, + ).asjson() + ) + self.default_generation_cfg_json_str: str = self._ffi["get_default_generation_config"]() + self.engine_config = EngineConfig.from_json(self._ffi["get_complete_engine_config"]()) + self.max_input_sequence_length = min( + self.engine_config.max_single_sequence_length, + self.engine_config.max_total_sequence_length, + ) + def terminate(self): """Terminate the engine.""" self._terminated = True @@ -1215,7 +627,6 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments # Process generation config. Create request id. generation_cfg = protocol_utils.get_generation_config( request, - model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1336,11 +747,10 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments return response, num_completion_tokens -def process_completion_request( # pylint: disable=too-many-arguments +def process_completion_request( request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, - model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: @@ -1392,7 +802,7 @@ def process_completion_request( # pylint: disable=too-many-arguments assert isinstance(prompt, list) # Process generation config. Create request id. 
- generation_cfg = protocol_utils.get_generation_config(request, model_config) + generation_cfg = protocol_utils.get_generation_config(request) # - Echo back the prompt. echo_response = None diff --git a/python/mlc_llm/serve/request.py b/python/mlc_llm/serve/request.py index 5c2d8ad196..44cdcd292c 100644 --- a/python/mlc_llm/serve/request.py +++ b/python/mlc_llm/serve/request.py @@ -1,6 +1,6 @@ """The request class in MLC LLM serving""" -from typing import List, Union +from typing import List, Optional, Union import tvm._ffi from tvm.runtime import Object @@ -28,6 +28,11 @@ class Request(Object): generation_config : GenerationConfig The sampling configuration which may contain temperature, top_p, repetition_penalty, max_gen_len, etc. + + default_generation_config_json_str : Optional[str] + The JSON string of the default generation config. + When a field in the input generation_config is not defined, + we use the value in the default generation config. """ def __init__( @@ -35,6 +40,7 @@ def __init__( request_id: str, inputs: Union[Data, List[Data]], generation_config: GenerationConfig, + default_generation_config_json_str: Optional[str] = None, ): if not isinstance(inputs, list): inputs = [inputs] @@ -43,6 +49,7 @@ def __init__( request_id, inputs, generation_config.asjson(), + default_generation_config_json_str, ) @property diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 1d17f8e66a..dcecd25795 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -11,8 +11,6 @@ import requests from tvm.runtime import Device -from mlc_llm.serve.config import SpeculativeMode - class PopenServer: # pylint: disable=too-many-instance-attributes """The wrapper of MLC LLM server, which runs the server in @@ -23,14 +21,14 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, host: str = "127.0.0.1", @@ -38,7 +36,7 @@ def __init__( # pylint: disable=too-many-arguments ) -> None: """Please check out `python/mlc_llm/cli/serve.py` for the server arguments.""" self.model = model - self.model_lib_path = model_lib_path + self.model_lib = model_lib self.device = device self.mode = mode self.additional_models = additional_models @@ -59,8 +57,8 @@ def start(self) -> None: # pylint: disable=too-many-branches """ cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] - if self.model_lib_path is not None: - cmd += ["--model-lib-path", self.model_lib_path] + if self.model_lib is not None: + cmd += ["--model-lib", self.model_lib] cmd += ["--device", self.device] if self.mode is not None: cmd += ["--mode", self.mode] @@ -72,10 +70,10 @@ def start(self) -> None: # pylint: disable=too-many-branches cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] if self.prefill_chunk_size is not None: cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] - if self.speculative_mode != 
SpeculativeMode.DISABLE: + if self.speculative_mode != "disable": cmd += [ "--speculative-mode", - self.speculative_mode.name, + self.speculative_mode, "--spec-draft-length", str(self.spec_draft_length), ] diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 1be841cb08..39b09b36ce 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -14,10 +14,10 @@ import tvm from mlc_llm.serve import data -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.engine_base import ( - _infer_kv_cache_config, _parse_models, + _print_engine_mode_logging_msg, _process_model_args, detect_device, ) @@ -58,13 +58,6 @@ class SyncMLCEngine: Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] The provided callback function to handle the generation output. It has the signature of `(str, data.TokenData, bool) -> None`, @@ -80,11 +73,11 @@ class SyncMLCEngine: the `set_request_stream_callback` method. Otherwise, the engine will raise exception. - engine_config : Optional[EngineConfig] - The Engine execution configuration. - enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments,too-many-locals @@ -92,7 +85,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model: str, device: Union[str, tvm.runtime.Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -101,12 +94,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, + verbose: bool = True, request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, ): # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, tvm.runtime.Device) @@ -119,31 +113,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. 
- ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. + if verbose: + _print_engine_mode_logging_msg(mode) self._ffi = _create_tvm_module( "mlc.serve.create_engine", @@ -156,6 +132,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "reset", "get_request_stream_callback", "set_request_stream_callback", + "get_default_generation_config", ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None @@ -163,23 +140,25 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self._ffi["init"]( EngineConfig( model=model_args[0][0], - model_lib_path=model_args[0][1], + model_lib=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, max_history_size=max_history_size, - kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, - ), + verbose=verbose, + ).asjson(), device, request_stream_callback, self.trace_recorder, ) + self.default_generation_cfg_json_str: str = self._ffi["get_default_generation_config"]() self.tokenizer = Tokenizer(model_args[0][0]) def generate( # pylint: disable=too-many-locals @@ -304,6 +283,7 @@ def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data request_id=str(req_id), inputs=input_data, generation_config=generation_cfg, + default_generation_config_json_str=self.default_generation_cfg_json_str, ) ) diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 4f1cfe103d..8ff370e9d9 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -144,7 +144,7 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public dc = DebugChat( model="./dist/Llama-2-7b-chat-hf-q4f16_1-MLC", debug_dir=Path("./debug-llama-2"), - model_lib_path="./dist/llama-2-7b-chat-q4f16_1-metal.so", + model_lib="./dist/llama-2-7b-chat-q4f16_1-metal.so", ) dc.generate("hello world", 3) """ @@ -152,7 +152,7 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public def __init__( # pylint: disable=too-many-arguments self, model: str, - model_lib_path: str, + model_lib: str, debug_dir: Path, device: Optional[str] = "auto", chat_config: Optional[ChatConfig] = None, @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-arguments folder. In the former case, we will use the provided name to search for the model folder over possible paths. - model_lib_path : str + model_lib : str The full path to the model library file to use (e.g. a ``.so`` file). 
debug_dir: Path @@ -213,7 +213,7 @@ def instrument( debug_instrument if debug_instrument else DefaultDebugInstrument(debug_dir / "prefill") ) self.mod, self.params, self.metadata = _get_tvm_module( - model, model_lib_path, self.device, self.instrument + model, model_lib, self.device, self.instrument ) self.model_path, self.config_file_path = _get_model_path(model) self.chat_config = _get_chat_config(self.config_file_path, chat_config) @@ -427,7 +427,7 @@ def main(): required=True, ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, help="The full path to the model library file to use (e.g. a ``.so`` file).", required=True, @@ -447,7 +447,7 @@ def main(): parsed = parser.parse_args() dc = DebugChat( model=parsed.model, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, debug_dir=Path(parsed.debug_dir), device=parsed.device, ) diff --git a/python/mlc_llm/testing/debug_compare.py b/python/mlc_llm/testing/debug_compare.py index b3487e3e48..d257d0f3b0 100644 --- a/python/mlc_llm/testing/debug_compare.py +++ b/python/mlc_llm/testing/debug_compare.py @@ -139,7 +139,7 @@ def get_instrument(args): if args.cmp_device is None: assert args.cmp_lib_path is None, "cmp_lib_path must be None if cmp_device is None" args.cmp_device = args.device - args.cmp_lib_path = args.model_lib_path + args.cmp_lib_path = args.model_lib if args.cmp_device == "iphone": assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone" @@ -194,7 +194,7 @@ def main(): required=True, ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, help="The full path to the model library file to use (e.g. a ``.so`` file).", required=True, @@ -230,7 +230,7 @@ def main(): instrument = get_instrument(parsed) debug_chat = DebugChat( model=parsed.model, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, debug_dir=Path(parsed.debug_dir), device=parsed.device, debug_instrument=instrument, diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs index b90549d06c..e8c1893a98 100644 --- a/rust/src/chat_module.rs +++ b/rust/src/chat_module.rs @@ -213,24 +213,24 @@ fn get_chat_config(config_file_path: &Path) -> result::Result, device_name: &str, + model: &str, model_path: &Path, chat_config: &ChatConfig, model_lib: Option<&str>, device_name: &str, config_file_path: &Path, ) -> PathBuf { - // 1. Use user's model_lib_path if provided - if let Some(lib_path) = model_lib_path { + // 1. Use user's model_lib if provided + if let Some(lib_path) = model_lib { let path = Path::new(lib_path); if path.is_file() { info!("Using library model: {:?}", path); return path.to_path_buf(); } else { - panic!("The `model_lib_path` you passed in is not a file: {:?}.", lib_path); + panic!("The `model_lib` you passed in is not a file: {:?}.", lib_path); } } @@ -290,7 +290,7 @@ fn get_lib_module_path( } err_msg += &format!( "If you would like to directly specify the model library path, you may \ - consider passing in the `ChatModule.model_lib_path` parameter." + consider passing in the `ChatModule.model_lib` parameter." ); panic!("{}", err_msg); @@ -323,7 +323,7 @@ pub struct ChatModule { } impl ChatModule { - pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result { + pub fn new(model: &str, device: &str, model_lib: Option<&str>) -> Result { let device_err_msg = format!( "Invalid device name: {}. 
Please enter the device in the form \ 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ @@ -362,11 +362,11 @@ impl ChatModule { let chat_config = get_chat_config(&config_file_path).unwrap(); // 4. Look up the model library - let model_lib_path = get_lib_module_path( + let model_lib = get_lib_module_path( model, &model_path, &chat_config, - model_lib_path, + model_lib, device_name, &config_file_path, ); @@ -375,7 +375,7 @@ impl ChatModule { chat_module: m, chat_config, }; - let model_lib_str = model_lib_path.as_path().display().to_string(); + let model_lib_str = model_lib.as_path().display().to_string(); let model_path_str = model_path.as_path().display().to_string(); chat_mod.reload(&model_lib_str, &model_path_str, "").unwrap(); Ok(chat_mod) diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index c52571b522..b438c2a352 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union +from typing import Dict, List, Optional from mlc_llm.json_ffi import JSONFFIEngine @@ -120,12 +120,10 @@ def test_reload_reset_unload(): def test_function_calling(): model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC" - model_lib_path = ( - "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" - ) + model_lib = "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" engine = JSONFFIEngine( model, - model_lib_path=model_lib_path, + model_lib=model_lib, max_total_sequence_length=1024, ) diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index c89a9e2c38..da9b486476 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -10,14 +10,14 @@ def _parse_args(): args = argparse.ArgumentParser() - args.add_argument("--model-lib-path", type=str) + args.add_argument("--model-lib", type=str) args.add_argument("--device", type=str, default="auto") args.add_argument("--batch-size", type=int, default=80) args.add_argument("--max-total-seq-length", type=int) args.add_argument("--seed", type=int, default=0) parsed = args.parse_args() - parsed.model = os.path.dirname(parsed.model_lib_path) + parsed.model = os.path.dirname(parsed.model_lib) assert parsed.batch_size % 16 == 0 return parsed @@ -44,7 +44,7 @@ def benchmark(args: argparse.Namespace): engine = SyncMLCEngine( model=args.model, device=args.device, - model_lib_path=args.model_lib_path, + model_lib=args.model_lib, mode="server", max_batch_size=args.batch_size, max_total_sequence_length=args.max_total_seq_length, diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py index e425494231..1ba0d096e8 100644 --- a/tests/python/serve/server/conftest.py +++ b/tests/python/serve/server/conftest.py @@ -9,15 +9,15 @@ @pytest.fixture(scope="session") def served_model() -> Tuple[str, str]: - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: + model_lib = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib is None: raise ValueError( 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' "Please set it to model lib compiled by MLC LLM " "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." 
) - model = os.path.dirname(model_lib_path) - return model, model_lib_path + model = os.path.dirname(model_lib) + return model, model_lib @pytest.fixture(scope="session") @@ -25,7 +25,7 @@ def launch_server(served_model): # pylint: disable=redefined-outer-name """A pytest session-level fixture which launches the server in a subprocess.""" server = PopenServer( model=served_model[0], - model_lib_path=served_model[1], + model_lib=served_model[1], enable_tracing=True, ) diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py index e4f64d2ce4..db2d601f11 100644 --- a/tests/python/serve/server/test_server.py +++ b/tests/python/serve/server/test_server.py @@ -1287,14 +1287,14 @@ def test_debug_dump_event_trace( if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: + model_lib = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib is None: raise ValueError( 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' "Please set it to model lib compiled by MLC LLM " "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." ) - MODEL = (os.path.dirname(model_lib_path), model_lib_path) + MODEL = (os.path.dirname(model_lib), model_lib) test_openai_v1_models(MODEL, None) diff --git a/tests/python/serve/server/test_server_function_call.py b/tests/python/serve/server/test_server_function_call.py index 3fff27b938..b55fe10455 100644 --- a/tests/python/serve/server/test_server_function_call.py +++ b/tests/python/serve/server/test_server_function_call.py @@ -195,15 +195,15 @@ def test_openai_v1_chat_completion_function_call( if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: + model_lib = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib is None: raise ValueError( 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' "Please set it to model lib compiled by MLC LLM " "(e.g., `./dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so`) " "which supports function calls." ) - MODEL = (os.path.dirname(model_lib_path), model_lib_path) + MODEL = (os.path.dirname(model_lib), model_lib) for msg in CHAT_COMPLETION_MESSAGES: test_openai_v1_chat_completion_function_call(MODEL, None, stream=False, messages=msg) diff --git a/tests/python/serve/server/test_server_image.py b/tests/python/serve/server/test_server_image.py index 9b016224e4..d1a79c5445 100644 --- a/tests/python/serve/server/test_server_image.py +++ b/tests/python/serve/server/test_server_image.py @@ -239,8 +239,8 @@ def test_openai_v1_chat_completions( if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: + model_lib = os.environ.get("MLC_SERVE_MODEL_LIB") + if model_lib is None: raise ValueError( 'Environment variable "MLC_SERVE_MODEL_LIB" not found. 
' "Please set it to model lib compiled by MLC LLM " @@ -249,9 +249,9 @@ def test_openai_v1_chat_completions( model = os.environ.get("MLC_SERVE_MODEL") if model is None: - MODEL = (os.path.dirname(model_lib_path), model_lib_path) + MODEL = (os.path.dirname(model_lib), model_lib) else: - MODEL = (model, model_lib_path) + MODEL = (model, model_lib) for msg in CHAT_COMPLETION_MESSAGES: test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py index cea421cd95..06d2196d67 100644 --- a/tests/python/serve/test_radix_tree.py +++ b/tests/python/serve/test_radix_tree.py @@ -1,6 +1,3 @@ -from tvm import TVMError -from tvm.runtime import ShapeTuple - from mlc_llm.serve import PagedRadixTree diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 6e3835238a..2c431ebcf5 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -22,10 +22,10 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) @@ -79,10 +79,10 @@ async def generate_task( async def test_chat_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) @@ -131,10 +131,10 @@ async def generate_task(prompt: str, request_id: str): async def test_chat_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) @@ -182,10 +182,10 @@ async def generate_task(prompt: str, request_id: str): async def test_completion(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) @@ -234,10 +234,10 @@ async def generate_task(prompt: str, request_id: str): async def test_completion_non_stream(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) diff --git a/tests/python/serve/test_serve_async_engine_spec.py 
b/tests/python/serve/test_serve_async_engine_spec.py index c3963af613..926aa87f60 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,7 +3,7 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode +from mlc_llm.serve import AsyncMLCEngine, GenerationConfig prompts = [ "What is the meaning of life?", @@ -22,17 +22,15 @@ async def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) + small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" async_engine = AsyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="small_draft", ) num_requests = 10 diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 37d1833b14..dc67f3c91e 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -31,11 +31,11 @@ ] -def create_engine(model: str, model_lib_path: str): +def create_engine(model: str, model_lib: str): if "rwkv" in model: return MLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_batch_size=8, max_history_size=1, @@ -43,15 +43,15 @@ def create_engine(model: str, model_lib_path: str): else: return MLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_engine_generate(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) +@pytest.mark.parametrize("model,model_lib", test_models) +def test_engine_generate(model: str, model_lib: str): + engine = create_engine(model, model_lib) num_requests = 10 max_tokens = 256 @@ -81,10 +81,10 @@ def test_engine_generate(model: str, model_lib_path: str): del engine -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_chat_completion(model: str, model_lib_path: str): +@pytest.mark.parametrize("model,model_lib", test_models) +def test_chat_completion(model: str, model_lib: str): # Create engine - engine = create_engine(model, model_lib_path) + engine = create_engine(model, model_lib) num_requests = 2 max_tokens = 64 @@ -119,9 +119,9 @@ def test_chat_completion(model: str, model_lib_path: str): del engine -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_chat_completion_non_stream(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) +@pytest.mark.parametrize("model,model_lib", test_models) +def test_chat_completion_non_stream(model: str, model_lib: str): + engine = create_engine(model, model_lib) num_requests = 2 max_tokens = 64 @@ -155,9 +155,9 @@ def test_chat_completion_non_stream(model: str, model_lib_path: str): del engine -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_completion(model: str, 
model_lib_path: str): - engine = create_engine(model, model_lib_path) +@pytest.mark.parametrize("model,model_lib", test_models) +def test_completion(model: str, model_lib: str): + engine = create_engine(model, model_lib) num_requests = 2 max_tokens = 128 @@ -192,9 +192,9 @@ def test_completion(model: str, model_lib_path: str): del engine -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_completion_non_stream(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) +@pytest.mark.parametrize("model,model_lib", test_models) +def test_completion_non_stream(model: str, model_lib: str): + engine = create_engine(model, model_lib) num_requests = 2 max_tokens = 128 @@ -229,9 +229,9 @@ def test_completion_non_stream(model: str, model_lib_path: str): if __name__ == "__main__": - for model, model_lib_path in test_models: - test_engine_generate(model, model_lib_path) - test_chat_completion(model, model_lib_path) - test_chat_completion_non_stream(model, model_lib_path) - test_completion(model, model_lib_path) - test_completion_non_stream(model, model_lib_path) + for model, model_lib in test_models: + test_engine_generate(model, model_lib) + test_chat_completion(model, model_lib) + test_chat_completion_non_stream(model, model_lib) + test_completion(model, model_lib) + test_completion_non_stream(model, model_lib) diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index b764c62cd2..2b3ce29c7f 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -17,12 +17,12 @@ "Generate a JSON with 5 elements:", ] model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" -model_lib_path = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so" +model_lib = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so" def test_batch_generation_with_grammar(): # Create engine - engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib=model_lib, mode="server") prompt_len = len(prompts_list) prompts = prompts_list * 3 @@ -69,7 +69,7 @@ def test_batch_generation_with_grammar(): def test_batch_generation_with_schema(): # Create engine - engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + engine = SyncMLCEngine(model=model_path, model_lib=model_lib, mode="server") prompt = ( "Generate a json containing three fields: an integer field named size, a " @@ -121,7 +121,7 @@ class Schema(BaseModel): async def run_async_engine(): # Create engine - async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") + async_engine = AsyncMLCEngine(model=model_path, model_lib=model_lib, mode="server") prompts = prompts_list * 20 diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index 59e8c97196..01bb1967e0 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -12,10 +12,10 @@ def get_test_image(config) -> data.ImageData: def test_engine_generate(): # Create engine model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" - model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" + model_lib = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) diff --git 
a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index 33c06b1c5e..3f1fa5107c 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -4,13 +4,7 @@ import numpy as np -from mlc_llm.serve import ( - GenerationConfig, - Request, - RequestStreamOutput, - SpeculativeMode, - data, -) +from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data from mlc_llm.serve.sync_engine import SyncMLCEngine prompts = [ @@ -85,18 +79,16 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) + small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="small_draft", request_stream_callback=fcallback, ) @@ -153,18 +145,16 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" - ) + small_model_lib = "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="eagle", spec_draft_length=2, request_stream_callback=fcallback, ) @@ -236,19 +226,17 @@ def step(self) -> None: # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) + small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" timer = CallbackTimer() engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="small_draft", request_stream_callback=timer.callback_getter(), ) @@ -322,19 +310,19 @@ def step(self) -> None: # Create engine model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = 
"dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" - small_model_lib_path = ( + small_model_lib = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) timer = CallbackTimer() engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="eagle", request_stream_callback=timer.callback_getter(), ) @@ -379,19 +367,17 @@ def compare_output_text(output_text1, output_text2): def test_engine_generate(compare_precision=False): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) + small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="small_draft", ) num_requests = 10 @@ -405,7 +391,7 @@ def test_engine_generate(compare_precision=False): ) engine_single_model = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) @@ -441,18 +427,18 @@ def test_engine_generate(compare_precision=False): def test_engine_eagle_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" - small_model_lib_path = ( + small_model_lib = ( "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" ) engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, + additional_models=[small_model + ":" + small_model_lib], + speculative_mode="eagle", ) num_requests = 10 @@ -493,10 +479,10 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + model_lib = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, request_stream_callback=fcallback, @@ -556,24 +542,22 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" - model_lib_path = 
"dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" + model_lib = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) + small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" # If Flashinfer allows head_dim < 128, we can test this model # small_model = "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC" - # small_model_lib_path = ( + # small_model_lib = ( # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" # ) spec_engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], + additional_models=[small_model + ":" + small_model_lib], spec_draft_length=6, - speculative_mode=SpeculativeMode.SMALL_DRAFT, + speculative_mode="small_draft", request_stream_callback=fcallback, ) @@ -631,19 +615,17 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" - ) + small_model_lib = "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" spec_engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], + additional_models=[small_model + ":" + small_model_lib], spec_draft_length=6, - speculative_mode=SpeculativeMode.EAGLE, + speculative_mode="eagle", request_stream_callback=fcallback, ) diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index f68f48b7c5..8c574f875f 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -79,10 +79,10 @@ def fcallback(delta_outputs: List[RequestStreamOutput]): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", request_stream_callback=fcallback, ) @@ -155,10 +155,10 @@ def step(self) -> None: # Create engine timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", request_stream_callback=timer.callback_getter(), ) @@ -236,10 +236,10 @@ def step(self) -> None: # Create engine timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = 
"dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", request_stream_callback=timer.callback_getter(), ) @@ -322,10 +322,10 @@ def all_finished(self) -> bool: # Create engine timer = CallbackTimer() model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", request_stream_callback=timer.callback_getter(), ) @@ -364,10 +364,10 @@ def all_finished(self) -> bool: def test_engine_generate(): # Create engine model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" + model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" engine = SyncMLCEngine( model=model, - model_lib_path=model_lib_path, + model_lib=model_lib, mode="server", max_total_sequence_length=4096, ) From 17fb1c4dc5cd4b3424be4761006c91ce6eeeb914 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 3 May 2024 05:36:28 -0700 Subject: [PATCH 264/531] [Serving] Add some try-except captures in AsyncMLCEngine (#2265) * [Serving] Add some try-except captures in AsyncMLCEngine --- python/mlc_llm/serve/engine.py | 124 +++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 51 deletions(-) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 8b63a65130..c99dbd4794 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -982,19 +982,26 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( [[] for _ in range(n)] if logprobs else None ) - async for response in chatcmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens - for choice in response.choices: - assert isinstance(choice.delta.content, str) - output_texts[choice.index] += choice.delta.content - if choice.finish_reason is not None and finish_reasons[choice.index] is None: - finish_reasons[choice.index] = choice.finish_reason - if choice.logprobs is not None: - assert logprob_results is not None - logprob_results[ # pylint: disable=unsupported-assignment-operation - choice.index - ] += choice.logprobs.content + try: + async for response in chatcmpl_generator: + num_prompt_tokens = response.usage.prompt_tokens + num_completion_tokens = response.usage.completion_tokens + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in chat completion with request ID %s: %s", request_id, err) + raise err assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = 
engine_base.process_function_call_output( @@ -1157,23 +1164,30 @@ async def _handle_chat_completion( finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] num_completion_tokens = 0 self.state.record_event(request_id, event="invoke generate") - async for delta_outputs in self._generate( - prompts, generation_cfg, request_id # type: ignore - ): - response, num_completion_tokens = engine_base.process_chat_completion_stream_output( - delta_outputs, - request_id, - self.state, - request.model, - generation_cfg, - use_function_calling, - prompt_length, - finish_reasons, - num_completion_tokens, - ) - if response is not None: - yield response - self.state.record_event(request_id, event="finish") + try: + async for delta_outputs in self._generate( + prompts, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + use_function_calling, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response + self.state.record_event(request_id, event="finish") + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err) + raise err async def _handle_completion( self, request: openai_api_protocol.CompletionRequest, request_id: str @@ -1210,28 +1224,35 @@ async def _handle_completion( num_completion_tokens = 0 finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] self.state.record_event(request_id, event="invoke generate") - async for delta_outputs in self._generate( - prompt, generation_cfg, request_id # type: ignore - ): - response, num_completion_tokens = engine_base.process_completion_stream_output( - delta_outputs, - request_id, - self.state, - request.model, - generation_cfg, - prompt_length, - finish_reasons, - num_completion_tokens, - ) - if response is not None: - yield response + try: + async for delta_outputs in self._generate( + prompt, generation_cfg, request_id # type: ignore + ): + response, num_completion_tokens = engine_base.process_completion_stream_output( + delta_outputs, + request_id, + self.state, + request.model, + generation_cfg, + prompt_length, + finish_reasons, + num_completion_tokens, + ) + if response is not None: + yield response - suffix_response = engine_base.create_completion_suffix_response( - request, request_id, prompt_length, finish_reasons, num_completion_tokens - ) - if suffix_response is not None: - yield suffix_response - self.state.record_event(request_id, event="finish") + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, prompt_length, finish_reasons, num_completion_tokens + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in _handle_completion for request %s: %s", request_id, err) + raise err async def _generate( self, @@ -1301,6 +1322,7 @@ async def _generate( Exception, asyncio.CancelledError, ) as exception: # pylint: disable=broad-exception-caught + logger.error("Error in _generate for request %s: %s", request_id, exception) await self.abort(request_id) raise exception From b124b0b74f1b4d4b5e022279e44c12758aaea9fe Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: 
Fri, 3 May 2024 05:36:34 -0700 Subject: [PATCH 265/531] [Eagle] Fix token shifting for prefill step (#2266) --- .../eagle_new_request_prefill.cc | 64 +++++++++---------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 80de254ca8..2844f76c6b 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -123,45 +123,27 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { if (rsentry->child_indices.empty()) { models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id); } + // Shift the input tokens by 1 for eagle models. + if (model_id == 0) { + for (int j = 1; j < static_cast(models_.size()); ++j) { + ICHECK(rsentry->mstates[j]->inputs.size()); + TokenData token_data = Downcast(rsentry->mstates[j]->inputs[0]); + rsentry->mstates[j]->inputs.Set( + 0, TokenData( + IntTuple(token_data->token_ids.begin() + 1, token_data->token_ids.end()))); + } + } } request_internal_ids.push_back(mstate->internal_id); RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding"); // Speculative models shift left the input tokens by 1 when base model has committed tokens. // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens. - int embed_offset = - prefill_inputs[i].rsentry->mstates[model_id]->committed_tokens.empty() ? 0 : 1; for (int j = 0; j < static_cast(input_data.size()); ++j) { - if (j == static_cast(input_data.size()) - 1) { - std::vector tail_tokens; - TokenData tk_data = Downcast(input_data[j]); - CHECK(tk_data.defined()); - for (int k = embed_offset; k < static_cast(tk_data->token_ids.size()); ++k) { - tail_tokens.push_back(tk_data->token_ids[k]); - } - embeddings = models_[model_id]->TokenEmbed( - {tail_tokens.begin(), tail_tokens.end()}, - /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, - /*offset=*/cum_prefill_length); - cum_prefill_length += input_data[j]->GetLength(); - cum_prefill_length -= embed_offset; - } else { - embeddings = input_data[i]->GetEmbedding( - models_[model_id], - /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, - /*offset=*/cum_prefill_length); - cum_prefill_length += input_data[j]->GetLength(); - } - } - if (embed_offset > 0) { - std::vector new_tokens = {prefill_inputs[i] - .rsentry->mstates[model_id] - ->committed_tokens.back() - .sampled_token_id.first}; - embeddings = - models_[model_id]->TokenEmbed({new_tokens.begin(), new_tokens.end()}, - /*dst=*/&model_workspaces_[model_id].embeddings, - /*offset=*/cum_prefill_length); - cum_prefill_length += new_tokens.size(); + embeddings = input_data[j]->GetEmbedding( + models_[model_id], + /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr, + /*offset=*/cum_prefill_length); + cum_prefill_length += input_data[j]->GetLength(); } RECORD_EVENT(trace_recorder_, rsentry->request->id, "finish embedding"); } @@ -238,6 +220,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + // No sample for rsentries with remaining inputs. 
+ if (!rsentry->mstates[0]->inputs.empty()) { + continue; + } + int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate; for (int child_idx : rsentry->child_indices) { // Only use base model to judge if we need to add child entries. @@ -310,6 +297,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { rsentries_for_sample[i]->mstates[mid]->inputs.push_back( TokenData(std::vector{sample_results[i].sampled_token_id.first})); } + if (mid > 0) { + // Add the sampled token as an input of the eagle models. + TokenData token_data = + Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); + std::vector token_ids = {token_data->token_ids.begin(), + token_data->token_ids.end()}; + token_ids.push_back(sample_results[i].sampled_token_id.first); + int ninputs = static_cast(rsentries_for_sample[i]->mstates[mid]->inputs.size()); + rsentries_for_sample[i]->mstates[mid]->inputs.Set( + ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); + } } // Only base model trigger timing records. if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { From c0306602492c8a6c89eb0dd679a0ef50f5313173 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 3 May 2024 13:05:49 -0400 Subject: [PATCH 266/531] [Fix] Fix the two-stage softmax func by removing log2e (#2269) * [Fix] Fix the two-stage softmax func by removing log2e When two-stage softmax was introduced, we use a log2e numeric transformation for some potentially better performance. However, under the case of low temperature, the log2e transformation is not numerically stable, which may cause the softmax result not summing up to 1. This PR fixes this by removing all the log2e related calculation. * Remove redundant import --- python/mlc_llm/compiler_pass/rewrite_softmax.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py index 82e6cf863b..df879b37ec 100644 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -1,7 +1,5 @@ """A compiler pass that rewrites one-shot softmax into two-stage softmax.""" -import math - import tvm from tvm import relax from tvm.ir.module import IRModule @@ -81,8 +79,6 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements target: tvm.target.Target, chunk_size: int ): - log2e = math.log2(math.exp(1)) - # pylint: disable=invalid-name @T.prim_func def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals @@ -117,13 +113,13 @@ def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=to temp_sum[v0, v1] = T.float32(0) temp_sum[v0, v1] += T.if_then_else( v1 * T.int64(chunk_size) + v2 < vocab_size, - T.exp2((A_pad[v0, v1, v2] - temp_max[v0, v1]) * log2e), + T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.float32(0), ) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): with T.block("log"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) - chunked_lse[v0, v1] = T.log2(temp_sum[v0, v1]) + temp_max[v0, v1] * log2e + chunked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1] @T.prim_func def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): @@ -148,17 +144,17 @@ def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_sof v0, v1 = 
T.axis.remap("SR", [l0, l1]) with T.init(): temp_sum[v0] = T.float32(0) - temp_sum[v0] += T.exp2(chunked_lse[v0, v1] - temp_max[v0]) + temp_sum[v0] += T.exp(chunked_lse[v0, v1] - temp_max[v0]) for l0 in T.serial(0, batch_size): with T.block("log"): v0 = T.axis.remap("S", [l0]) - lse[v0] = T.log2(temp_sum[v0]) + temp_max[v0] + lse[v0] = T.log(temp_sum[v0]) + temp_max[v0] for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): with T.block("pad"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) if v1 * T.int64(chunk_size) + v2 < vocab_size: - softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp2( - A[v0, v1 * T.int64(chunk_size) + v2] * log2e - lse[v0] + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp( + A[v0, v1 * T.int64(chunk_size) + v2] - lse[v0] ) sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) From 8d58e52320e085ffcf28295fd70b146024f71764 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 May 2024 18:14:26 -0700 Subject: [PATCH 267/531] [Eagle] Fix missing broadcast in hidden states gather/scatter (#2271) * [Eagle] Fix missing broadcast in hidden states gather/scatter --- python/mlc_llm/cli/serve.py | 2 +- .../attach_spec_decode_aux_funcs.py | 76 ++++++++++++++++--- python/mlc_llm/compiler_pass/pipeline.py | 3 +- 3 files changed, 68 insertions(+), 13 deletions(-) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 9ba0e01e3d..d776ed146b 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -50,7 +50,7 @@ def main(argv): parser.add_argument( "--speculative-mode", type=str, - choices=["disable", "small_draft", "eable"], + choices=["disable", "small_draft", "eagle"], default="disable", help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py index b7cfd76fa3..f7bb3dbe14 100644 --- a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -1,7 +1,8 @@ """The pass that attaches logit processor functions to the IRModule.""" import tvm -from tvm import IRModule +from tvm import IRModule, relax, tir +from tvm.relax import BlockBuilder, TensorStructInfo from tvm.script import tir as T @@ -9,25 +10,29 @@ class AttachSpecDecodeAuxFuncs: # pylint: disable=too-few-public-methods """Attach logit processing TIR functions to IRModule.""" + tensor_parallel_shards: int + + def __init__(self, tensor_parallel_shards: int): + self.tensor_parallel_shards = tensor_parallel_shards + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" mod = mod.clone() - mod["scatter_probs"] = _get_scatter_2d_inplace( - dtype="float32", global_symbol="scatter_probs" + bb = BlockBuilder(mod) + bb.add_func( + _get_scatter_2d_inplace(dtype="float32", global_symbol="scatter_probs"), "scatter_probs" + ) + bb.add_func( + _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs"), "gather_probs" ) - mod["gather_probs"] = _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs") if "prefill_to_last_hidden_states" in mod: hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[ 0 ] # pylint: disable=no-member dtype = hidden_states_struct_info.dtype - mod["scatter_hidden_states"] = _get_scatter_2d_inplace( - dtype, global_symbol="scatter_hidden_states" - ) - mod["gather_hidden_states"] = 
_get_gather_2d_inplace( - dtype, global_symbol="gather_hidden_states" - ) - return mod + _add_gather_hidden_states(bb, self.tensor_parallel_shards, dtype) + _add_scatter_hidden_states(bb, self.tensor_parallel_shards, dtype) + return bb.finalize() def _get_scatter_2d_inplace(dtype: str, global_symbol: str): @@ -64,3 +69,52 @@ def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): dst[vb, vj] = src[indices[vb], vj] return _gather_2d + + +def _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str): + batch_size = tir.Var("batch_size", "int64") + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + src = relax.Var("src", struct_info=TensorStructInfo([batch_size, n], dtype)) + indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32")) + dst = relax.Var("dst", struct_info=TensorStructInfo([m, n], dtype)) + with bb.function("scatter_hidden_states", [src, indices, dst]): + with bb.dataflow(): + if tensor_parallel_shards > 1: + indices = relax.op.ccl.broadcast_from_worker0(indices) + output = relax.op.call_tir_inplace( + bb.add_func( + _get_scatter_2d_inplace(dtype, "_scatter_hidden_states"), + "_scatter_hidden_states", + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) + bb.emit_output(output) + gv = bb.emit_func_output(output) + return gv + + +def _add_gather_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str): + batch_size = tir.Var("batch_size", "int64") + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + src = relax.Var("src", struct_info=TensorStructInfo([m, n], dtype)) + indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32")) + dst = relax.Var("dst", struct_info=TensorStructInfo([batch_size, n], dtype)) + with bb.function("gather_hidden_states", [src, indices, dst]): + with bb.dataflow(): + if tensor_parallel_shards > 1: + indices = relax.op.ccl.broadcast_from_worker0(indices) + output = relax.op.call_tir_inplace( + bb.add_func( + _get_gather_2d_inplace(dtype, "_gather_hidden_states"), "_gather_hidden_states" + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) + bb.emit_output(output) + gv = bb.emit_func_output(output) + return gv diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 3c80d2c4df..7bc89de21b 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -92,6 +92,7 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments additional_tirs = additional_tirs or {} metadata = metadata or {} ext_mods = ext_mods or [] + tensor_parallel_shards = metadata.get("tensor_parallel_shards", 1) @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: @@ -105,7 +106,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), - AttachSpecDecodeAuxFuncs(), + AttachSpecDecodeAuxFuncs(tensor_parallel_shards), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), From c166a900a86b7c6ce2d7b3599030df5bddfd85fd Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 4 May 2024 09:20:45 -0400 Subject: [PATCH 268/531] [Sampler] Use pivot-based renormalization for top-p sampling 
(#2272) This PR integrates the pivot-based prob renormalization for top-p sampling, whose performance is a few times faster than the current sort-based top-p sampling on CUDA. --- cpp/serve/engine_actions/batch_decode.cc | 6 +- .../engine_actions/new_request_prefill.cc | 6 +- cpp/serve/sampler/gpu_sampler.cc | 40 ++++++++++--- .../mlc_llm/compiler_pass/attach_sampler.py | 58 +++++++++---------- .../mlc_llm/compiler_pass/rewrite_softmax.py | 9 +++ python/mlc_llm/op/top_p_pivot.py | 42 +++++++++++--- 6 files changed, 109 insertions(+), 52 deletions(-) diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index ecff914baa..3c5c8fdb5b 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -114,8 +114,10 @@ class BatchDecodeActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Update the committed tokens of states. diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index f801b1e282..5a5847aaa0 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -229,8 +229,10 @@ class NewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. 
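To make the pivot-based renormalization described in the commit message easier to follow before reading the gpu_sampler.cc changes below, here is a minimal scalar sketch of the idea. It is not the shipped kernel (the real implementation is the parallel TIR kernel in python/mlc_llm/op/top_p_pivot.py, which evaluates several candidate pivots per round and handles tied probabilities explicitly); the function name and the plain binary search here are illustrative assumptions.

#include <vector>

// Illustrative only: binary-search a cutoff value ("pivot") so that the total mass of
// probabilities >= pivot just covers top_p, then zero everything below the pivot and
// renormalize by the surviving mass. No sort over the full vocabulary is needed.
std::vector<float> RenormalizeTopPByPivot(std::vector<float> probs, float top_p,
                                          int max_iters = 32, float eps = 1e-7f) {
  float lo = 0.0f, hi = 1.0f;   // search range for the pivot value
  float kept_mass = 1.0f;       // mass of entries >= the best pivot found so far
  for (int it = 0; it < max_iters && hi - lo > eps; ++it) {
    float pivot = (lo + hi) * 0.5f;
    float mass = 0.0f;
    for (float p : probs) {
      if (p >= pivot) mass += p;
    }
    if (mass >= top_p) {
      lo = pivot;               // pivot can be raised; enough mass still survives
      kept_mass = mass;
    } else {
      hi = pivot;               // pivot cut off too much; lower it
    }
  }
  for (float& p : probs) {
    p = (p >= lo) ? p / kept_mass : 0.0f;
  }
  return probs;
}

The GPU version follows the same contract as this sketch, returning a renormalized probability tensor directly, which is why the BatchDecode and NewRequestPrefill actions above can call BatchRenormalizeProbsByTopP once and then sample with BatchSampleTokensWithProbAfterTopP.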
diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 36cb6e5c0a..1a013a9627 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -60,6 +60,8 @@ class GPUSampler : public SamplerObj { uniform_samples_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); + top_p_init_pivots_host_ = + NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device_cpu); top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); @@ -73,6 +75,8 @@ class GPUSampler : public SamplerObj { uniform_samples_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); + top_p_init_pivots_device_ = + NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); @@ -118,21 +122,35 @@ class GPUSampler : public SamplerObj { return probs_on_device; } - // - Argsort the probability. - Array argsort_results = gpu_argsort_probs_func_(probs_on_device); - ICHECK_EQ(argsort_results.size(), 2); - NDArray sorted_probs_on_device = argsort_results[0]; - NDArray sorted_indices_on_device = argsort_results[1]; - - // - Copy auxiliary array for top-p. + // - Copy auxiliary array for top-p and initial pivots. NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + + NDArray top_p_init_pivots_host = + top_p_init_pivots_host_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_); + NDArray top_p_init_pivots_device = + top_p_init_pivots_device_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_); + const float* p_top_p = static_cast(top_p_host->data); + float* p_top_p_init_pivots = static_cast(top_p_init_pivots_host->data); + for (int i = 0; i < num_probs; ++i) { + if (1 - p_top_p[i] >= 0.02) { + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = + std::min(1 - p_top_p[i], static_cast(0.5)); + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = 0.02; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = 0.01; + } else { + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i]; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = (1 - p_top_p[i]) / 2; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = (1 - p_top_p[i]) / 4; + } + } + CopyArray(/*src=*/top_p_init_pivots_host, /*dst=*/top_p_init_pivots_device, copy_stream_); SyncCopyStream(device_, compute_stream_, copy_stream_); // - Renormalize the prob with top p. 
NDArray renormed_probs_on_device = - gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device); + gpu_renormalize_by_top_p_func_(probs_on_device, top_p_device, top_p_init_pivots_device); RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p"); return renormed_probs_on_device; @@ -500,6 +518,9 @@ class GPUSampler : public SamplerObj { << "GPU sampler requires the top_p values for each prob distribution are the same."; } } + for (int i = 0; i < num_probs; ++i) { + p_top_p[i] = std::max(p_top_p[i], eps_); + } return need_top_p; } @@ -665,6 +686,7 @@ class GPUSampler : public SamplerObj { NDArray uniform_samples_host_; NDArray sample_indices_host_; NDArray top_p_host_; + NDArray top_p_init_pivots_host_; NDArray top_prob_offsets_host_; NDArray draft_tokens_host_; NDArray token_tree_first_child_host_; @@ -678,6 +700,7 @@ class GPUSampler : public SamplerObj { NDArray uniform_samples_device_; NDArray sample_indices_device_; NDArray top_p_device_; + NDArray top_p_init_pivots_device_; NDArray top_prob_offsets_device_; NDArray draft_tokens_device_; NDArray token_tree_first_child_device_; @@ -691,6 +714,7 @@ class GPUSampler : public SamplerObj { // The device stream for copying auxiliary data structure to GPU. TVMStreamHandle copy_stream_ = nullptr; const float eps_ = 1e-5; + const int num_top_p_cutoff_pivots_ = 3; }; Sampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 46dc40c106..5bf62257a1 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,7 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T -from ..op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") @@ -49,7 +50,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), _attach_batch_verifier(bb, vocab_size), - _attach_renormalize_by_top_p(bb, vocab_size), + _attach_renormalize_by_top_p(bb, vocab_size, self.target), ] ] @@ -227,41 +228,36 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals return gv -def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_renormalize_by_top_p( + bb: relax.BlockBuilder, vocab_size: tir.PrimExpr, target: tvm.target.Target +): batch_size = tir.Var("batch_size", "int64") + num_pivots = 3 probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) - sorted_probs = relax.Var( - "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") - ) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + init_pivots = relax.Var( + "init_pivots", relax.TensorStructInfo((batch_size, num_pivots), "float32") + ) + with bb.function("renormalize_by_top_p", [probs, top_p, init_pivots]): with bb.dataflow(): - probs_tensor = nn.wrap_nested(probs, name="probs") - sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") - top_p_shape = relax.ShapeExpr([batch_size, 1]) - top_p_tensor = nn.wrap_nested( - relax.call_pure_packed( - "vm.builtin.reshape", - 
top_p, - top_p_shape, - sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), - ), - name="sample_indices", - ) - top_k_tensor = nn.tensor_ir_op( - full, - name_hint="full", - args=[vocab_size], - out=nn.Tensor.placeholder( - [batch_size, 1], - "int32", - ), + cutoff_output = bb.emit( + relax.call_tir( + bb.add_func(top_p_pivot(num_pivots, target), "top_p_pivot_cutoff"), + args=[probs, top_p, init_pivots], + out_sinfo=[top_p.struct_info, top_p.struct_info], # pylint: disable=no-member + ) ) - renormalized_probs = nn.renormalize_top_p_top_k_prob( - probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + final_pivot = cutoff_output[0] + renorm_sum = cutoff_output[1] + renormalized_probs = bb.emit( + relax.call_tir( + bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"), + args=[probs, final_pivot, renorm_sum], + out_sinfo=probs.struct_info, # pylint: disable=no-member + ) ) - bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access - gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + bb.emit_output(renormalized_probs) + gv = bb.emit_func_output(renormalized_probs) return gv diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py index df879b37ec..47a5a168d7 100644 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -79,6 +79,15 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements target: tvm.target.Target, chunk_size: int ): + # NOTE: A quick note on the softmax implementation. + # We once tried to multiply every element by log2e which can be computed + # potentially more efficiently on hardware. + # However, when the input values are large, multiplying by the factor of log2e + # causes numerical issue in float32 dtype. + # This leads to the softmax output not summing up to 1. + # For numerical stability, we removed the log2e factor and switched back + # to the standard log/exp computation. + # pylint: disable=invalid-name @T.prim_func def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py index 9c97959bff..b9565a83c9 100644 --- a/python/mlc_llm/op/top_p_pivot.py +++ b/python/mlc_llm/op/top_p_pivot.py @@ -3,12 +3,14 @@ import tvm from tvm.script import tir as T +from mlc_llm.support.max_thread_check import get_max_num_threads_per_block + # mypy: disable-error-code="attr-defined,valid-type,name-defined" # pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda # pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches -def top_p_pivot(pN): +def top_p_pivot(pN, target: tvm.target.Target): """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. A valide pivot should satisfy the following conditions: @@ -23,7 +25,7 @@ def top_p_pivot(pN): prob: The probability vector - top_p_global: + top_p_arr: The top-p threshold init_pivots: @@ -31,11 +33,18 @@ def top_p_pivot(pN): final_pivot: The final pivot to cut-off top-p percentile + + final_lsum: + The final sum of the values after top-p filtering. 
""" TX = 1024 K = 32 eps_LR = 1e-7 + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -46,7 +55,7 @@ def valid(lsum, lmin, cmin, top_p): @T.prim_func(private=True) def _func( var_prob: T.handle, - top_p_global: T.buffer([1], dtype="float32"), + var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, @@ -55,7 +64,8 @@ def _func( B = T.int32() N = T.int32() prob = T.match_buffer(var_prob, (B, N,), "float32") - init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype="float32") + init_pivots = T.match_buffer(var_init_pivots, (B, pN), "float32") final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") @@ -92,7 +102,7 @@ def _func( with T.block("CTA"): b, tx = T.axis.remap("SS", [_bx, _tx]) - top_p[0] = top_p_global[0] + top_p[0] = top_p_arr[b] if tx == 0: # leader thread initializes L, R @@ -105,8 +115,14 @@ def _func( R_local[0] = R[0] for i in T.unroll(0, pN): # pivots are in descending order - pivot[i] = init_pivots[i] + pivot[i] = init_pivots[b, i] find_pivot_local[0] = False + if L_local[0] - R_local[0] <= eps_LR: + # When the initial value is too small, set the result directly. + if tx == 0: + final_lsum[b] = 1.0 + final_pivot[b] = 0.0 + find_pivot_local[0] = True while T.tvm_thread_invariant( L_local[0] - R_local[0] > eps_LR @@ -118,7 +134,7 @@ def _func( ### get lsum, lmin, total_sum for pidx in T.unroll(0, pN): lsum[pidx] = 0.0 - lmin[pidx] = 1.0 + lmin[pidx] = T.max_value("float32") cmin[pidx] = 0 total_sum[0] = 0.0 it[0] = 0 @@ -226,6 +242,7 @@ def _func( final_lsum[b] = lsum[pidx] elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: R[0] = pivot[pidx] + final_lsum[b] = lsum[pidx] elif lsum[pidx] < top_p[0]: L[0] = pivot[pidx] it[0] += 1 @@ -243,13 +260,15 @@ def _func( if tx == 0: # leader thread writes back the pivot if T.Not(find_pivot_local[0]): - final_pivot[b] = -1e5 + final_pivot[b] = R_local[0] + if R_local[0] == eps_LR: + final_lsum[b] = lsum[pN - 1] # fmt: on return _func -def top_p_renorm(): +def top_p_renorm(target: tvm.target.Target = None): """Top-p renormalization function. This function renormalizes the probability vector. Given the pivot, the probability vector is renormalized as follows: @@ -273,6 +292,11 @@ def top_p_renorm(): TX = 1024 CTA_COUNT = 512 + if target: + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") From 0ca6b33f7682673f68db6b906c75f9c69e304b32 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 5 May 2024 09:49:10 -0400 Subject: [PATCH 269/531] [JSONFFI] Update JSONFFI error checking with the Result class (#2275) This PR updates the error checking in JSONFFIEngine and related request parsing to use the Result class. 
--- cpp/json_ffi/conv_template.cc | 281 ++++++++++---------- cpp/json_ffi/conv_template.h | 34 +-- cpp/json_ffi/json_ffi_engine.cc | 38 +-- cpp/json_ffi/openai_api_protocol.cc | 386 +++++++++++++++------------- cpp/json_ffi/openai_api_protocol.h | 63 ++--- cpp/metadata/model.cc | 2 +- cpp/serve/config.cc | 2 +- cpp/serve/grammar/grammar_parser.cc | 2 +- cpp/support/json_parser.h | 208 +++++++++------ 9 files changed, 541 insertions(+), 475 deletions(-) diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc index 9511bb5b64..4feee6f98e 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/conv_template.cc @@ -34,14 +34,8 @@ Conversation::Conversation() {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} -std::vector Conversation::CheckMessageSeps(std::vector& seps) { - if (seps.size() == 0 || seps.size() > 2) { - throw std::invalid_argument("seps should have size 1 or 2."); - } - return seps; -} - -std::optional> Conversation::AsPrompt(std::string* err) { +Result> Conversation::AsPrompt() { + using TResult = Result>; // Get the system message std::string system_msg = system_template; size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); @@ -64,11 +58,11 @@ std::optional> Conversation::AsPrompt(std::string* err) { for (int i = 0; i < messages.size(); i++) { std::string role = messages[i].role; + // Todo(mlc-team): support content to be a single string. std::optional>> content = messages[i].content; if (roles.find(role) == roles.end()) { - *err += "\nRole " + role + " is not supported. "; - return std::nullopt; + return TResult::Error("Role \"" + role + "\" is not supported"); } std::string separator = separators[role == "assistant"]; // check assistant role @@ -90,29 +84,30 @@ std::optional> Conversation::AsPrompt(std::string* err) { message += role_prefix; - for (auto& item : content.value()) { - if (item.find("type") == item.end()) { - *err += "Content item should have a type field"; - return std::nullopt; + for (const auto& item : content.value()) { + auto it_type = item.find("type"); + if (it_type == item.end()) { + return TResult::Error("The content of a message does not have \"type\" field"); } - if (item["type"] == "text") { - if (item.find("text") == item.end()) { - *err += "Content item should have a text field"; - return std::nullopt; + if (it_type->second == "text") { + auto it_text = item.find("text"); + if (it_text == item.end()) { + return TResult::Error("The text type content of a message does not have \"text\" field"); } // replace placeholder[ROLE] with input message from role std::string role_text = role_templates[role]; std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; size_t pos = role_text.find(placeholder); if (pos != std::string::npos) { - role_text.replace(pos, placeholder.length(), item["text"]); + role_text.replace(pos, placeholder.length(), it_text->second); } - if (use_function_calling.has_value() && use_function_calling.value()) { + if (use_function_calling) { // replace placeholder[FUNCTION] with function_string // this assumes function calling is used for a single request scenario only if (!function_string.has_value()) { - *err += "Function string is required for function calling"; - return std::nullopt; + return TResult::Error( + "The function string in conversation template is not defined for function " + "calling."); } pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); if (pos != std::string::npos) { @@ 
-122,8 +117,7 @@ std::optional> Conversation::AsPrompt(std::string* err) { } message += role_text; } else { - *err += "Unsupported content type: " + item["type"]; - return std::nullopt; + return TResult::Error("Unsupported content type: " + it_type->second); } } @@ -131,186 +125,201 @@ std::optional> Conversation::AsPrompt(std::string* err) { message_list.push_back(TextData(message)); } - return message_list; + return TResult::Ok(message_list); } -std::optional Conversation::FromJSON(const picojson::object& json, std::string* err) { +Result Conversation::FromJSON(const picojson::object& json_obj) { + using TResult = Result; Conversation conv; - // name - std::string name; - if (json::ParseJSONField(json, "name", name, err, false)) { - conv.name = name; + Result> name_res = + json::LookupOptionalWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } + conv.name = name_res.Unwrap(); - std::string system_template; - if (!json::ParseJSONField(json, "system_template", system_template, err, true)) { - return std::nullopt; + Result system_template_res = + json::LookupWithResultReturn(json_obj, "system_template"); + if (system_template_res.IsErr()) { + return TResult::Error(system_template_res.UnwrapErr()); } - conv.system_template = system_template; + conv.system_template = system_template_res.Unwrap(); - std::string system_message; - if (!json::ParseJSONField(json, "system_message", system_message, err, true)) { - return std::nullopt; + Result system_message_res = + json::LookupWithResultReturn(json_obj, "system_message"); + if (system_message_res.IsErr()) { + return TResult::Error(system_message_res.UnwrapErr()); } - conv.system_message = system_message; + conv.system_message = system_message_res.Unwrap(); - picojson::array system_prefix_token_ids_arr; - if (json::ParseJSONField(json, "system_prefix_token_ids", system_prefix_token_ids_arr, err, - false)) { + Result> system_prefix_token_ids_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "system_prefix_token_ids"); + if (system_prefix_token_ids_arr_res.IsErr()) { + return TResult::Error(system_prefix_token_ids_arr_res.UnwrapErr()); + } + std::optional system_prefix_token_ids_arr = + system_prefix_token_ids_arr_res.Unwrap(); + if (system_prefix_token_ids_arr.has_value()) { std::vector system_prefix_token_ids; - for (const auto& token_id : system_prefix_token_ids_arr) { + system_prefix_token_ids.reserve(system_prefix_token_ids_arr.value().size()); + for (const auto& token_id : system_prefix_token_ids_arr.value()) { if (!token_id.is()) { - *err += "system_prefix_token_ids should be an array of integers."; - return std::nullopt; + return TResult::Error("A system prefix token id is not integer."); } system_prefix_token_ids.push_back(token_id.get()); } - conv.system_prefix_token_ids = system_prefix_token_ids; + conv.system_prefix_token_ids = std::move(system_prefix_token_ids); } - bool add_role_after_system_message; - if (!json::ParseJSONField(json, "add_role_after_system_message", add_role_after_system_message, - err, true)) { - return std::nullopt; + Result add_role_after_system_message_res = + json::LookupWithResultReturn(json_obj, "add_role_after_system_message"); + if (add_role_after_system_message_res.IsErr()) { + return TResult::Error(add_role_after_system_message_res.UnwrapErr()); } - conv.add_role_after_system_message = add_role_after_system_message; + conv.add_role_after_system_message = add_role_after_system_message_res.Unwrap(); - picojson::object roles_object; - if 
(!json::ParseJSONField(json, "roles", roles_object, err, true)) { - return std::nullopt; + Result roles_object_res = + json::LookupWithResultReturn(json_obj, "roles"); + if (roles_object_res.IsErr()) { + return TResult::Error(roles_object_res.UnwrapErr()); } - std::unordered_map roles; - for (const auto& role : roles_object) { + for (const auto& role : roles_object_res.Unwrap()) { if (!role.second.is()) { - *err += "roles should be a map of string to string."; - return std::nullopt; + return TResult::Error("A role value in the conversation template is not a string."); } - roles[role.first] = role.second.get(); + conv.roles[role.first] = role.second.get(); } - conv.roles = roles; - - picojson::object role_templates_object; - if (json::ParseJSONField(json, "role_templates", role_templates_object, err, false)) { - for (const auto& role : role_templates_object) { - if (!role.second.is()) { - *err += "role_templates should be a map of string to string."; - return std::nullopt; + + Result> role_templates_object_res = + json::LookupOptionalWithResultReturn(json_obj, "role_templates"); + if (role_templates_object_res.IsErr()) { + return TResult::Error(role_templates_object_res.UnwrapErr()); + } + std::optional role_templates_object = role_templates_object_res.Unwrap(); + if (role_templates_object.has_value()) { + for (const auto& [role, msg] : role_templates_object.value()) { + if (!msg.is()) { + return TResult::Error("A value in \"role_templates\" is not a string."); } - conv.role_templates[role.first] = role.second.get(); + conv.role_templates[role] = msg.get(); } } - picojson::array messages_arr; - if (!json::ParseJSONField(json, "messages", messages_arr, err, true)) { - return std::nullopt; + Result messages_arr_res = + json::LookupWithResultReturn(json_obj, "messages"); + if (messages_arr_res.IsErr()) { + return TResult::Error(messages_arr_res.UnwrapErr()); } - std::vector messages; - for (const auto& message : messages_arr) { + for (const auto& message : messages_arr_res.Unwrap()) { if (!message.is()) { - *err += "messages should be an array of objects."; - return std::nullopt; + return TResult::Error("A message in the conversation template is not a JSON object."); } picojson::object message_obj = message.get(); - std::string role; - if (!json::ParseJSONField(message_obj, "role", role, err, true)) { - *err += "role field is required in messages."; - return std::nullopt; + Result role_res = json::LookupWithResultReturn(message_obj, "role"); + if (role_res.IsErr()) { + return TResult::Error(role_res.UnwrapErr()); + } + Result> content_arr_res = + json::LookupOptionalWithResultReturn(message_obj, "content"); + if (content_arr_res.IsErr()) { + return TResult::Error(content_arr_res.UnwrapErr()); } - picojson::array content_arr; + std::optional content_arr = content_arr_res.Unwrap(); std::vector> content; - if (json::ParseJSONField(message_obj, "content", content_arr, err, false)) { - for (const auto& item : content_arr) { + if (content_arr.has_value()) { + content.reserve(content_arr.value().size()); + for (const auto& item : content_arr.value()) { + // Todo(mlc-team): allow content item to be a single string. 
if (!item.is()) { - *err += "Content item is not an object"; - return std::nullopt; + return TResult::Error("The content of conversation template message is not an object"); } std::unordered_map item_map; - picojson::object item_obj = item.get(); - for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); - ++i) { - item_map[i->first] = i->second.to_str(); + for (const auto& [key, value] : item.get()) { + item_map[key] = value.to_str(); } - content.push_back(item_map); + content.push_back(std::move(item_map)); } } - messages.push_back({role, content}); + conv.messages.push_back({role_res.Unwrap(), content}); } - conv.messages = messages; - picojson::array seps_arr; - if (!json::ParseJSONField(json, "seps", seps_arr, err, true)) { - return std::nullopt; + Result seps_arr_res = + json::LookupWithResultReturn(json_obj, "seps"); + if (seps_arr_res.IsErr()) { + return TResult::Error(seps_arr_res.UnwrapErr()); } std::vector seps; - for (const auto& sep : seps_arr) { + for (const auto& sep : seps_arr_res.Unwrap()) { if (!sep.is()) { - *err += "seps should be an array of strings."; - return std::nullopt; + return TResult::Error("A separator (\"seps\") of the conversation template is not a string"); } - seps.push_back(sep.get()); + conv.seps.push_back(sep.get()); } - conv.seps = seps; - std::string role_content_sep; - if (!json::ParseJSONField(json, "role_content_sep", role_content_sep, err, true)) { - return std::nullopt; + Result role_content_sep_res = + json::LookupWithResultReturn(json_obj, "role_content_sep"); + if (role_content_sep_res.IsErr()) { + return TResult::Error(role_content_sep_res.UnwrapErr()); } - conv.role_content_sep = role_content_sep; + conv.role_content_sep = role_content_sep_res.Unwrap(); - std::string role_empty_sep; - if (!json::ParseJSONField(json, "role_empty_sep", role_empty_sep, err, true)) { - return std::nullopt; + Result role_empty_sep_res = + json::LookupWithResultReturn(json_obj, "role_empty_sep"); + if (role_empty_sep_res.IsErr()) { + return TResult::Error(role_empty_sep_res.UnwrapErr()); } - conv.role_empty_sep = role_empty_sep; + conv.role_empty_sep = role_empty_sep_res.Unwrap(); - picojson::array stop_str_arr; - if (!json::ParseJSONField(json, "stop_str", stop_str_arr, err, true)) { - return std::nullopt; + Result stop_str_arr_res = + json::LookupWithResultReturn(json_obj, "stop_str"); + if (stop_str_arr_res.IsErr()) { + return TResult::Error(stop_str_arr_res.UnwrapErr()); } - std::vector stop_str; - for (const auto& stop : stop_str_arr) { + for (const auto& stop : stop_str_arr_res.Unwrap()) { if (!stop.is()) { - *err += "stop_str should be an array of strings."; - return std::nullopt; + return TResult::Error( + "A stop string (\"stop_str\") of the conversation template is not a string."); } - stop_str.push_back(stop.get()); + conv.stop_str.push_back(stop.get()); } - conv.stop_str = stop_str; - picojson::array stop_token_ids_arr; - if (!json::ParseJSONField(json, "stop_token_ids", stop_token_ids_arr, err, true)) { - return std::nullopt; + Result stop_token_ids_arr_res = + json::LookupWithResultReturn(json_obj, "stop_token_ids"); + if (stop_token_ids_arr_res.IsErr()) { + return TResult::Error(stop_token_ids_arr_res.UnwrapErr()); } - std::vector stop_token_ids; - for (const auto& stop : stop_token_ids_arr) { + for (const auto& stop : stop_token_ids_arr_res.Unwrap()) { if (!stop.is()) { - *err += "stop_token_ids should be an array of integers."; - return std::nullopt; + return TResult::Error( + "A stop token id (\"stop_token_ids\") of 
the conversation template is not an integer."); } - stop_token_ids.push_back(stop.get()); + conv.stop_token_ids.push_back(stop.get()); } - conv.stop_token_ids = stop_token_ids; - std::string function_string; - if (!json::ParseJSONField(json, "function_string", function_string, err, false)) { - conv.function_string = function_string; + Result> function_string_res = + json::LookupOptionalWithResultReturn(json_obj, "function_string"); + if (function_string_res.IsErr()) { + return TResult::Error(function_string_res.UnwrapErr()); } + conv.function_string = function_string_res.Unwrap(); - bool use_function_calling; - if (json::ParseJSONField(json, "use_function_calling", use_function_calling, err, false)) { - conv.use_function_calling = use_function_calling; + Result use_function_calling_res = json::LookupOrDefaultWithResultReturn( + json_obj, "use_function_calling", conv.use_function_calling); + if (use_function_calling_res.IsErr()) { + return TResult::Error(use_function_calling_res.UnwrapErr()); } + conv.use_function_calling = use_function_calling_res.Unwrap(); - return conv; + return TResult::Ok(conv); } -std::optional Conversation::FromJSON(const std::string& json_str, std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!json_obj.has_value()) { - return std::nullopt; +Result Conversation::FromJSON(const std::string& json_str) { + Result json_obj = json::ParseToJSONObjectWithResultReturn(json_str); + if (json_obj.IsErr()) { + return Result::Error(json_obj.UnwrapErr()); } - return Conversation::FromJSON(json_obj.value(), err); + return Conversation::FromJSON(json_obj.Unwrap()); } } // namespace json_ffi diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h index eeb348831c..2d579a8d94 100644 --- a/cpp/json_ffi/conv_template.h +++ b/cpp/json_ffi/conv_template.h @@ -10,6 +10,7 @@ #include #include "../serve/data.h" +#include "../support/result.h" #include "picojson.h" using namespace mlc::llm::serve; @@ -86,34 +87,17 @@ struct Conversation { // Function call fields // whether using function calling or not, helps check for output message format in API call std::optional function_string = std::nullopt; - std::optional use_function_calling = false; + bool use_function_calling = false; Conversation(); - /** - * @brief Checks the size of the separators vector. - * This function checks if the size of the separators vector is either 1 or 2. - * If the size is not 1 or 2, it throws an invalid_argument exception. - */ - static std::vector CheckMessageSeps(std::vector& seps); - - /*! - * \brief Create the list of prompts from the messages based on the conversation template. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - std::optional> AsPrompt(std::string* err); - - /*! - * \brief Create a Conversation instance from the given JSON object. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const picojson::object& json, std::string* err); - - /*! - * \brief Parse and create a Conversation instance from the given JSON string. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const std::string& json_str, std::string* err); + /*! \brief Create the list of prompts from the messages based on the conversation template. */ + Result> AsPrompt(); + + /*! \brief Create a Conversation instance from the given JSON object. 
*/ + static Result FromJSON(const picojson::object& json); + /*! \brief Parse and create a Conversation instance from the given JSON string. */ + static Result FromJSON(const std::string& json_str); }; } // namespace json_ffi diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 6b2676ee3f..b4f9751719 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -41,16 +41,16 @@ void JSONFFIEngine::StreamBackError(std::string request_id) { response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - this->request_stream_callback_(Array{picojson::value(response.ToJSON()).serialize()}); + this->request_stream_callback_(Array{picojson::value(response.AsJSON()).serialize()}); } bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) { - std::optional optional_request = - ChatCompletionRequest::FromJSON(request_json_str, &err_); - if (!optional_request.has_value()) { + Result request_res = ChatCompletionRequest::FromJSON(request_json_str); + if (request_res.IsErr()) { + err_ = request_res.UnwrapErr(); return false; } - ChatCompletionRequest request = optional_request.value(); + ChatCompletionRequest request = request_res.Unwrap(); // Create Request // TODO: Check if request_id is present already @@ -74,17 +74,20 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request conv_template.messages = messages; // check function calling - bool success_check = request.CheckFunctionCalling(conv_template, &err_); - if (!success_check) { + Result updated_conv_template = request.CheckFunctionCalling(conv_template); + if (updated_conv_template.IsErr()) { + err_ = updated_conv_template.UnwrapErr(); return false; } + conv_template = updated_conv_template.Unwrap(); // get prompt - std::optional> inputs_obj = conv_template.AsPrompt(&err_); - if (!inputs_obj.has_value()) { + Result> inputs_obj = conv_template.AsPrompt(); + if (inputs_obj.IsErr()) { + err_ = inputs_obj.UnwrapErr(); return false; } - Array inputs = inputs_obj.value(); + Array inputs = inputs_obj.Unwrap(); // generation_cfg Array stop_strs; @@ -162,18 +165,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { this->engine_->Reload(engine_config_json_str); this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString(); picojson::object engine_config_json = - json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString()); + json::ParseToJSONObject(this->engine_->GetCompleteEngineConfigJSONString()); // Load conversation template. Result model_config_json = serve::Model::LoadModelConfig(json::Lookup(engine_config_json, "model")); CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); - std::optional conv_template = Conversation::FromJSON( - json::Lookup(model_config_json.Unwrap(), "conv_template"), &err_); - if (!conv_template.has_value()) { - LOG(FATAL) << "Invalid conversation template JSON: " << err_; - } - this->conv_template_ = conv_template.value(); + Result conv_template = Conversation::FromJSON( + json::Lookup(model_config_json.Unwrap(), "conv_template")); + CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: " + << conv_template.UnwrapErr(); + this->conv_template_ = conv_template.Unwrap(); // Create streamer. // Todo(mlc-team): Create one streamer for each request, instead of a global one. 
this->streamer_ = @@ -240,7 +242,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { response.choices = choices; response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - response_arr.push_back(picojson::value(response.ToJSON()).serialize()); + response_arr.push_back(picojson::value(response.AsJSON()).serialize()); } return response_arr; } diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 4547108eb5..c07de8fef5 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -11,53 +11,41 @@ namespace mlc { namespace llm { namespace json_ffi { -std::string generate_uuid_string(size_t length) { - auto randchar = []() -> char { - const char charset[] = - "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[rand() % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; -} - -std::optional ChatFunction::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatFunction chatFunc; +Result ChatFunction::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatFunction chat_func; - // description (optional) - std::string description; - if (json::ParseJSONField(json_obj, "description", description, err, false)) { - chatFunc.description = description; + // description + Result> description_res = + json::LookupOptionalWithResultReturn(json_obj, "description"); + if (description_res.IsErr()) { + return TResult::Error(description_res.UnwrapErr()); } + chat_func.description = description_res.Unwrap(); // name - std::string name; - if (!json::ParseJSONField(json_obj, "name", name, err, true)) { - return std::nullopt; + Result name_res = json::LookupWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } - chatFunc.name = name; + chat_func.name = name_res.Unwrap(); // parameters - picojson::object parameters_obj; - if (!json::ParseJSONField(json_obj, "parameters", parameters_obj, err, true)) { - return std::nullopt; + Result parameters_obj_res = + json::LookupWithResultReturn(json_obj, "parameters"); + if (parameters_obj_res.IsErr()) { + return TResult::Error(parameters_obj_res.UnwrapErr()); } - std::unordered_map parameters; - for (picojson::value::object::const_iterator i = parameters_obj.begin(); - i != parameters_obj.end(); ++i) { - parameters[i->first] = i->second.to_str(); + picojson::object parameters_obj = parameters_obj_res.Unwrap(); + chat_func.parameters.reserve(parameters_obj.size()); + for (const auto& [key, value] : parameters_obj) { + chat_func.parameters[key] = value.to_str(); } - chatFunc.parameters = parameters; - return chatFunc; + return TResult::Ok(chat_func); } -picojson::object ChatFunction::ToJSON() const { +picojson::object ChatFunction::AsJSON() const { picojson::object obj; if (this->description.has_value()) { obj["description"] = picojson::value(this->description.value()); @@ -71,57 +59,63 @@ picojson::object ChatFunction::ToJSON() const { return obj; } -std::optional ChatTool::FromJSON(const picojson::object& json_obj, std::string* err) { +Result ChatTool::FromJSON(const picojson::object& json_obj) { + using TResult = Result; ChatTool chatTool; // function - picojson::object function_obj; - if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { - return std::nullopt; + 
Result function_obj_res = + json::LookupWithResultReturn(json_obj, "function"); + if (function_obj_res.IsErr()) { + return TResult::Error(function_obj_res.UnwrapErr()); } - - std::optional function = ChatFunction::FromJSON(function_obj, err); - if (!function.has_value()) { - return std::nullopt; + Result function = ChatFunction::FromJSON(function_obj_res.Unwrap()); + if (function.IsErr()) { + return TResult::Error(function.UnwrapErr()); } - chatTool.function = function.value(); + chatTool.function = function.Unwrap(); - return chatTool; + return TResult::Ok(chatTool); } -picojson::object ChatTool::ToJSON() const { +picojson::object ChatTool::AsJSON() const { picojson::object obj; obj["type"] = picojson::value("function"); - obj["function"] = picojson::value(this->function.ToJSON()); + obj["function"] = picojson::value(this->function.AsJSON()); return obj; } -std::optional ChatFunctionCall::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatFunctionCall chatFuncCall; +Result ChatFunctionCall::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatFunctionCall chat_func_call; // name - std::string name; - if (!json::ParseJSONField(json_obj, "name", name, err, true)) { - return std::nullopt; + Result name_res = json::LookupWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } - chatFuncCall.name = name; + chat_func_call.name = name_res.Unwrap(); // arguments - picojson::object arguments_obj; - if (json::ParseJSONField(json_obj, "arguments", arguments_obj, err, false)) { + Result> arguments_obj_res = + json::LookupOptionalWithResultReturn(json_obj, "arguments"); + if (arguments_obj_res.IsErr()) { + return TResult::Error(arguments_obj_res.UnwrapErr()); + } + std::optional arguments_obj = arguments_obj_res.Unwrap(); + if (arguments_obj.has_value()) { std::unordered_map arguments; - for (picojson::value::object::const_iterator i = arguments_obj.begin(); - i != arguments_obj.end(); ++i) { - arguments[i->first] = i->second.to_str(); + arguments.reserve(arguments_obj.value().size()); + for (const auto& [key, value] : arguments_obj.value()) { + arguments[key] = value.to_str(); } - chatFuncCall.arguments = arguments; + chat_func_call.arguments = std::move(arguments); } - return chatFuncCall; + return TResult::Ok(chat_func_call); } -picojson::object ChatFunctionCall::ToJSON() const { +picojson::object ChatFunctionCall::AsJSON() const { picojson::object obj; picojson::object arguments_obj; if (this->arguments.has_value()) { @@ -135,69 +129,75 @@ picojson::object ChatFunctionCall::ToJSON() const { return obj; } -std::optional ChatToolCall::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatToolCall chatToolCall; +Result ChatToolCall::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatToolCall chat_tool_call; // function - picojson::object function_obj; - if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { - return std::nullopt; + Result function_obj_res = + json::LookupWithResultReturn(json_obj, "function"); + if (function_obj_res.IsErr()) { + return TResult::Error(function_obj_res.UnwrapErr()); } - - std::optional function = ChatFunctionCall::FromJSON(function_obj, err); - if (!function.has_value()) { - return std::nullopt; - }; - chatToolCall.function = function.value(); + Result function_res = ChatFunctionCall::FromJSON(function_obj_res.Unwrap()); + if (function_res.IsErr()) { + return TResult::Error(function_res.UnwrapErr()); + } + 
chat_tool_call.function = function_res.Unwrap(); // overwrite default id - std::string id; - if (!json::ParseJSONField(json_obj, "id", id, err, false)) { - return std::nullopt; + Result> id_res = + json::LookupOptionalWithResultReturn(json_obj, "id"); + if (id_res.IsErr()) { + return TResult::Error(id_res.UnwrapErr()); + } + std::optional id = id_res.UnwrapErr(); + if (id.has_value()) { + chat_tool_call.id = id.value(); } - chatToolCall.id = id; - return chatToolCall; + return TResult::Ok(chat_tool_call); } -picojson::object ChatToolCall::ToJSON() const { +picojson::object ChatToolCall::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); - obj["function"] = picojson::value(this->function.ToJSON()); + obj["function"] = picojson::value(this->function.AsJSON()); obj["type"] = picojson::value("function"); return obj; } -std::optional ChatCompletionMessage::FromJSON( - const picojson::object& json_obj, std::string* err) { +Result ChatCompletionMessage::FromJSON(const picojson::object& json_obj) { + using TResult = Result; ChatCompletionMessage message; // content - picojson::array content_arr; - if (!json::ParseJSONField(json_obj, "content", content_arr, err, true)) { - return std::nullopt; - } - std::vector > content; - for (const auto& item : content_arr) { + Result content_arr_res = + json::LookupWithResultReturn(json_obj, "content"); + if (content_arr_res.IsErr()) { + return TResult::Error(content_arr_res.UnwrapErr()); + } + std::vector> content; + for (const auto& item : content_arr_res.Unwrap()) { + // Todo(mlc-team): allow content item to be a single string. if (!item.is()) { - *err += "Content item is not an object"; - return std::nullopt; + return TResult::Error("The content of chat completion message is not an object"); } - std::unordered_map item_map; picojson::object item_obj = item.get(); - for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); ++i) { - item_map[i->first] = i->second.to_str(); + std::unordered_map item_map; + for (const auto& [key, value] : item_obj) { + item_map[key] = value.to_str(); } - content.push_back(item_map); + content.push_back(std::move(item_map)); } message.content = content; // role - std::string role_str; - if (!json::ParseJSONField(json_obj, "role", role_str, err, true)) { - return std::nullopt; + Result role_str_res = json::LookupWithResultReturn(json_obj, "role"); + if (role_str_res.IsErr()) { + return TResult::Error(role_str_res.UnwrapErr()); } + std::string role_str = role_str_res.Unwrap(); if (role_str == "system") { message.role = Role::system; } else if (role_str == "user") { @@ -207,124 +207,148 @@ std::optional ChatCompletionMessage::FromJSON( } else if (role_str == "tool") { message.role = Role::tool; } else { - *err += "Invalid role"; - return std::nullopt; + return TResult::Error("Invalid role in chat completion message: " + role_str); } // name - std::string name; - if (json::ParseJSONField(json_obj, "name", name, err, false)) { - message.name = name; + Result> name_res = + json::LookupOptionalWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } + message.name = name_res.Unwrap(); // tool calls - picojson::array tool_calls_arr; - if (json::ParseJSONField(json_obj, "tool_calls", tool_calls_arr, err, false)) { + Result> tool_calls_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "tool_calls"); + if (tool_calls_arr_res.IsErr()) { + return TResult::Error(tool_calls_arr_res.UnwrapErr()); + } + std::optional 
tool_calls_arr = tool_calls_arr_res.Unwrap(); + if (tool_calls_arr.has_value()) { std::vector tool_calls; - for (const auto& item : tool_calls_arr) { + tool_calls.reserve(tool_calls_arr.value().size()); + for (const auto& item : tool_calls_arr.value()) { if (!item.is()) { - *err += "Chat Tool Call item is not an object"; - return std::nullopt; + return TResult::Error("A tool call item in the chat completion message is not an object"); + } + Result tool_call = ChatToolCall::FromJSON(item.get()); + if (tool_call.IsErr()) { + return TResult::Error(tool_call.UnwrapErr()); } - picojson::object item_obj = item.get(); - std::optional tool_call = ChatToolCall::FromJSON(item_obj, err); - if (!tool_call.has_value()) { - return std::nullopt; - }; - tool_calls.push_back(tool_call.value()); + tool_calls.push_back(tool_call.Unwrap()); } message.tool_calls = tool_calls; } // tool call id - std::string tool_call_id; - if (json::ParseJSONField(json_obj, "tool_call_id", tool_call_id, err, false)) { - message.tool_call_id = tool_call_id; + Result> tool_call_id_res = + json::LookupOptionalWithResultReturn(json_obj, "tool_call_id"); + if (tool_call_id_res.IsErr()) { + return TResult::Error(tool_call_id_res.UnwrapErr()); } + message.tool_call_id = tool_call_id_res.Unwrap(); - return message; + return TResult::Ok(message); } -std::optional ChatCompletionRequest::FromJSON( - const picojson::object& json_obj, std::string* err) { +Result ChatCompletionRequest::FromJSON(const std::string& json_str) { + using TResult = Result; + Result json_obj_res = json::ParseToJSONObjectWithResultReturn(json_str); + if (json_obj_res.IsErr()) { + return TResult::Error(json_obj_res.UnwrapErr()); + } + picojson::object json_obj = json_obj_res.Unwrap(); ChatCompletionRequest request; // messages - picojson::array messages_arr; - if (!json::ParseJSONField(json_obj, "messages", messages_arr, err, true)) { - return std::nullopt; + Result messages_arr_res = + json::LookupWithResultReturn(json_obj, "messages"); + if (messages_arr_res.IsErr()) { + return TResult::Error(messages_arr_res.UnwrapErr()); } std::vector messages; - for (const auto& item : messages_arr) { + for (const auto& item : messages_arr_res.Unwrap()) { + if (!item.is()) { + return TResult::Error("A message in chat completion request is not object"); + } picojson::object item_obj = item.get(); - std::optional message = ChatCompletionMessage::FromJSON(item_obj, err); - if (!message.has_value()) { - return std::nullopt; + Result message = ChatCompletionMessage::FromJSON(item_obj); + if (message.IsErr()) { + return TResult::Error(message.UnwrapErr()); } - messages.push_back(message.value()); + messages.push_back(message.Unwrap()); } request.messages = messages; // model - std::string model; - if (!json::ParseJSONField(json_obj, "model", model, err, true)) { - return std::nullopt; + Result model_res = json::LookupWithResultReturn(json_obj, "model"); + if (model_res.IsErr()) { + return TResult::Error(model_res.UnwrapErr()); + } + request.model = model_res.Unwrap(); + + // max_tokens + Result> max_tokens_res = + json::LookupOptionalWithResultReturn(json_obj, "max_tokens"); + if (max_tokens_res.IsErr()) { + return TResult::Error(max_tokens_res.UnwrapErr()); } - request.model = model; + request.max_tokens = max_tokens_res.Unwrap(); // frequency_penalty - double frequency_penalty; - if (json::ParseJSONField(json_obj, "frequency_penalty", frequency_penalty, err, false)) { - request.frequency_penalty = frequency_penalty; + Result> frequency_penalty_res = + 
json::LookupOptionalWithResultReturn(json_obj, "frequency_penalty"); + if (frequency_penalty_res.IsErr()) { + return TResult::Error(frequency_penalty_res.UnwrapErr()); } + request.frequency_penalty = frequency_penalty_res.Unwrap(); // presence_penalty - double presence_penalty; - if (json::ParseJSONField(json_obj, "presence_penalty", presence_penalty, err, false)) { - request.presence_penalty = presence_penalty; + Result> presence_penalty_res = + json::LookupOptionalWithResultReturn(json_obj, "presence_penalty"); + if (presence_penalty_res.IsErr()) { + return TResult::Error(presence_penalty_res.UnwrapErr()); } + request.presence_penalty = presence_penalty_res.Unwrap(); // tool_choice - std::string tool_choice = "auto"; - request.tool_choice = tool_choice; - if (json::ParseJSONField(json_obj, "tool_choice", tool_choice, err, false)) { - request.tool_choice = tool_choice; + Result tool_choice_res = + json::LookupOrDefaultWithResultReturn(json_obj, "tool_choice", "auto"); + if (tool_choice_res.IsErr()) { + return TResult::Error(tool_choice_res.UnwrapErr()); } + request.tool_choice = tool_choice_res.Unwrap(); // tools - picojson::array tools_arr; - if (json::ParseJSONField(json_obj, "tools", tools_arr, err, false)) { + Result> tools_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "tools"); + if (tool_choice_res.IsErr()) { + return TResult::Error(tool_choice_res.UnwrapErr()); + } + std::optional tools_arr = tools_arr_res.Unwrap(); + if (tools_arr.has_value()) { std::vector tools; - for (const auto& item : tools_arr) { + tools.reserve(tools_arr.value().size()); + for (const auto& item : tools_arr.value()) { if (!item.is()) { - *err += "Chat Tool item is not an object"; - return std::nullopt; + return TResult::Error("A tool of the chat completion request is not an object"); + } + Result tool = ChatTool::FromJSON(item.get()); + if (tool.IsErr()) { + return TResult::Error(tool.UnwrapErr()); } - picojson::object item_obj = item.get(); - std::optional tool = ChatTool::FromJSON(item_obj, err); - if (!tool.has_value()) { - return std::nullopt; - }; - tools.push_back(tool.value()); + tools.push_back(tool.Unwrap()); } request.tools = tools; } // TODO: Other parameters - return request; -} - -std::optional ChatCompletionRequest::FromJSON(const std::string& json_str, - std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!json_obj.has_value()) { - return std::nullopt; - } - return ChatCompletionRequest::FromJSON(json_obj.value(), err); + return TResult::Ok(request); } -picojson::object ChatCompletionMessage::ToJSON() const { +picojson::object ChatCompletionMessage::AsJSON() const { picojson::object obj; picojson::array content_arr; for (const auto& item : this->content.value()) { @@ -353,17 +377,18 @@ picojson::object ChatCompletionMessage::ToJSON() const { if (this->tool_calls.has_value()) { picojson::array tool_calls_arr; for (const auto& tool_call : this->tool_calls.value()) { - tool_calls_arr.push_back(picojson::value(tool_call.ToJSON())); + tool_calls_arr.push_back(picojson::value(tool_call.AsJSON())); } obj["tool_calls"] = picojson::value(tool_calls_arr); } return obj; } -bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, std::string* err) { +Result ChatCompletionRequest::CheckFunctionCalling(Conversation conv_template) { + using TResult = Result; if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) { conv_template.use_function_calling = false; - return true; + return 
TResult::Ok(conv_template); } std::vector tools_ = tools.value(); std::string tool_choice_ = tool_choice.value(); @@ -372,29 +397,28 @@ bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, st for (const auto& tool : tools_) { if (tool.function.name == tool_choice_) { conv_template.use_function_calling = true; - picojson::value function_str(tool.function.ToJSON()); + picojson::value function_str(tool.function.AsJSON()); conv_template.function_string = function_str.serialize(); - return true; + return TResult::Ok(conv_template); } } if (tool_choice_ != "auto") { - *err += "Invalid tool_choice value: " + tool_choice_; - return false; + return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_); } picojson::array function_list; for (const auto& tool : tools_) { - function_list.push_back(picojson::value(tool.function.ToJSON())); + function_list.push_back(picojson::value(tool.function.AsJSON())); } conv_template.use_function_calling = true; picojson::value function_list_json(function_list); conv_template.function_string = function_list_json.serialize(); - return true; + return TResult::Ok(conv_template); }; -picojson::object ChatCompletionResponseChoice::ToJSON() const { +picojson::object ChatCompletionResponseChoice::AsJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -410,11 +434,11 @@ picojson::object ChatCompletionResponseChoice::ToJSON() const { } } obj["index"] = picojson::value((int64_t)this->index); - obj["message"] = picojson::value(this->message.ToJSON()); + obj["message"] = picojson::value(this->message.AsJSON()); return obj; } -picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { +picojson::object ChatCompletionStreamResponseChoice::AsJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -431,16 +455,16 @@ picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { } obj["index"] = picojson::value((int64_t)this->index); - obj["delta"] = picojson::value(this->delta.ToJSON()); + obj["delta"] = picojson::value(this->delta.AsJSON()); return obj; } -picojson::object ChatCompletionResponse::ToJSON() const { +picojson::object ChatCompletionResponse::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; for (const auto& choice : this->choices) { - choices_arr.push_back(picojson::value(choice.ToJSON())); + choices_arr.push_back(picojson::value(choice.AsJSON())); } obj["choices"] = picojson::value(choices_arr); obj["created"] = picojson::value((int64_t)this->created); @@ -450,12 +474,12 @@ picojson::object ChatCompletionResponse::ToJSON() const { return obj; } -picojson::object ChatCompletionStreamResponse::ToJSON() const { +picojson::object ChatCompletionStreamResponse::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; for (const auto& choice : this->choices) { - choices_arr.push_back(picojson::value(choice.ToJSON())); + choices_arr.push_back(picojson::value(choice.AsJSON())); } obj["choices"] = picojson::value(choices_arr); obj["created"] = picojson::value((int64_t)this->created); diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 70ef2fb22f..914366c2f1 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,6 +13,7 @@ #include #include +#include "../support/result.h" #include 
"conv_template.h" #include "picojson.h" @@ -24,17 +25,30 @@ enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; -std::string generate_uuid_string(size_t length); +inline std::string generate_uuid_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} class ChatFunction { public: std::optional description = std::nullopt; std::string name; + // Todo: change to std::vector>? std::unordered_map parameters; // Assuming parameters are string key-value pairs - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatTool { @@ -42,8 +56,8 @@ class ChatTool { Type type = Type::function; ChatFunction function; - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatFunctionCall { @@ -52,8 +66,8 @@ class ChatFunctionCall { std::optional> arguments = std::nullopt; // Assuming arguments are string key-value pairs - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatToolCall { @@ -62,8 +76,8 @@ class ChatToolCall { Type type = Type::function; ChatFunctionCall function; - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatCompletionMessage { @@ -75,9 +89,8 @@ class ChatCompletionMessage { std::optional> tool_calls = std::nullopt; std::optional tool_call_id = std::nullopt; - static std::optional FromJSON(const picojson::object& json, - std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class RequestResponseFormat { @@ -108,20 +121,10 @@ class ChatCompletionRequest { bool ignore_eos = false; // RequestResponseFormat response_format; //TODO: implement this - /*! - * \brief Create a ChatCompletionRequest instance from the given JSON object. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const picojson::object& json_obj, - std::string* err); - /*! - * \brief Parse and create a ChatCompletionRequest instance from the given JSON string. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const std::string& json_str, - std::string* err); - - bool CheckFunctionCalling(Conversation& conv_template, std::string* err); + /*! \brief Parse and create a ChatCompletionRequest instance from the given JSON string. 
*/ + static Result FromJSON(const std::string& json_str); + + Result CheckFunctionCalling(Conversation conv_template); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; @@ -132,7 +135,7 @@ class ChatCompletionResponseChoice { ChatCompletionMessage message; // TODO: logprobs - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionStreamResponseChoice { @@ -142,7 +145,7 @@ class ChatCompletionStreamResponseChoice { ChatCompletionMessage delta; // TODO: logprobs - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionResponse { @@ -155,7 +158,7 @@ class ChatCompletionResponse { std::string object = "chat.completion"; // TODO: usage_info - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionStreamResponse { @@ -167,7 +170,7 @@ class ChatCompletionStreamResponse { std::string system_fingerprint; std::string object = "chat.completion.chunk"; - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; } // namespace json_ffi diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 2daf1d0338..62ba2787b9 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -90,7 +90,7 @@ ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module, std::string json_str = ""; TypedPackedFunc pf = module.GetFunction("_metadata"); json_str = pf(); - picojson::object json = json::ParseToJsonObject(json_str); + picojson::object json = json::ParseToJSONObject(json_str); try { return ModelMetadata::FromJSON(json, model_config); } catch (const std::exception& e) { diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 30a3617a8d..9b9d5ba65a 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -79,7 +79,7 @@ GenerationConfig::GenerationConfig( GenerationConfig::GenerationConfig(String config_json_str, Optional default_config_json_str) { - picojson::object config = json::ParseToJsonObject(config_json_str); + picojson::object config = json::ParseToJSONObject(config_json_str); ObjectPtr n = make_object(); GenerationConfig default_config; if (default_config_json_str.defined()) { diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index a0ae4d98f3..a4eda4e395 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -420,7 +420,7 @@ BNFGrammar EBNFParser::Parse(std::string ebnf_string, std::string main_rule) { BNFGrammar BNFJSONParser::Parse(std::string json_string) { auto node = make_object(); - auto grammar_json = json::ParseToJsonObject(json_string); + auto grammar_json = json::ParseToJSONObject(json_string); auto rules_json = json::Lookup(grammar_json, "rules"); for (const auto& rule_json : rules_json) { auto rule_json_obj = rule_json.get(); diff --git a/cpp/support/json_parser.h b/cpp/support/json_parser.h index f71757435a..ef1225081d 100644 --- a/cpp/support/json_parser.h +++ b/cpp/support/json_parser.h @@ -12,6 +12,8 @@ #include +#include "result.h" + namespace mlc { namespace llm { namespace json { @@ -21,52 +23,31 @@ namespace json { * \param json_str The JSON string to parse. * \return The parsed JSON object. */ -picojson::object ParseToJsonObject(const std::string& json_str); - -// Todo(mlc-team): implement "Result" class for JSON parsing with error collection. -/*! - * \brief Parse input JSON string into JSON dict. - * Any error will be dumped to the input error string. 
- */ -inline std::optional LoadJSONFromString(const std::string& json_str, - std::string* err) { - ICHECK_NOTNULL(err); - picojson::value json; - *err = picojson::parse(json, json_str); - if (!json.is()) { - *err += "The input JSON string does not correspond to a JSON dict."; - return std::nullopt; - } - return json.get(); +inline picojson::object ParseToJSONObject(const std::string& json_str) { + picojson::value result; + std::string err = picojson::parse(result, json_str); + CHECK(err.empty()) << "Failed to parse JSON: err. The JSON string is:" << json_str; + CHECK(result.is()) + << "ValueError: The given string is not a JSON object: " << json_str; + return result.get(); } - /*! - * \brief // Todo(mlc-team): document this function. - * \tparam T - * \param json_obj - * \param field - * \param value - * \param err - * \param required - * \return + * \brief Parse a JSON string to a JSON object. + * \param json_str The JSON string to parse. + * \return The parsed JSON object, or the error message. */ -template -inline bool ParseJSONField(const picojson::object& json_obj, const std::string& field, T& value, - std::string* err, bool required) { - // T can be int, double, bool, string, picojson::array - if (json_obj.count(field)) { - if (!json_obj.at(field).is()) { - *err += "Field " + field + " is not of type " + typeid(T).name() + "\n"; - return false; - } - value = json_obj.at(field).get(); - } else { - if (required) { - *err += "Field " + field + " is required\n"; - return false; - } +inline Result ParseToJSONObjectWithResultReturn(const std::string& json_str) { + using TResult = Result; + picojson::value result; + std::string err = picojson::parse(result, json_str); + if (!err.empty()) { + return TResult::Error("Failed to parse JSON: err. The JSON string is: " + json_str + + ". The error is " + err); + } + if (!result.is()) { + return TResult::Error("ValueError: The given string is not a JSON object: " + json_str); } - return true; + return TResult::Ok(result.get()); } /*! @@ -87,6 +68,109 @@ ValueType Lookup(const picojson::object& json, const std::string& key); */ template ValueType Lookup(const picojson::array& json, int index); +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, the default value is returned. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the default value if the key doesn't exist or has null value. + */ +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return default_value; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, return std::nullopt. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or std::nullopt if the value doesn't exist or has null value. 
+ */ +template +inline std::optional LookupOptional(const picojson::object& json, + const std::string& key) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return std::nullopt; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the error message. + */ +template +inline Result LookupWithResultReturn(const picojson::object& json, + const std::string& key) { + using TResult = Result; + auto it = json.find(key); + if (it == json.end()) { + return TResult::Error("ValueError: key \"" + key + "\" not found in the JSON object"); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, the default value is returned. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the default value if the key doesn't exist or has null value + * , or the error message. + */ +template +inline Result LookupOrDefaultWithResultReturn(const picojson::object& json, + const std::string& key, + const ValueType& default_value) { + using TResult = Result; + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return TResult::Ok(default_value); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, return std::nullopt. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or std::nullopt if the value doesn't exist or has null value, + * , or the error message. + */ +template +inline Result> LookupOptionalWithResultReturn(const picojson::object& json, + const std::string& key) { + using TResult = Result>; + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return TResult::Ok(std::nullopt); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} + +// Implementation details /*! \brief ShapeTuple extension to incorporate symbolic shapes. 
*/ struct SymShapeTuple { @@ -112,8 +196,6 @@ struct SymShapeTuple { } }; -// Implementation details - namespace details { inline tvm::runtime::DataType DTypeFromString(const std::string& s) { @@ -149,33 +231,6 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } -template -inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, - const ValueType& default_value) { - auto it = json.find(key); - if (it == json.end()) { - return default_value; - } - - if (it->second.is()) { - return default_value; - } - - CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; - return it->second.get(); -} - -template -inline std::optional LookupOptional(const picojson::object& json, - const std::string& key) { - auto it = json.find(key); - if (it == json.end() || it->second.is()) { - return std::nullopt; - } - CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; - return it->second.get(); -} - template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; @@ -205,17 +260,6 @@ inline SymShapeTuple Lookup(const picojson::array& json, int index) { return details::SymShapeTupleFromArray(Lookup(json, index)); } -inline picojson::object ParseToJsonObject(const std::string& json_str) { - picojson::value result; - std::string err = picojson::parse(result, json_str); - if (!err.empty()) { - LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << json_str; - } - CHECK(result.is()) - << "ValueError: The given string is not a JSON object: " << json_str; - return result.get(); -} - } // namespace json } // namespace llm } // namespace mlc From f181ce2e9e5ec6c445fe123d7e4e5fd89a7764c5 Mon Sep 17 00:00:00 2001 From: Wei Tao <1136862851@qq.com> Date: Sun, 5 May 2024 21:58:27 +0800 Subject: [PATCH 270/531] [Bugfix] fix _kv_cache_transpose_append buffer read region error (#2277) * improve Install via environment variable * [HotFix] fix kv_cache_transpose_append buffer region --- python/mlc_llm/nn/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index e4cbf1c047..e5cae1e5cd 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -399,7 +399,7 @@ def tir_kv_cache_transpose_append( pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] From 23636e5c0f4ede72e143ed1168a22860b814a59b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 5 May 2024 18:08:34 -0400 Subject: [PATCH 271/531] [GenConfig] Set upper bound for prefill chunk size (#2278) By default the prefill chunk size is set to the context window size or the sliding window size. When the number is large, our memory planning during model compilation will allocate a lot memory. Given we have support for input chunking, we can reduce the prefill chunk size to a small value to save runtime memory. 
This PR sets the prefill chunk size to be at most 2048. --- python/mlc_llm/model/baichuan/baichuan_model.py | 14 ++++++-------- python/mlc_llm/model/chatglm3/chatglm3_model.py | 14 ++++++-------- python/mlc_llm/model/gemma/gemma_model.py | 14 ++++++-------- python/mlc_llm/model/gpt2/gpt2_model.py | 14 ++++++-------- .../model/gpt_bigcode/gpt_bigcode_model.py | 14 ++++++-------- python/mlc_llm/model/gpt_neox/gpt_neox_model.py | 14 ++++++-------- python/mlc_llm/model/internlm/internlm_model.py | 14 ++++++-------- python/mlc_llm/model/llama/llama_model.py | 14 ++++++-------- python/mlc_llm/model/mistral/mistral_model.py | 7 +++---- python/mlc_llm/model/orion/orion_model.py | 14 ++++++-------- python/mlc_llm/model/phi/phi_model.py | 17 ++++++++++++++--- python/mlc_llm/model/qwen/qwen_model.py | 14 ++++++-------- python/mlc_llm/model/qwen2/qwen2_model.py | 14 ++++++-------- .../mlc_llm/model/stable_lm/stablelm_model.py | 14 ++++++-------- 14 files changed, 89 insertions(+), 103 deletions(-) diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index 1d8f88c676..0b6dfb1477 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -66,21 +66,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py index f7e81019e0..df86353540 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_model.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -72,21 +72,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 118f3ce856..c08c6d9ad4 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ 
b/python/mlc_llm/model/gemma/gemma_model.py @@ -68,21 +68,19 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index ede9dc350f..0922a7a1bf 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -63,21 +63,19 @@ def __post_init__(self): assert self.head_dim * self.n_head == self.n_embd if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring,too-many-locals diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index c13d169be1..dd721ad444 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -55,21 +55,19 @@ def __post_init__(self): ) if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 5e940a15b3..0ce1858c89 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -70,21 +70,19 @@ def __post_init__(self): if self.prefill_chunk_size == 0: 
logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index f8e95ab4ec..00683add3b 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -65,21 +65,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 60c8f138d1..69f01ee13b 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -70,21 +70,19 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 3439f7b41f..966dc6e35e 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -54,12 +54,11 @@ def __post_init__(self): assert self.attention_sink_size >= 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - 
bold("sliding_window_size"), - self.sliding_window_size, + min(self.sliding_window_size, 2048), ) - self.prefill_chunk_size = self.sliding_window_size + self.prefill_chunk_size = min(self.sliding_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index c6a2293cd2..d9c55e1f6c 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -70,21 +70,19 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index 2c9c596ed7..7ecb5e211f 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -64,9 +64,20 @@ def __post_init__(self): "provided in `config.json`." ) if self.prefill_chunk_size == 0: - self.prefill_chunk_size = self.context_window_size - if self.prefill_chunk_size > self.context_window_size: - self.prefill_chunk_size = self.context_window_size + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) if self.num_key_value_heads == 0 or self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads if self.intermediate_size == 0 or self.intermediate_size is None: diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 09bb8e854f..cbca790246 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -63,21 +63,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + 
self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index 6eae4c2bb0..88e49af635 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -63,21 +63,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring,too-many-locals diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index 10e16cded6..ea87e64fc7 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -64,21 +64,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring From 6bcd70ca696b4527242e9a679cd9b30f802c73b3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 6 May 2024 07:57:50 -0400 Subject: [PATCH 272/531] [iOS] Initial scaffolding of MLCEngine in Swift (#2279) [iOS] Initial scaffolding of LLMEngine in Swift This PR adds initial scaffolding of LLMEngine in swift. We wraps callback to AsyncStream so it can be accessed using for await API. We also added an minimal example app to showcase the new MLCEngine, the old ChatModule is still used in the MLCChat App. The return value is structified already. We will still need to structurify the chat completion interface. 
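As a rough illustration of the intended usage pattern (a minimal sketch only: the `reload`/`chatCompletion` entry points, parameter names, and response field names below are assumptions for illustration and may not match the exact API this PR lands):

    import MLCSwift

    let engine = MLCEngine()

    Task {
        // Placeholder model path and library name; a real app would resolve
        // these from its bundled or downloaded model artifacts.
        await engine.reload(modelPath: "/path/to/model", modelLib: "model_lib")

        // The callback-based engine output is surfaced as an AsyncStream,
        // so streamed responses can be consumed with `for await`.
        for await response in await engine.chatCompletion(
            messages: [ChatCompletionMessage(role: .user, content: "Hello!")]
        ) {
            // Each streamed response carries an incremental delta, printed as it arrives.
            print(response.choices.first?.delta.content ?? "", terminator: "")
        }
    }

Exposing the stream this way lets Swift apps consume engine output with structured concurrency instead of managing raw callbacks.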
--- ios/MLCChat/States/AppState.swift | 8 +- ios/MLCChat/States/ChatState.swift | 16 +- .../project.pbxproj | 415 ++++++++++++++++++ .../contents.xcworkspacedata | 7 + .../xcshareddata/IDEWorkspaceChecks.plist | 8 + .../AccentColor.colorset/Contents.json | 11 + .../AppIcon.appiconset/Contents.json | 13 + .../Assets.xcassets/Contents.json | 6 + .../MLCEngineExample/ContentView.swift | 21 + .../MLCEngineExample.entitlements | 10 + .../MLCEngineExampleApp.swift | 92 ++++ .../Preview Assets.xcassets/Contents.json | 6 + ios/MLCEngineExample/READMD.md | 6 + ios/MLCSwift/Sources/ObjC/LLMEngine.mm | 112 +++++ ios/MLCSwift/Sources/ObjC/include/LLMEngine.h | 32 ++ ios/MLCSwift/Sources/Swift/LLMEngine.swift | 111 +++++ .../Sources/Swift/OpenAIProtocol.swift | 70 +++ ios/MLCSwift/Sources/Swift/ThreadWorker.swift | 4 +- ios/prepare_libs.sh | 3 +- 19 files changed, 936 insertions(+), 15 deletions(-) create mode 100644 ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj create mode 100644 ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata create mode 100644 ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist create mode 100644 ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AccentColor.colorset/Contents.json create mode 100644 ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/Contents.json create mode 100644 ios/MLCEngineExample/MLCEngineExample/ContentView.swift create mode 100644 ios/MLCEngineExample/MLCEngineExample/MLCEngineExample.entitlements create mode 100644 ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift create mode 100644 ios/MLCEngineExample/MLCEngineExample/Preview Content/Preview Assets.xcassets/Contents.json create mode 100644 ios/MLCEngineExample/READMD.md create mode 100644 ios/MLCSwift/Sources/ObjC/LLMEngine.mm create mode 100644 ios/MLCSwift/Sources/ObjC/include/LLMEngine.h create mode 100644 ios/MLCSwift/Sources/Swift/LLMEngine.swift create mode 100644 ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift diff --git a/ios/MLCChat/States/AppState.swift b/ios/MLCChat/States/AppState.swift index 4b8af5086f..4dc8d9f315 100644 --- a/ios/MLCChat/States/AppState.swift +++ b/ios/MLCChat/States/AppState.swift @@ -13,7 +13,7 @@ final class AppState: ObservableObject { @Published var alertMessage = "" // TODO: Should move out @Published var alertDisplayed = false // TODO: Should move out - + private var appConfig: AppConfig? 
private var modelIDs = Set() @@ -33,7 +33,7 @@ final class AppState: ObservableObject { } loadModelsConfig(modelList: appConfig.modelList) } - + func requestDeleteModel(modelID: String) { // model dir should have been deleted in ModelState assert(!fileManager.fileExists(atPath: cacheDirectoryURL.appending(path: modelID).path())) @@ -65,7 +65,7 @@ private extension AppState { return nil } } - + func loadModelsConfig(modelList: [AppConfig.ModelRecord]) { for model in modelList { if model.modelPath != nil { @@ -131,7 +131,7 @@ private extension AppState { let fileHandle = try FileHandle(forReadingFrom: modelConfigURL) let data = fileHandle.readDataToEndOfFile() var modelConfig = try jsonDecoder.decode(ModelConfig.self, from: data) - modelConfig.modelLib = modelLib + modelConfig.modelLib = modelLib modelConfig.modelID = modelID modelConfig.estimatedVRAMReq = estimatedVRAMReq return modelConfig diff --git a/ios/MLCChat/States/ChatState.swift b/ios/MLCChat/States/ChatState.swift index 7a5a60f66f..cb1903c1d7 100644 --- a/ios/MLCChat/States/ChatState.swift +++ b/ios/MLCChat/States/ChatState.swift @@ -37,7 +37,7 @@ final class ChatState: ObservableObject { @Published var infoText = "" @Published var displayName = "" @Published var useVision = false - + private let modelChatStateLock = NSLock() private var modelChatState: ModelChatState = .ready @@ -46,12 +46,12 @@ final class ChatState: ObservableObject { private var modelLib = "" private var modelPath = "" var modelID = "" - + init() { threadWorker.qualityOfService = QualityOfService.userInteractive threadWorker.start() } - + var isInterruptible: Bool { return getModelChatState() == .ready || getModelChatState() == .generating @@ -71,7 +71,7 @@ final class ChatState: ObservableObject { return getModelChatState() == .ready || getModelChatState() == .generating } - + func requestResetChat() { assert(isResettable) interruptChat(prologue: { @@ -80,7 +80,7 @@ final class ChatState: ObservableObject { self?.mainResetChat() }) } - + func requestTerminateChat(callback: @escaping () -> Void) { assert(isInterruptible) interruptChat(prologue: { @@ -89,7 +89,7 @@ final class ChatState: ObservableObject { self?.mainTerminateChat(callback: callback) }) } - + func requestReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) { if (isCurrentModel(modelID: modelID)) { return @@ -105,7 +105,7 @@ final class ChatState: ObservableObject { displayName: displayName) }) } - + func requestGenerate(prompt: String) { assert(isChattable) switchToGenerating() @@ -222,7 +222,7 @@ private extension ChatState { func interruptChat(prologue: () -> Void, epilogue: @escaping () -> Void) { assert(isInterruptible) - if getModelChatState() == .ready + if getModelChatState() == .ready || getModelChatState() == .failed || getModelChatState() == .pendingImageUpload { prologue() diff --git a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj new file mode 100644 index 0000000000..f24f333d83 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj @@ -0,0 +1,415 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 60; + objects = { + +/* Begin PBXBuildFile section */ + C0B37B892BE8226A00B2F80B /* MLCEngineExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */; }; + C0B37B8B2BE8226A00B2F80B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B8A2BE8226A00B2F80B /* ContentView.swift */; }; + C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */; }; + C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */; }; + C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0B37B972BE8234D00B2F80B /* MLCSwift */; }; + C0B37C0A2BE82D5900B2F80B /* dist in Copy Files */ = {isa = PBXBuildFile; fileRef = C0B37C062BE825DC00B2F80B /* dist */; }; +/* End PBXBuildFile section */ + +/* Begin PBXCopyFilesBuildPhase section */ + C0B37B992BE8255600B2F80B /* Copy Files */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 12; + dstPath = ""; + dstSubfolderSpec = 7; + files = ( + C0B37C0A2BE82D5900B2F80B /* dist in Copy Files */, + ); + name = "Copy Files"; + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXCopyFilesBuildPhase section */ + +/* Begin PBXFileReference section */ + C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCEngineExample.app; sourceTree = BUILT_PRODUCTS_DIR; }; + C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCEngineExampleApp.swift; sourceTree = ""; }; + C0B37B8A2BE8226A00B2F80B /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; + C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; + C0B37C062BE825DC00B2F80B /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; name = dist; path = ../dist; sourceTree = ""; }; + C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCEngineExample.entitlements; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + C0B37B822BE8226A00B2F80B /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + C0B37B7C2BE8226A00B2F80B = { + isa = PBXGroup; + children = ( + C0B37C062BE825DC00B2F80B /* dist */, + C0B37B872BE8226A00B2F80B /* MLCEngineExample */, + C0B37B862BE8226A00B2F80B /* Products */, + ); + sourceTree = ""; + }; + C0B37B862BE8226A00B2F80B /* Products */ = { + isa = PBXGroup; + children = ( + C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */, + ); + name = Products; + sourceTree = ""; + }; + C0B37B872BE8226A00B2F80B /* MLCEngineExample */ = { + isa = PBXGroup; + children = ( + 
C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */, + C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */, + C0B37B8A2BE8226A00B2F80B /* ContentView.swift */, + C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */, + C0B37B8E2BE8226B00B2F80B /* Preview Content */, + ); + path = MLCEngineExample; + sourceTree = ""; + }; + C0B37B8E2BE8226B00B2F80B /* Preview Content */ = { + isa = PBXGroup; + children = ( + C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */, + ); + path = "Preview Content"; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + C0B37B842BE8226A00B2F80B /* MLCEngineExample */ = { + isa = PBXNativeTarget; + buildConfigurationList = C0B37B932BE8226B00B2F80B /* Build configuration list for PBXNativeTarget "MLCEngineExample" */; + buildPhases = ( + C0B37B812BE8226A00B2F80B /* Sources */, + C0B37B822BE8226A00B2F80B /* Frameworks */, + C0B37B832BE8226A00B2F80B /* Resources */, + C0B37B992BE8255600B2F80B /* Copy Files */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MLCEngineExample; + packageProductDependencies = ( + C0B37B972BE8234D00B2F80B /* MLCSwift */, + ); + productName = MLCEngineExample; + productReference = C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + C0B37B7D2BE8226A00B2F80B /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastSwiftUpdateCheck = 1530; + LastUpgradeCheck = 1530; + TargetAttributes = { + C0B37B842BE8226A00B2F80B = { + CreatedOnToolsVersion = 15.3; + }; + }; + }; + buildConfigurationList = C0B37B802BE8226A00B2F80B /* Build configuration list for PBXProject "MLCEngineExample" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = C0B37B7C2BE8226A00B2F80B; + packageReferences = ( + C0B37B962BE8234D00B2F80B /* XCLocalSwiftPackageReference "../MLCSwift" */, + ); + productRefGroup = C0B37B862BE8226A00B2F80B /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + C0B37B842BE8226A00B2F80B /* MLCEngineExample */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + C0B37B832BE8226A00B2F80B /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */, + C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + C0B37B812BE8226A00B2F80B /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C0B37B8B2BE8226A00B2F80B /* ContentView.swift in Sources */, + C0B37B892BE8226A00B2F80B /* MLCEngineExampleApp.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + C0B37B912BE8226B00B2F80B /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + 
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 17.4; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + C0B37B922BE8226B00B2F80B /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + 
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 17.4; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + C0B37B942BE8226B00B2F80B /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CODE_SIGN_ENTITLEMENTS = MLCEngineExample/MLCEngineExample.entitlements; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_ASSET_PATHS = "\"MLCEngineExample/Preview Content\""; + DEVELOPMENT_TEAM = 3FR42MXLK9; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../build/lib"; + MARKETING_VERSION = 1.0; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = mlc.MLCEngineExample; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + C0B37B952BE8226B00B2F80B /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + CODE_SIGN_ENTITLEMENTS = MLCEngineExample/MLCEngineExample.entitlements; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_ASSET_PATHS = "\"MLCEngineExample/Preview Content\""; + DEVELOPMENT_TEAM = 3FR42MXLK9; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../build/lib"; + MARKETING_VERSION = 1.0; + OTHER_LDFLAGS = ( + "-Wl,-all_load", + "-lmodel_iphone", + "-lmlc_llm", + "-ltvm_runtime", + "-ltokenizers_cpp", + "-lsentencepiece", + "-ltokenizers_c", + ); + PRODUCT_BUNDLE_IDENTIFIER = mlc.MLCEngineExample; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + 
TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + C0B37B802BE8226A00B2F80B /* Build configuration list for PBXProject "MLCEngineExample" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C0B37B912BE8226B00B2F80B /* Debug */, + C0B37B922BE8226B00B2F80B /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C0B37B932BE8226B00B2F80B /* Build configuration list for PBXNativeTarget "MLCEngineExample" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C0B37B942BE8226B00B2F80B /* Debug */, + C0B37B952BE8226B00B2F80B /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + +/* Begin XCLocalSwiftPackageReference section */ + C0B37B962BE8234D00B2F80B /* XCLocalSwiftPackageReference "../MLCSwift" */ = { + isa = XCLocalSwiftPackageReference; + relativePath = ../MLCSwift; + }; +/* End XCLocalSwiftPackageReference section */ + +/* Begin XCSwiftPackageProductDependency section */ + C0B37B972BE8234D00B2F80B /* MLCSwift */ = { + isa = XCSwiftPackageProductDependency; + productName = MLCSwift; + }; +/* End XCSwiftPackageProductDependency section */ + }; + rootObject = C0B37B7D2BE8226A00B2F80B /* Project object */; +} diff --git a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000000..919434a625 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000000..18d981003d --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AccentColor.colorset/Contents.json b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AccentColor.colorset/Contents.json new file mode 100644 index 0000000000..eb87897008 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AccentColor.colorset/Contents.json @@ -0,0 +1,11 @@ +{ + "colors" : [ + { + "idiom" : "universal" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AppIcon.appiconset/Contents.json b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000..13613e3ee1 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,13 @@ +{ + "images" : [ + { + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/Contents.json b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/Contents.json new file mode 100644 index 0000000000..73c00596a7 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + 
"author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCEngineExample/MLCEngineExample/ContentView.swift b/ios/MLCEngineExample/MLCEngineExample/ContentView.swift new file mode 100644 index 0000000000..650cd38cb5 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/ContentView.swift @@ -0,0 +1,21 @@ +// This is a minimum example App to interact with MLC Engine +// +// for a complete example, take a look at the MLCChat + +import SwiftUI + +struct ContentView: View { + @EnvironmentObject private var appState: AppState + // simply display text on the app + var body: some View { + HStack { + Text(appState.displayText) + Spacer() + } + .padding() + } +} + +#Preview { + ContentView() +} diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExample.entitlements b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExample.entitlements new file mode 100644 index 0000000000..caa3d58396 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExample.entitlements @@ -0,0 +1,10 @@ + + + + + com.apple.developer.kernel.extended-virtual-addressing + + com.apple.developer.kernel.increased-memory-limit + + + diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift new file mode 100644 index 0000000000..19b6ab45de --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -0,0 +1,92 @@ +// NOTE: This example is still work in progress +// +// This is a minimum example App to interact with MLC Engine +// This app is mainly created with minimalism in mind for +// example and quick testing purposes. +// +// To build this app, select target My Mac(Designed for iPad) and run +// Make sure you run prepare_libs.sh and prepare_params.sh first +// to ensure the dist folder populates with the right model file +// and we have the model lib packaged correctly +import Foundation +import SwiftUI + +// Import MLCSwift +import MLCSwift + +class AppState: ObservableObject { + // the MLC engine instance + private let engine = MLCEngine() + // obtain the local path to store models + // this that stores the model files in the dist folder + private let distURL = Bundle.main.bundleURL.appending(path: "dist") + // NOTE: this does not yet work out of box + // need to supply the Llama-3-8B-Instruct-q3f16_1-MLC and llama_q3f16_1 + // via manual local compile + // TODO(mlc-team): update prebuild so it can be used out of box + // + // model path, this must match a builtin + // file name in prepare_params.sh + private let modelPath = "Llama-3-8B-Instruct-q3f16_1-MLC" + // model lib identifier of within the packaged library + // this must match a config in MLCChat/app-config.json + // make sure we run prepare_libs.sh + private let modelLib = "llama_q3f16_1" + + // this is a message to be displayed in app + @Published var displayText = "" + + public func runExample() { + // MLCEngine is a actor that can be called in an async context + Task { + let modelLocalPath = distURL.appending(path: modelPath).path() + // Step 0: load the engine + await engine.reload(modelPath: modelLocalPath, modelLib: modelLib) + + // TODO(mlc-team) update request so it is also structure based + // as in open ai api + // sent a request + let jsonRequest = """ + { + "model": "llama3", + "messages": [ + { + "role": "user", + "content": [ + { "type": "text", "text": "What is the meaning of life?" 
} + ] + } + ] + } + """ + // run chat completion as in OpenAI API style + for await res in await engine.chatCompletion(jsonRequest: jsonRequest) { + // publish at main event loop + DispatchQueue.main.async { + // parse the result content in structured form + // and stream back to the display + self.displayText += res.choices[0].delta.content![0]["text"]! + } + } + } + } +} + + +@main +struct MLCEngineExampleApp: App { + private let appState = AppState() + + init() { + // we simply run test + // please checkout output in console + appState.runExample() + } + + var body: some Scene { + WindowGroup { + ContentView() + .environmentObject(appState) + } + } +} diff --git a/ios/MLCEngineExample/MLCEngineExample/Preview Content/Preview Assets.xcassets/Contents.json b/ios/MLCEngineExample/MLCEngineExample/Preview Content/Preview Assets.xcassets/Contents.json new file mode 100644 index 0000000000..73c00596a7 --- /dev/null +++ b/ios/MLCEngineExample/MLCEngineExample/Preview Content/Preview Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/ios/MLCEngineExample/READMD.md b/ios/MLCEngineExample/READMD.md new file mode 100644 index 0000000000..e08265f4b2 --- /dev/null +++ b/ios/MLCEngineExample/READMD.md @@ -0,0 +1,6 @@ +# MLCEngine Example + +Minimal example of the latest MLCEngine Swift API. + +NOTE: this project is still work in progress, +things may not yet be fully functioning and are subject to change diff --git a/ios/MLCSwift/Sources/ObjC/LLMEngine.mm b/ios/MLCSwift/Sources/ObjC/LLMEngine.mm new file mode 100644 index 0000000000..bafc7a29db --- /dev/null +++ b/ios/MLCSwift/Sources/ObjC/LLMEngine.mm @@ -0,0 +1,112 @@ +// +// LLMEngine.mm +// LLMEngine +// +#import +#import +#include + +#include "LLMEngine.h" + +#define TVM_USE_LIBBACKTRACE 0 +#define DMLC_USE_LOGGING_LIBRARY + +#include +#include + +using namespace tvm::runtime; + +@implementation JSONFFIEngine { + // Internal c++ classes + // internal module backed by JSON FFI + Module json_ffi_engine_; + // member functions + PackedFunc init_background_engine_func_; + PackedFunc unload_func_; + PackedFunc reload_func_; + PackedFunc reset_func_; + PackedFunc chat_completion_func_; + PackedFunc abort_func_; + PackedFunc run_background_loop_func_; + PackedFunc run_background_stream_back_loop_func_; + PackedFunc exit_background_loop_func_; +} + +- (instancetype)init { + if (self = [super init]) { + // load chat module + const PackedFunc* f_json_ffi_create = Registry::Get("mlc.json_ffi.CreateJSONFFIEngine"); + ICHECK(f_json_ffi_create) << "Cannot find mlc.json_ffi.CreateJSONFFIEngine"; + json_ffi_engine_ = (*f_json_ffi_create)(); + init_background_engine_func_ = json_ffi_engine_->GetFunction("init_background_engine"); + reload_func_ = json_ffi_engine_->GetFunction("reload"); + unload_func_ = json_ffi_engine_->GetFunction("unload"); + reset_func_ = json_ffi_engine_->GetFunction("reset"); + chat_completion_func_ = json_ffi_engine_->GetFunction("chat_completion"); + abort_func_ = json_ffi_engine_->GetFunction("abort"); + run_background_loop_func_ = json_ffi_engine_->GetFunction("run_background_loop"); + run_background_stream_back_loop_func_ = + json_ffi_engine_->GetFunction("run_background_stream_back_loop"); + exit_background_loop_func_ = json_ffi_engine_->GetFunction("exit_background_loop"); + + ICHECK(init_background_engine_func_ != nullptr); + ICHECK(reload_func_ != nullptr); + ICHECK(unload_func_ != nullptr); + ICHECK(reset_func_ != nullptr); + ICHECK(chat_completion_func_ != 
nullptr); + ICHECK(abort_func_ != nullptr); + ICHECK(run_background_loop_func_ != nullptr); + ICHECK(run_background_stream_back_loop_func_ != nullptr); + ICHECK(exit_background_loop_func_ != nullptr); + } + return self; +} + +- (void)initBackgroundEngine:(void (^)(NSString*))streamCallback { + TypedPackedFunc)> internal_stream_callback( + [streamCallback](Array res) { + for (String value : res) { + streamCallback([NSString stringWithUTF8String:value.c_str()]); + } + }); + DLDevice metal_device{kDLMetal, 0}; + init_background_engine_func_(metal_device, internal_stream_callback, nullptr); +} + +- (void)reload:(NSString*)engineConfigJson { + std::string engine_config = engineConfigJson.UTF8String; + reload_func_(engine_config); +} + +- (void)unload { + unload_func_(); +} + +- (void)reset { + reset_func_(); +} + +- (void)chatCompletion:(NSString*)requestJSON requestID:(NSString*)requestID { + std::string request_json = requestJSON.UTF8String; + std::string request_id = requestID.UTF8String; + chat_completion_func_(request_json, request_id); +} + +- (void)abort:(NSString*)requestID { + std::string request_id = requestID.UTF8String; + abort_func_(request_id); +} + +- (void)runBackgroundLoop { + run_background_loop_func_(); +} + +- (void)runBackgroundStreamBackLoop { + run_background_stream_back_loop_func_(); +} + +- (void)exitBackgroundLoop { + exit_background_loop_func_(); +} + +@end diff --git a/ios/MLCSwift/Sources/ObjC/include/LLMEngine.h b/ios/MLCSwift/Sources/ObjC/include/LLMEngine.h new file mode 100644 index 0000000000..22fc4ef653 --- /dev/null +++ b/ios/MLCSwift/Sources/ObjC/include/LLMEngine.h @@ -0,0 +1,32 @@ +// +// Use this file to import your target's public headers that you would like to expose to Swift. +// LLM Chat Module +// +// Exposed interface of Object-C, enables swift binding. +#import +#import + +/** + * This is an internal Raw JSON FFI Engine that redirects request to internal JSON FFI Engine in C++ + */ +@interface JSONFFIEngine : NSObject + +- (void)initBackgroundEngine:(void (^)(NSString*))streamCallback; + +- (void)reload:(NSString*)engineConfig; + +- (void)unload; + +- (void)reset; + +- (void)chatCompletion:(NSString*)requestJSON requestID:(NSString*)requestID; + +- (void)abort:(NSString*)requestID; + +- (void)runBackgroundLoop; + +- (void)runBackgroundStreamBackLoop; + +- (void)exitBackgroundLoop; + +@end diff --git a/ios/MLCSwift/Sources/Swift/LLMEngine.swift b/ios/MLCSwift/Sources/Swift/LLMEngine.swift new file mode 100644 index 0000000000..91a4d20b81 --- /dev/null +++ b/ios/MLCSwift/Sources/Swift/LLMEngine.swift @@ -0,0 +1,111 @@ +import Foundation +import LLMChatObjC +import os + +class BackgroundWorker : Thread { + private var task: ()->Void; + + public init(task: @escaping () -> Void) { + self.task = task + } + + public override func main() { + self.task(); + } +} + +@available(iOS 14.0.0, *) +public actor MLCEngine { + private let jsonFFIEngine = JSONFFIEngine() + private var threads = Array(); + private var continuationMap = Dictionary.Continuation>() + private let logger = Logger() + + + public init() { + jsonFFIEngine.initBackgroundEngine { (result : String?) 
-> Void in + self.streamCallback(result: result) + } + // startup background threads with + let backgroundWorker = BackgroundWorker { + Thread.setThreadPriority(1) + self.jsonFFIEngine.runBackgroundLoop() + } + let backgroundStreamBackWorker = BackgroundWorker { + self.jsonFFIEngine.runBackgroundStreamBackLoop() + } + // set background worker to be high QoS so it gets higher p for gpu + backgroundWorker.qualityOfService = QualityOfService.userInteractive + threads.append(backgroundWorker) + threads.append(backgroundStreamBackWorker) + backgroundWorker.start() + backgroundStreamBackWorker.start() + } + + deinit { + jsonFFIEngine.exitBackgroundLoop() + } + + public func reload(modelPath: String, modelLib: String) { + let engineConfig = """ + { + "model": "\(modelPath)", + "model_lib": "system://\(modelLib)", + "mode": "interactive" + } + """ + jsonFFIEngine.reload(engineConfig) + } + + public func unload() { + jsonFFIEngine.unload() + } + + // TODO(mlc-team) turn into a structured interface + public func chatCompletion(jsonRequest: String) -> AsyncStream { + // generate a UUID for the request + let requestID = UUID().uuidString + let stream = AsyncStream(ChatCompletionStreamResponse.self) { continuation in + continuation.onTermination = { termination in + if termination == .cancelled { + self.jsonFFIEngine.abort(requestID); + } + } + // store continuation map for further callbacks + self.continuationMap[requestID] = continuation + // start invoking engine for completion + self.jsonFFIEngine.chatCompletion(jsonRequest, requestID: requestID) + } + return stream + } + + private func streamCallback(result: String?) { + var responses: [ChatCompletionStreamResponse] = [] + + let decoder = JSONDecoder() + do { + let msg = try decoder.decode(ChatCompletionStreamResponse.self, from: result!.data(using: .utf8)!) 
+ responses.append(msg) + } catch let lastError { + logger.error("Swift json parsing error: error=\(lastError), jsonsrc=\(result!)") + } + + // dispatch to right request ID + for res in responses { + if let continuation = self.continuationMap[res.id] { + continuation.yield(res) + // detect finished from result + var finished = false + for choice in res.choices { + if choice.finish_reason != "" && choice.finish_reason != nil { + finished = true; + } + } + if finished { + continuation.finish() + self.continuationMap.removeValue(forKey: res.id) + } + } + } + } +} diff --git a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift new file mode 100644 index 0000000000..1aa652af5e --- /dev/null +++ b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift @@ -0,0 +1,70 @@ +// Protocol definition of OpenAI API +import Foundation + +// Protocols for v1/chat/completions +// API reference: https://platform.openai.com/docs/api-reference/chat/create + +public struct TopLogProbs : Codable { + public let token: String + public let logprob: Float + public let bytes: Optional<[Int]> +} + +public struct LogProbsContent : Codable { + public let token: String + public let logprob: Float + public var bytes: Optional<[Int]> = nil + public var top_logprobs: [TopLogProbs] = [] +} + +public struct LogProbs : Codable { + public var content: [LogProbsContent] = [] +} + +public struct ChatFunction : Codable { + public let name: String + public var description: Optional = nil + public let parameters: [String: String] +} + +public struct ChatTool : Codable { + public let type: String = "function" + public let function: ChatFunction +} + +public struct ChatFunctionCall : Codable { + public let name: String + // NOTE: arguments shold be dict str to any codable + // for now only allow string output due to typing issues + public var arguments: Optional<[String: String]> = nil +} + +public struct ChatToolCall : Codable { + public let id: String = UUID().uuidString + public let type: String = "function" + public let function: ChatFunctionCall +} + +public struct ChatCompletionMessage : Codable { + public let role: String + public var content: Optional<[[String: String]]> = nil + public var name: Optional = nil + public var tool_calls: Optional<[ChatToolCall]> = nil + public var tool_call_id: Optional = nil +} + +public struct ChatCompletionStreamResponseChoice: Codable { + public var finish_reason: Optional = nil + public let index: Int + public let delta: ChatCompletionMessage + public var lobprobs: Optional = nil +} + +public struct ChatCompletionStreamResponse: Codable { + public let id : String + public var choices: [ChatCompletionStreamResponseChoice] = [] + public var created: Optional = nil + public var model: Optional = nil + public let system_fingerprint: String + public var object: Optional = nil +} diff --git a/ios/MLCSwift/Sources/Swift/ThreadWorker.swift b/ios/MLCSwift/Sources/Swift/ThreadWorker.swift index 79f1eb2004..6f992f681d 100644 --- a/ios/MLCSwift/Sources/Swift/ThreadWorker.swift +++ b/ios/MLCSwift/Sources/Swift/ThreadWorker.swift @@ -7,7 +7,7 @@ import Foundation public class ThreadWorker : Thread { private var cond = NSCondition(); private var queue = Array<()->Void>(); - + public override func main() { Thread.setThreadPriority(1) while (true) { @@ -20,7 +20,7 @@ public class ThreadWorker : Thread { task() } } - + public func push(task: @escaping ()->Void) { self.cond.lock() self.queue.append(task) diff --git a/ios/prepare_libs.sh b/ios/prepare_libs.sh index 
d87423890d..3885024b51 100755 --- a/ios/prepare_libs.sh +++ b/ios/prepare_libs.sh @@ -64,7 +64,8 @@ cmake ../..\ -DCMAKE_CXX_FLAGS="-O3"\ -DMLC_LLM_INSTALL_STATIC_LIB=ON\ -DUSE_METAL=ON -make mlc_llm_static + +cmake --build . --config release --target mlc_llm_static -j cmake --build . --target install --config release -j cd .. From d31941fc6fd41bb899d35b6b1c391a2c2f43a35e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 6 May 2024 08:05:22 -0400 Subject: [PATCH 273/531] Rename READMD.md to README.md --- ios/MLCEngineExample/{READMD.md => README.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ios/MLCEngineExample/{READMD.md => README.md} (100%) diff --git a/ios/MLCEngineExample/READMD.md b/ios/MLCEngineExample/README.md similarity index 100% rename from ios/MLCEngineExample/READMD.md rename to ios/MLCEngineExample/README.md From 5ae393abd6f2157ef136f4ccd47cfcce618a2420 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Mon, 6 May 2024 20:26:34 +0530 Subject: [PATCH 274/531] [Serving] Image support in JSONFFIEngine (#2208) Using new Result interface Co-authored-by: Animesh Bohara --- .gitmodules | 3 + 3rdparty/stb | 1 + CMakeLists.txt | 1 + cpp/json_ffi/conv_template.cc | 144 +++++++++++++++- cpp/json_ffi/conv_template.h | 39 ++++- cpp/json_ffi/image_utils.cc | 156 ++++++++++++++++++ cpp/json_ffi/image_utils.h | 31 ++++ cpp/json_ffi/json_ffi_engine.cc | 12 +- cpp/json_ffi/json_ffi_engine.h | 2 + python/mlc_llm/model/llava/llava_model.py | 97 +---------- python/mlc_llm/serve/data.py | 4 +- .../json_ffi/test_json_ffi_engine_image.py | 91 ++++++++++ 12 files changed, 481 insertions(+), 100 deletions(-) create mode 160000 3rdparty/stb create mode 100644 cpp/json_ffi/image_utils.cc create mode 100644 cpp/json_ffi/image_utils.h create mode 100644 tests/python/json_ffi/test_json_ffi_engine_image.py diff --git a/.gitmodules b/.gitmodules index 10ef4b2682..ac9bafe076 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm url = https://github.com/mlc-ai/relax.git +[submodule "3rdparty/stb"] + path = 3rdparty/stb + url = https://github.com/nothings/stb.git diff --git a/3rdparty/stb b/3rdparty/stb new file mode 160000 index 0000000000..ae721c50ea --- /dev/null +++ b/3rdparty/stb @@ -0,0 +1 @@ +Subproject commit ae721c50eaf761660b4f90cc590453cdb0c2acd0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 7f0dd7ef24..24504c8bee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,6 +88,7 @@ target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES}) target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS}) target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include) target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS) +target_include_directories(mlc_llm_objs PRIVATE 3rdparty/stb) add_library(mlc_llm SHARED $) add_library(mlc_llm_static STATIC $) diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc index 4feee6f98e..e23258f0b8 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/conv_template.cc @@ -3,6 +3,7 @@ #include #include "../support/json_parser.h" +#include "image_utils.h" namespace mlc { namespace llm { @@ -10,6 +11,124 @@ namespace json_ffi { using namespace mlc::llm; +/****************** Model vision config ******************/ + +ModelVisionConfig ModelVisionConfig::FromJSON(const picojson::object& json_obj) { + ModelVisionConfig config; + + Result hidden_size_res = json::LookupWithResultReturn(json_obj, "hidden_size"); + if 
(hidden_size_res.IsOk()) { + config.hidden_size = hidden_size_res.Unwrap(); + } + + Result image_size_res = json::LookupWithResultReturn(json_obj, "image_size"); + if (image_size_res.IsOk()) { + config.image_size = image_size_res.Unwrap(); + } + + Result intermediate_size_res = + json::LookupWithResultReturn(json_obj, "intermediate_size"); + if (intermediate_size_res.IsOk()) { + config.intermediate_size = intermediate_size_res.Unwrap(); + } + + Result num_attention_heads_res = + json::LookupWithResultReturn(json_obj, "num_attention_heads"); + if (num_attention_heads_res.IsOk()) { + config.num_attention_heads = num_attention_heads_res.Unwrap(); + } + + Result num_hidden_layers_res = + json::LookupWithResultReturn(json_obj, "num_hidden_layers"); + if (num_hidden_layers_res.IsOk()) { + config.num_hidden_layers = num_hidden_layers_res.Unwrap(); + } + + Result patch_size_res = json::LookupWithResultReturn(json_obj, "patch_size"); + if (patch_size_res.IsOk()) { + config.patch_size = patch_size_res.Unwrap(); + } + + Result projection_dim_res = + json::LookupWithResultReturn(json_obj, "projection_dim"); + if (projection_dim_res.IsOk()) { + config.projection_dim = projection_dim_res.Unwrap(); + } + + Result vocab_size_res = json::LookupWithResultReturn(json_obj, "vocab_size"); + if (vocab_size_res.IsOk()) { + config.vocab_size = vocab_size_res.Unwrap(); + } + + Result dtype_res = json::LookupWithResultReturn(json_obj, "dtype"); + if (dtype_res.IsOk()) { + config.dtype = dtype_res.Unwrap(); + } + + Result num_channels_res = + json::LookupWithResultReturn(json_obj, "num_channels"); + if (num_channels_res.IsOk()) { + config.num_channels = num_channels_res.Unwrap(); + } + + Result layer_norm_eps_res = + json::LookupWithResultReturn(json_obj, "layer_norm_eps"); + if (layer_norm_eps_res.IsOk()) { + config.layer_norm_eps = layer_norm_eps_res.Unwrap(); + } + + return config; +} + +/****************** Model config ******************/ + +ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) { + ModelConfig config; + + Result vocab_size_res = json::LookupWithResultReturn(json_obj, "vocab_size"); + if (vocab_size_res.IsOk()) { + config.vocab_size = vocab_size_res.Unwrap(); + } + + Result context_window_size_res = + json::LookupWithResultReturn(json_obj, "context_window_size"); + if (context_window_size_res.IsOk()) { + config.context_window_size = context_window_size_res.Unwrap(); + } + + Result sliding_window_size_res = + json::LookupWithResultReturn(json_obj, "sliding_window_size"); + if (sliding_window_size_res.IsOk()) { + config.sliding_window_size = sliding_window_size_res.Unwrap(); + } + + Result prefill_chunk_size_res = + json::LookupWithResultReturn(json_obj, "prefill_chunk_size"); + if (prefill_chunk_size_res.IsOk()) { + config.prefill_chunk_size = prefill_chunk_size_res.Unwrap(); + } + + Result tensor_parallel_shards_res = + json::LookupWithResultReturn(json_obj, "tensor_parallel_shards"); + if (tensor_parallel_shards_res.IsOk()) { + config.tensor_parallel_shards = tensor_parallel_shards_res.Unwrap(); + } + + Result max_batch_size_res = + json::LookupWithResultReturn(json_obj, "max_batch_size"); + if (max_batch_size_res.IsOk()) { + config.max_batch_size = max_batch_size_res.Unwrap(); + } + + if (json_obj.count("vision_config")) { + const picojson::object& vision_config_obj = + json_obj.at("vision_config").get(); + config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj); + } + + return config; +} + /****************** Conversation template ******************/ std::map 
PLACEHOLDERS = { @@ -34,7 +153,7 @@ Conversation::Conversation() {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} -Result> Conversation::AsPrompt() { +Result> Conversation::AsPrompt(ModelConfig config, DLDevice device) { using TResult = Result>; // Get the system message std::string system_msg = system_template; @@ -116,6 +235,29 @@ Result> Conversation::AsPrompt() { } } message += role_text; + } else if (it_type->second == "image_url") { + if (item.find("image_url") == item.end()) { + return TResult::Error("Content should have an image_url field"); + } + std::string image_url = + item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this + // should be a map, with a "url" key containing the URL, but + // we are just assuming this as the URL for now + std::string base64_image = image_url.substr(image_url.find(",") + 1); + Result image_data_res = LoadImageFromBase64(base64_image); + if (image_data_res.IsErr()) { + return TResult::Error(image_data_res.UnwrapErr()); + } + if (!config.vision_config.has_value()) { + return TResult::Error("Vision config is required for image input"); + } + int image_size = config.vision_config.value().image_size; + int patch_size = config.vision_config.value().patch_size; + + int embed_size = (image_size * image_size) / (patch_size * patch_size); + + auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device); + message_list.push_back(ImageData(image_ndarray, embed_size)); } else { return TResult::Error("Unsupported content type: " + it_type->second); } diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h index 2d579a8d94..8217c5d6e5 100644 --- a/cpp/json_ffi/conv_template.h +++ b/cpp/json_ffi/conv_template.h @@ -19,6 +19,43 @@ namespace mlc { namespace llm { namespace json_ffi { +/****************** Model vision config ******************/ + +/*! \brief Defines the Vision config of the model (if present) */ +class ModelVisionConfig { + public: + int hidden_size; + int image_size; + int intermediate_size; + int num_attention_heads; + int num_hidden_layers; + int patch_size; + int projection_dim; + int vocab_size; + std::string dtype; + int num_channels; + double layer_norm_eps; + + static ModelVisionConfig FromJSON(const picojson::object& json_obj); +}; + +/****************** Model config ******************/ + +/*! \brief Defines the config of the model. +Populated from "model_config" field in mlc-chat-config.json */ +class ModelConfig { + public: + int vocab_size; + int context_window_size; + int sliding_window_size; + int prefill_chunk_size; + int tensor_parallel_shards; + int max_batch_size; + std::optional vision_config = std::nullopt; + + static ModelConfig FromJSON(const picojson::object& json_obj); +}; + /****************** Conversation template ******************/ enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; @@ -92,7 +129,7 @@ struct Conversation { Conversation(); /*! \brief Create the list of prompts from the messages based on the conversation template. */ - Result> AsPrompt(); + Result> AsPrompt(ModelConfig config, DLDevice device); /*! \brief Create a Conversation instance from the given JSON object. 
*/ static Result FromJSON(const picojson::object& json); diff --git a/cpp/json_ffi/image_utils.cc b/cpp/json_ffi/image_utils.cc new file mode 100644 index 0000000000..24c785fbd5 --- /dev/null +++ b/cpp/json_ffi/image_utils.cc @@ -0,0 +1,156 @@ +#include "image_utils.h" + +#include + +#include "../../3rdparty/tvm/src/support/base64.h" +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +using namespace tvm::runtime; + +class MemoryBufferStream : public dmlc::Stream { + public: + MemoryBufferStream(const char* data, size_t size) : data_(data), size_(size), pos_(0) {} + + size_t Read(void* ptr, size_t size) override { + size_t remaining = size_ - pos_; + if (size > remaining) { + size = remaining; + } + if (size == 0) { + return 0; + } + std::memcpy(ptr, data_ + pos_, size); + pos_ += size; + return size; + } + + void Write(const void* ptr, size_t size) override { + LOG(FATAL) << "MemoryBufferStream does not support write"; + } + + private: + const char* data_; + size_t size_; + size_t pos_; +}; + +size_t Base64DecodedSize(const std::string& base64_str) { + size_t len = base64_str.size(); + size_t padding = 0; + if (base64_str[len - 1] == '=') { + padding++; + } + if (base64_str[len - 2] == '=') { + padding++; + } + return 3 * len / 4 - padding; +} + +Result LoadImageFromBase64(const std::string& base64_str) { + using TResult = Result; + MemoryBufferStream stream(base64_str.c_str(), base64_str.size()); + tvm::support::Base64InStream base64_stream(&stream); + size_t decoded_size = Base64DecodedSize(base64_str); + std::vector decoded(decoded_size); + base64_stream.InitPosition(); + base64_stream.Read((void*)decoded.data(), decoded_size); + int width, height, num_channels; + unsigned char* image_data = + stbi_load_from_memory(decoded.data(), decoded_size, &width, &height, &num_channels, 3); + if (!image_data) { + return TResult::Error(stbi_failure_reason()); + } + auto image_ndarray = NDArray::Empty({height, width, 3}, {kDLUInt, 8, 1}, {kDLCPU, 0}); + image_ndarray.CopyFromBytes((void*)image_data, width * height * 3); + stbi_image_free(image_data); + return TResult::Ok(image_ndarray); +} + +NDArray ClipPreprocessor(NDArray image_data, int target_size, DLDevice device) { + int height = image_data->shape[0]; + int width = image_data->shape[1]; + // Resize + const int short_side = width < height ? width : height; + const int long_side = width > height ? width : height; + const int new_short_side = target_size; + const int new_long_side = (int)(new_short_side * (long_side / (float)short_side)); + const int new_width = width < height ? new_short_side : new_long_side; + const int new_height = width > height ? 
new_short_side : new_long_side; + + std::vector processed_image_data(new_width * new_height * 3); + + // Bilinear Interpolation + for (int y = 0; y < new_height; y++) { + for (int x = 0; x < new_width; x++) { + const float x_ratio = float(width - 1) / new_width; + const float y_ratio = float(height - 1) / new_height; + const int x1 = int(x_ratio * x); + const int y1 = int(y_ratio * y); + const int x2 = x1 + 1; + const int y2 = y1 + 1; + const float x_diff = x_ratio * x - x1; + const float y_diff = y_ratio * y - y1; + for (int c = 0; c < 3; c++) { + const uint8_t top_left = ((uint8_t*)image_data->data)[(y1 * width + x1) * 3 + c]; + const uint8_t top_right = ((uint8_t*)image_data->data)[(y1 * width + x2) * 3 + c]; + const uint8_t bottom_left = ((uint8_t*)image_data->data)[(y2 * width + x1) * 3 + c]; + const uint8_t bottom_right = ((uint8_t*)image_data->data)[(y2 * width + x2) * 3 + c]; + processed_image_data[(y * new_width + x) * 3 + c] = + (float)(int(top_left * (1 - x_diff) * (1 - y_diff) + top_right * x_diff * (1 - y_diff) + + bottom_left * y_diff * (1 - x_diff) + bottom_right * x_diff * y_diff)); + } + } + } + + // Center crop + const int crop_x = (new_width - target_size) / 2; + const int crop_y = (new_height - target_size) / 2; + std::vector cropped_image_data(target_size * target_size * 3); + for (int y = 0; y < target_size; y++) { + for (int x = 0; x < target_size; x++) { + for (int c = 0; c < 3; c++) { + cropped_image_data[(y * target_size + x) * 3 + c] = + processed_image_data[((y + crop_y) * new_width + x + crop_x) * 3 + c]; + } + } + } + + // Rescale + for (int i = 0; i < target_size * target_size * 3; i++) { + cropped_image_data[i] = cropped_image_data[i] / 255.0f; + } + + // Normalize + const float IMAGE_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f}; + const float IMAGE_STD[] = {0.26862954f, 0.26130258f, 0.27577711f}; + for (int i = 0; i < target_size * target_size * 3; i++) { + const int c = i % 3; + cropped_image_data[i] = (cropped_image_data[i] - IMAGE_MEAN[c]) / IMAGE_STD[c]; + } + + std::vector image_data_channel_first(target_size * target_size * 3); + for (int y = 0; y < target_size; y++) { + for (int x = 0; x < target_size; x++) { + for (int c = 0; c < 3; c++) { + image_data_channel_first[c * target_size * target_size + y * target_size + x] = + cropped_image_data[(y * target_size + x) * 3 + c]; + } + } + } + + // Create NDArray + auto image_ndarray = NDArray::Empty({1, 3, target_size, target_size}, {kDLFloat, 32, 1}, device); + image_ndarray.CopyFromBytes((void*)image_data_channel_first.data(), + target_size * target_size * 3 * sizeof(float)); + + return image_ndarray; +} + +} // namespace json_ffi +} // namespace llm +} // namespace mlc diff --git a/cpp/json_ffi/image_utils.h b/cpp/json_ffi/image_utils.h new file mode 100644 index 0000000000..1a89b7bc13 --- /dev/null +++ b/cpp/json_ffi/image_utils.h @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file json_ffi/image_utils.h + * \brief The header of Image utils for JSON FFI Engine in MLC LLM. + */ +#ifndef MLC_LLM_JSON_FFI_IMAGE_UTILS_H_ +#define MLC_LLM_JSON_FFI_IMAGE_UTILS_H_ + +#include + +#include +#include + +#include "../support/result.h" + +namespace mlc { +namespace llm { +namespace json_ffi { + +/*! \brief Load a base64 encoded image string into a CPU NDArray of shape {height, width, 3} */ +Result LoadImageFromBase64(const std::string& base64_str); + +/*! 
\brief Preprocess the CPU image for CLIP encoder and return an NDArray on the given device */ +tvm::runtime::NDArray ClipPreprocessor(tvm::runtime::NDArray image_data, int target_size, + DLDevice device); + +} // namespace json_ffi +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_JSON_FFI_IMAGE_UTILS_H_ diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index b4f9751719..65f3183424 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -4,6 +4,9 @@ #include #include +#include +#include + #include "../serve/model.h" #include "../support/json_parser.h" #include "../support/result.h" @@ -82,7 +85,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request conv_template = updated_conv_template.Unwrap(); // get prompt - Result> inputs_obj = conv_template.AsPrompt(); + Result> inputs_obj = conv_template.AsPrompt(this->model_config_, this->device_); if (inputs_obj.IsErr()) { err_ = inputs_obj.UnwrapErr(); return false; @@ -145,6 +148,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) { + this->device_ = device; CHECK(request_stream_callback.defined()) << "JSONFFIEngine requires request stream callback function, but it is not given."; this->request_stream_callback_ = request_stream_callback.value(); @@ -171,11 +175,15 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { Result model_config_json = serve::Model::LoadModelConfig(json::Lookup(engine_config_json, "model")); CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); + const picojson::object& model_config_json_unwrapped = model_config_json.Unwrap(); Result conv_template = Conversation::FromJSON( - json::Lookup(model_config_json.Unwrap(), "conv_template")); + json::Lookup(model_config_json_unwrapped, "conv_template")); CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: " << conv_template.UnwrapErr(); this->conv_template_ = conv_template.Unwrap(); + this->model_config_ = ModelConfig::FromJSON( + json::Lookup(model_config_json_unwrapped, "model_config")); + // Create streamer. // Todo(mlc-team): Create one streamer for each request, instead of a global one. 
this->streamer_ = diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index e805cb6e8a..13dc5809bd 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -50,6 +50,8 @@ class JSONFFIEngine { TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request Conversation conv_template_; String default_generation_cfg_json_str_; + ModelConfig model_config_; + DLDevice device_; }; } // namespace json_ffi diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index 1498c13fdb..d3c409e92d 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -7,7 +7,7 @@ import logging from typing import Any, Dict, Optional, Tuple -from tvm import relax, te, tir +from tvm import relax, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Module, Tensor, op from tvm.relax.frontend.nn.modules import Conv2D @@ -375,84 +375,11 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def _embed_input_ids(self, input_ids: Tensor) -> Tensor: - return self.language_model.embed(input_ids) - - def _embed_pixel_values_and_input_ids(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: - def _index(x, value, batch_size, seq_len): - return te.compute( - (batch_size, seq_len), - lambda i, j: tir.if_then_else( - x[i, j] == value, - j, - tir.IntImm("int32", 0), - ), - name="index", - ) - - def _concat(x: Tensor, y: Tensor, new_shape: tuple, insert_index: Tensor): - return te.compute( - (new_shape), - lambda b, i, j: tir.if_then_else( - i < insert_index[0], - x[b, i, j], - tir.if_then_else( - i < insert_index[0] + y.shape[1], - y[b, i - insert_index[0], j], - x[b, i - y.shape[1] + 1, j], - ), - ), - ) - - input_embeddings = self._embed_input_ids(input_ids) - - image_features_all = self.vision_tower.forward(pixel_values) - image_features = wrap_nested( - strided_slice( - image_features_all._expr, # pylint: disable=protected-access - axes=[1], - begin=[1], - end=[image_features_all.shape[1]], - ), - name="slice", - ) - image_features = self.multi_modal_projector(image_features) - batch_size, seq_len = input_ids.shape - image_index_tensor = op.tensor_expr_op( - _index, - name_hint="index", - args=[ - input_ids, - tir.IntImm("int32", self.config.image_token_index), - batch_size, - seq_len, - ], - ).astype("int32") - ##! Assume only one token in input - ##! 
Also assume batch_size = 1 for now - # TODO: Support image_count > 1 and batch_size > 1 # pylint: disable=fixme - insert_index = op.sum(image_index_tensor, axis=1) - - new_shape = ( - batch_size, - seq_len + tir.IntImm("int32", image_features.shape[1] - 1), - self.config.text_config.hidden_size, - ) - - combined_embeddings = op.tensor_expr_op( - _concat, - name_hint="combined_embeddings", - args=[input_embeddings, image_features, new_shape, insert_index], - ) - return combined_embeddings - def embed(self, input_ids: Tensor) -> Tensor: - return self._embed_input_ids(input_ids) - - def embed_with_pixel_values(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: - return self._embed_pixel_values_and_input_ids(pixel_values, input_ids) + return self.language_model.embed(input_ids) def image_embed(self, pixel_values: Tensor) -> Tensor: + pixel_values = pixel_values.astype(self.dtype) image_features_all = self.vision_tower.forward(pixel_values) image_features = wrap_nested( strided_slice( @@ -536,22 +463,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "embed_with_pixel_values": { - "pixel_values": nn.spec.Tensor( - [ - 1, - 3, - self.config.vision_config.image_size, - self.config.vision_config.image_size, - ], - self.dtype, - ), - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, "image_embed": { "pixel_values": nn.spec.Tensor( [ @@ -560,7 +471,7 @@ def get_default_spec(self): self.config.vision_config.image_size, self.config.vision_config.image_size, ], - self.dtype, + "float32", ), "$": { "param_mode": "packed", diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index 1c56178ad1..7b946836ea 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -112,11 +112,9 @@ def from_url(url: str, config: Dict) -> "ImageData": # pylint: disable=too-many size={"shortest_edge": image_input_size}, crop_size={"height": image_input_size, "width": image_input_size}, ) - quantization = config["quantization"] - out_dtype = "float16" if "f16" in quantization else "float32" image_features = tvm.nd.array( image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( - out_dtype + "float32" ) ) image_data = ImageData(image_features, image_embed_size) diff --git a/tests/python/json_ffi/test_json_ffi_engine_image.py b/tests/python/json_ffi/test_json_ffi_engine_image.py new file mode 100644 index 0000000000..cfafb2bb9c --- /dev/null +++ b/tests/python/json_ffi/test_json_ffi_engine_image.py @@ -0,0 +1,91 @@ +import base64 +from typing import Dict, List, Optional + +import requests + +from mlc_llm.json_ffi import JSONFFIEngine + + +def base64_encode_image(url: str) -> str: + response = requests.get(url) + response.raise_for_status() # Ensure we got a successful response + image_data = base64.b64encode(response.content) + image_data_str = image_data.decode("utf-8") + data_url = f"data:image/jpeg;base64,{image_data_str}" + return data_url + + +image_prompts = [ + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": f"{base64_encode_image('https://llava-vl.github.io/static/images/view.jpg')}", + }, + {"type": "text", "text": "What does the image represent?"}, + ], + } + ] +] + + +def run_chat_completion( + engine: JSONFFIEngine, + model: str, + prompts: List[List[Dict]] = image_prompts, + tools: Optional[List[Dict]] = None, +): + num_requests = 1 + max_tokens = 64 + n = 1 + output_texts: List[List[str]] = [["" for _ in range(n)] for _ 
in range(num_requests)] + + for rid in range(num_requests): + print(f"chat completion for request {rid}") + for response in engine.chat_completion( + messages=prompts[rid], + model=model, + max_tokens=max_tokens, + n=n, + request_id=str(rid), + tools=tools, + ): + for choice in response.choices: + assert choice.delta.role == "assistant" + assert isinstance(choice.delta.content[0], Dict) + assert choice.delta.content[0]["type"] == "text" + output_texts[rid][choice.index] += choice.delta.content[0]["text"] + + # Print output. + print("Chat completion all finished") + for req_id, outputs in enumerate(output_texts): + print(f"Prompt {req_id}: {prompts[req_id]}") + if len(outputs) == 1: + print(f"Output {req_id}:{outputs[0]}\n") + else: + for i, output in enumerate(outputs): + print(f"Output {req_id}({i}):{output}\n") + + +def test_chat_completion(): + # Create engine. + model = "dist/llava-1.5-7b-hf-q4f16_1-MLC" + engine = JSONFFIEngine( + model, + max_total_sequence_length=1024, + ) + + run_chat_completion(engine, model) + + # Test malformed requests. + for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): + assert len(response.choices) == 1 + assert response.choices[0].finish_reason == "error" + + engine.terminate() + + +if __name__ == "__main__": + test_chat_completion() From cd0993390e964523f7a69d07c1b92796ce0c8a8a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 6 May 2024 12:52:23 -0400 Subject: [PATCH 275/531] [Pass] Attach manual softmax-with-temperature (#2280) This PR updates all the models to use the new softmax-with-temperature function, which inlines the temperature division (or argmax if temperature is 0) process into the two-stage softmax. Unit benchmark shows that the inline of division does no harm to the softmax. When batch size is large, the inlined softmax can have better performance than a standalone divide kernel, which takes much time when batch size is large. 
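For reference, the behavior the attached kernels are meant to implement can be
sketched in NumPy as below. This is only an illustration of the math, not the
TIR kernels the pass emits: the chunk size mirrors the pass's default of 4096,
while the function name, the per-batch loop, and the -inf padding are
assumptions made for readability.

    import numpy as np

    def softmax_with_temperature_ref(logits, temperature, chunk_size=4096):
        """logits: (batch, vocab) float32; temperature: (batch,) float32."""
        batch, vocab = logits.shape
        num_chunks = (vocab + chunk_size - 1) // chunk_size
        out = np.empty_like(logits)
        for b in range(batch):
            t = temperature[b]
            if t > 1e-5:
                # Stage 1: per-chunk max and log-sum-exp of the temperature-scaled
                # logits, padding partial chunks with -inf so they contribute zero.
                scaled = np.full(num_chunks * chunk_size, -np.inf, dtype=np.float32)
                scaled[:vocab] = logits[b] / t
                chunks = scaled.reshape(num_chunks, chunk_size)
                chunk_max = chunks.max(axis=1)
                chunk_lse = np.log(np.exp(chunks - chunk_max[:, None]).sum(axis=1)) + chunk_max
                # Stage 2: merge the chunk statistics into one log-sum-exp and normalize.
                global_max = chunk_max.max()
                lse = np.log(np.exp(chunk_lse - global_max).sum()) + global_max
                out[b] = np.exp(logits[b] / t - lse)
            else:
                # Temperature 0 degenerates to argmax: spread the mass uniformly
                # over the entries that attain the maximum logit.
                is_max = (logits[b] == logits[b].max()).astype(np.float32)
                out[b] = is_max / is_max.sum()
        return out

Fusing the temperature division (and the argmax fallback) into these two stages
is what removes the standalone divide kernel whose cost grows with the batch
size.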
--- cpp/serve/logit_processor.cc | 2 +- .../mlc_llm/compiler_pass/attach_sampler.py | 15 +- .../attach_softmax_with_temperature.py | 243 ++++++++++++++++++ .../attach_spec_decode_aux_funcs.py | 37 +-- python/mlc_llm/compiler_pass/pipeline.py | 4 +- .../mlc_llm/compiler_pass/rewrite_softmax.py | 198 -------------- .../mlc_llm/model/baichuan/baichuan_model.py | 11 - .../mlc_llm/model/chatglm3/chatglm3_model.py | 11 - python/mlc_llm/model/gemma/gemma_model.py | 11 - python/mlc_llm/model/gpt2/gpt2_model.py | 11 - .../model/gpt_bigcode/gpt_bigcode_model.py | 11 - .../mlc_llm/model/gpt_neox/gpt_neox_model.py | 11 - .../mlc_llm/model/internlm/internlm_model.py | 11 - python/mlc_llm/model/llama/llama_model.py | 11 - python/mlc_llm/model/llava/llava_model.py | 11 - python/mlc_llm/model/mistral/mistral_model.py | 11 - python/mlc_llm/model/orion/orion_model.py | 11 - python/mlc_llm/model/phi/phi_model.py | 11 - python/mlc_llm/model/qwen/qwen_model.py | 11 - python/mlc_llm/model/qwen2/qwen2_model.py | 11 - python/mlc_llm/model/rwkv5/rwkv5_model.py | 12 - python/mlc_llm/model/rwkv6/rwkv6_model.py | 12 - .../mlc_llm/model/stable_lm/stablelm_model.py | 11 - 23 files changed, 271 insertions(+), 417 deletions(-) create mode 100644 python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py delete mode 100644 python/mlc_llm/compiler_pass/rewrite_softmax.py diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 7ce70a0d26..628a4ec1c5 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -166,7 +166,7 @@ class LogitProcessorImpl : public LogitProcessorObj { cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i)); int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i); for (int j = 0; j < num_token_to_process; ++j) { - p_temperature[token_offset + j] = std::max(generation_cfg[i]->temperature, eps_); + p_temperature[token_offset + j] = std::max(generation_cfg[i]->temperature, 0.0); } } diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 5bf62257a1..4761914e2f 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -125,8 +125,7 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): sorted_indices, primfunc_name_hint="take_sorted_probs", ) - output = (sorted_values, sorted_indices) - bb.emit_output(output) + output = bb.emit_output((sorted_values, sorted_indices)) gv = bb.emit_func_output(output) return gv @@ -215,7 +214,7 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals sample_indices_tensor, ) ) - result = bb.emit( + result = bb.emit_output( relax.call_pure_packed( "vm.builtin.reshape", result_tensor._expr, # pylint: disable=protected-access @@ -223,7 +222,6 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals sinfo_args=sample_indices.struct_info, # pylint: disable=no-member ) ) - bb.emit_output(result) gv = bb.emit_func_output(result) return gv @@ -249,14 +247,13 @@ def _attach_renormalize_by_top_p( ) final_pivot = cutoff_output[0] renorm_sum = cutoff_output[1] - renormalized_probs = bb.emit( + renormalized_probs = bb.emit_output( relax.call_tir( bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"), args=[probs, final_pivot, renorm_sum], out_sinfo=probs.struct_info, # pylint: disable=no-member ) ) - bb.emit_output(renormalized_probs) gv = bb.emit_func_output(renormalized_probs) return gv @@ -315,7 +312,7 @@ def 
sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument args = [unsorted_probs, sorted_indices, sample_indices, sampling_results, top_prob_offsets] with bb.function("sampler_take_probs", args): with bb.dataflow(): - taken_probs_indices = bb.emit( + taken_probs_indices = bb.emit_output( relax.call_tir( bb.add_func(sampler_take_probs_tir, "sampler_take_probs_tir"), args, @@ -326,7 +323,6 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument ], ) ) - bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv @@ -362,7 +358,7 @@ def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): ] with bb.function("sampler_verify_draft_tokens", args): with bb.dataflow(): - res = bb.emit( + res = bb.emit_output( relax.call_tir_inplace( bb.add_func(batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"), args, @@ -373,6 +369,5 @@ def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): ], ) ) - bb.emit_output(res) gv = bb.emit_func_output(res) return gv diff --git a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py new file mode 100644 index 0000000000..f454ab1b85 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py @@ -0,0 +1,243 @@ +"""A compiler pass that attaches two-stage softmax with temperature.""" + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="AttachSoftmaxWithTemperature") +class AttachSoftmaxWithTemperature: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") + dtype = "float32" + logits = relax.Var("logits", relax.TensorStructInfo([batch_size, 1, vocab_size], dtype)) + temperature = relax.Var("temperature", relax.TensorStructInfo([batch_size], dtype)) + with self.builder_.function("softmax_with_temperature", params=[logits, temperature]): + with self.builder_.dataflow(): + output_struct_info = logits.struct_info # pylint: disable=no-member + new_shape = relax.ShapeExpr([batch_size, vocab_size]) + logits = relax.call_pure_packed( + "vm.builtin.reshape", + logits, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func( + self.target, self.chunk_size + ) + chunked_result_struct_info = relax.TensorStructInfo( + (batch_size, (vocab_size + self.chunk_size - 1) // self.chunk_size), + "float32", + ) + chunked_results = self.builder_.emit( + relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[logits, temperature], + 
out_sinfo=[chunked_result_struct_info, chunked_result_struct_info], + ) + ) + chunked_sum = chunked_results[0] + chunked_max = chunked_results[1] + softmax = self.builder_.emit( + relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_sum"), + args=[logits, temperature, chunked_sum, chunked_max], + out_sinfo=logits.struct_info, + ) + ) + softmax = self.builder_.emit_output( + relax.call_pure_packed( + "vm.builtin.reshape", + softmax, + output_struct_info.shape, + sinfo_args=output_struct_info, + ) + ) + self.builder_.emit_func_output(softmax) + return self.builder_.get() + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + # NOTE: A quick note on the softmax implementation. + # We once tried to multiply every element by log2e which can be computed + # potentially more efficiently on hardware. + # However, when the input values are large, multiplying by the factor of log2e + # causes numerical issue in float32 dtype. + # This leads to the softmax output not summing up to 1. + # For numerical stability, we removed the log2e factor and switched back + # to the standard log/exp computation. + + # The kernels below handle both the cases of temperature=0 and temperature != 0. + # - When temperature is not 0, the first kernel computes the log-sum-exp of + # chunks (subtracted by the max value in chunk), and the max values of chunks. + # The second kernel merges the log-sum-exp with the maximum values. + # - When temperature is 0, the first kernel computes the max value and the counts + # of the max value. The second kernel merges the max and counts, and set the + # softmax of the maximum values to "max_value / max_count". + + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse( # pylint: disable=too-many-locals + var_A: T.handle, + var_temperature: T.handle, + var_chunked_sum: T.handle, + var_chunked_max: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + temperature = T.match_buffer(var_temperature, (batch_size,), dtype="float32") + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype="float32") + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.if_then_else( + temperature[v0] > T.float32(1e-5), + A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0], + A[v0, v1 * T.int64(chunk_size) + v2], + ), + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with 
T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.Select( + temperature[v0] > T.float32(1e-5), + T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), + T.cast(A_pad[v0, v1, v2] == temp_max[v0, v1], "float32"), + ), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_sum[v0, v1] = T.Select( + temperature[v0] > T.float32(1e-5), + T.log(temp_sum[v0, v1]), + temp_sum[v0, v1], + ) + chunked_max[v0, v1] = temp_max[v0, v1] + + @T.prim_func + def softmax_with_chunked_sum( + var_A: T.handle, + var_temperature: T.handle, + var_chunked_sum: T.handle, + var_chunked_max: T.handle, + var_softmax: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + temperature = T.match_buffer(var_temperature, (batch_size,), dtype="float32") + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype="float32") + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_max[v0, v1]) + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.Select( + temperature[v0] > T.float32(1e-5), + T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max[v0]), + T.cast(chunked_max[v0, v1] == temp_max[v0], "float32") * chunked_sum[v0, v1], + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("log_pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.if_then_else( + temperature[v0] > T.float32(1e-5), + T.exp( + A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0] + - (T.log(temp_sum[v0]) + temp_max[v0]) + ), + T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32") + / temp_sum[v0], + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_sum") + l0, l1, l2 = sch.get_loops("log_pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, 
"threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_sum"] diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py index f7bb3dbe14..ef3d6af722 100644 --- a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -82,16 +82,17 @@ def _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dt with bb.dataflow(): if tensor_parallel_shards > 1: indices = relax.op.ccl.broadcast_from_worker0(indices) - output = relax.op.call_tir_inplace( - bb.add_func( - _get_scatter_2d_inplace(dtype, "_scatter_hidden_states"), - "_scatter_hidden_states", - ), - [src, indices, dst], - 2, - dst.struct_info, # pylint: disable=no-member + output = bb.emit_output( + relax.op.call_tir_inplace( + bb.add_func( + _get_scatter_2d_inplace(dtype, "_scatter_hidden_states"), + "_scatter_hidden_states", + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) ) - bb.emit_output(output) gv = bb.emit_func_output(output) return gv @@ -107,14 +108,16 @@ def _add_gather_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dty with bb.dataflow(): if tensor_parallel_shards > 1: indices = relax.op.ccl.broadcast_from_worker0(indices) - output = relax.op.call_tir_inplace( - bb.add_func( - _get_gather_2d_inplace(dtype, "_gather_hidden_states"), "_gather_hidden_states" - ), - [src, indices, dst], - 2, - dst.struct_info, # pylint: disable=no-member + output = bb.emit_output( + relax.op.call_tir_inplace( + bb.add_func( + _get_gather_2d_inplace(dtype, "_gather_hidden_states"), + "_gather_hidden_states", + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) ) - bb.emit_output(output) gv = bb.emit_func_output(output) return gv diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 7bc89de21b..a80bbaf8d7 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -15,6 +15,7 @@ from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc from .attach_logit_processor import AttachLogitProcessFunc from .attach_sampler import AttachGPUSamplingFunc +from .attach_softmax_with_temperature import AttachSoftmaxWithTemperature from .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs from .attach_support_info import ( AttachAdditionalPrimFuncs, @@ -34,7 +35,6 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize -from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -100,6 +100,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I [ # Phase 0. Add additional information for compilation and remove unused Relax func DispatchKVCacheCreation(target, flashinfer, metadata), + AttachSoftmaxWithTemperature(target), AttachVariableBounds(variable_bounds), AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints), AttachLogitProcessFunc(target), @@ -121,7 +122,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 2. 
Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.backend.DispatchSortScan(), - RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py deleted file mode 100644 index 47a5a168d7..0000000000 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ /dev/null @@ -1,198 +0,0 @@ -"""A compiler pass that rewrites one-shot softmax into two-stage softmax.""" - -import tvm -from tvm import relax -from tvm.ir.module import IRModule -from tvm.relax.expr import Expr -from tvm.relax.expr_functor import PyExprMutator, mutator -from tvm.script import tir as T - -from ..support.max_thread_check import get_max_num_threads_per_block - - -@tvm.transform.module_pass(opt_level=0, name="RewriteTwoStageSoftmax") -class RewriteTwoStageSoftmax: # pylint: disable=too-few-public-methods - """Rewrites one-shot softmax into two-stage softmax.""" - - def __init__(self, target: tvm.target.Target) -> None: - self.target = target - - def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: - """IRModule-level transformation""" - return _Rewriter(mod, self.target).transform() - - -@mutator -class _Rewriter(PyExprMutator): # pylint: disable=abstract-method - def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: - super().__init__(mod) - self.mod = mod - self.target = target - self.chunk_size = 4096 - - def transform(self) -> IRModule: - """Entry point""" - func_name = "softmax_with_temperature" - if func_name not in self.mod: - return self.mod - gv = self.mod.get_global_var(func_name) - updated_func = self.visit_expr(self.mod[gv]) - self.builder_.update_func(gv, updated_func) - return self.builder_.get() - - def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-renamed - if call.op != tvm.ir.Op.get("relax.nn.softmax"): - return call - x = call.args[0] - if call.attrs.axis not in [-1, x.struct_info.ndim - 1]: - return call - # Currently the softmax input is 3-dim, and dtype is float32. - assert x.struct_info.ndim == 3 - assert x.struct_info.dtype == "float32" - x_shape = x.struct_info.shape - new_shape = relax.ShapeExpr([x_shape[0] * x_shape[1], x_shape[2]]) - x_reshaped = relax.call_pure_packed( - "vm.builtin.reshape", - x, - new_shape, - sinfo_args=relax.TensorStructInfo(new_shape, x.struct_info.dtype), - ) - f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(self.target, self.chunk_size) - chunked_lse = relax.call_tir( - self.builder_.add_func(f_chunk_lse, "chunk_lse"), - args=[x_reshaped], - out_sinfo=relax.TensorStructInfo( - (new_shape[0], (new_shape[1] + self.chunk_size - 1) // self.chunk_size), - x.struct_info.dtype, - ), - ) - softmax = relax.call_tir( - self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_lse"), - args=[x_reshaped, chunked_lse], - out_sinfo=relax.TensorStructInfo(new_shape, x.struct_info.dtype), - ) - return relax.call_pure_packed( - "vm.builtin.reshape", softmax, x_shape, sinfo_args=x.struct_info - ) - - -def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements - target: tvm.target.Target, chunk_size: int -): - # NOTE: A quick note on the softmax implementation. - # We once tried to multiply every element by log2e which can be computed - # potentially more efficiently on hardware. 
- # However, when the input values are large, multiplying by the factor of log2e - # causes numerical issue in float32 dtype. - # This leads to the softmax output not summing up to 1. - # For numerical stability, we removed the log2e factor and switched back - # to the standard log/exp computation. - - # pylint: disable=invalid-name - @T.prim_func - def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals - T.func_attr({"tir.noalias": T.bool(True)}) - batch_size = T.int64(is_size_var=True) - vocab_size = T.int64(is_size_var=True) - num_chunks = T.int64(is_size_var=True) - A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") - chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") - A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") - temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") - temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") - - for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): - with T.block("pad"): - v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) - A_pad[v0, v1, v2] = T.if_then_else( - v1 * T.int64(chunk_size) + v2 < vocab_size, - A[v0, v1 * T.int64(chunk_size) + v2], - T.min_value("float32"), - ) - for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): - with T.block("max"): - v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) - with T.init(): - temp_max[v0, v1] = T.min_value("float32") - temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) - for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): - with T.block("sum_exp"): - v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) - with T.init(): - temp_sum[v0, v1] = T.float32(0) - temp_sum[v0, v1] += T.if_then_else( - v1 * T.int64(chunk_size) + v2 < vocab_size, - T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), - T.float32(0), - ) - for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): - with T.block("log"): - v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) - chunked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1] - - @T.prim_func - def softmax_with_chunked_lse(var_A: T.handle, var_chunked_lse: T.handle, var_softmax: T.handle): - T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) - batch_size = T.int64(is_size_var=True) - vocab_size = T.int64(is_size_var=True) - num_chunks = T.int64(is_size_var=True) - A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") - chunked_lse = T.match_buffer(var_chunked_lse, (batch_size, num_chunks), dtype="float32") - softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype="float32") - temp_max = T.alloc_buffer((batch_size,), dtype="float32") - temp_sum = T.alloc_buffer((batch_size,), dtype="float32") - lse = T.alloc_buffer((batch_size,), dtype="float32") - for l0, l1 in T.grid(batch_size, num_chunks): - with T.block("max"): - v0, v1 = T.axis.remap("SR", [l0, l1]) - with T.init(): - temp_max[v0] = T.min_value("float32") - temp_max[v0] = T.max(temp_max[v0], chunked_lse[v0, v1]) - for l0, l1 in T.grid(batch_size, num_chunks): - with T.block("sum_exp"): - v0, v1 = T.axis.remap("SR", [l0, l1]) - with T.init(): - temp_sum[v0] = T.float32(0) - temp_sum[v0] += T.exp(chunked_lse[v0, v1] - temp_max[v0]) - for l0 in T.serial(0, batch_size): - with T.block("log"): - v0 = T.axis.remap("S", [l0]) - lse[v0] = T.log(temp_sum[v0]) + temp_max[v0] - for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): - with T.block("pad"): - v0, v1, v2 = 
T.axis.remap("SSS", [l0, l1, l2]) - if v1 * T.int64(chunk_size) + v2 < vocab_size: - softmax[v0, v1 * T.int64(chunk_size) + v2] = T.exp( - A[v0, v1 * T.int64(chunk_size) + v2] - lse[v0] - ) - - sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_lse": softmax_with_chunked_lse})) - max_threads = get_max_num_threads_per_block(target) - TX = 32 - TY = max_threads // TX - unroll_depth = 64 - # pylint: enable=invalid-name - - sch.work_on("softmax_with_chunked_lse") - sch.compute_inline("log") - l0, l1, l2 = sch.get_loops("pad") - bx = sch.fuse(l0, l1) - sch.bind(bx, "blockIdx.x") - unroll, ty, tx = sch.split(l2, [None, TY, TX]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) - - for block_name in ["sum_exp", "max"]: - block = sch.get_block(block_name) - sch.set_scope(block, buffer_index=0, storage_scope="shared") - sch.compute_at(block, bx) - r_loop = sch.get_loops(block)[-1] - r_loop, tx = sch.split(r_loop, [None, TX]) - sch.reorder(tx, r_loop) - sch.bind(tx, "threadIdx.x") - sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) - - return chunk_lse, sch.mod["softmax_with_chunked_lse"] diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index 0b6dfb1477..9981b06449 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -260,9 +260,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -337,14 +334,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py index df86353540..88849214b7 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_model.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -336,9 +336,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -413,14 +410,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git 
a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index c08c6d9ad4..2f88642893 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -288,9 +288,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -365,14 +362,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index 0922a7a1bf..43d7df1d3b 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -280,9 +280,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -357,14 +354,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index dd721ad444..fd84601112 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -257,9 +257,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -334,14 +331,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 0ce1858c89..022a05602e 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -311,9 +311,6 @@ def batch_verify(self, input_embeds: Tensor, 
paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -389,14 +386,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index 00683add3b..8bd59de7d6 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -271,9 +271,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -348,14 +345,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 69f01ee13b..cd99301132 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -318,9 +318,6 @@ def batch_verify_to_last_hidden_states( hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -450,14 +447,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index d3c409e92d..e4facaf1cb 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -425,9 +425,6 @@ def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): return self.language_model.batch_verify(input_embeds, paged_kv_cache) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 
1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -529,14 +526,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 966dc6e35e..4522c4877d 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -253,9 +253,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -330,14 +327,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index d9c55e1f6c..9f2f6173db 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -272,9 +272,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -349,14 +346,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index 7ecb5e211f..b30aad8c20 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -388,9 +388,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def embed(self, input_ids: Tensor): if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) @@ -472,14 +469,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": 
nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index cbca790246..6ce101441c 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -266,9 +266,6 @@ def batch_verify(self, inputs: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(inputs, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -343,14 +340,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index 88e49af635..52c0742e17 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -279,9 +279,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -356,14 +353,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 81c9e9aa7f..987d9f8b6b 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -379,10 +379,6 @@ def batch_verify(self, input_embeds: Tensor, state: RNNState): """Verify step.""" return self.forward(input_embeds, state) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - """Softmax.""" - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state( self, max_batch_size: tir.Var, @@ -451,14 +447,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_rnn_state": { "max_batch_size": int, "max_history": int, diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index a8faf48a6b..7c090206c5 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -421,10 +421,6 @@ def 
batch_verify(self, input_embeds: Tensor, state: RNNState): """Verify step.""" return self.forward(input_embeds, state) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - """Softmax.""" - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state( self, max_batch_size: tir.Var, @@ -493,14 +489,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_rnn_state": { "max_batch_size": int, "max_history": int, diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index ea87e64fc7..8958495da2 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -275,9 +275,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -353,14 +350,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, From eb1454f8ae42b14130f193faba13500b843939d1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 6 May 2024 13:54:37 -0400 Subject: [PATCH 276/531] [Model] Remove unused import to fix lint (#2284) This PR removes the unused import in llava model to fix lint. --- python/mlc_llm/model/llava/llava_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index e4facaf1cb..a6ccfe8edc 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -9,7 +9,7 @@ from tvm import relax, tir from tvm.relax.frontend import nn -from tvm.relax.frontend.nn import Module, Tensor, op +from tvm.relax.frontend.nn import Module, Tensor from tvm.relax.frontend.nn.modules import Conv2D from tvm.relax.frontend.nn.op import ( broadcast_to, From 44b56753a602df40195f67063951e37959c0ff6e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 6 May 2024 23:46:54 -0400 Subject: [PATCH 277/531] [Serving] Fix BatchVerify to feed the extra token when fully accepted (#2285) This PR fixes a bug in the BatchVerify action. When a draft model's proposal is fully accepted by the main model, there is an extra token which is already in the main model's KV cache but not in the draft model's KV cache. Prior to this PR, BatchVerify action does not feed this extra token into the draft model's KV cache, which causes size mismatch between the main model's KV cache and draft model's KV cache. This PR fixes this issue by adding an additional BatchDecode step for the requests whose draft proposals are fully accepted by the main model. 
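The cache bookkeeping this fix restores can be illustrated with the toy,
single-request sketch below. The counts and the helper name are made up for
the illustration; only the rollback and extra-decode branches mirror the
actual engine logic.

    def sync_caches_after_verify(main_kv, draft_kv, rollback_length):
        """main_kv / draft_kv are plain token counts for one request."""
        # Entering this step the draft cache is one token behind the main cache:
        # the last proposed token was sampled but never decoded by the draft model.
        assert draft_kv == main_kv - 1
        if rollback_length > 0:
            main_kv -= rollback_length        # PopNFromKVCache on the main model
            draft_kv -= rollback_length - 1   # one fewer pop for the draft model
        else:
            # Fully accepted draft: without an extra step the draft cache stays
            # permanently short. The fix runs one BatchDecode on the draft model,
            # feeding committed_tokens[-2], which appends the missing token.
            draft_kv += 1
        assert main_kv == draft_kv            # aligned again for the next round
        return main_kv, draft_kv

Keeping the two counts equal at the end of every verify round is exactly the
size match between the main and draft KV caches that was broken before this
change.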
--- cpp/serve/engine_actions/batch_verify.cc | 47 +++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 42524d46b2..80c5a5e125 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -133,6 +133,13 @@ class BatchVerifyActionObj : public EngineActionObj { draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); + // We collect the requests whose drafts are fully accepted. + // When a request's draft is fully accepted, there is an extra token proposed + // by the draft model but not added into the draft model's KV cache. + // In this case, an additional batch decode step is needed for these requests. + std::vector fully_accepted_rsentries; + fully_accepted_rsentries.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { const std::vector& sample_results = sample_results_arr[i]; int accept_length = sample_results.size(); @@ -151,9 +158,47 @@ class BatchVerifyActionObj : public EngineActionObj { if (rollback_length > 0) { models_[verify_model_id_]->PopNFromKVCache( rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length); + // The last accepted token is not yet added into the draft model. + // Therefore, the rollback length for the draft model is one less. models_[draft_model_id_]->PopNFromKVCache( - rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length); + rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); + } else { + fully_accepted_rsentries.push_back(i); + } + } + + if (!fully_accepted_rsentries.empty()) { + // - Run a step of batch decode for requests whose drafts are fully accepted. + // When a request's draft is fully accepted, there is an extra token proposed + // by the draft model but not added into the draft model's KV cache. + // In this case, an additional batch decode step is needed for these requests. + std::vector input_tokens; + std::vector fully_accepted_request_internal_ids; + input_tokens.reserve(fully_accepted_rsentries.size()); + fully_accepted_request_internal_ids.reserve(fully_accepted_rsentries.size()); + for (int rsentry_id : fully_accepted_rsentries) { + int num_committed_tokens = + rsentries[rsentry_id]->mstates[verify_model_id_]->committed_tokens.size(); + // When a request's draft is fully accepted, an additional new token is sampled. + // So the token needed to fill in the draft model is the committed_token[-2]. + ICHECK_GE(num_committed_tokens, 2); + input_tokens.push_back(rsentries[rsentry_id] + ->mstates[verify_model_id_] + ->committed_tokens[num_committed_tokens - 2] + .sampled_token_id.first); + fully_accepted_request_internal_ids.push_back( + rsentries[rsentry_id]->mstates[draft_model_id_]->internal_id); } + // - Compute embeddings. + ObjectRef embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + // - Invoke model decode. + NDArray logits = + models_[draft_model_id_]->BatchDecode(embeddings, fully_accepted_request_internal_ids); + // - We explicitly synchronize to avoid the input tokens getting overriden in the + // next runs of BatchDecode. + // This is because we do not do sample for this round of batch decode. 
+ TVMSynchronize(logits->device.device_type, logits->device.device_id, nullptr); } // clear the draft model state entries From ec6cc300636e78b93f4cce01c1b6cd49440a0bd2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 6 May 2024 23:54:10 -0400 Subject: [PATCH 278/531] Update engine.cc --- cpp/serve/engine.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 6fd6188562..8c26b55778 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -129,8 +129,9 @@ class EngineImpl : public Engine { DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; + // multiply max num_tokens by two so we can do ping-pong swaping during draft/verify process draft_token_workspace_manager = - n->models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + n->models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens * 2); draft_token_workspace_manager->AllocWorkspace( &n->model_workspaces_[0], /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); From d01e1fcaa9eae0d95fac7847c45ba3d962626c12 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Tue, 7 May 2024 17:23:55 +0530 Subject: [PATCH 279/531] [CMAKE][BUILD] Add config option to enable OpenCL Host ptr (#2287) [CMAKE][BUILD] Add user option to enable OpenCL Host ptr --- cmake/gen_cmake_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/gen_cmake_config.py b/cmake/gen_cmake_config.py index f12983c441..13d56af783 100644 --- a/cmake/gen_cmake_config.py +++ b/cmake/gen_cmake_config.py @@ -29,6 +29,7 @@ "USE_OPENCL", "Use OpenCL? (y/n) ", ), + Backend("OpenCLHostPtr", "USE_OPENCL_ENABLE_HOST_PTR", "Use OpenCLHostPtr? (y/n): "), ] enabled_backends = set() From 0829bcf7728650b9b8c9c244b534faa99f785476 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 7 May 2024 10:59:06 -0400 Subject: [PATCH 280/531] [Serving][Fix] Pass draft length when constructing draft action (#2291) This PR fixes a bug which does not pass the speculative decoding draft length to the draft generation stage. --- cpp/serve/engine.cc | 3 ++- cpp/serve/engine_actions/action.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 8c26b55778..616c463d9c 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -170,7 +170,8 @@ class EngineImpl : public Engine { engine_config, // n->trace_recorder_), EngineAction::BatchDraft(n->models_, logit_processor, sampler, n->model_workspaces_, - draft_token_workspace_manager, n->trace_recorder_), + draft_token_workspace_manager, n->trace_recorder_, + engine_config->spec_draft_length), EngineAction::BatchVerify(n->models_, logit_processor, sampler, n->model_workspaces_, draft_token_workspace_manager, engine_config, n->trace_recorder_)}; diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index c69c508810..067ef11dac 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -115,7 +115,7 @@ class EngineAction : public ObjectRef { static EngineAction BatchDraft(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, DraftTokenWorkspaceManager draft_token_workspace_manager, - Optional trace_recorder, int draft_length = 4); + Optional trace_recorder, int draft_length); /*! 
* \brief Create the action that runs one-step speculative draft proposal for From 2306086c9432d59aed2454335f66492688bd679f Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 7 May 2024 16:45:19 -0400 Subject: [PATCH 281/531] [Pass] Fix sampling func attachment to not read existing vocab size (#2292) This PR updates the AttachGPUSamplingFunc pass to make each sampling func have independent dynamic vocab size var. So we do not have to read the vocab size from the prefill function. --- .../mlc_llm/compiler_pass/attach_sampler.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 4761914e2f..0a92f88cd8 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -33,24 +33,15 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod bb = relax.BlockBuilder(mod) - # Prefill method exists in base models. - # Prefill_to_last_hidden method exists in base model and speculative small models - if "prefill" in mod: - vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] - else: - assert ( - "prefill_to_last_hidden_states" in mod - ), "Everay model should either has 'prefill' or 'prefill_to_last_hidden_states' method" - vocab_size = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0].shape[-1] gv_names = [ gv.name_hint for gv in [ - _attach_multinomial_sampling_func(bb, vocab_size), - _attach_argsort_func(bb, vocab_size), - _attach_sample_with_top_p(bb, vocab_size), - _attach_take_probs_func(bb, vocab_size), - _attach_batch_verifier(bb, vocab_size), - _attach_renormalize_by_top_p(bb, vocab_size, self.target), + _attach_multinomial_sampling_func(bb), + _attach_argsort_func(bb), + _attach_sample_with_top_p(bb), + _attach_take_probs_func(bb), + _attach_batch_verifier(bb), + _attach_renormalize_by_top_p(bb, self.target), ] ] @@ -64,9 +55,10 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod -def _attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_multinomial_sampling_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") + vocab_size = tir.Var("vocab_size", "int64") probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) uniform_samples = relax.Var( "uniform_samples", relax.TensorStructInfo((num_samples,), "float32") @@ -109,8 +101,9 @@ def _attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.Pr return gv -def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_argsort_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) with bb.function("argsort_probs", [probs]): with bb.dataflow(): @@ -141,11 +134,10 @@ def full(var_result: T.handle, value: T.int32): result[vi, 0] = value -def _attach_sample_with_top_p( # pylint: disable=too-many-locals - bb: relax.BlockBuilder, vocab_size: tir.PrimExpr -): +def _attach_sample_with_top_p(bb: relax.BlockBuilder): # pylint: disable=too-many-locals batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") + vocab_size = tir.Var("vocab_size", "int64") sorted_probs = relax.Var( "sorted_probs", 
relax.TensorStructInfo((batch_size, vocab_size), "float32") ) @@ -226,10 +218,9 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals return gv -def _attach_renormalize_by_top_p( - bb: relax.BlockBuilder, vocab_size: tir.PrimExpr, target: tvm.target.Target -): +def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, target: tvm.target.Target): batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") num_pivots = 3 probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) @@ -258,10 +249,11 @@ def _attach_renormalize_by_top_p( return gv -def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_take_probs_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") num_positions = tir.Var("num_positions", "int64") + vocab_size = tir.Var("vocab_size", "int64") unsorted_probs = relax.Var( "unsorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") ) @@ -327,9 +319,10 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument return gv -def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_batch_verifier(bb: relax.BlockBuilder): num_nodes = tir.Var("num_nodes", "int64") nbatch = tir.Var("nbatch", "int64") + vocab_size = tir.Var("vocab_size", "int64") draft_probs = relax.Var( "draft_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") ) From b499d2b3ea91c5260ccc47fd6b07d5792f8dd8e0 Mon Sep 17 00:00:00 2001 From: Mengshiun Yu Date: Tue, 7 May 2024 19:31:30 -0400 Subject: [PATCH 282/531] [SLM] Introduce microsoft/Phi-3 (#2222) Introduce microsoft/Phi-3 from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct --- python/mlc_llm/conversation_template.py | 18 + python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/model/model.py | 15 + python/mlc_llm/model/model_preset.py | 33 ++ python/mlc_llm/model/phi3/__init__.py | 0 python/mlc_llm/model/phi3/phi3_loader.py | 79 ++++ python/mlc_llm/model/phi3/phi3_model.py | 371 ++++++++++++++++++ .../mlc_llm/model/phi3/phi3_quantization.py | 54 +++ 8 files changed, 571 insertions(+) create mode 100644 python/mlc_llm/model/phi3/__init__.py create mode 100644 python/mlc_llm/model/phi3/phi3_loader.py create mode 100644 python/mlc_llm/model/phi3/phi3_model.py create mode 100644 python/mlc_llm/model/phi3/phi3_quantization.py diff --git a/python/mlc_llm/conversation_template.py b/python/mlc_llm/conversation_template.py index 56547ec1c3..22cd49c8dd 100644 --- a/python/mlc_llm/conversation_template.py +++ b/python/mlc_llm/conversation_template.py @@ -246,6 +246,24 @@ def get_conv_template(name: str) -> Optional[Conversation]: ) ) +# Phi-3 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="phi-3", + system_template=f"<|system|>\n{MessagePlaceholders.SYSTEM.value}", + system_message="You are a helpful digital assistant. 
Please provide safe, " + "ethical and accurate information to the user.", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=["<|end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + system_prefix_token_ids=[1], + stop_str=["<|endoftext|>"], + stop_token_ids=[32000, 32001, 32007], + ) +) + + # StableLM Tuned Alpha ConvTemplateRegistry.register_conv_template( Conversation( diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 13f0e1215f..e7ae49df2a 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -379,6 +379,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "glm", "custom", # for web-llm only "phi-2", + "phi-3", "stablelm-2", "gemma_instruction", "orion", diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 595d7ba9a3..84d47ffd68 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -22,6 +22,7 @@ from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .orion import orion_loader, orion_model, orion_quantization from .phi import phi_loader, phi_model, phi_quantization +from .phi3 import phi3_loader, phi3_model, phi3_quantization from .qwen import qwen_loader, qwen_model, qwen_quantization from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization @@ -201,6 +202,20 @@ class Model: "ft-quant": phi_quantization.ft_quant, }, ), + "phi3": Model( + name="phi3", + model=phi3_model.Phi3ForCausalLM, + config=phi3_model.Phi3Config, + source={ + "huggingface-torch": phi3_loader.phi3_huggingface, + "huggingface-safetensor": phi3_loader.phi3_huggingface, + }, + quantize={ + "no-quant": phi3_quantization.no_quant, + "group-quant": phi3_quantization.group_quant, + "ft-quant": phi3_quantization.ft_quant, + }, + ), "qwen": Model( name="qwen", model=qwen_model.QWenLMHeadModel, diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 41abf0292c..a7276308b7 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -358,6 +358,39 @@ "transformers_version": "4.35.2", "vocab_size": 51200, }, + "phi-3": { + "_name_or_path": "Phi-3-mini-4k-instruct", + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM", + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "vocab_size": 32064, + }, "qwen": { "architectures": ["QWenLMHeadModel"], "auto_map": { diff --git a/python/mlc_llm/model/phi3/__init__.py b/python/mlc_llm/model/phi3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/phi3/phi3_loader.py b/python/mlc_llm/model/phi3/phi3_loader.py new file mode 100644 index 0000000000..ab694457d7 --- /dev/null +++ 
b/python/mlc_llm/model/phi3/phi3_loader.py @@ -0,0 +1,79 @@ +""" +This file specifies how MLC's Phi parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .phi3_model import Phi3Config, Phi3ForCausalLM + + +def phi3_huggingface(model_config: Phi3Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of Phi-1/Phi-1.5 HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : PhiConfig + The configuration of the Phi model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = Phi3ForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + def _concat_add(mlc_name, hf_names): + mapping.add_mapping( + mlc_name, + hf_names, + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + _add("lm_head.weight", "lm_head.weight") + _add("transformer.norm.weight", "model.norm.weight") + _add("transformer.embd.weight", "model.embed_tokens.weight") + + prefix = "transformer.h" + hf_prefix = "model.layers" + for i in range(model_config.num_hidden_layers): + _add(f"{prefix}.{i}.ln.weight", f"{hf_prefix}.{i}.input_layernorm.weight") + _add(f"{prefix}.{i}.mlp.down_proj.weight", f"{hf_prefix}.{i}.mlp.down_proj.weight") + _add(f"{prefix}.{i}.mlp.gate_up_proj.weight", f"{hf_prefix}.{i}.mlp.gate_up_proj.weight") + _add( + f"{prefix}.{i}.post_attention_layernorm.weight", + f"{hf_prefix}.{i}.post_attention_layernorm.weight", + ) + _add(f"{prefix}.{i}.mixer.out_proj.weight", f"{hf_prefix}.{i}.self_attn.o_proj.weight") + _add(f"{prefix}.{i}.mixer.qkv_proj.weight", f"{hf_prefix}.{i}.self_attn.qkv_proj.weight") + return mapping diff --git a/python/mlc_llm/model/phi3/phi3_model.py b/python/mlc_llm/model/phi3/phi3_model.py new file mode 100644 index 0000000000..7169ba2668 --- /dev/null +++ b/python/mlc_llm/model/phi3/phi3_model.py @@ -0,0 +1,371 @@ +""" +Implementation for Phi architecture. 
+TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Phi3Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Phi-3 model.""" + + model_type: str # "phi", "phi-msft", "mixformer-sequential" + hidden_size: int + vocab_size: int + num_hidden_layers: int + num_attention_heads: int + intermediate_size: int + rms_norm_eps: float + num_key_value_heads: int + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + + if self.num_key_value_heads == 0 or self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.num_attention_heads % self.num_key_value_heads == 0 + + +# pylint: disable=invalid-name,missing-docstring + + +class Phi3MLP(nn.Module): + def __init__(self, config: Phi3Config): + super().__init__() + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor): + up_states = self.gate_up_proj(hidden_states) + gate, up_states = nn.op.split(up_states, 2, axis=-1) + up_states = up_states * op.silu(gate) + return self.down_proj(up_states) + + +class PhiMHA(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: Phi3Config): + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert config.num_attention_heads % config.tensor_parallel_shards == 0, ( + f"num_attention_heads({config.num_attention_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards + assert config.num_key_value_heads % config.tensor_parallel_shards == 0, ( + f"num_attention_heads({config.num_key_value_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.head_dim = config.head_dim + + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=False, + ) + self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.out_proj(output) + + +class Phi3ParallelBlock(nn.Module): + def __init__(self, config: Phi3Config): + super().__init__() + + self.ln = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.mixer = PhiMHA(config) + self.mlp = Phi3MLP(config) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.rms_norm_eps, bias=False + ) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.mixer.num_q_heads * hd + k = self.mixer.num_key_value_heads * hd + v = self.mixer.num_key_value_heads * hd + i = self.mlp.intermediate_size + + _set(self.mixer.qkv_proj, 
tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.mixer.out_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + attn_outputs = self.mixer(self.ln(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_parallel_residual(attn_outputs, hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_parallel_residual(out, hidden_states) + return hidden_states + + def _apply_parallel_residual(self, mlp_out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(mlp_out + residual / self.tensor_parallel_shards, "sum") + return mlp_out + residual + + +class Phi3Model(nn.Module): + def __init__(self, config: Phi3Config) -> None: + super().__init__() + self.embd = nn.Embedding(config.vocab_size, config.hidden_size) + self.h = nn.ModuleList([Phi3ParallelBlock(config) for _ in range(config.num_hidden_layers)]) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Phi3ForCausalLM(nn.Module): + # pylint: disable=too-many-instance-attributes + def __init__(self, config: Phi3Config) -> None: + super().__init__() + + self.transformer = Phi3Model(config) + self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + lm_logits = self.lm_head(hidden_states) + if lm_logits.dtype != "float32": + lm_logits = lm_logits.astype("float32") + return lm_logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + + if logits.dtype != "float32": + logits = logits.astype("float32") + + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = 
self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + embeds = self.transformer.embd(input_ids) + return embeds + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/phi3/phi3_quantization.py 
b/python/mlc_llm/model/phi3/phi3_quantization.py new file mode 100644 index 0000000000..008b3e22c9 --- /dev/null +++ b/python/mlc_llm/model/phi3/phi3_quantization.py @@ -0,0 +1,54 @@ +"""This file specifies how MLC's Llama parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .phi3_model import Phi3Config, Phi3ForCausalLM + + +def group_quant( + model_config: Phi3Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using group quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: Phi3Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using FasterTransformer quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: Phi3Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi model without quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map From 3621bf63b53f494888a0c2ce1fae9136315c1e15 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 7 May 2024 16:43:15 -0700 Subject: [PATCH 283/531] [Eagle] Run additional decode for draft model when all proposals are accepted (#2294) --- .../engine_actions/eagle_batch_verify.cc | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 6b23035f78..0f5fba4a5a 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -140,6 +140,13 @@ class EagleBatchVerifyActionObj : public EngineActionObj { draft_output_tokens, draft_probs_on_device); ICHECK_EQ(sample_results_arr.size(), num_rsentries); + // We collect the requests whose drafts are fully accepted. + // When a request's draft is fully accepted, there is an extra token proposed + // by the draft model but not added into the draft model's KV cache. + // In this case, an additional batch decode step is needed for these requests. + std::vector fully_accepted_rsentries; + fully_accepted_rsentries.reserve(num_rsentries); + std::vector last_accepted_hidden_positions; last_accepted_hidden_positions.reserve(num_rsentries); for (int i = 0; i < num_rsentries; ++i) { @@ -157,6 +164,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Take max with 0 in case of all accepted. int rollback_length = std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0); + // rollback kv cache // NOTE: when number of small models is more than 1 (in the future), // it is possible to re-compute prefill for the small models. @@ -166,6 +174,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // Draft model rollback minus one because verify uses one more token. 
models_[draft_model_id_]->PopNFromKVCache( rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1); + } else { + fully_accepted_rsentries.push_back(i); } // clear the draft model state entries rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); @@ -173,7 +183,62 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Slice and save hidden_states_for_sample last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } + if (!fully_accepted_rsentries.empty()) { + // - Run a step of batch decode for requests whose drafts are fully accepted. + // When a request's draft is fully accepted, there is an extra token proposed + // by the draft model but not added into the draft model's KV cache. + // In this case, an additional batch decode step is needed for these requests. + std::vector input_tokens; + std::vector fully_accepted_request_internal_ids; + input_tokens.reserve(fully_accepted_rsentries.size()); + fully_accepted_request_internal_ids.reserve(fully_accepted_rsentries.size()); + + std::vector hidden_states_positions_for_fully_accepted; + hidden_states_positions_for_fully_accepted.reserve(fully_accepted_rsentries.size()); + + for (int rsentry_id : fully_accepted_rsentries) { + int num_committed_tokens = + rsentries[rsentry_id]->mstates[verify_model_id_]->committed_tokens.size(); + // When a request's draft is fully accepted, an additional new token is sampled. + // So the token needed to fill in the draft model is the committed_token[-2]. + ICHECK_GE(num_committed_tokens, 2); + input_tokens.push_back(rsentries[rsentry_id] + ->mstates[verify_model_id_] + ->committed_tokens[num_committed_tokens - 2] + .sampled_token_id.first); + + // Taking the hidden states of the token before the last token + hidden_states_positions_for_fully_accepted.push_back( + last_accepted_hidden_positions[rsentry_id] - 1); + fully_accepted_request_internal_ids.push_back( + rsentries[rsentry_id]->mstates[draft_model_id_]->internal_id); + } + // - Compute embeddings. + ObjectRef embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + // - Gather hidden states + ObjectRef hidden_states_for_fully_accepted = models_[draft_model_id_]->GatherHiddenStates( + hidden_states, hidden_states_positions_for_fully_accepted, + &model_workspaces_[draft_model_id_].hidden_states); + // - Invoke model decode. + ObjectRef fused_embedding_hidden_states = + models_[draft_model_id_]->FuseEmbedHidden(embeddings, hidden_states_for_fully_accepted, + /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states_for_fully_accepted = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); + // - We explicitly synchronize to avoid the input tokens getting overriden in the + // next runs of BatchDecode. + // This is because we do not do sample for this round of batch decode. 
+ if (hidden_states_for_fully_accepted->IsInstance()) { + Downcast(Downcast(hidden_states_for_fully_accepted)->session)->SyncWorker(0); + } else { + NDArray hidden_states_for_fully_accepted_nd = + Downcast(hidden_states_for_fully_accepted); + TVMSynchronize(hidden_states_for_fully_accepted_nd->device.device_type, + hidden_states_for_fully_accepted_nd->device.device_id, nullptr); + } + } { // One step draft for the following steps From df4e2f37bbaeace797f278ed8c8b1dba33a0370c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 8 May 2024 06:57:27 -0400 Subject: [PATCH 284/531] [iOS] Introducing package CLI for iOS app packaging (#2297) This PR introduces the packaging CLI `mlc_llm package` which reads from a `mlc-package-config.json` and compiles model and prepares model/runtime libraries automatically. With this PR, we get rid of prebuilt model library dependency for iOS app build. Validated that the iOS build can work. iOS documentation is updated according to this latest change. The same flow is supposed to work for Android as well, while it still needs verification for Android app build. --- docs/compilation/compile_models.rst | 4 +- docs/compilation/convert_weights.rst | 4 +- docs/deploy/ios.rst | 421 +++++++----------- docs/deploy/javascript.rst | 4 +- ios/MLCChat.xcodeproj/project.pbxproj | 14 +- ios/MLCChat/Common/Constants.swift | 4 +- ios/MLCChat/States/AppState.swift | 2 +- ios/MLCChat/mlc-package-config.json | 33 ++ .../project.pbxproj | 12 +- .../MLCEngineExampleApp.swift | 16 +- ios/MLCEngineExample/mlc-package-config.json | 11 + ios/prepare_model_lib.py | 88 ---- ios/{prepare_libs.sh => prepare_package.sh} | 3 +- ios/prepare_params.sh | 32 -- python/mlc_llm/__main__.py | 6 +- python/mlc_llm/chat_module.py | 12 +- python/mlc_llm/cli/package.py | 55 +++ python/mlc_llm/help.py | 11 + python/mlc_llm/interface/jit.py | 46 +- python/mlc_llm/interface/package.py | 274 ++++++++++++ python/mlc_llm/serve/engine_base.py | 12 +- 21 files changed, 625 insertions(+), 439 deletions(-) create mode 100644 ios/MLCChat/mlc-package-config.json create mode 100644 ios/MLCEngineExample/mlc-package-config.json delete mode 100644 ios/prepare_model_lib.py rename ios/{prepare_libs.sh => prepare_package.sh} (94%) delete mode 100755 ios/prepare_params.sh create mode 100644 python/mlc_llm/cli/package.py create mode 100644 python/mlc_llm/interface/package.py diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 560ca17255..a98de7d97a 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -245,10 +245,10 @@ All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_con .. note:: - For the ``conv-template``, `conv_template.cc `__ + For the ``conv-template``, `conversation_template.py `__ contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. - Follow `this PR `__ as an example. + Follow `this PR `__ as an example. However, adding your own template would require you :ref:`build mlc_llm from source ` in order for it to be recognized by the runtime. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index 1518f5145a..e350ba4ac5 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -107,10 +107,10 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. 
``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). You can also simply use the default configuration. - `conv_template.cc `__ + `conversation_template.py `__ contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. - Follow `this PR `__ as an example. However, + Follow `this PR `__ as an example. However, adding your own template would require you :ref:`build mlc_llm from source ` in order for it to be recognized by the runtime. diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 2bcf7997d3..d326a53fbb 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -29,10 +29,17 @@ Step 1. Install Build Dependencies ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ First and foremost, please clone the `MLC LLM GitHub repository `_. +After cloning, go to the ``ios/`` directory. + +.. code:: bash + + git clone https://github.com/mlc-ai/mlc-llm.git + cd mlc-llm + git submodule update --init --recursive + cd ./ios + Please follow :doc:`/install/tvm` to install TVM Unity. -Note that we **do not** have to run `build.py` since we can use prebuilt weights. -We only need TVM Unity's utility to combine the libraries (`local-id-iphone.tar`) into a single library. We also need to have the following build dependencies: @@ -40,88 +47,84 @@ We also need to have the following build dependencies: * Git and Git-LFS, * `Rust and Cargo `_, which are required by Hugging Face's tokenizer. +.. _ios-build-runtime-and-model-libraries: -Step 2. Download Prebuilt Weights and Library -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Step 2. Build Runtime and Model Libraries +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You also need to obtain a copy of the MLC-LLM source code -by cloning the `MLC LLM GitHub repository `_. -To simplify the build, we will use prebuilt model -weights and libraries here. Run the following command -in the root directory of the MLC-LLM. +The models to be built for the iOS app are specified in ``MLCChat/mlc-package-config.json``: +in the ``model_list`` field of this file, ``model`` points to the Hugging Face model repository, +where model weights are downloaded from. ``model_id`` is a unique model identifier. +``estimated_vram_bytes`` is an estimation of the vRAM the model takes at runtime. -.. code:: bash - - mkdir -p dist/prebuilt - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib - - cd dist/prebuilt - git lfs install - git clone https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - cd ../.. - -Validate that the files and directories exist: +We have a one-line command to build and prepare all the model libraries: .. code:: bash - >>> ls -l ./dist/prebuilt/lib/*/*-iphone.tar - ./dist/prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar - ./dist/prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar - ... + ./prepare_package.sh - >>> ls -l ./dist/prebuilt/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - # chat config: - mlc-chat-config.json - # model weights: - ndarray-cache.json - params_shard_*.bin - ... - - -Step 3. Build Auxiliary Components -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +This command mainly executes the following two steps: -**Tokenizer and runtime** +1. 
**Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM. +2. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library. -In addition to the model itself, a lightweight runtime and tokenizer are -required to actually run the LLM. You can build and organize these -components by following these steps: +The command creates a ``./dist/`` directory that contains the runtime and model build output. +Please make sure all the following files exist in ``./dist/``. .. code:: bash - git submodule update --init --recursive - cd ./ios - ./prepare_libs.sh + >>> ls ./dist + bundle # The directory for mlc-app-config.json (and optionally model weights) + # that will be bundled into the iOS app. + lib # The directory for runtime and model libraries. -This will create a ``./build`` folder that contains the following files. -Please make sure all the following files exist in ``./build/``. - -.. code:: bash + >>> ls ./dist/bundle + mlc-app-config.json # The app config JSON file. - >>> ls ./build/lib/ + >>> ls ./dist/lib libmlc_llm.a # A lightweight interface to interact with LLM, tokenizer, and TVM Unity runtime libmodel_iphone.a # The compiled model lib libsentencepiece.a # SentencePiece tokenizer libtokenizers_cpp.a # Huggingface tokenizer libtvm_runtime.a # TVM Unity runtime -**Add prepackage model** -We can also *optionally* add prepackage weights into the app, -run the following command under the ``./ios`` directory: +.. _ios-bundle-model-weights: -.. code:: bash +Step 3. (Optional) Bundle model weights into the app +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - cd ./ios - open ./prepare_params.sh # make sure builtin_list only contains "RedPajama-INCITE-Chat-3B-v1-q4f16_1" - ./prepare_params.sh +By default, we download the model weights from Hugging Face when running the app. +**As an option,**, we bundle model weights into the app: +set the field ``"bundle_weight": true`` for any model you want to bundle weights +in ``MLCChat/mlc-package-config.json``, and run ``prepare_package.sh`` again. +Below is an example: -The outcome should be as follows: +.. code:: json + + { + "model_list": [ + { + "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", + "model_id": "gemma-2b-q4f16_1", + "estimated_vram_bytes": 3000000000, + "overrides": { + "prefill_chunk_size": 128 + }, + "bundle_weight": true + } + ] + } + +The outcome of running ``prepare_package.sh`` should be as follows: .. code:: bash - >>> ls ./dist/ - RedPajama-INCITE-Chat-3B-v1-q4f16_1 + >>> ls ./dist/bundle + mlc-app-config.json + gemma-2b-it-q4f16_1-MLC # The model weights that will be bundled into the app. + +.. _ios-build-app: Step 4. Build iOS App ^^^^^^^^^^^^^^^^^^^^^ @@ -146,51 +149,99 @@ to run on your Mac. You can also directly run it on your iPad or iPhone. Customize the App ----------------- -We can customize the iOS app in several ways. -`MLCChat/app-config.json `_ -controls the list of local and remote models to be packaged into the app, given a local path or a URL respectively. Only models in ``model_list`` will have their libraries brought into the app when running `./prepare_libs` to package them into ``libmodel_iphone.a``. Each model defined in `app-config.json` contain the following fields: +We can customize the models built in the iOS app by customizing `MLCChat/mlc-package-config.json `_. +We introduce each field of the JSON file here. 
-``model_path`` - (Required if local model) Name of the local folder containing the weights. +Each entry in ``"model_list"`` of the JSON file has the following fields: -``model_url`` - (Required if remote model) URL to the repo containing the weights. +``model`` + (Required) The path to the MLC-converted model to be built into the app. + + It can be either a Hugging Face URL (e.g., ``"model": "HF://mlc-ai/phi-2-q4f16_1-MLC"```), or a path to a local model directory which contains converted model weights (e.g., ``"model": "../dist/gemma-2b-q4f16_1"``). Please check out :ref:`convert-weights-via-MLC` if you want to build local model into the app. + + *Note: the local path (if relative) is relative to the* ``ios/`` *directory.* ``model_id`` - (Required) Unique local identifier to identify the model. + (Required) A unique local identifier to identify the model. + It can be an arbitrary one. -``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` - for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during - ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. +``estimated_vram_bytes`` + (Required) Estimated requirements of vRAM to run the model. -``required_vram_bytes`` - (Required) Estimated requirements of VRAM to run the model. +``bundle_weight`` + (Optional) A boolean flag indicating whether to bundle model weights into the app. See :ref:`ios-bundle-model-weights`. -``model_lib_path_for_prepare_libs`` - (Required) List of paths to the model libraries in the app (respective ``.tar`` file in the ``binary-mlc-llm-libs`` - repo, relative path in the ``dist`` artifact folder or full path to the library). Only used while running - ``prepare_libs.sh`` to determine which model library to use during runtime. Useful when selecting a library with - different settings (e.g. ``prefill_chunk_size``, ``context_window_size``, and ``sliding_window_size``). +``overrides`` + (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model temporary execution memory). + Example: -Additionally, the app prepackages the models under ``./ios/dist``. -This built-in list can be controlled by editing ``prepare_params.sh``. -You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above. + .. code:: json + { + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "overrides": { + "context_window_size": 512, + "prefill_chunk_size": 128 + } + } + ] + } -Bring Your Own Model Variant ----------------------------- +``model_lib`` + (Optional) A string specifying the system library prefix to use for the model. + Usually this is used when you want to build multiple model variants with the same architecture into the app. + **This field does not affect any app functionality.** + The ``"model_lib_path_for_prepare_libs"`` introduced below is also related. + Example: -In cases where the model you are adding is simply a variant of an existing -model, we only need to convert weights and reuse existing model library. For instance: + .. 
code:: json -- Adding ``NeuralHermes`` when MLC already supports the ``Mistral`` architecture + { + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "model_lib": "gpt_neox_q4f16_1" + } + ] + } -In this section, we walk you through adding ``NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC`` to the MLC iOS app. -According to the model's ``config.json`` on `its Huggingface repo `_, -it reuses the Mistral model architecture. +Besides ``model_list`` in ``MLCChat/mlc-package-config.json``, +you can also **optionally** specify a dictionary of ``"model_lib_path_for_prepare_libs"``, +**if you want to use model libraries that are manually compiled**. +The keys of this dictionary should be the ``model_lib`` that specified in model list, +and the values of this dictionary are the paths (absolute, or relative) to the manually compiled model libraries. +The model libraries specified in ``"model_lib_path_for_prepare_libs"`` will be built into the app when running ``prepare_package.sh``. +Example: + +.. code:: json + + { + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "model_lib": "gpt_neox_q4f16_1" + } + ], + "model_lib_path_for_prepare_libs": { + "gpt_neox_q4f16_1": "../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar" + } + } + + +Bring Your Own Model +-------------------- + +This section introduces how to build your own model into the iOS app. +We use the example of `NeuralHermes `_ model, which a variant of Mistral model. .. note:: @@ -198,7 +249,7 @@ it reuses the Mistral model architecture. See that page for more details. Note that the weights are shared across all platforms in MLC. -**Step 1 Clone from HF and convert_weight** +**Step 1. Clone from HF and convert_weight** You can be under the mlc-llm repo, or your own working directory. Note that all platforms can share the same compiled/quantized weights. See :ref:`compile-command-specification` @@ -217,7 +268,7 @@ for specification of ``convert_weight``. --quantization q4f16_1 \ -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC -**Step 2 Generate MLC Chat Config** +**Step 2. Generate MLC Chat Config** Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. See :ref:`compile-command-specification` for specification of ``gen_config``. @@ -228,16 +279,16 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. --quantization q3f16_1 --conv-template neural_hermes_mistral \ -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC -For the ``conv-template``, `conv_template.cc `__ +For the ``conv-template``, `conversation_template.py `__ contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. -Follow `this PR `__ as an example. +Follow `this PR `__ as an example. We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. For more details, please see :ref:`configure-mlc-chat-json`. -**Step 3 Upload weights to HF** +**Step 3. Upload weights to HF** .. code:: shell @@ -255,185 +306,33 @@ After successfully following all steps, you should end up with a Huggingface rep which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. 
-**Step 4 Register as a ModelRecord** - -Finally, we modify the code snippet for -`app-config.json `__ -pasted above. - -We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib`` for -``Mistral-7B``. - -.. code:: javascript - - "model_list": [ - // Other records here omitted... - { - // Substitute model_url with the one you created `my-huggingface-account/my-mistral-weight-huggingface-repo` - "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", - "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_lib": "mistral_q3f16_1", - "model_lib": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", - "estimated_vram_bytes": 3316000000 - } - ] - - -Now, the app will use the ``NeuralHermes-Mistral`` model you just added. - - -Bring Your Own Model Library ----------------------------- - -A model library is specified by: - - - The model architecture (e.g. ``mistral``, ``phi-msft``) - - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) - -In cases where the model you want to run is not compatible with the provided MLC -prebuilt model libraries (e.g. having a different quantization, a different -metadata spec, or even a different model architecture), you need to build your -own model library. - -In this section, we walk you through adding ``phi-2`` to the iOS app. +**Step 4. Register in Model List** -This section largely replicates :ref:`compile-model-libraries`. See that page for -more details, specifically the ``iOS`` option. +Finally, we add the model into the ``model_list`` of +`MLCChat/mlc-package-config.json `_ by specifying the Hugging Face link as ``model``: -**Step 0. Install dependencies** +.. code:: json -To compile model libraries for iOS, you need to :ref:`build mlc_llm from source `. - -**Step 1. Clone from HF and convert_weight** - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/microsoft/phi-2 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/phi-2/ \ - --quantization q4f16_1 \ - -o dist/phi-2-q4f16_1-MLC - -**Step 2. Generate mlc-chat-config and compile** - -A model library is specified by: - - - The model architecture (e.g. ``mistral``, ``phi-msft``) - - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) - -All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. - -.. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/phi-2/ \ - --quantization q4f16_1 --conv-template phi-2 \ - -o dist/phi-2-q4f16_1-MLC/ - # 2. mkdir: create a directory to store the compiled model library - mkdir -p dist/libs - # 3. 
compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ - --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. -That information will be displayed at the end of the console log when the ``compile`` is executed. -It might look something like this: - -.. code:: shell - - [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. Temporary buffer: 133.28 MB) - [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar - -.. note:: - When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size - while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ - to decrease memory usage. Otherwise, during runtime, you may run out of memory. - - -**Step 3. Distribute model library and model weights** - -After following the steps above, you should end up with: - -.. code:: shell - - ~/mlc-llm > ls dist/libs - phi-2-q4f16_1-iphone.tar # ===> the model library - - ~/mlc-llm > ls dist/phi-2-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - -Upload the ``phi-2-q4f16_1-iphone.tar`` to a github repository (for us, -it is in `binary-mlc-llm-libs `__). Then -upload the weights ``phi-2-q4f16_1-MLC`` to a Huggingface repo: - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-phi-weight-huggingface-repo - cd my-phi-weight-huggingface-repo - cp path/to/mlc-llm/dist/phi-2-q4f16_1-MLC/* . - git add . && git commit -m "Add phi-2 model weights" - git push origin main - -This would result in something like `phi-2-q4f16_1-MLC -`_. - - -**Step 4. Register as a ModelRecord** - -Finally, we update the code snippet for -`app-config.json `__ -pasted above. - -We simply specify the Huggingface link as ``model_url``, while using the new ``model_lib`` for -``phi-2``. Regarding the field ``estimated_vram_bytes``, we can use the output of the last step -rounded up to MB. - -.. code:: javascript - - "model_list": [ - // Other records here omitted... - { - // Substitute model_url with the one you created `my-huggingface-account/my-phi-weight-huggingface-repo` - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_id": "phi-2-q4f16_1", - "model_lib": "phi_msft_q4f16_1", - "estimated_vram_bytes": 3043000000 - } - ] + { + "model_list": [ + { + "model": "HF://mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", + "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", + "estimated_vram_bytes": 3316000000, + } + ] + } -Now, the app will use the ``phi-2`` model library you just added. +Now, go through :ref:`ios-build-runtime-and-model-libraries` and :ref:`ios-build-app` again. +The app will use the ``NeuralHermes-Mistral`` model you just added. 
Build Apps with MLC Swift API ----------------------------- We also provide a Swift package that you can use to build -your own app. The package is located under `ios/MLCSwift`. +your own app. The package is located under ``ios/MLCSwift``. - First make sure you have run the same steps listed in the previous section. This will give us the necessary libraries diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index bd92908cff..92e5b87ce1 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -150,11 +150,11 @@ See :ref:`compile-command-specification` for specification of ``gen_config``. --quantization q4f16_1 --conv-template wizard_coder_or_math \ -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/ -For the ``conv-template``, `conv_template.cc `__ +For the ``conv-template``, `conversation_template.py `__ contains a full list of conversation templates that MLC provides. If the model you are adding requires a new conversation template, you would need to add your own. -Follow `this PR `__ as an example. Besides, you also need to add the new template to ``/path/to/web-llm/src/conversation.ts``. +Follow `this PR `__ as an example. Besides, you also need to add the new template to ``/path/to/web-llm/src/conversation.ts``. We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. For more details, please see :ref:`configure-mlc-chat-json`. diff --git a/ios/MLCChat.xcodeproj/project.pbxproj b/ios/MLCChat.xcodeproj/project.pbxproj index 4c5173fa3c..8b390e1401 100644 --- a/ios/MLCChat.xcodeproj/project.pbxproj +++ b/ios/MLCChat.xcodeproj/project.pbxproj @@ -16,8 +16,6 @@ AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; }; AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; }; AEC27F022A86337E00254E67 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; }; - C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; }; - C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; }; C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; }; C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; }; C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; }; @@ -25,6 +23,7 @@ C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; }; C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; }; C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */; }; + F3C280002BEB16ED00F1E016 /* bundle in CopyFiles */ = {isa = PBXBuildFile; fileRef = F3C27FFF2BEB16ED00F1E016 /* bundle */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -34,8 +33,7 @@ dstPath = ""; dstSubfolderSpec = 7; files = ( - C09834192A16F4E000A05B51 /* app-config.json in 
CopyFiles */, - C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */, + F3C280002BEB16ED00F1E016 /* bundle in CopyFiles */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -61,7 +59,6 @@ AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ParamsConfig.swift; sourceTree = ""; }; AEC27EFB2A85C3B000254E67 /* AppConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppConfig.swift; sourceTree = ""; }; AEC27F012A86337E00254E67 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; - C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = ""; }; C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = ""; }; C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = ""; }; C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -72,6 +69,7 @@ C0D643C229F99B07004DDAA4 /* ChatView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatView.swift; sourceTree = ""; }; C0D643C729F99B34004DDAA4 /* MessageView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MessageView.swift; sourceTree = ""; }; C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */ = {isa = PBXFileReference; lastKnownFileType = wrapper; path = MLCSwift; sourceTree = ""; }; + F3C27FFF2BEB16ED00F1E016 /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = dist/bundle; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -129,8 +127,8 @@ C0D643A629F99A7F004DDAA4 = { isa = PBXGroup; children = ( + F3C27FFF2BEB16ED00F1E016 /* bundle */, C0DDBDF02A39068900E9D060 /* Packages */, - C06A74E029F99C9F00BC4BE6 /* dist */, C0D643B129F99A7F004DDAA4 /* MLCChat */, C0D643B029F99A7F004DDAA4 /* Products */, C0D643C929F99BDA004DDAA4 /* Frameworks */, @@ -422,7 +420,7 @@ ); LIBRARY_SEARCH_PATHS = ( "$(inherited)", - "$(PROJECT_DIR)/build/lib", + "$(PROJECT_DIR)/dist/lib", ); MARKETING_VERSION = 1.3; OTHER_LDFLAGS = ( @@ -474,7 +472,7 @@ ); LIBRARY_SEARCH_PATHS = ( "$(inherited)", - "$(PROJECT_DIR)/build/lib", + "$(PROJECT_DIR)/dist/lib", ); MARKETING_VERSION = 1.3; OTHER_LDFLAGS = ( diff --git a/ios/MLCChat/Common/Constants.swift b/ios/MLCChat/Common/Constants.swift index cf3a240fcf..aa3d9654de 100644 --- a/ios/MLCChat/Common/Constants.swift +++ b/ios/MLCChat/Common/Constants.swift @@ -4,8 +4,8 @@ // struct Constants { - static let prebuiltModelDir = "dist" - static let appConfigFileName = "app-config.json" + static let prebuiltModelDir = "bundle" + static let appConfigFileName = "bundle/mlc-app-config.json" static let modelConfigFileName = "mlc-chat-config.json" static let paramsConfigFileName = "ndarray-cache.json" } diff --git a/ios/MLCChat/States/AppState.swift b/ios/MLCChat/States/AppState.swift index 4dc8d9f315..bd2f252b68 100644 --- a/ios/MLCChat/States/AppState.swift +++ b/ios/MLCChat/States/AppState.swift @@ -225,7 +225,7 @@ private extension AppState { // model_id dir should exist if modelURL == nil { - // prebuilt model in 
dist + // prebuilt model in bundle modelBaseURL = Bundle.main.bundleURL.appending(path: Constants.prebuiltModelDir).appending(path: modelPath!) } else { // download model in cache diff --git a/ios/MLCChat/mlc-package-config.json b/ios/MLCChat/mlc-package-config.json new file mode 100644 index 0000000000..db5b29206f --- /dev/null +++ b/ios/MLCChat/mlc-package-config.json @@ -0,0 +1,33 @@ +{ + "model_list": [ + { + "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC", + "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", + "estimated_vram_bytes": 3316000000, + "overrides": { + "context_window_size": 512 + } + }, + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "overrides": { + "prefill_chunk_size": 128 + } + }, + { + "model": "HF://mlc-ai/phi-2-q4f16_1-MLC", + "model_id": "phi-2-q4f16_1", + "estimated_vram_bytes": 3043000000 + }, + { + "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", + "model_id": "gemma-2b-q4f16_1", + "estimated_vram_bytes": 3000000000, + "overrides": { + "prefill_chunk_size": 128 + } + } + ] +} \ No newline at end of file diff --git a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj index f24f333d83..52c9ac0108 100644 --- a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj +++ b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj @@ -12,7 +12,7 @@ C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */; }; C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */; }; C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0B37B972BE8234D00B2F80B /* MLCSwift */; }; - C0B37C0A2BE82D5900B2F80B /* dist in Copy Files */ = {isa = PBXBuildFile; fileRef = C0B37C062BE825DC00B2F80B /* dist */; }; + F31E1EEE2BEAD4870061D498 /* bundle in Copy Files */ = {isa = PBXBuildFile; fileRef = F31E1EED2BEAD4870061D498 /* bundle */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -22,7 +22,7 @@ dstPath = ""; dstSubfolderSpec = 7; files = ( - C0B37C0A2BE82D5900B2F80B /* dist in Copy Files */, + F31E1EEE2BEAD4870061D498 /* bundle in Copy Files */, ); name = "Copy Files"; runOnlyForDeploymentPostprocessing = 0; @@ -35,8 +35,8 @@ C0B37B8A2BE8226A00B2F80B /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; - C0B37C062BE825DC00B2F80B /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; name = dist; path = ../dist; sourceTree = ""; }; C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCEngineExample.entitlements; sourceTree = ""; }; + F31E1EED2BEAD4870061D498 /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = ../dist/bundle; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin 
PBXFrameworksBuildPhase section */ @@ -54,7 +54,7 @@ C0B37B7C2BE8226A00B2F80B = { isa = PBXGroup; children = ( - C0B37C062BE825DC00B2F80B /* dist */, + F31E1EED2BEAD4870061D498 /* bundle */, C0B37B872BE8226A00B2F80B /* MLCEngineExample */, C0B37B862BE8226A00B2F80B /* Products */, ); @@ -314,7 +314,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../build/lib"; + LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../dist/lib"; MARKETING_VERSION = 1.0; OTHER_LDFLAGS = ( "-Wl,-all_load", @@ -355,7 +355,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../build/lib"; + LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../dist/lib"; MARKETING_VERSION = 1.0; OTHER_LDFLAGS = ( "-Wl,-all_load", diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift index 19b6ab45de..cf4d3dae53 100644 --- a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -5,8 +5,8 @@ // example and quick testing purposes. // // To build this app, select target My Mac(Designed for iPad) and run -// Make sure you run prepare_libs.sh and prepare_params.sh first -// to ensure the dist folder populates with the right model file +// Make sure you run prepare_package.sh first with "MLCChat" replaced by "MLCEngineExample" +// to ensure the "dist/bundle" folder populates with the right model file // and we have the model lib packaged correctly import Foundation import SwiftUI @@ -19,18 +19,12 @@ class AppState: ObservableObject { private let engine = MLCEngine() // obtain the local path to store models // this that stores the model files in the dist folder - private let distURL = Bundle.main.bundleURL.appending(path: "dist") - // NOTE: this does not yet work out of box - // need to supply the Llama-3-8B-Instruct-q3f16_1-MLC and llama_q3f16_1 - // via manual local compile - // TODO(mlc-team): update prebuild so it can be used out of box - // + private let bundleURL = Bundle.main.bundleURL.appending(path: "bundle") // model path, this must match a builtin // file name in prepare_params.sh private let modelPath = "Llama-3-8B-Instruct-q3f16_1-MLC" // model lib identifier of within the packaged library - // this must match a config in MLCChat/app-config.json - // make sure we run prepare_libs.sh + // make sure we run prepare_package.sh private let modelLib = "llama_q3f16_1" // this is a message to be displayed in app @@ -39,7 +33,7 @@ class AppState: ObservableObject { public func runExample() { // MLCEngine is a actor that can be called in an async context Task { - let modelLocalPath = distURL.appending(path: modelPath).path() + let modelLocalPath = bundleURL.appending(path: modelPath).path() // Step 0: load the engine await engine.reload(modelPath: modelLocalPath, modelLib: modelLib) diff --git a/ios/MLCEngineExample/mlc-package-config.json b/ios/MLCEngineExample/mlc-package-config.json new file mode 100644 index 0000000000..066fe7fa10 --- /dev/null +++ b/ios/MLCEngineExample/mlc-package-config.json @@ -0,0 +1,11 @@ +{ + "model_list": [ + { + "model": "HF://mlc-ai/Llama-3-8B-Instruct-q3f16_1-MLC", + "model_id": "llama3", + "estimated_vram_bytes": 3316000000, + "bundle_weight": true, + "model_lib": "llama_q3f16_1" + } + ] +} \ No newline at end of file diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py deleted file mode 100644 index ff56236321..0000000000 --- a/ios/prepare_model_lib.py 
+++ /dev/null @@ -1,88 +0,0 @@ -import json -import os -import sys -from tvm.contrib import cc - - -def get_model_libs(lib_path): - global_symbol_map = cc.get_global_symbol_section_map(lib_path) - libs = [] - suffix = "___tvm_dev_mblob" - for name in global_symbol_map.keys(): - if name.endswith(suffix): - model_lib = name[: -len(suffix)] - if model_lib.startswith("_"): - model_lib = model_lib[1:] - libs.append(model_lib) - return libs - - -def main(): - app_config_path = "MLCChat/app-config.json" - app_config = json.load(open(app_config_path, "r")) - artifact_path = os.path.abspath(os.path.join("..", "dist")) - - tar_list = [] - model_set = set() - - for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): - paths = [ - os.path.join(artifact_path, model_lib_path), - os.path.join(artifact_path, "prebuilt", model_lib_path), - os.path.join(model_lib_path), - ] - valid_paths = [p for p in paths if os.path.isfile(p)] - if not valid_paths: - raise RuntimeError( - f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" - ) - tar_list.append(valid_paths[0]) - model_set.add(model) - - lib_path = os.path.join("build", "lib", "libmodel_iphone.a") - - cc.create_staticlib(lib_path, tar_list) - available_model_libs = get_model_libs(lib_path) - print(f"Creating lib from {tar_list}..") - print(f"Validating the library {lib_path}...") - print( - f"List of available model libs packaged: {available_model_libs}," - " if we have '-' in the model_lib string, it will be turned into '_'" - ) - global_symbol_map = cc.get_global_symbol_section_map(lib_path) - error_happened = False - for item in app_config["model_list"]: - model_lib = item["model_lib"] - model_id = item["model_id"] - if model_lib not in model_set: - print( - f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " - "is not included in model_lib_path_for_prepare_libs field, " - "This will cause the specific model not being able to load, " - f"please check {app_config_path}." - ) - error_happened = True - - model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" - if ( - model_prefix_pattern not in global_symbol_map - and "_" + model_prefix_pattern not in global_symbol_map - ): - model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] - print( - "ValidationError:\n" - f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" - f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" - f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" - ) - error_happened = True - - if not error_happened: - print("Validation pass") - else: - print("Validation failed") - exit(255) - - -if __name__ == "__main__": - main() diff --git a/ios/prepare_libs.sh b/ios/prepare_package.sh similarity index 94% rename from ios/prepare_libs.sh rename to ios/prepare_package.sh index 3885024b51..695c113760 100755 --- a/ios/prepare_libs.sh +++ b/ios/prepare_package.sh @@ -72,4 +72,5 @@ cd .. 
rm -rf MLCSwift/tvm_home ln -s ../../3rdparty/tvm MLCSwift/tvm_home -python prepare_model_lib.py +python -m mlc_llm package MLCChat/mlc-package-config.json --device iphone -o dist +cp build/lib/* dist/lib/ diff --git a/ios/prepare_params.sh b/ios/prepare_params.sh deleted file mode 100755 index 0ac293228c..0000000000 --- a/ios/prepare_params.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -set -euxo pipefail - -# NOTE: this is optional, prepackage weight into app -rm -rf dist -mkdir -p dist - -declare -a builtin_list=( - "Mistral-7B-Instruct-v0.2-q3f16_1" - # "OpenHermes-2.5-Mistral-7B-q3f16_1" - # "Llama-2-7b-chat-hf-q3f16_1" - # "RedPajama-INCITE-Chat-3B-v1-q4f16_1" - # "vicuna-v1-7b-q3f16_0" - # "rwkv-raven-1b5-q8f16_0" - # "rwkv-raven-3b-q8f16_0" - # "rwkv-raven-7b-q8f16_0" -) - -for model in "${builtin_list[@]}"; do - if [ -d ../dist/$model/params ]; then - cp -r ../dist/$model/params dist/$model - elif [ -d ../dist/prebuilt/$model ]; then - cp -r ../dist/prebuilt/$model dist/$model - elif [ -d ../dist/prebuilt/mlc-chat-$model ]; then - cp -r ../dist/prebuilt/mlc-chat-$model dist/$model - elif [ -d ../dist/prebuilt/$model-MLC ]; then - cp -r ../dist/prebuilt/$model-MLC dist/$model - else - echo "Cannot find prebuilt weights for " $model - exit 1 - fi -done diff --git a/python/mlc_llm/__main__.py b/python/mlc_llm/__main__.py index 857cfc479a..ef34f5a40e 100644 --- a/python/mlc_llm/__main__.py +++ b/python/mlc_llm/__main__.py @@ -14,7 +14,7 @@ def main(): parser.add_argument( "subcommand", type=str, - choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench"], + choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench", "package"], help="Subcommand to to run. (choices: %(choices)s)", ) parsed = parser.parse_args(sys.argv[1:2]) @@ -42,6 +42,10 @@ def main(): elif parsed.subcommand == "bench": from mlc_llm.cli import bench as cli + cli.main(sys.argv[2:]) + elif parsed.subcommand == "package": + from mlc_llm.cli import package as cli + cli.main(sys.argv[2:]) else: raise ValueError(f"Unknown subcommand {parsed.subcommand}") diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 2efc3ec9b9..72d1e5315e 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -781,13 +781,11 @@ def __init__( # pylint: disable=too-many-arguments logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - self.model_lib = str( - jit.jit( - model_path=Path(self.model_path), - chat_config=asdict(self.chat_config), - device=self.device, - ) - ) + self.model_lib = jit.jit( + model_path=Path(self.model_path), + chat_config=asdict(self.chat_config), + device=self.device, + ).model_lib_path _inspect_model_lib_metadata_memory_usage(self.model_lib, self.config_file_path) # 5. 
Call reload diff --git a/python/mlc_llm/cli/package.py b/python/mlc_llm/cli/package.py new file mode 100644 index 0000000000..f605858d67 --- /dev/null +++ b/python/mlc_llm/cli/package.py @@ -0,0 +1,55 @@ +"""Command line entrypoint of package.""" + +from pathlib import Path +from typing import Union + +from mlc_llm.help import HELP +from mlc_llm.interface.package import package +from mlc_llm.support.argparse import ArgumentParser + + +def main(argv): + """Parse command line arguments and call `mlc_llm.interface.package`.""" + parser = ArgumentParser("MLC LLM Package CLI") + + def _parse_package_config(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.exists(): + raise ValueError( + f"Path {str(path)} is expected to be a JSON file, but the file does not exist." + ) + if not path.is_file(): + raise ValueError(f"Path {str(path)} is expected to be a JSON file.") + return path + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + return path + + parser.add_argument( + "package_config", + type=_parse_package_config, + help=HELP["config_package"] + " (required)", + ) + parser.add_argument( + "--device", + type=str, + choices=["iphone", "android"], + required=True, + help=HELP["device_package"] + " (required)", + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help=HELP["output_package"] + " (required)", + ) + parsed = parser.parse_args(argv) + package( + package_config_path=parsed.package_config, + device=parsed.device, + output=parsed.output, + ) diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index f6ef6c38af..6af5495a77 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -213,5 +213,16 @@ For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to specify the eagle-style speculative decoding. Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. +""", + "config_package": """ +The path to "mlc-package-config.json" which is used for package build. +See "ios/MLCChat/mlc-package-config.json" as an example. +""", + "device_package": """ +The device to build package for. +Options are ["iphone", "android"]. +""", + "output_package": """ +The path of output directory for the package build outputs. """, } diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index e999a36468..dd0179b811 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -10,7 +10,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional, Union from tvm.runtime import Device @@ -30,7 +30,20 @@ logger = logging.getLogger(__name__) -def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path: +@dataclasses.dataclass +class JITResult: + """The jit compilation result class.""" + + model_lib_path: str + system_lib_prefix: Optional[str] = None + + +def jit( # pylint: disable=too-many-locals,too-many-statements + model_path: Path, + chat_config: Dict[str, Any], + device: Union[Device, str], + system_lib_prefix: Optional[str] = None, +) -> JITResult: """Just-in-time compile a MLC-Chat model.""" logger.info( "%s = %s. 
Can be one of: ON, OFF, REDO, READONLY", @@ -44,6 +57,7 @@ def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path: mlc_chat_config = json.load(in_file) model_type = mlc_chat_config.pop("model_type") quantization = mlc_chat_config.pop("quantization") + lib_suffix = MLC_DSO_SUFFIX if device not in ["iphone", "android"] else "tar" def _get_optimization_flags() -> str: opt = chat_config.pop("opt", None) @@ -73,9 +87,9 @@ def _get_model_config() -> Dict[str, Any]: model_config[field.name] = value return MODELS[model_type].config.from_dict(model_config).asdict() - def _run_jit(opt: str, overrides: str, device: str, dst: str): + def _run_jit(opt: str, overrides: str, device: str, system_lib_prefix: Optional[str], dst: str): with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: - dso_path = os.path.join(tmp_dir, f"lib.{MLC_DSO_SUFFIX}") + dso_path = os.path.join(tmp_dir, f"lib.{lib_suffix}") cmd = [ sys.executable, "-m", @@ -91,6 +105,8 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): "--output", dso_path, ] + if system_lib_prefix: + cmd += ["--system-lib-prefix", system_lib_prefix + "_"] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) subprocess.run(cmd, check=False, env=os.environ) @@ -105,10 +121,23 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): "model_config": _get_model_config(), "overrides": _get_overrides(), "opt": _get_optimization_flags(), - "device": device2str(device), + "device": device2str(device) if isinstance(device, Device) else device, "model_type": model_type, "quantization": quantization, } + if device in ["iphone", "android"]: + if system_lib_prefix is None: + system_lib_hash_value = hashlib.md5( + json.dumps( + hash_key, + sort_keys=True, + indent=2, + ).encode("utf-8") + ).hexdigest() + system_lib_prefix = f"{model_type}_{quantization}_{system_lib_hash_value}".replace( + "-", "_" + ) + hash_key["system_lib_prefix"] = system_lib_prefix hash_value = hashlib.md5( json.dumps( hash_key, @@ -116,10 +145,10 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): indent=2, ).encode("utf-8") ).hexdigest() - dst = MLC_CACHE_DIR / "model_lib" / f"{hash_value}.so" + dst = MLC_CACHE_DIR / "model_lib" / f"{hash_value}.{lib_suffix}" if dst.is_file() and MLC_JIT_POLICY in ["ON", "READONLY"]: logger.info("Using cached model lib: %s", bold(str(dst))) - return dst + return JITResult(str(dst), system_lib_prefix) if MLC_JIT_POLICY == "READONLY": raise RuntimeError( "No cached model lib found, and JIT is disabled by MLC_JIT_POLICY=READONLY" @@ -128,6 +157,7 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): opt=hash_key["opt"], overrides=hash_key["overrides"], device=hash_key["device"], + system_lib_prefix=system_lib_prefix, dst=str(dst), ) - return dst + return JITResult(str(dst), system_lib_prefix) diff --git a/python/mlc_llm/interface/package.py b/python/mlc_llm/interface/package.py new file mode 100644 index 0000000000..335c57d1db --- /dev/null +++ b/python/mlc_llm/interface/package.py @@ -0,0 +1,274 @@ +"""Python entrypoint of package.""" + +import dataclasses +import json +import os +import shutil +import sys +from dataclasses import asdict +from pathlib import Path +from typing import List, Literal + +from tvm.contrib import cc + +from mlc_llm.chat_module import ChatConfig, _get_chat_config, _get_model_path +from mlc_llm.interface import jit +from mlc_llm.support import logging, style + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def 
_get_model_libs(lib_path: Path) -> List[str]: + """Get the model lib prefixes in the given static lib path.""" + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name, _ in global_symbol_map.items(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + +def validate_model_lib( # pylint: disable=too-many-locals + app_config_path: Path, device: Literal["iphone", "android"], output: Path +) -> None: + """Validate the model lib prefixes of model libraries.""" + # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + if device == "android": + from tvm.contrib import ndk as cc + else: + from tvm.contrib import cc + # pylint: enable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + + with open(app_config_path, "r", encoding="utf-8") as file: + app_config = json.load(file) + + tar_list = [] + model_set = set() + + for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): + model_lib_path = os.path.join(model_lib_path) + lib_path_valid = os.path.isfile(model_lib_path) + if not lib_path_valid: + raise RuntimeError(f"Cannot find file {model_lib_path} as an {device} model library") + tar_list.append(model_lib_path) + model_set.add(model) + + os.makedirs(output / "lib", exist_ok=True) + lib_path = ( + output / "lib" / ("libmodel_iphone.a" if device == "iphone" else "libmodel_android.a") + ) + + cc.create_staticlib(lib_path, tar_list) + available_model_libs = _get_model_libs(lib_path) + logger.info("Creating lib from %s", str(tar_list)) + logger.info("Validating the library %s", str(lib_path)) + logger.info( + "List of available model libs packaged: %s," + " if we have '-' in the model_lib string, it will be turned into '_'", + str(available_model_libs), + ) + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + error_happened = False + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + logger.info( + "ValidationError: model_lib=%s specified for model_id=%s " + "is not included in model_lib_path_for_prepare_libs field, " + "This will cause the specific model not being able to load, " + "please check %s.", + model_lib, + model_id, + str(app_config_path), + ) + error_happened = True + + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + logger.info( + "ValidationError:\n" + "\tmodel_lib %s requested in %s is not found in %s\n" + "\tspecifically the model_lib for %s in model_lib_path_for_prepare_libs.\n" + "\tcurrent available model_libs in %s: %s", + model_lib, + str(app_config_path), + str(lib_path), + model_lib_path, + str(lib_path), + str(available_model_libs), + ) + error_happened = True + + if not error_happened: + logger.info(style.green("Validation pass")) + else: + logger.info(style.red("Validation failed")) + sys.exit(255) + + +def package( # pylint: disable=too-many-locals,too-many-statements,too-many-branches + package_config_path: Path, + device: Literal["iphone", "android"], + output: Path, +) -> None: + """Python entrypoint of package.""" + # - Read package config. 
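+    # A valid package config is a JSON dict shaped like
+    #   {"model_list": [{"model": ..., "model_id": ..., "estimated_vram_bytes": ...,
+    #                    "bundle_weight": ..., "overrides": ..., "model_lib": ...}, ...],
+    #    "model_lib_path_for_prepare_libs": {...}}
+    # (see ios/MLCChat/mlc-package-config.json for a concrete example).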
+ with open(package_config_path, "r", encoding="utf-8") as file: + package_config = json.load(file) + if not isinstance(package_config, dict): + raise ValueError( + "The content of MLC package config is expected to be a dict with " + f'field "model_list". However, the content of "{package_config_path}" is not a dict.' + ) + + # - Create the bundle directory. + bundle_dir = output / "bundle" + os.makedirs(bundle_dir, exist_ok=True) + # Clean up all the directories in `output/bundle`. + logger.info('Clean up all directories under "%s"', str(bundle_dir)) + for content_path in bundle_dir.iterdir(): + if content_path.is_dir(): + shutil.rmtree(content_path) + + # - Process each model, and prepare the app config. + app_config_model_list = [] + + model_entries = package_config.get("model_list", []) + if not isinstance(model_entries, list): + raise ValueError('The "model_list" in "mlc-package-config.json" is expected to be a list.') + model_lib_path_for_prepare_libs = package_config.get("model_lib_path_for_prepare_libs", {}) + if not isinstance(model_lib_path_for_prepare_libs, dict): + raise ValueError( + 'The "model_lib_path_for_prepare_libs" in "mlc-package-config.json" is expected to be ' + "a dict." + ) + + for model_entry in package_config.get("model_list", []): + # - Parse model entry. + if not isinstance(model_entry, dict): + raise ValueError('The element of "model_list" is expected to be a dict.') + model = model_entry["model"] + model_id = model_entry["model_id"] + bundle_weight = model_entry.get("bundle_weight", False) + overrides = model_entry.get("overrides", {}) + model_lib = model_entry.get("model_lib", None) + estimated_vram_bytes = model_entry["estimated_vram_bytes"] + if not isinstance(model, str): + raise ValueError('The value of "model" in "model_list" is expected to be a string.') + if not isinstance(model_id, str): + raise ValueError('The value of "model_id" in "model_list" is expected to be a string.') + if not isinstance(bundle_weight, bool): + raise ValueError( + 'The value of "bundle_weight" in "model_list" is expected to be a boolean.' + ) + if not isinstance(overrides, dict): + raise ValueError('The value of "overrides" in "model_list" is expected to be a dict.') + if model_lib is not None and not isinstance(model_lib, str): + raise ValueError('The value of "model_lib" in "model_list" is expected to be string.') + + # - Load model config. Download happens when needed. + model_path_and_config_file_path = _get_model_path(model) + model_path = Path(model_path_and_config_file_path[0]) + config_file_path = model_path_and_config_file_path[1] + chat_config = _get_chat_config( + config_file_path, user_chat_config=ChatConfig.from_dict(overrides) + ) + # - Jit compile if the model lib path is not specified. + model_lib_path = ( + model_lib_path_for_prepare_libs.get(model_lib, None) if model_lib is not None else None + ) + if model_lib_path is None: + if model_lib is None: + logger.info( + 'Model lib is not specified for model "%s". Now jit compile the model library.', + model_id, + ) + else: + logger.info( + 'Model lib path for "%s" is not specified in "model_lib_path_for_prepare_libs".' 
+ "Now jit compile the model library.", + model_lib, + ) + model_lib_path, model_lib = dataclasses.astuple( + jit.jit( + model_path=model_path, + chat_config=asdict(chat_config), + device=device, + system_lib_prefix=model_lib, + ) + ) + assert model_lib is not None + model_lib_path_for_prepare_libs[model_lib] = model_lib_path + + # - Set "model_url"/"model_path" and "model_id" + app_config_model_entry = {} + is_local_model = not model.startswith("HF://") and not model.startswith("https://") + app_config_model_entry["model_id"] = model_id + app_config_model_entry["model_lib"] = model_lib + + # - Bundle weight + if is_local_model and not bundle_weight: + raise ValueError( + f'Model "{model}" in "model_list" is a local path.' + f'Please set \'"bundle_weight": true\' in the entry of model "{model}".' + ) + if bundle_weight: + if not os.path.isfile(model_path / "ndarray-cache.json"): + raise ValueError( + f'Bundle weight is set for model "{model}". However, model weights are not' + f'found under the directory "{model}". ' + + ( + "Please follow https://llm.mlc.ai/docs/compilation/convert_weights.html to " + "convert model weights." + if is_local_model + else "Please report this issue to https://github.com/mlc-ai/mlc-llm/issues." + ) + ) + # Overwrite the model weight directory in bundle. + bundle_model_weight_path = bundle_dir / model_path.name + logger.info( + 'Bundle weight for model "%s". Copying weights from "%s" to "%s".', + model_id, + model_path, + bundle_model_weight_path, + ) + if bundle_model_weight_path.exists(): + shutil.rmtree(bundle_model_weight_path) + shutil.copytree(model_path, bundle_model_weight_path) + app_config_model_entry["model_path"] = model_path.name + else: + app_config_model_entry["model_url"] = model.replace("HF://", "https://huggingface.co/") + + # - estimated_vram_bytes + app_config_model_entry["estimated_vram_bytes"] = estimated_vram_bytes + + app_config_model_list.append(app_config_model_entry) + + # - Dump "mlc-app-config.json". + app_config_json_str = json.dumps( + { + "model_list": app_config_model_list, + "model_lib_path_for_prepare_libs": model_lib_path_for_prepare_libs, + }, + indent=2, + ) + app_config_path = bundle_dir / "mlc-app-config.json" + with open(app_config_path, "w", encoding="utf-8") as file: + print(app_config_json_str, file=file) + logger.info( + 'Dump the app config below to "dist/bundle/mlc-app-config.json":\n%s', + style.green(app_config_json_str), + ) + + # - Validate model libraries. 
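+    # validate_model_lib packs all model libraries into one static library under
+    # `output/lib` and checks that every `model_lib` prefix referenced in the dumped
+    # app config is actually present in that library.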
+ validate_model_lib(app_config_path, device, output) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index e0d7160ece..641c8f6ed5 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -105,13 +105,11 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: # Run jit if model_lib is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - model_lib = str( - jit.jit( - model_path=Path(model_path), - chat_config=asdict(chat_config), - device=device, - ) - ) + model_lib = jit.jit( + model_path=Path(model_path), + chat_config=asdict(chat_config), + device=device, + ).model_lib_path return model_path, model_lib model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] From 8a3198600ab8fae781884892b14d03b51c743032 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 8 May 2024 04:55:12 -0700 Subject: [PATCH 285/531] Increase the timeout in PopenServer (#2298) --- python/mlc_llm/help.py | 2 +- python/mlc_llm/serve/server/popen_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 6af5495a77..a9b8917990 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -210,7 +210,7 @@ "engine_config_serve": """ The MLCEngine execution configuration. Currently speculative decoding mode is specified via engine config. -For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" to +For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=eagle'" to specify the eagle-style speculative decoding. Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. """, diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index dcecd25795..e9e1c8e9a9 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -95,7 +95,7 @@ def start(self) -> None: # pylint: disable=too-many-branches # Try to query the server until it is ready. openai_v1_models_url = f"http://{self.host}:{str(self.port)}/v1/models" query_result = None - timeout = 60 + timeout = 120 attempts = 0.0 while query_result is None and attempts < timeout: try: From 65f97160133c1264ca85bea5e940199ca778d811 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 8 May 2024 17:25:39 +0530 Subject: [PATCH 286/531] [LLM-CHAT] Enable gpu softmax for penality softmax (#2288) 1. Avoid the cpu softmax for different penality config by having copy sync to gpu and use gpu softmax. 2. Disable decode token time counter for first token. --- cpp/llm_chat.cc | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 9485ccad02..93de185eb2 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -710,7 +710,7 @@ class LLMChat { /*! \brief reset the runtime stats. 
*/ void ResetRuntimeStats() { this->prefill_total_tokens = 0; - this->decode_total_tokens = 0; + this->decode_total_tokens = -1; this->embed_total_time = 0; this->prefill_total_time = 0; this->decode_total_time = 0; @@ -1031,8 +1031,8 @@ class LLMChat { int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); - - this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; + if (this->decode_total_tokens >= 0) + this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; this->decode_total_tokens += 1; this->ProcessNextToken(next_token, generation_config); } @@ -1223,14 +1223,16 @@ class LLMChat { if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty); + this->UpdateLogitsOrProbOnGPUSync(logits_on_device); if (gen_temperature >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } else if (gen_repetition_penalty != 1.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty); + this->UpdateLogitsOrProbOnGPUSync(logits_on_device); if (gen_temperature >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } else { if (gen_temperature < 1e-6f) { @@ -1505,6 +1507,12 @@ class LLMChat { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } + void UpdateLogitsOrProbOnGPUSync(NDArray logits_or_prob) { + logits_or_prob.CopyFrom(logits_on_cpu_); + + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + } + // Clear kv cache void ResetKVCache() { ft_.reset_kv_cache_func_(kv_cache_); @@ -1547,7 +1555,7 @@ class LLMChat { double decode_total_time = 0; double sample_total_time = 0; double prefill_total_time = 0; - int64_t decode_total_tokens = 0; + int64_t decode_total_tokens = -1; int64_t prefill_total_tokens = 0; //---------------------------- // Conversation From 1bd1ab08863d29264559c80d69c77f02bcc28ee1 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 8 May 2024 12:36:17 -0400 Subject: [PATCH 287/531] [iOS][REFACTOR] Restructure the iOS folders (#2299) Move MLCChat to its own sub folder minor improvements to package. 
--- docs/deploy/ios.rst | 29 ++++++-- .../MLCChat.xcodeproj/project.pbxproj | 22 ++++-- .../contents.xcworkspacedata | 0 .../xcshareddata/IDEWorkspaceChecks.plist | 0 .../xcshareddata/WorkspaceSettings.xcsettings | 0 .../xcshareddata/xcschemes/MLCChat.xcscheme | 0 .../AccentColor.colorset/Contents.json | 0 .../AppIcon.appiconset/Contents.json | 0 .../AppIcon.appiconset/mlc-logo.png | Bin .../Assets.xcassets/Contents.json | 0 .../{ => MLCChat}/Common/Constants.swift | 0 ios/MLCChat/{ => MLCChat}/Info.plist | 0 .../{ => MLCChat}/MLCChat.entitlements | 0 ios/MLCChat/{ => MLCChat}/MLCChatApp.swift | 0 .../{ => MLCChat}/Models/AppConfig.swift | 0 .../{ => MLCChat}/Models/ModelConfig.swift | 0 .../{ => MLCChat}/Models/ParamsConfig.swift | 0 .../Preview Assets.xcassets/Contents.json | 0 .../{ => MLCChat}/States/AppState.swift | 0 .../{ => MLCChat}/States/ChatState.swift | 0 .../{ => MLCChat}/States/ModelState.swift | 0 .../{ => MLCChat}/Views/ChatView.swift | 0 .../{ => MLCChat}/Views/ImageProcessing.swift | 0 .../{ => MLCChat}/Views/MessageView.swift | 0 .../{ => MLCChat}/Views/ModelView.swift | 0 .../{ => MLCChat}/Views/StartView.swift | 0 ios/MLCChat/README.md | 6 ++ ios/MLCChat/app-config.json | 34 --------- ios/MLCChat/mlc-package-config.json | 3 +- ios/MLCChat/prepare_package.sh | 10 +++ .../project.pbxproj | 19 +++-- ios/MLCEngineExample/README.md | 6 ++ ios/MLCEngineExample/prepare_package.sh | 10 +++ .../Sources/Swift/OpenAIProtocol.swift | 6 +- ios/{prepare_package.sh => prepare_libs.sh} | 5 +- python/mlc_llm/interface/jit.py | 20 ++++-- python/mlc_llm/interface/package.py | 65 ++++++++++-------- 37 files changed, 139 insertions(+), 96 deletions(-) rename ios/{ => MLCChat}/MLCChat.xcodeproj/project.pbxproj (97%) rename ios/{ => MLCChat}/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata (100%) rename ios/{ => MLCChat}/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist (100%) rename ios/{ => MLCChat}/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings (100%) rename ios/{ => MLCChat}/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme (100%) rename ios/MLCChat/{ => MLCChat}/Assets.xcassets/AccentColor.colorset/Contents.json (100%) rename ios/MLCChat/{ => MLCChat}/Assets.xcassets/AppIcon.appiconset/Contents.json (100%) rename ios/MLCChat/{ => MLCChat}/Assets.xcassets/AppIcon.appiconset/mlc-logo.png (100%) rename ios/MLCChat/{ => MLCChat}/Assets.xcassets/Contents.json (100%) rename ios/MLCChat/{ => MLCChat}/Common/Constants.swift (100%) rename ios/MLCChat/{ => MLCChat}/Info.plist (100%) rename ios/MLCChat/{ => MLCChat}/MLCChat.entitlements (100%) rename ios/MLCChat/{ => MLCChat}/MLCChatApp.swift (100%) rename ios/MLCChat/{ => MLCChat}/Models/AppConfig.swift (100%) rename ios/MLCChat/{ => MLCChat}/Models/ModelConfig.swift (100%) rename ios/MLCChat/{ => MLCChat}/Models/ParamsConfig.swift (100%) rename ios/MLCChat/{ => MLCChat}/Preview Content/Preview Assets.xcassets/Contents.json (100%) rename ios/MLCChat/{ => MLCChat}/States/AppState.swift (100%) rename ios/MLCChat/{ => MLCChat}/States/ChatState.swift (100%) rename ios/MLCChat/{ => MLCChat}/States/ModelState.swift (100%) rename ios/MLCChat/{ => MLCChat}/Views/ChatView.swift (100%) rename ios/MLCChat/{ => MLCChat}/Views/ImageProcessing.swift (100%) rename ios/MLCChat/{ => MLCChat}/Views/MessageView.swift (100%) rename ios/MLCChat/{ => MLCChat}/Views/ModelView.swift (100%) rename ios/MLCChat/{ => MLCChat}/Views/StartView.swift (100%) create mode 100644 ios/MLCChat/README.md 
delete mode 100644 ios/MLCChat/app-config.json create mode 100755 ios/MLCChat/prepare_package.sh create mode 100755 ios/MLCEngineExample/prepare_package.sh rename ios/{prepare_package.sh => prepare_libs.sh} (93%) diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index d326a53fbb..b90c48a84d 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -61,6 +61,7 @@ We have a one-line command to build and prepare all the model libraries: .. code:: bash + cd /path/to/MLCChat ./prepare_package.sh This command mainly executes the following two steps: @@ -89,6 +90,17 @@ Please make sure all the following files exist in ``./dist/``. libtvm_runtime.a # TVM Unity runtime +.. note:: + + We leverage a local JIT cache to avoid repetitive compilation of the same input. + However, sometimes it is helpful to force rebuild when we have a new compiler update + or when something goes wrong with the ached library. + You can do so by setting the environment variable ``MLC_JIT_POLICY=REDO`` + + .. code:: bash + + MLC_JIT_POLICY=REDO ./prepare_package.sh + .. _ios-bundle-model-weights: Step 3. (Optional) Bundle model weights into the app @@ -129,7 +141,7 @@ The outcome of running ``prepare_package.sh`` should be as follows: Step 4. Build iOS App ^^^^^^^^^^^^^^^^^^^^^ -Open ``./ios/MLCChat.xcodeproj`` using Xcode. Note that you will need an +Open ``./ios/MLCChat/MLCChat.xcodeproj`` using Xcode. Note that you will need an Apple Developer Account to use Xcode, and you may be prompted to use your own developer team credential and product bundle identifier. @@ -232,7 +244,7 @@ Example: } ], "model_lib_path_for_prepare_libs": { - "gpt_neox_q4f16_1": "../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar" + "gpt_neox_q4f16_1": "../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar" } } @@ -334,15 +346,18 @@ Build Apps with MLC Swift API We also provide a Swift package that you can use to build your own app. The package is located under ``ios/MLCSwift``. -- First make sure you have run the same steps listed - in the previous section. This will give us the necessary libraries - under ``/path/to/ios/build/lib``. -- Then you can add ``ios/MLCSwift`` package to your app in Xcode. +- First, create `mlc-package-config.json` and `prepare_package.sh` in your project folder. + You do so by copying the files in MLCChat folder. + Run `prepare_package.sh` + This will give us the necessary libraries under ``/path/to/project/dist``. +- Under "Build phases", add ``/path/to/project/dist/bundle`` this will copying + this folder into your app to include bundled weights and configs. +- Add ``ios/MLCSwift`` package to your app in Xcode. Under "Frameworks, Libraries, and Embedded Content", click add package dependencies and add local package that points to ``ios/MLCSwift``. - Finally, we need to add the libraries dependencies. Under build settings: - - Add library search path ``/path/to/ios/build/lib``. + - Add library search path ``/path/to/project/dist/lib``. - Add the following items to "other linker flags". .. 
code:: diff --git a/ios/MLCChat.xcodeproj/project.pbxproj b/ios/MLCChat/MLCChat.xcodeproj/project.pbxproj similarity index 97% rename from ios/MLCChat.xcodeproj/project.pbxproj rename to ios/MLCChat/MLCChat.xcodeproj/project.pbxproj index 8b390e1401..3580a5d200 100644 --- a/ios/MLCChat.xcodeproj/project.pbxproj +++ b/ios/MLCChat/MLCChat.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 56; + objectVersion = 60; objects = { /* Begin PBXBuildFile section */ @@ -16,13 +16,13 @@ AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; }; AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; }; AEC27F022A86337E00254E67 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; }; + C04105DD2BEBBEA6005A434D /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C04105DC2BEBBEA6005A434D /* MLCSwift */; }; C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; }; C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; }; C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; }; C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C229F99B07004DDAA4 /* ChatView.swift */; }; C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; }; C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; }; - C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */; }; F3C280002BEB16ED00F1E016 /* bundle in CopyFiles */ = {isa = PBXBuildFile; fileRef = F3C27FFF2BEB16ED00F1E016 /* bundle */; }; /* End PBXBuildFile section */ @@ -60,7 +60,6 @@ AEC27EFB2A85C3B000254E67 /* AppConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppConfig.swift; sourceTree = ""; }; AEC27F012A86337E00254E67 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = ""; }; - C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = ""; }; C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; }; C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCChatApp.swift; sourceTree = ""; }; C0D643B629F99A80004DDAA4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; @@ -77,7 +76,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in 
Frameworks */, + C04105DD2BEBBEA6005A434D /* MLCSwift in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -146,7 +145,6 @@ C0D643B129F99A7F004DDAA4 /* MLCChat */ = { isa = PBXGroup; children = ( - C09834182A16F4CB00A05B51 /* app-config.json */, AEC27F032A86338800254E67 /* Common */, AEC27EF82A85C29000254E67 /* Models */, AEC27EFF2A85EE2800254E67 /* States */, @@ -201,7 +199,7 @@ ); name = MLCChat; packageProductDependencies = ( - C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */, + C04105DC2BEBBEA6005A434D /* MLCSwift */, ); productName = MLCChat; productReference = C0D643AF29F99A7F004DDAA4 /* MLCChat.app */; @@ -232,6 +230,9 @@ Base, ); mainGroup = C0D643A629F99A7F004DDAA4; + packageReferences = ( + C04105DB2BEBBEA6005A434D /* XCLocalSwiftPackageReference "../MLCSwift" */, + ); productRefGroup = C0D643B029F99A7F004DDAA4 /* Products */; projectDirPath = ""; projectRoot = ""; @@ -517,8 +518,15 @@ }; /* End XCConfigurationList section */ +/* Begin XCLocalSwiftPackageReference section */ + C04105DB2BEBBEA6005A434D /* XCLocalSwiftPackageReference "../MLCSwift" */ = { + isa = XCLocalSwiftPackageReference; + relativePath = ../MLCSwift; + }; +/* End XCLocalSwiftPackageReference section */ + /* Begin XCSwiftPackageProductDependency section */ - C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */ = { + C04105DC2BEBBEA6005A434D /* MLCSwift */ = { isa = XCSwiftPackageProductDependency; productName = MLCSwift; }; diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata similarity index 100% rename from ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata rename to ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist similarity index 100% rename from ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist rename to ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings similarity index 100% rename from ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings rename to ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings diff --git a/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme b/ios/MLCChat/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme similarity index 100% rename from ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme rename to ios/MLCChat/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme diff --git a/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json b/ios/MLCChat/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json similarity index 100% rename from ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json rename to ios/MLCChat/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json diff --git a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json b/ios/MLCChat/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json similarity index 100% rename from ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json rename to ios/MLCChat/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json diff --git 
a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png b/ios/MLCChat/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png similarity index 100% rename from ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png rename to ios/MLCChat/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png diff --git a/ios/MLCChat/Assets.xcassets/Contents.json b/ios/MLCChat/MLCChat/Assets.xcassets/Contents.json similarity index 100% rename from ios/MLCChat/Assets.xcassets/Contents.json rename to ios/MLCChat/MLCChat/Assets.xcassets/Contents.json diff --git a/ios/MLCChat/Common/Constants.swift b/ios/MLCChat/MLCChat/Common/Constants.swift similarity index 100% rename from ios/MLCChat/Common/Constants.swift rename to ios/MLCChat/MLCChat/Common/Constants.swift diff --git a/ios/MLCChat/Info.plist b/ios/MLCChat/MLCChat/Info.plist similarity index 100% rename from ios/MLCChat/Info.plist rename to ios/MLCChat/MLCChat/Info.plist diff --git a/ios/MLCChat/MLCChat.entitlements b/ios/MLCChat/MLCChat/MLCChat.entitlements similarity index 100% rename from ios/MLCChat/MLCChat.entitlements rename to ios/MLCChat/MLCChat/MLCChat.entitlements diff --git a/ios/MLCChat/MLCChatApp.swift b/ios/MLCChat/MLCChat/MLCChatApp.swift similarity index 100% rename from ios/MLCChat/MLCChatApp.swift rename to ios/MLCChat/MLCChat/MLCChatApp.swift diff --git a/ios/MLCChat/Models/AppConfig.swift b/ios/MLCChat/MLCChat/Models/AppConfig.swift similarity index 100% rename from ios/MLCChat/Models/AppConfig.swift rename to ios/MLCChat/MLCChat/Models/AppConfig.swift diff --git a/ios/MLCChat/Models/ModelConfig.swift b/ios/MLCChat/MLCChat/Models/ModelConfig.swift similarity index 100% rename from ios/MLCChat/Models/ModelConfig.swift rename to ios/MLCChat/MLCChat/Models/ModelConfig.swift diff --git a/ios/MLCChat/Models/ParamsConfig.swift b/ios/MLCChat/MLCChat/Models/ParamsConfig.swift similarity index 100% rename from ios/MLCChat/Models/ParamsConfig.swift rename to ios/MLCChat/MLCChat/Models/ParamsConfig.swift diff --git a/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json b/ios/MLCChat/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json similarity index 100% rename from ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json rename to ios/MLCChat/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json diff --git a/ios/MLCChat/States/AppState.swift b/ios/MLCChat/MLCChat/States/AppState.swift similarity index 100% rename from ios/MLCChat/States/AppState.swift rename to ios/MLCChat/MLCChat/States/AppState.swift diff --git a/ios/MLCChat/States/ChatState.swift b/ios/MLCChat/MLCChat/States/ChatState.swift similarity index 100% rename from ios/MLCChat/States/ChatState.swift rename to ios/MLCChat/MLCChat/States/ChatState.swift diff --git a/ios/MLCChat/States/ModelState.swift b/ios/MLCChat/MLCChat/States/ModelState.swift similarity index 100% rename from ios/MLCChat/States/ModelState.swift rename to ios/MLCChat/MLCChat/States/ModelState.swift diff --git a/ios/MLCChat/Views/ChatView.swift b/ios/MLCChat/MLCChat/Views/ChatView.swift similarity index 100% rename from ios/MLCChat/Views/ChatView.swift rename to ios/MLCChat/MLCChat/Views/ChatView.swift diff --git a/ios/MLCChat/Views/ImageProcessing.swift b/ios/MLCChat/MLCChat/Views/ImageProcessing.swift similarity index 100% rename from ios/MLCChat/Views/ImageProcessing.swift rename to ios/MLCChat/MLCChat/Views/ImageProcessing.swift diff --git a/ios/MLCChat/Views/MessageView.swift b/ios/MLCChat/MLCChat/Views/MessageView.swift similarity index 100% rename from 
ios/MLCChat/Views/MessageView.swift rename to ios/MLCChat/MLCChat/Views/MessageView.swift diff --git a/ios/MLCChat/Views/ModelView.swift b/ios/MLCChat/MLCChat/Views/ModelView.swift similarity index 100% rename from ios/MLCChat/Views/ModelView.swift rename to ios/MLCChat/MLCChat/Views/ModelView.swift diff --git a/ios/MLCChat/Views/StartView.swift b/ios/MLCChat/MLCChat/Views/StartView.swift similarity index 100% rename from ios/MLCChat/Views/StartView.swift rename to ios/MLCChat/MLCChat/Views/StartView.swift diff --git a/ios/MLCChat/README.md b/ios/MLCChat/README.md new file mode 100644 index 0000000000..831d7eee73 --- /dev/null +++ b/ios/MLCChat/README.md @@ -0,0 +1,6 @@ +# MLC Chat App + +Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information. + +- run `./prepare_package.sh` +- open the xcode project diff --git a/ios/MLCChat/app-config.json b/ios/MLCChat/app-config.json deleted file mode 100644 index 1379fc6647..0000000000 --- a/ios/MLCChat/app-config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "model_list": [ - { - "model_path": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_lib": "mistral_q3f16_1", - "estimated_vram_bytes": 3316000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", - "model_lib": "gpt_neox_q4f16_1", - "estimated_vram_bytes": 2960000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_id": "phi-2-q4f16_1", - "model_lib": "phi_msft_q4f16_1", - "estimated_vram_bytes": 3043000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC", - "model_id": "gemma-2b-q4f16_1", - "model_lib": "gemma_q4f16_1", - "estimated_vram_bytes": 3000000000 - } - ], - "model_lib_path_for_prepare_libs": { - "mistral_q3f16_1": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", - "gpt_neox_q4f16_1": "lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar", - "phi_msft_q4f16_1": "lib/phi-2/phi-2-q4f16_1-iphone.tar", - "gemma_q4f16_1": "lib/gemma-2b-it/gemma-2b-it-q4f16_1-iphone.tar" - } -} diff --git a/ios/MLCChat/mlc-package-config.json b/ios/MLCChat/mlc-package-config.json index db5b29206f..66ca1379f7 100644 --- a/ios/MLCChat/mlc-package-config.json +++ b/ios/MLCChat/mlc-package-config.json @@ -4,6 +4,7 @@ "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC", "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", "estimated_vram_bytes": 3316000000, + "bundle_weight": true, "overrides": { "context_window_size": 512 } @@ -30,4 +31,4 @@ } } ] -} \ No newline at end of file +} diff --git a/ios/MLCChat/prepare_package.sh b/ios/MLCChat/prepare_package.sh new file mode 100755 index 0000000000..6dedca46ae --- /dev/null +++ b/ios/MLCChat/prepare_package.sh @@ -0,0 +1,10 @@ +# This script does two things +# It calls prepare_libs.sh in $MLC_LLM_HOME/ios/ to setup the iOS package and build binaries +# It then calls mlc_llm package to setup the weight and library bundle +# Feel free to copy this file and mlc-package-config.json to your project + +MLC_LLM_HOME="${MLC_LLM_HOME:-../..}" +cd ${MLC_LLM_HOME}/ios && ./prepare_libs.sh $@ && cd - +mkdir -p dist/lib +cp ${MLC_LLM_HOME}/ios/build/lib/* dist/lib/ +python -m mlc_llm package mlc-package-config.json --device iphone -o dist diff --git a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj index 
52c9ac0108..2791b78391 100644 --- a/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj +++ b/ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj @@ -7,12 +7,13 @@ objects = { /* Begin PBXBuildFile section */ + C04105DF2BEBC61B005A434D /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C04105DE2BEBC61B005A434D /* MLCSwift */; }; + C07094522BEBC6C4005C29FC /* bundle in Copy Files */ = {isa = PBXBuildFile; fileRef = C07094512BEBC6C4005C29FC /* bundle */; }; C0B37B892BE8226A00B2F80B /* MLCEngineExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */; }; C0B37B8B2BE8226A00B2F80B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B8A2BE8226A00B2F80B /* ContentView.swift */; }; C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */; }; C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */; }; C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0B37B972BE8234D00B2F80B /* MLCSwift */; }; - F31E1EEE2BEAD4870061D498 /* bundle in Copy Files */ = {isa = PBXBuildFile; fileRef = F31E1EED2BEAD4870061D498 /* bundle */; }; /* End PBXBuildFile section */ /* Begin PBXCopyFilesBuildPhase section */ @@ -22,7 +23,7 @@ dstPath = ""; dstSubfolderSpec = 7; files = ( - F31E1EEE2BEAD4870061D498 /* bundle in Copy Files */, + C07094522BEBC6C4005C29FC /* bundle in Copy Files */, ); name = "Copy Files"; runOnlyForDeploymentPostprocessing = 0; @@ -30,13 +31,13 @@ /* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ + C07094512BEBC6C4005C29FC /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = dist/bundle; sourceTree = ""; }; C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCEngineExample.app; sourceTree = BUILT_PRODUCTS_DIR; }; C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCEngineExampleApp.swift; sourceTree = ""; }; C0B37B8A2BE8226A00B2F80B /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCEngineExample.entitlements; sourceTree = ""; }; - F31E1EED2BEAD4870061D498 /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = ../dist/bundle; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -45,6 +46,7 @@ buildActionMask = 2147483647; files = ( C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */, + C04105DF2BEBC61B005A434D /* MLCSwift in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -54,7 +56,7 @@ C0B37B7C2BE8226A00B2F80B = { isa = PBXGroup; children = ( - F31E1EED2BEAD4870061D498 /* 
bundle */, + C07094512BEBC6C4005C29FC /* bundle */, C0B37B872BE8226A00B2F80B /* MLCEngineExample */, C0B37B862BE8226A00B2F80B /* Products */, ); @@ -107,6 +109,7 @@ name = MLCEngineExample; packageProductDependencies = ( C0B37B972BE8234D00B2F80B /* MLCSwift */, + C04105DE2BEBC61B005A434D /* MLCSwift */, ); productName = MLCEngineExample; productReference = C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */; @@ -314,7 +317,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../dist/lib"; + LIBRARY_SEARCH_PATHS = "${PROJECT_DIR}/dist/lib"; MARKETING_VERSION = 1.0; OTHER_LDFLAGS = ( "-Wl,-all_load", @@ -355,7 +358,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - LIBRARY_SEARCH_PATHS = "$(PROJECT_DIR)/../dist/lib"; + LIBRARY_SEARCH_PATHS = "${PROJECT_DIR}/dist/lib"; MARKETING_VERSION = 1.0; OTHER_LDFLAGS = ( "-Wl,-all_load", @@ -405,6 +408,10 @@ /* End XCLocalSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ + C04105DE2BEBC61B005A434D /* MLCSwift */ = { + isa = XCSwiftPackageProductDependency; + productName = MLCSwift; + }; C0B37B972BE8234D00B2F80B /* MLCSwift */ = { isa = XCSwiftPackageProductDependency; productName = MLCSwift; diff --git a/ios/MLCEngineExample/README.md b/ios/MLCEngineExample/README.md index e08265f4b2..67bf06089b 100644 --- a/ios/MLCEngineExample/README.md +++ b/ios/MLCEngineExample/README.md @@ -1,6 +1,12 @@ # MLCEngine Example + Minimal example of the latest MLCEngine Swift API. NOTE: this project is still work in progress, things may not yet be fully functioning and are subject to change + +Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information. + +- run `./prepare_package.sh` +- open the xcode project diff --git a/ios/MLCEngineExample/prepare_package.sh b/ios/MLCEngineExample/prepare_package.sh new file mode 100755 index 0000000000..d1f022166d --- /dev/null +++ b/ios/MLCEngineExample/prepare_package.sh @@ -0,0 +1,10 @@ +# This script does two things +# It calls prepare_libs.sh in $MLC_LLM_HOME/ios/ to setup the iOS package and build binaries +# It then calls mlc_llm package to setup the weight and library bundle +# Feel free to copy this file and mlc-package-config.json to your project + +MLC_LLM_HOME="${MLC_LLM_HOME:-../..}" +cd ${MLC_LLM_HOME}/ios && ./prepare_libs.sh $@ && cd - +rm -rf dist/lib && mkdir -p dist/lib +cp ${MLC_LLM_HOME}/ios/build/lib/* dist/lib/ +python -m mlc_llm package mlc-package-config.json --device iphone -o dist diff --git a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift index 1aa652af5e..1f36933a15 100644 --- a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift +++ b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift @@ -28,7 +28,7 @@ public struct ChatFunction : Codable { } public struct ChatTool : Codable { - public let type: String = "function" + public var type: String = "function" public let function: ChatFunction } @@ -40,8 +40,8 @@ public struct ChatFunctionCall : Codable { } public struct ChatToolCall : Codable { - public let id: String = UUID().uuidString - public let type: String = "function" + public var id: String = UUID().uuidString + public var type: String = "function" public let function: ChatFunctionCall } diff --git a/ios/prepare_package.sh b/ios/prepare_libs.sh similarity index 93% rename from ios/prepare_package.sh rename to ios/prepare_libs.sh index 695c113760..58e6468637 100755 --- a/ios/prepare_package.sh +++ b/ios/prepare_libs.sh @@ -1,3 +1,5 @@ +# 
Command to prepare the mlc llm static libraries +# This command will be invoked by prepare_package.sh in the subfolder function help { echo -e "OPTION:" echo -e " -s, --simulator Build for Simulator" @@ -71,6 +73,3 @@ cd .. rm -rf MLCSwift/tvm_home ln -s ../../3rdparty/tvm MLCSwift/tvm_home - -python -m mlc_llm package MLCChat/mlc-package-config.json --device iphone -o dist -cp build/lib/* dist/lib/ diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index dd0179b811..7744ffe894 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -38,18 +38,28 @@ class JITResult: system_lib_prefix: Optional[str] = None +def log_jit_policy(): + """log current jit policy""" + logger.info( + "%s = %s. Can be one of: ON, OFF, REDO, READONLY", + bold("MLC_JIT_POLICY"), + MLC_JIT_POLICY, + ) + + def jit( # pylint: disable=too-many-locals,too-many-statements model_path: Path, chat_config: Dict[str, Any], device: Union[Device, str], system_lib_prefix: Optional[str] = None, + *, + skip_log_jit_policy=False, ) -> JITResult: """Just-in-time compile a MLC-Chat model.""" - logger.info( - "%s = %s. Can be one of: ON, OFF, REDO, READONLY", - bold("MLC_JIT_POLICY"), - MLC_JIT_POLICY, - ) + # skip logging jit policy since when outside can hint once + if not skip_log_jit_policy: + log_jit_policy() + if MLC_JIT_POLICY == "OFF": raise RuntimeError("JIT is disabled by MLC_JIT_POLICY=OFF") diff --git a/python/mlc_llm/interface/package.py b/python/mlc_llm/interface/package.py index 335c57d1db..d342ff589d 100644 --- a/python/mlc_llm/interface/package.py +++ b/python/mlc_llm/interface/package.py @@ -34,7 +34,11 @@ def _get_model_libs(lib_path: Path) -> List[str]: def validate_model_lib( # pylint: disable=too-many-locals - app_config_path: Path, device: Literal["iphone", "android"], output: Path + app_config_path: Path, + package_config_path: Path, + model_lib_path_for_prepare_libs: dict, + device: Literal["iphone", "android"], + output: Path, ) -> None: """Validate the model lib prefixes of model libraries.""" # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported @@ -50,7 +54,7 @@ def validate_model_lib( # pylint: disable=too-many-locals tar_list = [] model_set = set() - for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): + for model, model_lib_path in model_lib_path_for_prepare_libs.items(): model_lib_path = os.path.join(model_lib_path) lib_path_valid = os.path.isfile(model_lib_path) if not lib_path_valid: @@ -74,39 +78,39 @@ def validate_model_lib( # pylint: disable=too-many-locals ) global_symbol_map = cc.get_global_symbol_section_map(lib_path) error_happened = False + for item in app_config["model_list"]: model_lib = item["model_lib"] model_id = item["model_id"] if model_lib not in model_set: - logger.info( - "ValidationError: model_lib=%s specified for model_id=%s " - "is not included in model_lib_path_for_prepare_libs field, " + # NOTE: this cannot happen under new setting + # since if model_lib is not included, it will be jitted + raise RuntimeError( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs argument, " "This will cause the specific model not being able to load, " - "please check %s.", - model_lib, - model_id, - str(app_config_path), + f"model_lib_path_for_prepare_libs={model_lib_path_for_prepare_libs}" ) - error_happened = True model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" if ( 
model_prefix_pattern not in global_symbol_map and "_" + model_prefix_pattern not in global_symbol_map ): - model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] - logger.info( + # NOTE: no lazy format is ok since this is a slow pass + model_lib_path = model_lib_path_for_prepare_libs[model_lib] + log_msg = ( "ValidationError:\n" - "\tmodel_lib %s requested in %s is not found in %s\n" - "\tspecifically the model_lib for %s in model_lib_path_for_prepare_libs.\n" - "\tcurrent available model_libs in %s: %s", - model_lib, - str(app_config_path), - str(lib_path), - model_lib_path, - str(lib_path), - str(available_model_libs), + f"\tmodel_lib {model_lib} requested in {str(app_config_path)}" + f" is not found in {str(lib_path)}\n" + f"\tspecifically the model_lib for {model_lib_path}.\n" + f"\tcurrent available model_libs in {str(lib_path)}: {available_model_libs}\n" + f"\tThis can happen when we manually specified model_lib_path_for_prepare_libs" + f" in {str(package_config_path)}\n" + f"\tConsider remove model_lib_path_for_prepare_libs (so library can be jitted)" + "or check the compile command" ) + logger.info(log_msg) error_happened = True if not error_happened: @@ -153,6 +157,8 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra "a dict." ) + jit.log_jit_policy() + for model_entry in package_config.get("model_list", []): # - Parse model entry. if not isinstance(model_entry, dict): @@ -205,6 +211,7 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra chat_config=asdict(chat_config), device=device, system_lib_prefix=model_lib, + skip_log_jit_policy=True, ) ) assert model_lib is not None @@ -237,10 +244,9 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra # Overwrite the model weight directory in bundle. bundle_model_weight_path = bundle_dir / model_path.name logger.info( - 'Bundle weight for model "%s". Copying weights from "%s" to "%s".', - model_id, - model_path, - bundle_model_weight_path, + "Bundle weight for %s, copy into %s", + style.bold(model_id), + style.bold(str(bundle_model_weight_path)), ) if bundle_model_weight_path.exists(): shutil.rmtree(bundle_model_weight_path) @@ -256,10 +262,7 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra # - Dump "mlc-app-config.json". app_config_json_str = json.dumps( - { - "model_list": app_config_model_list, - "model_lib_path_for_prepare_libs": model_lib_path_for_prepare_libs, - }, + {"model_list": app_config_model_list}, indent=2, ) app_config_path = bundle_dir / "mlc-app-config.json" @@ -271,4 +274,6 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra ) # - Validate model libraries. - validate_model_lib(app_config_path, device, output) + validate_model_lib( + app_config_path, package_config_path, model_lib_path_for_prepare_libs, device, output + ) From c5801409cf555c925dfbbae42abdf7d1c9a2f8bc Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 9 May 2024 04:50:33 +0530 Subject: [PATCH 288/531] [KVCACHE][TIR] Improved tir schedule for decode tir page attention (#2289) * [KVCACHE][TIR] Improved tir schedule for decode tir page attention 1. Improved tir schedule of page attention (It improved 30% to this function). 2. 
Enable missing dequant+matmul fusion in ph-2 model * Updated K_local to QK_local * Update kv_cache.py * Increase max thread for android:adreno --- .../fuse_dequantize_matmul_ewise.py | 2 +- python/mlc_llm/nn/kv_cache.py | 43 +++++++------------ python/mlc_llm/support/auto_target.py | 2 + 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py index 0943828933..36d133fb9a 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -16,7 +16,7 @@ def transform_module( """IRModule-level transformation""" seq = [] for n_aux_tensor in [0, 1, 2, 3, 4]: - for match_ewise in [0, 1, 2, 6]: + for match_ewise in [0, 1, 2, 3, 6]: if match_ewise == 6 and n_aux_tensor != 4: continue seq.append( diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index e5cae1e5cd..092278d0de 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -887,7 +887,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 256 + THREAD_LIMIT = 256 if H_kv < 8 else 512 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -976,12 +976,14 @@ def batch_decode_paged_kv( t0 = T.alloc_buffer((1,), "float32", scope="local") S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") - K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + QK_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") m_prev = T.alloc_buffer((1,), "float32", scope="local") d_prev = T.alloc_buffer((1,), "float32", scope="local") other_m = T.alloc_buffer((1,), "float32", scope="local") other_d = T.alloc_buffer((1,), "float32", scope="local") + exp_mprev = T.alloc_buffer((1,), "float32", scope="local") + exp_otherm = T.alloc_buffer((1,), "float32", scope="local") other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") st_m = T.alloc_buffer((1,), "float32", scope="local") st_d = T.alloc_buffer((1,), "float32", scope="local") @@ -1015,9 +1017,9 @@ def batch_decode_paged_kv( for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore - # load K from global memory to shared memory + # load KV from global memory to shared memory for j in T.serial(tile_size_per_bdx): - with T.block("K_load"): + with T.block("KV_load"): T.reads() T.writes() row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore @@ -1031,36 +1033,21 @@ def batch_decode_paged_kv( _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] ) - else: - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 - T.tvm_storage_sync("shared") - # load V from global memory to shared memory - for j in T.serial(tile_size_per_bdx): - with T.block("V_load"): - T.reads() - T.writes() - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: 
ignore - if row_g < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] else: for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 T.tvm_storage_sync("shared") # compute QK m_prev[0] = st_m[0] for j in T.serial(bdy * tile_size_per_bdx): - # load K from shared memory to local memory - for vec in T.vectorized(VEC_SIZE): - K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] # compute S = Q * K * sm_scale + for vec in T.vectorized(VEC_SIZE): + QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * sm_scale S_reduce_local[0] = 0 - for vec in T.serial(VEC_SIZE): - S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale + for vec in T.unroll(VEC_SIZE): + S_reduce_local[0] += QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) @@ -1117,11 +1104,13 @@ def batch_decode_paged_kv( other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) - for vec in T.serial(VEC_SIZE): - O_local[vec] = O_local[vec] * T.exp2(m_prev[0] - st_m[0]) + other_o[vec] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] # normalize O - for vec in T.serial(VEC_SIZE): + for vec in T.vectorized(VEC_SIZE): O_local[vec] /= st_d[0] # store O to global memory diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 5239756d9d..001f3116cb 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -400,6 +400,7 @@ def detect_system_lib_prefix( "target": { "kind": "opencl", "device": "adreno", + "max_threads_per_block": 512, "host": { "kind": "llvm", "mtriple": "aarch64-linux-android", @@ -411,6 +412,7 @@ def detect_system_lib_prefix( "target": { "kind": "opencl", "device": "adreno", + "max_threads_per_block": 512, "host": { "kind": "llvm", "mtriple": "aarch64-linux-android", From 10f3e4df2b02a01fb6d0436210bb5c7d47a6607e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 8 May 2024 18:22:55 -0700 Subject: [PATCH 289/531] [Sampler] Remove unneeded output_prob_dist param (#2300) --- cpp/serve/engine_actions/batch_draft.cc | 2 +- cpp/serve/sampler/cpu_sampler.cc | 61 ++++++------------------- cpp/serve/sampler/gpu_sampler.cc | 29 ++++-------- cpp/serve/sampler/sampler.h | 4 +- 4 files changed, 24 insertions(+), 72 deletions(-) diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index 513a0fe447..2e9d4dd536 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -123,7 +123,7 @@ class BatchDraftActionObj : public EngineActionObj { 
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( probs_on_device, sample_indices, request_ids, generation_cfg); std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs, &prob_dist); + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Add draft token to the state. diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc index 196a6dd695..6c71169872 100644 --- a/cpp/serve/sampler/cpu_sampler.cc +++ b/cpp/serve/sampler/cpu_sampler.cc @@ -27,15 +27,12 @@ namespace serve { * \param input_prob_offset The offset specifying which distribution to sample from. * \param top_p The top-p value of sampling. * \param uniform_sample The random number in [0, 1] for sampling. - * \param output_prob_dist Optional pointer to store the corresponding probability distribution of - * each token, offset by unit_offset. If nullptr provided, nothing will be stored out. * \return The sampled value and probability. * \note This function is an enhancement of SampleTopPFromProb in TVM Unity. * We will upstream the enhancement after it gets stable. */ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_offset, double top_p, - double uniform_sample, - std::vector* output_prob_dist = nullptr) { + double uniform_sample) { // prob: (*, v) // The prob array may have arbitrary ndim and shape. // The last dimension corresponds to the prob distribution size. @@ -51,13 +48,6 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o static_cast(__builtin_assume_aligned(prob->data, 4)) + (input_prob_offset * ndata); constexpr double one = 1.0f - 1e-5f; - if (output_prob_dist) { - ICHECK_LT(unit_offset, static_cast(output_prob_dist->size())); - if (!(*output_prob_dist)[unit_offset].defined()) { - (*output_prob_dist)[unit_offset] = NDArray::Empty({ndata}, prob->dtype, DLDevice{kDLCPU, 0}); - } - } - if (top_p == 0) { // Specially handle case where top_p == 0. // This case is equivalent to doing argmax. @@ -75,20 +65,9 @@ TokenProbPair SampleTopPFromProb(NDArray prob, int unit_offset, int input_prob_o break; } } - if (output_prob_dist) { - float* __restrict p_output_prob = - static_cast(__builtin_assume_aligned((*output_prob_dist)[unit_offset]->data, 4)); - for (int i = 0; i < ndata; ++i) { - p_output_prob[i] = i == argmax_pos ? 1.0 : 0.0; - } - } return {argmax_pos, 1.0}; } - if (output_prob_dist) { - (*output_prob_dist)[unit_offset].CopyFromBytes(p_prob, ndata * sizeof(float)); - } - if (top_p >= one) { // Specially handle case where top_p == 1. 
double prob_sum = 0.0f; @@ -419,10 +398,9 @@ class CPUSampler : public SamplerObj { const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist) final { + const std::vector& rngs) final { return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs, - /*top_p_applied=*/true, output_prob_dist); + /*top_p_applied=*/true); } std::vector> BatchVerifyDraftTokensWithProbAfterTopP( @@ -520,14 +498,12 @@ class CPUSampler : public SamplerObj { } private: - std::vector BatchSampleTokensImpl( - NDArray probs_on_host, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - bool top_p_applied, // - std::vector* output_prob_dist = nullptr) { + std::vector BatchSampleTokensImpl(NDArray probs_on_host, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied) { // probs_on_host: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); ICHECK_EQ(probs_on_host->ndim, 2); @@ -540,29 +516,20 @@ class CPUSampler : public SamplerObj { std::vector sample_results; sample_results.resize(n); - if (output_prob_dist) { - output_prob_dist->resize(n); - } tvm::runtime::parallel_for_with_threading_backend( [this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied, - sample_indices, output_prob_dist](int i) { + sample_indices](int i) { RECORD_EVENT(this->trace_recorder_, request_ids[i], "start sample token"); // Sample top p from probability. double top_p = top_p_applied ? 1.0f : (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p); - sample_results[i].sampled_token_id = - SampleTopPFromProb(probs_on_host, i, sample_indices[i], top_p, - rngs[i]->GetRandomNumber(), output_prob_dist); - if (output_prob_dist == nullptr) { - // When `output_prob_dist` is not nullptr, it means right now - // we are sampling for a small model in speculation, in which - // case we do not need to get the top probs. 
- sample_results[i].top_prob_tokens = - ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs); - } + sample_results[i].sampled_token_id = SampleTopPFromProb( + probs_on_host, i, sample_indices[i], top_p, rngs[i]->GetRandomNumber()); + sample_results[i].top_prob_tokens = + ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs); RECORD_EVENT(this->trace_recorder_, request_ids[i], "finish sample token"); }, 0, n); diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 1a013a9627..7f09da7e1c 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -172,11 +172,10 @@ class GPUSampler : public SamplerObj { const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist = nullptr) final { + const std::vector& rngs) final { NVTXScopedRange nvtx_scope("BatchSampleTokensWithProbAfterTopP"); return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids, - generation_cfg, rngs, /*top_p_applied=*/true, output_prob_dist); + generation_cfg, rngs, /*top_p_applied=*/true); } std::vector> BatchVerifyDraftTokensWithProbAfterTopP( @@ -326,14 +325,12 @@ class GPUSampler : public SamplerObj { } private: - std::vector BatchSampleTokensImpl( - NDArray probs_on_device, // - const std::vector& sample_indices, // - const Array& request_ids, // - const Array& generation_cfg, // - const std::vector& rngs, // - bool top_p_applied, // - std::vector* output_prob_dist = nullptr) { + std::vector BatchSampleTokensImpl(NDArray probs_on_device, // + const std::vector& sample_indices, // + const Array& request_ids, // + const Array& generation_cfg, // + const std::vector& rngs, // + bool top_p_applied) { // probs_on_device: (n, v) RECORD_EVENT(trace_recorder_, request_ids, "start sampling"); CHECK_EQ(probs_on_device->ndim, 2); @@ -342,16 +339,6 @@ class GPUSampler : public SamplerObj { int num_samples = sample_indices.size(); int num_probs = probs_on_device->shape[0]; int vocab_size = probs_on_device->shape[1]; - if (output_prob_dist != nullptr) { - ICHECK(output_prob_dist->empty()); - output_prob_dist->reserve(num_samples); - for (int i = 0; i < num_samples; ++i) { - NDArray prob_dist = NDArray::Empty({vocab_size}, dtype_f32_, device_); - float* p_prob = static_cast(probs_on_device->data) + sample_indices[i] * vocab_size; - prob_dist.CopyFromBytes(p_prob, vocab_size * sizeof(float)); - output_prob_dist->push_back(std::move(prob_dist)); - } - } if (num_samples == 0) { // This synchronization is necessary for making sure that this round // of model forward is finished. diff --git a/cpp/serve/sampler/sampler.h b/cpp/serve/sampler/sampler.h index 59e433ac47..d9f6dbcb4f 100644 --- a/cpp/serve/sampler/sampler.h +++ b/cpp/serve/sampler/sampler.h @@ -83,7 +83,6 @@ class SamplerObj : public Object { * \param generation_cfg The generation config of each request * in the input batch. * \param rngs The random number generator of each sequence. - * \param output_prob_dist The output probability distribution * \return The batch of sampling results, which contain the sampled token id * and other probability info. */ @@ -92,8 +91,7 @@ class SamplerObj : public Object { const std::vector& sample_indices, // const Array& request_ids, // const Array& generation_cfg, // - const std::vector& rngs, // - std::vector* output_prob_dist = nullptr) = 0; + const std::vector& rngs) = 0; /*! 
* \brief Verify draft tokens generated by small models in the large model From 33c15e72a3567292cba577ea7f89652ec9f2bd6e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 9 May 2024 05:42:57 -0700 Subject: [PATCH 290/531] Enable cuda graph for batch_verify (#2304) --- python/mlc_llm/interface/compile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 7be9dadd39..7aafc64738 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -166,6 +166,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: cuda_graph_symbolic_capture_hints = { "batch_decode": ["batch_size"], "batch_decode_to_last_hidden_states": ["batch_size"], + "batch_verify": ["batch_size", "seq_len"], "batch_verify_to_last_hidden_states": ["batch_size", "seq_len"], } metadata = { From dbd13f414acf453b957e2448207bce2a72b488b1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 9 May 2024 21:15:34 -0400 Subject: [PATCH 291/531] [Android] Introducing mlc4j and app packaging (#2305) This PR lifts the existing `library` of android app into a standalone `mlc4j` directory, which can be referenced by android app at any location. On the app side, this PR moves the android app into a subfolder `MLCChat` which itself is a well-formed Android app. This folder contains two core files for app build: * `MLCChat/mlc-package-config.json` the config file that specifies the models to build into the app. * `MLCChat/prepare_package.py` the Python script that helps automatically prepare/build mlc4j and model libraries. This PR also updates the android app documentation to reflect this latest change. --- android/MLCChat/README.md | 6 + android/{ => MLCChat}/app/.gitignore | 0 android/{ => MLCChat}/app/build.gradle | 2 +- android/{ => MLCChat}/app/proguard-rules.pro | 0 .../app/src/main/AndroidManifest.xml | 0 .../app/src/main/ic_launcher-playstore.png | Bin .../main/java/ai/mlc/mlcchat/AppViewModel.kt | 2 +- .../src/main/java/ai/mlc/mlcchat/ChatView.kt | 0 .../main/java/ai/mlc/mlcchat/MainActivity.kt | 0 .../src/main/java/ai/mlc/mlcchat/NavView.kt | 0 .../src/main/java/ai/mlc/mlcchat/StartView.kt | 0 .../java/ai/mlc/mlcchat/ui/theme/Color.kt | 0 .../java/ai/mlc/mlcchat/ui/theme/Theme.kt | 0 .../main/java/ai/mlc/mlcchat/ui/theme/Type.kt | 0 .../res/drawable/ic_android_black_24dp.xml | 0 .../src/main/res/drawable/mlc_logo_108.xml | 0 .../app/src/main/res/values/colors.xml | 0 .../app/src/main/res/values/strings.xml | 0 .../app/src/main/res/values/themes.xml | 0 .../app/src/main/res/xml/backup_rules.xml | 0 .../main/res/xml/data_extraction_rules.xml | 0 android/{ => MLCChat}/build.gradle | 0 android/MLCChat/bundle_weight.py | 65 ++++ android/{ => MLCChat}/gradle.properties | 0 .../gradle/wrapper/gradle-wrapper.jar | Bin .../gradle/wrapper/gradle-wrapper.properties | 0 android/{ => MLCChat}/gradlew | 0 android/{ => MLCChat}/gradlew.bat | 0 android/MLCChat/mlc-package-config.json | 38 ++ android/{ => MLCChat}/settings.gradle | 3 +- android/library/prepare_libs.sh | 34 -- android/library/prepare_model_lib.py | 79 ---- .../library/src/main/assets/app-config.json | 41 --- android/{library => mlc4j}/.gitignore | 0 android/{library => mlc4j}/CMakeLists.txt | 2 +- android/{library => mlc4j}/build.gradle | 4 +- android/mlc4j/prepare_libs.py | 90 +++++ .../{library => mlc4j}/src/cpp/tvm_runtime.h | 0 .../src/main/AndroidManifest.xml | 0 .../main/java/ai/mlc/mlcllm/ChatModule.java | 0 .../package_model_libraries_weights.rst | 208 
+++++++++++ docs/deploy/android.rst | 297 ++++++++++----- docs/deploy/ios.rst | 80 +++-- docs/index.rst | 1 + ios/.gitignore | 1 + ios/MLCChat/README.md | 4 +- ios/MLCChat/mlc-package-config.json | 1 + ios/MLCChat/prepare_package.sh | 10 - .../MLCEngineExampleApp.swift | 7 +- ios/MLCEngineExample/README.md | 4 +- ios/MLCEngineExample/mlc-package-config.json | 1 + ios/MLCEngineExample/prepare_package.sh | 10 - ios/MLCSwift/tvm_home | 1 - ios/README.md | 2 +- ios/prepare_libs.sh | 9 +- python/mlc_llm/cli/package.py | 32 +- python/mlc_llm/help.py | 23 +- python/mlc_llm/interface/package.py | 337 +++++++++++------- 58 files changed, 939 insertions(+), 455 deletions(-) create mode 100644 android/MLCChat/README.md rename android/{ => MLCChat}/app/.gitignore (100%) rename android/{ => MLCChat}/app/build.gradle (98%) rename android/{ => MLCChat}/app/proguard-rules.pro (100%) rename android/{ => MLCChat}/app/src/main/AndroidManifest.xml (100%) rename android/{ => MLCChat}/app/src/main/ic_launcher-playstore.png (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt (99%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/ChatView.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/NavView.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/StartView.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt (100%) rename android/{ => MLCChat}/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt (100%) rename android/{ => MLCChat}/app/src/main/res/drawable/ic_android_black_24dp.xml (100%) rename android/{ => MLCChat}/app/src/main/res/drawable/mlc_logo_108.xml (100%) rename android/{ => MLCChat}/app/src/main/res/values/colors.xml (100%) rename android/{ => MLCChat}/app/src/main/res/values/strings.xml (100%) rename android/{ => MLCChat}/app/src/main/res/values/themes.xml (100%) rename android/{ => MLCChat}/app/src/main/res/xml/backup_rules.xml (100%) rename android/{ => MLCChat}/app/src/main/res/xml/data_extraction_rules.xml (100%) rename android/{ => MLCChat}/build.gradle (100%) create mode 100644 android/MLCChat/bundle_weight.py rename android/{ => MLCChat}/gradle.properties (100%) rename android/{ => MLCChat}/gradle/wrapper/gradle-wrapper.jar (100%) rename android/{ => MLCChat}/gradle/wrapper/gradle-wrapper.properties (100%) rename android/{ => MLCChat}/gradlew (100%) rename android/{ => MLCChat}/gradlew.bat (100%) create mode 100644 android/MLCChat/mlc-package-config.json rename android/{ => MLCChat}/settings.gradle (82%) delete mode 100755 android/library/prepare_libs.sh delete mode 100644 android/library/prepare_model_lib.py delete mode 100644 android/library/src/main/assets/app-config.json rename android/{library => mlc4j}/.gitignore (100%) rename android/{library => mlc4j}/CMakeLists.txt (97%) rename android/{library => mlc4j}/build.gradle (84%) create mode 100644 android/mlc4j/prepare_libs.py rename android/{library => mlc4j}/src/cpp/tvm_runtime.h (100%) rename android/{library => mlc4j}/src/main/AndroidManifest.xml (100%) rename android/{library => mlc4j}/src/main/java/ai/mlc/mlcllm/ChatModule.java (100%) create mode 100644 docs/compilation/package_model_libraries_weights.rst delete mode 100755 ios/MLCChat/prepare_package.sh delete mode 100755 ios/MLCEngineExample/prepare_package.sh delete mode 120000 
ios/MLCSwift/tvm_home diff --git a/android/MLCChat/README.md b/android/MLCChat/README.md new file mode 100644 index 0000000000..445d09a659 --- /dev/null +++ b/android/MLCChat/README.md @@ -0,0 +1,6 @@ +# MLC-LLM Android + +Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/android.html) for more information. + +- run `mlc_llm package` +- open this `MLCChat/` folder as a project in Android Studio diff --git a/android/app/.gitignore b/android/MLCChat/app/.gitignore similarity index 100% rename from android/app/.gitignore rename to android/MLCChat/app/.gitignore diff --git a/android/app/build.gradle b/android/MLCChat/app/build.gradle similarity index 98% rename from android/app/build.gradle rename to android/MLCChat/app/build.gradle index 1fd30e3985..47b2915460 100644 --- a/android/app/build.gradle +++ b/android/MLCChat/app/build.gradle @@ -47,7 +47,7 @@ android { } dependencies { - implementation project(":library") + implementation project(":mlc4j") implementation 'androidx.core:core-ktx:1.10.1' implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1' implementation 'androidx.activity:activity-compose:1.7.1' diff --git a/android/app/proguard-rules.pro b/android/MLCChat/app/proguard-rules.pro similarity index 100% rename from android/app/proguard-rules.pro rename to android/MLCChat/app/proguard-rules.pro diff --git a/android/app/src/main/AndroidManifest.xml b/android/MLCChat/app/src/main/AndroidManifest.xml similarity index 100% rename from android/app/src/main/AndroidManifest.xml rename to android/MLCChat/app/src/main/AndroidManifest.xml diff --git a/android/app/src/main/ic_launcher-playstore.png b/android/MLCChat/app/src/main/ic_launcher-playstore.png similarity index 100% rename from android/app/src/main/ic_launcher-playstore.png rename to android/MLCChat/app/src/main/ic_launcher-playstore.png diff --git a/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt similarity index 99% rename from android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt index 6a3bf4a211..cd8b23ce08 100644 --- a/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt @@ -38,7 +38,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { private val modelIdSet = emptySet().toMutableSet() companion object { - const val AppConfigFilename = "app-config.json" + const val AppConfigFilename = "mlc-app-config.json" const val ModelConfigFilename = "mlc-chat-config.json" const val ParamsConfigFilename = "ndarray-cache.json" const val ModelUrlSuffix = "resolve/main/" diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/NavView.kt rename to 
android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/StartView.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt similarity index 100% rename from android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt rename to android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt diff --git a/android/app/src/main/res/drawable/ic_android_black_24dp.xml b/android/MLCChat/app/src/main/res/drawable/ic_android_black_24dp.xml similarity index 100% rename from android/app/src/main/res/drawable/ic_android_black_24dp.xml rename to android/MLCChat/app/src/main/res/drawable/ic_android_black_24dp.xml diff --git a/android/app/src/main/res/drawable/mlc_logo_108.xml b/android/MLCChat/app/src/main/res/drawable/mlc_logo_108.xml similarity index 100% rename from android/app/src/main/res/drawable/mlc_logo_108.xml rename to android/MLCChat/app/src/main/res/drawable/mlc_logo_108.xml diff --git a/android/app/src/main/res/values/colors.xml b/android/MLCChat/app/src/main/res/values/colors.xml similarity index 100% rename from android/app/src/main/res/values/colors.xml rename to android/MLCChat/app/src/main/res/values/colors.xml diff --git a/android/app/src/main/res/values/strings.xml b/android/MLCChat/app/src/main/res/values/strings.xml similarity index 100% rename from android/app/src/main/res/values/strings.xml rename to android/MLCChat/app/src/main/res/values/strings.xml diff --git a/android/app/src/main/res/values/themes.xml b/android/MLCChat/app/src/main/res/values/themes.xml similarity index 100% rename from android/app/src/main/res/values/themes.xml rename to android/MLCChat/app/src/main/res/values/themes.xml diff --git a/android/app/src/main/res/xml/backup_rules.xml b/android/MLCChat/app/src/main/res/xml/backup_rules.xml similarity index 100% rename from android/app/src/main/res/xml/backup_rules.xml rename to android/MLCChat/app/src/main/res/xml/backup_rules.xml diff --git a/android/app/src/main/res/xml/data_extraction_rules.xml b/android/MLCChat/app/src/main/res/xml/data_extraction_rules.xml similarity index 100% rename from android/app/src/main/res/xml/data_extraction_rules.xml rename to android/MLCChat/app/src/main/res/xml/data_extraction_rules.xml diff --git a/android/build.gradle b/android/MLCChat/build.gradle similarity index 100% rename from android/build.gradle rename to android/MLCChat/build.gradle diff --git a/android/MLCChat/bundle_weight.py b/android/MLCChat/bundle_weight.py new file mode 100644 index 0000000000..adade13071 --- /dev/null +++ b/android/MLCChat/bundle_weight.py @@ -0,0 +1,65 @@ +import argparse 
+import os +import subprocess +from pathlib import Path + +from mlc_llm.support import logging + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def main(apk_path: Path, package_output_path: Path): + """Push weights to the android device with adb""" + # - Install the apk on device. + logger.info('Install apk "%s" to device', str(apk_path.absolute())) + subprocess.run(["adb", "install", str(apk_path)], check=True, env=os.environ) + # - Create the weight directory for the app. + device_weihgt_dir = "/storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" + logger.info('Creating directory "%s" on device', device_weihgt_dir) + subprocess.run( + ["adb", "shell", "mkdir", "-p", device_weihgt_dir], + check=True, + env=os.environ, + ) + for model_weight_dir in (package_output_path / "bundle").iterdir(): + if model_weight_dir.is_dir(): + src_path = str(model_weight_dir.absolute()) + dst_path = "/data/local/tmp/" + model_weight_dir.name + logger.info('Pushing local weights "%s" to device location "%s"', src_path, dst_path) + subprocess.run(["adb", "push", src_path, dst_path], check=True, env=os.environ) + + src_path = dst_path + dst_path = "/storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" + logger.info('Move weights from "%s" to "%s"', src_path, dst_path) + subprocess.run(["adb", "shell", "mv", src_path, dst_path], check=True, env=os.environ) + logger.info("All finished.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLC LLM Android Weight Bundle") + + def _parse_apk_path(path: str) -> Path: + path = Path(path) + if not path.exists(): + raise ValueError( + f"Path {str(path)} is expected to be an apk file, but the file does not exist." + ) + if not path.is_file(): + raise ValueError(f"Path {str(path)} is expected to be an apk file.") + return path + + parser.add_argument( + "--apk-path", + type=_parse_apk_path, + default="app/release/app-release.apk", + help="The path to generated MLCChat apk file.", + ) + parser.add_argument( + "--package-output-path", + type=Path, + default="dist", + help='The path to the output directory of "mlc_llm package".', + ) + args = parser.parse_args() + main(args.apk_path, args.package_output_path) diff --git a/android/gradle.properties b/android/MLCChat/gradle.properties similarity index 100% rename from android/gradle.properties rename to android/MLCChat/gradle.properties diff --git a/android/gradle/wrapper/gradle-wrapper.jar b/android/MLCChat/gradle/wrapper/gradle-wrapper.jar similarity index 100% rename from android/gradle/wrapper/gradle-wrapper.jar rename to android/MLCChat/gradle/wrapper/gradle-wrapper.jar diff --git a/android/gradle/wrapper/gradle-wrapper.properties b/android/MLCChat/gradle/wrapper/gradle-wrapper.properties similarity index 100% rename from android/gradle/wrapper/gradle-wrapper.properties rename to android/MLCChat/gradle/wrapper/gradle-wrapper.properties diff --git a/android/gradlew b/android/MLCChat/gradlew similarity index 100% rename from android/gradlew rename to android/MLCChat/gradlew diff --git a/android/gradlew.bat b/android/MLCChat/gradlew.bat similarity index 100% rename from android/gradlew.bat rename to android/MLCChat/gradlew.bat diff --git a/android/MLCChat/mlc-package-config.json b/android/MLCChat/mlc-package-config.json new file mode 100644 index 0000000000..766d6d2a80 --- /dev/null +++ b/android/MLCChat/mlc-package-config.json @@ -0,0 +1,38 @@ +{ + "device": "android", + "model_list": [ + { + "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", + "model_id": "gemma-2b-q4f16_1", + 
"estimated_vram_bytes": 3000000000 + }, + { + "model": "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC", + "estimated_vram_bytes": 4348727787, + "model_id": "Llama-2-7b-chat-hf-q4f16_1", + "overrides": { + "context_window_size": 768, + "prefill_chunk_size": 256 + } + }, + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "estimated_vram_bytes": 1948348579, + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1" + }, + { + "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", + "estimated_vram_bytes": 4275453296, + "model_id": "Mistral-7B-Instruct-v0.2-q4f16_1", + "overrides": { + "sliding_window_size": 768, + "prefill_chunk_size": 256 + } + }, + { + "model": "HF://mlc-ai/phi-2-q4f16_1-MLC", + "estimated_vram_bytes": 2036816936, + "model_id": "phi-2-q4f16_1" + } + ] +} diff --git a/android/settings.gradle b/android/MLCChat/settings.gradle similarity index 82% rename from android/settings.gradle rename to android/MLCChat/settings.gradle index 31e8cf1d87..6866480997 100644 --- a/android/settings.gradle +++ b/android/MLCChat/settings.gradle @@ -14,4 +14,5 @@ dependencyResolutionManagement { } rootProject.name = "MLCChat" include ':app' -include ':library' +include ':mlc4j' +project(':mlc4j').projectDir = file('dist/lib/mlc4j') diff --git a/android/library/prepare_libs.sh b/android/library/prepare_libs.sh deleted file mode 100755 index c089927d09..0000000000 --- a/android/library/prepare_libs.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -set -euxo pipefail - -rustup target add aarch64-linux-android - -mkdir -p build/model_lib - -python3 prepare_model_lib.py - -cd build -touch config.cmake -if [ ${TVM_HOME-0} -ne 0 ]; then - echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake -fi - -cmake .. \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ - -DCMAKE_INSTALL_PREFIX=. \ - -DCMAKE_CXX_FLAGS="-O3" \ - -DANDROID_ABI=arm64-v8a \ - -DANDROID_NATIVE_API_LEVEL=android-24 \ - -DANDROID_PLATFORM=android-24 \ - -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON \ - -DANDROID_STL=c++_static \ - -DUSE_HEXAGON_SDK=OFF \ - -DMLC_LLM_INSTALL_STATIC_LIB=ON \ - -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON \ - -DUSE_OPENCL=ON \ - -DUSE_OPENCL_ENABLE_HOST_PTR=ON \ - -DUSE_CUSTOM_LOGGING=ON \ - -cmake --build . --target tvm4j_runtime_packed --config release -cmake --build . 
--target install --config release -j diff --git a/android/library/prepare_model_lib.py b/android/library/prepare_model_lib.py deleted file mode 100644 index 9f143d7357..0000000000 --- a/android/library/prepare_model_lib.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -import os - -from tvm.contrib import ndk - - -def get_model_libs(lib_path): - global_symbol_map = ndk.get_global_symbol_section_map(lib_path) - libs = [] - suffix = "___tvm_dev_mblob" - for name in global_symbol_map.keys(): - if name.endswith(suffix): - model_lib = name[: -len(suffix)] - if model_lib.startswith("_"): - model_lib = model_lib[1:] - libs.append(model_lib) - return libs - - -def main(): - app_config_path = "src/main/assets/app-config.json" - app_config = json.load(open(app_config_path, "r")) - artifact_path = os.path.abspath(os.path.join("../..", "dist")) - tar_list = [] - model_set = set() - - for model, model_lib in app_config["model_lib_path_for_prepare_libs"].items(): - path = os.path.join(artifact_path, model_lib) - if not os.path.isfile(path): - raise RuntimeError(f"Cannot find android library {path}") - tar_list.append(path) - model_set.add(model) - - lib_path = os.path.join("build", "model_lib", "libmodel_android.a") - ndk.create_staticlib(lib_path, tar_list) - print(f"Creating lib from {tar_list}..") - - available_model_libs = get_model_libs(lib_path) - print(f"Validating the library {lib_path}...") - print( - f"List of available model libs packaged: {available_model_libs}," - " if we have '-' in the model_lib string, it will be turned into '_'" - ) - global_symbol_map = ndk.get_global_symbol_section_map(lib_path) - error_happened = False - for item in app_config["model_list"]: - model_lib = item["model_lib"] - model_id = item["model_id"] - if model_lib not in model_set: - print( - f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " - "is not included in model_lib_path_for_prepare_libs field, " - "This will cause the specific model not being able to load, " - f"please check {app_config_path}." 
- ) - error_happened = True - model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" - if ( - model_prefix_pattern not in global_symbol_map - and "_" + model_prefix_pattern not in global_symbol_map - ): - model_lib = app_config["model_lib_path_for_prepare_libs"][model_lib] - print( - "ValidationError:\n" - f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" - f"\tspecifically the model_lib for {model_lib} in model_lib_path_for_prepare_libs.\n" - f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" - ) - error_happened = True - - if not error_happened: - print("Validation pass") - else: - print("Validation failed") - exit(255) - - -if __name__ == "__main__": - main() diff --git a/android/library/src/main/assets/app-config.json b/android/library/src/main/assets/app-config.json deleted file mode 100644 index 68442c234e..0000000000 --- a/android/library/src/main/assets/app-config.json +++ /dev/null @@ -1,41 +0,0 @@ -{ - "model_list": [ - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC", - "model_id": "gemma-2b-q4f16_1", - "model_lib": "gemma_q4f16_1", - "estimated_vram_bytes": 3000000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "model_lib": "llama_q4f16_1", - "estimated_vram_bytes": 4348727787, - "model_id": "Llama-2-7b-chat-hf-q4f16_1" - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/", - "model_lib": "gpt_neox_q4f16_1", - "estimated_vram_bytes": 1948348579, - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1" - }, - { - "model_url": "https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", - "model_lib": "mistral_q4f16_1", - "estimated_vram_bytes": 4275453296, - "model_id": "Mistral-7B-Instruct-v0.2-q4f16_1" - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_lib": "phi_msft_q4f16_1", - "estimated_vram_bytes": 2036816936, - "model_id": "phi-2-q4f16_1" - } - ], - "model_lib_path_for_prepare_libs": { - "gemma_q4f16_1": "prebuilt/lib/gemma-2b-it/gemma-2b-it-q4f16_1-android.tar", - "llama_q4f16_1": "prebuilt/lib/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-android.tar", - "gpt_neox_q4f16_1": "prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar", - "phi_msft_q4f16_1": "prebuilt/lib/phi-2/phi-2-q4f16_1-android.tar", - "mistral_q4f16_1": "prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar" - } -} \ No newline at end of file diff --git a/android/library/.gitignore b/android/mlc4j/.gitignore similarity index 100% rename from android/library/.gitignore rename to android/mlc4j/.gitignore diff --git a/android/library/CMakeLists.txt b/android/mlc4j/CMakeLists.txt similarity index 97% rename from android/library/CMakeLists.txt rename to android/mlc4j/CMakeLists.txt index a7d5a1caf0..f4ce6f218d 100644 --- a/android/library/CMakeLists.txt +++ b/android/mlc4j/CMakeLists.txt @@ -37,7 +37,7 @@ add_custom_command( ) add_library(model_android STATIC IMPORTED) -set_target_properties(model_android PROPERTIES IMPORTED_LOCATION ${ANDROID_BIN_DIR}/model_lib/libmodel_android.a) +set_target_properties(model_android PROPERTIES IMPORTED_LOCATION ${ANDROID_BIN_DIR}/lib/libmodel_android.a) add_library(tvm4j_runtime_packed SHARED ${TVM_HOME}/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc) diff --git a/android/library/build.gradle b/android/mlc4j/build.gradle similarity index 84% rename from 
android/library/build.gradle rename to android/mlc4j/build.gradle index 8e4a1b8408..a9058fd827 100644 --- a/android/library/build.gradle +++ b/android/mlc4j/build.gradle @@ -19,13 +19,13 @@ android { } sourceSets { main { - jniLibs.srcDirs = ['build/output'] + jniLibs.srcDirs = ['output'] } } } dependencies { - implementation fileTree(dir: 'build/output', include: ['*.jar']) + implementation fileTree(dir: 'output', include: ['*.jar']) implementation 'androidx.core:core-ktx:1.9.0' implementation 'androidx.appcompat:appcompat:1.6.1' implementation 'com.google.android.material:material:1.10.0' diff --git a/android/mlc4j/prepare_libs.py b/android/mlc4j/prepare_libs.py new file mode 100644 index 0000000000..19f80718f0 --- /dev/null +++ b/android/mlc4j/prepare_libs.py @@ -0,0 +1,90 @@ +"""The build script for mlc4j (MLC LLM and tvm4j)""" + +import argparse +import os +import subprocess +from pathlib import Path + +from mlc_llm.support import logging + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def run_cmake(mlc4j_path: Path): + if "ANDROID_NDK" not in os.environ: + raise ValueError( + f'Environment variable "ANDROID_NDK" is required but not found.' + "Please follow https://llm.mlc.ai/docs/deploy/android.html to properly " + 'specify "ANDROID_NDK".' + ) + logger.info("Running cmake") + cmd = [ + "cmake", + str(mlc4j_path), + "-DCMAKE_BUILD_TYPE=Release", + f"-DCMAKE_TOOLCHAIN_FILE={os.environ['ANDROID_NDK']}/build/cmake/android.toolchain.cmake", + "-DCMAKE_INSTALL_PREFIX=.", + '-DCMAKE_CXX_FLAGS="-O3"', + "-DANDROID_ABI=arm64-v8a", + "-DANDROID_NATIVE_API_LEVEL=android-24", + "-DANDROID_PLATFORM=android-24", + "-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON", + "-DANDROID_STL=c++_static", + "-DUSE_HEXAGON_SDK=OFF", + "-DMLC_LLM_INSTALL_STATIC_LIB=ON", + "-DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON", + "-DUSE_OPENCL=ON", + "-DUSE_OPENCL_ENABLE_HOST_PTR=ON", + "-DUSE_CUSTOM_LOGGING=ON", + ] + subprocess.run(cmd, check=True, env=os.environ) + + +def run_cmake_build(): + logger.info("Running cmake build") + cmd = ["cmake", "--build", ".", "--target", "tvm4j_runtime_packed", "--config", "release"] + subprocess.run(cmd, check=True, env=os.environ) + + +def run_cmake_install(): + logger.info("Running cmake install") + cmd = ["cmake", "--build", ".", "--target", "install", "--config", "release", "-j"] + subprocess.run(cmd, check=True, env=os.environ) + + +def main(mlc_llm_home: Path): + # - Setup rust. + subprocess.run(["rustup", "target", "add", "aarch64-linux-android"], check=True, env=os.environ) + + # - Build MLC LLM and tvm4j. + build_path = Path("build") + os.makedirs(build_path / "lib", exist_ok=True) + logger.info('Entering "%s" for MLC LLM and tvm4j build.', os.path.abspath(build_path)) + os.chdir(build_path) + # Generate config.cmake if TVM Home is set. 
+ if "TVM_HOME" in os.environ: + logger.info('Set TVM_HOME to "%s"', os.environ["TVM_HOME"]) + with open("config.cmake", "w", encoding="utf-8") as file: + print("set(TVM_HOME ${%s})" % os.environ["TVM_HOME"], file=file) + + # - Run cmake, build and install + run_cmake(mlc_llm_home / "android" / "mlc4j") + run_cmake_build() + run_cmake_install() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLC LLM Android Lib Preparation") + + parser.add_argument( + "--mlc-llm-home", + type=Path, + default=os.environ.get("MLC_LLM_HOME", None), + help="The path to MLC LLM source", + ) + parsed = parser.parse_args() + if parsed.mlc_llm_home is None: + parsed.mlc_llm_home = Path(os.path.abspath(os.path.curdir)).parent.parent + os.environ["MLC_LLM_HOME"] = str(parsed.mlc_llm_home) + main(parsed.mlc_llm_home) diff --git a/android/library/src/cpp/tvm_runtime.h b/android/mlc4j/src/cpp/tvm_runtime.h similarity index 100% rename from android/library/src/cpp/tvm_runtime.h rename to android/mlc4j/src/cpp/tvm_runtime.h diff --git a/android/library/src/main/AndroidManifest.xml b/android/mlc4j/src/main/AndroidManifest.xml similarity index 100% rename from android/library/src/main/AndroidManifest.xml rename to android/mlc4j/src/main/AndroidManifest.xml diff --git a/android/library/src/main/java/ai/mlc/mlcllm/ChatModule.java b/android/mlc4j/src/main/java/ai/mlc/mlcllm/ChatModule.java similarity index 100% rename from android/library/src/main/java/ai/mlc/mlcllm/ChatModule.java rename to android/mlc4j/src/main/java/ai/mlc/mlcllm/ChatModule.java diff --git a/docs/compilation/package_model_libraries_weights.rst b/docs/compilation/package_model_libraries_weights.rst new file mode 100644 index 0000000000..0bab235eb4 --- /dev/null +++ b/docs/compilation/package_model_libraries_weights.rst @@ -0,0 +1,208 @@ +.. _package-model-libraries-weights: + +Package Model Libraries & Weights +================================= + +When we want to build LLM applications with MLC LLM (e.g., iOS/Android apps), +usually we need to build static model libraries and app binding libraries, +and sometimes bundle model weights into the app. +MLC LLM provides a tool for fast model library and weight packaging: ``mlc_llm package``. + +This page briefly introduces how to use ``mlc_llm package`` for packaging. +Tutorials :ref:`deploy-ios` and :ref:`deploy-android` contain detailed examples and instructions +on using this packaging tool for iOS and Android deployment. + +----- + +Introduction +------------ + +To use ``mlc_llm package``, we must clone the source code of `MLC LLM `_ +and `install the MLC LLM and TVM Unity package `_. +Depending on the app we build, there might be some other dependencies, which are described in +corresponding :ref:`iOS ` and :ref:`Android ` tutorials. + +After cloning, the basic usage of ``mlc_llm package`` is as the following. + +.. code:: bash + + export MLC_LLM_HOME=/path/to/mlc-llm + cd /path/to/app # The app root directory which contains "mlc-package-config.json". + # E.g., "ios/MLCChat" or "android/MLCChat" + mlc_llm package + +**The package command reads from the JSON file** ``mlc-package-config.json`` **under the current directory.** +The output of this command is a directory ``dist/``, +which contains the packaged model libraries (under ``dist/lib/``) and weights (under ``dist/bundle/``). +This directory contains all necessary data for the app build. +Depending on the app we build, the internal structure of ``dist/lib/`` may be different. + +.. code:: + + dist + ├── lib + │ └── ... 
+ └── bundle + └── ... + +The input ``mlc-package-config.json`` file specifies + +* the device (e.g., iPhone or Android) to package model libraries and weights for, +* the list of models to package. + +Below is an example ``mlc-package-config.json`` file: + +.. code:: json + + { + "device": "iphone", + "model_list": [ + { + "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC", + "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", + "estimated_vram_bytes": 3316000000, + "bundle_weight": true, + "overrides": { + "context_window_size": 512 + } + }, + { + "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", + "model_id": "gemma-2b-q4f16_1", + "estimated_vram_bytes": 3000000000, + "overrides": { + "prefill_chunk_size": 128 + } + } + ] + } + +This example ``mlc-package-config.json`` specifies "iphone" as the target device. +In the ``model_list``, + +* ``model`` points to the Hugging Face repository which contains the pre-converted model weights. Apps will download model weights from the Hugging Face URL. +* ``model_id`` is a unique model identifier. +* ``estimated_vram_bytes`` is an estimation of the vRAM the model takes at runtime. +* ``"bundle_weight": true`` means the model weights of the model will be bundled into the app when building. +* ``overrides`` specifies some model config parameter overrides. + + +Below is a more detailed specification of the ``mlc-package-config.json`` file. +Each entry in ``"model_list"`` of the JSON file has the following fields: + +``model`` + (Required) The path to the MLC-converted model to be built into the app. + + Usually it is a Hugging Face URL (e.g., ``"model": "HF://mlc-ai/phi-2-q4f16_1-MLC"```) that contains the pre-converted model weights. + For iOS, it can also be a path to a local model directory which contains converted model weights (e.g., ``"model": "../dist/gemma-2b-q4f16_1"``). + Please check out :ref:`convert-weights-via-MLC` if you want to build local model into the app. + +``model_id`` + (Required) A unique local identifier to identify the model. + It can be an arbitrary one. + +``estimated_vram_bytes`` + (Required) Estimated requirements of vRAM to run the model. + +``bundle_weight`` + (Optional) A boolean flag indicating whether to bundle model weights into the app. + If this field is set to true, the ``mlc_llm package`` command will copy the model weights + to ``dist/bundle/$model_id``. + +``overrides`` + (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model temporary execution memory). + Example: + + .. code:: json + + { + "device": "iphone", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "overrides": { + "context_window_size": 512, + "prefill_chunk_size": 128 + } + } + ] + } + +``model_lib`` + (Optional) A string specifying the system library prefix to use for the model. + Usually this is used when you want to build multiple model variants with the same architecture into the app. + **This field does not affect any app functionality.** + The ``"model_lib_path_for_prepare_libs"`` introduced below is also related. + Example: + + .. 
code:: json + + { + "device": "iphone", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "model_lib": "gpt_neox_q4f16_1" + } + ] + } + + +Besides ``model_list`` in ``MLCChat/mlc-package-config.json``, +you can also **optionally** specify a dictionary of ``"model_lib_path_for_prepare_libs"``, +**if you want to use model libraries that are manually compiled**. +The keys of this dictionary should be the ``model_lib`` that specified in model list, +and the values of this dictionary are the paths (absolute, or relative) to the manually compiled model libraries. +The model libraries specified in ``"model_lib_path_for_prepare_libs"`` will be built into the app when running ``mlc_llm package``. +Example: + +.. code:: json + + { + "device": "iphone", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 2960000000, + "model_lib": "gpt_neox_q4f16_1" + } + ], + "model_lib_path_for_prepare_libs": { + "gpt_neox_q4f16_1": "../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar" + } + } + + +Arguments of ``mlc_llm package`` +-------------------------------- + +Command ``mlc_llm package`` can optionally take the arguments below: + +``--package-config`` + A path to ``mlc-package-config.json`` which contains the device and model specification. + By default, it is the ``mlc-package-config.json`` under the current directory. + +``--mlc-llm-home`` + The path to MLC LLM source code (cloned from https://github.com/mlc-ai/mlc-llm). + By default, it is the ``$MLC_LLM_HOME`` environment variable. + If neither ``$MLC_LLM_HOME`` or ``--mlc-llm-home`` is specified, error will be reported. + +``--output`` / ``-o`` + The output directory of ``mlc_llm package`` command. + By default, it is ``dist/`` under the current directory. + + +Summary and What to Do Next +--------------------------- + +In this page, we introduced the ``mlc_llm package`` command for fast model library and weight packaging. + +* It takes input file ``mlc-package-config.json`` which contains the device and model specification for packaging. +* It outputs directory ``dist/``, which contains packaged libraries under ``dist/lib/`` and model weights under ``dist/bundle/``. + +Next, please feel free to check out the :ref:`iOS ` and :ref:`Android ` tutorials for detailed examples of using ``mlc_llm package``. diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index a9b2fcb18f..0a0d66b704 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -1,6 +1,6 @@ .. _deploy-android: -Android App +Android SDK =========== .. contents:: Table of Contents @@ -35,11 +35,14 @@ Prerequisite ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang -**JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. It could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. Set up the following environment variable: +**JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. +We recommended setting the ``JAVA_HOME`` to the JDK bundled with Android Studio. e.g. ``export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home`` for macOS. 
+In other ways, it could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. +Set up the following environment variable: - ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. -Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. We recommended setting the `JAVA_HOME` to the JDK bundled with Android Studio. e.g. `export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home` for macOS. +Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. **TVM Unity runtime** is placed under `3rdparty/tvm `__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable: @@ -60,128 +63,258 @@ Check if **environment variable** are properly set as the last check. One way to export JAVA_HOME=... # Java export TVM_HOME=... # TVM Unity runtime -Compile PyTorch Models from HuggingFace ---------------------------------------- -To deploy models on Android with reasonable performance, one has to cross-compile to and fully utilize mobile GPUs using TVM Unity. MLC provides a few pre-compiled models, or one could compile the models on their own. +Build Android App from Source +----------------------------- -**Cloning MLC LLM from GitHub**. Download MLC LLM via the following command: +This section shows how we can build the app from the source. -.. code-block:: bash +Step 1. Install Build Dependencies +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - git clone --recursive https://github.com/mlc-ai/mlc-llm/ - ^^^^^^^^^^^ - cd ./mlc-llm/ +First and foremost, please clone the `MLC LLM GitHub repository `_. +After cloning, go to the ``android/`` directory. -.. note:: - ❗ The ``--recursive`` flag is necessary to download submodules like `3rdparty/tvm `__. If you see any file missing during compilation, please double check if git submodules are properly cloned. +.. code:: bash -**Download the PyTorch model** using Git Large File Storage (LFS), and by default, under ``./dist/models/``: + git clone https://github.com/mlc-ai/mlc-llm.git + cd mlc-llm + git submodule update --init --recursive + cd android -.. code-block:: bash - MODEL_NAME=Llama-2-7b-chat-hf - QUANTIZATION=q4f16_1 +.. _android-build-runtime-and-model-libraries: - git lfs install - git clone https://huggingface.co/meta-llama/$MODEL_NAME \ - ./dist/models/ +Step 2. Build Runtime and Model Libraries +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -**Compile Android-capable models**. Install TVM Unity compiler as a Python package, and then compile the model for android using the following commands: +The models to be built for the Android app are specified in ``MLCChat/mlc-package-config.json``: +in the ``model_list``, ``model`` points to the Hugging Face repository which -.. code-block:: bash +* ``model`` points to the Hugging Face repository which contains the pre-converted model weights. The Android app will download model weights from the Hugging Face URL. +* ``model_id`` is a unique model identifier. +* ``estimated_vram_bytes`` is an estimation of the vRAM the model takes at runtime. +* ``"bundle_weight": true`` means the model weights of the model will be bundled into the app when building. +* ``overrides`` specifies some model config parameter overrides. 
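As a minimal sketch, the JSON file described by the bullet points above can also be written out from Python. The model URL, ``model_id``, and ``estimated_vram_bytes`` below are example values borrowed from the gemma-2b entry shown later on this page; adjust them for your own models.

.. code:: python

   """A minimal sketch: write an Android mlc-package-config.json with one model.

   The model URL, model_id and estimated_vram_bytes are example values taken
   from the gemma-2b entry used elsewhere on this page.
   """
   import json

   package_config = {
       "device": "android",
       "model_list": [
           {
               "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC",
               "model_id": "gemma-2b-q4f16_1",
               "estimated_vram_bytes": 3000000000,
               # Optional: bundle the converted weights into the app.
               "bundle_weight": True,
               # Optional: override model config values, e.g. prefill chunk size.
               "overrides": {"prefill_chunk_size": 128},
           }
       ],
   }

   with open("mlc-package-config.json", "w", encoding="utf-8") as file:
       json.dump(package_config, file, indent=2)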
- # convert weights - mlc_llm convert_weight ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION -o dist/$MODEL_NAME-$QUANTIZATION-MLC/ - # create mlc-chat-config.json - mlc_llm gen_config ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION \ - --conv-template llama-2 --context-window-size 768 -o dist/${MODEL_NAME}-${QUANTIZATION}-MLC/ +We have a one-line command to build and prepare all the model libraries: - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/mlc-chat-config.json \ - --device android -o ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/${MODEL_NAME}-${QUANTIZATION}-android.tar +.. code:: bash -This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION-MLC`` which contains the necessary components to run the model, as explained below. + cd /path/to/MLCChat # e.g., "android/MLCChat" + export MLC_LLM_HOME=/path/to/mlc-llm # e.g., "../.." + mlc_llm package -.. note:: - ❗ To run 7B models like llama-2-7B, Mistral-7B, it is recommended to use smaller values of parameter ``--context-window-size`` (``--sliding-window-size`` and ``--prefill-chunk-size`` for sliding window attention) to reduce the memory footprint of the model. Default configurations for certains models can be found under the Android tab in the `Compile Models `_ section. - -**Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}-MLC``, and the result consists of 3 major components: +This command mainly executes the following two steps: -- Runtime configuration: It configures conversation templates including system prompts, repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` alongside with tokenizer configurations. -- Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_1-android.tar``. -- Model weights: the model weights are sharded as ``params_shard_*.bin`` and the metadata is stored in ``ndarray-cache.json`` +1. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library. +2. **Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM. -Create Android Project using Compiled Models --------------------------------------------- +The command creates a ``./dist/`` directory that contains the runtime and model build output. +Please make sure all the following files exist in ``./dist/``. -The source code for MLC LLM is available under ``android/``, including scripts to build dependencies. Enter the directory first: +.. code:: -.. code-block:: bash + dist + └── lib + └── mlc4j + ├── build.gradle + ├── output + │ ├── arm64-v8a + │ │ └── libtvm4j_runtime_packed.so + │ └── tvm4j_core.jar + └── src + ├── cpp + │ └── tvm_runtime.h + └── main + ├── AndroidManifest.xml + ├── assets + │ └── mlc-app-config.json + └── java + └── ai + └── mlc + └── mlcllm + └── ChatModule.java - cd ./android/library +The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``, +while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ to it. 
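To sanity-check the packaging output from a script, a minimal sketch like the following (assuming the ``dist/lib/mlc4j`` layout shown above) verifies that the expected files were produced. If anything is missing, re-run ``mlc_llm package``.

.. code:: python

   """A minimal sketch that checks the mlc4j output of "mlc_llm package".

   The expected paths simply mirror the dist/lib/mlc4j tree listed above.
   """
   from pathlib import Path

   mlc4j = Path("dist") / "lib" / "mlc4j"
   expected = [
       mlc4j / "build.gradle",
       mlc4j / "output" / "arm64-v8a" / "libtvm4j_runtime_packed.so",
       mlc4j / "output" / "tvm4j_core.jar",
       mlc4j / "src" / "main" / "assets" / "mlc-app-config.json",
   ]

   missing = [str(path) for path in expected if not path.exists()]
   if missing:
       raise SystemExit("Incomplete packaging output, missing:\n" + "\n".join(missing))
   print("All expected mlc4j outputs are present.")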
-**Build necessary dependencies.** Configure the list of models the app comes with using the JSON file ``app-config.json`` which contains two properties `model_list` and `model_lib_path_for_prepare_libs` ``model_lib_path_for_prepare_libs`` contains list of model library paths under `./dist/` that will be bundled with the apk. The ``model_list`` property contains data for models that are not bundled with the apk, but downloaded from the internet at run-time. Each model defined in `model_list` contain the following fields: -``model_url`` - (Required) URL to the repo containing the weights. - -``model_id`` - (Required) Unique local identifier to identify the model. - -``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. - -``estimated_vram_bytes`` - (Optional) Estimated requirements of VRAM to run the model. - -To change the configuration, edit ``app-config.json``: +.. note:: -.. code-block:: bash + We leverage a local JIT cache to avoid repetitive compilation of the same input. + However, sometimes it is helpful to force rebuild when we have a new compiler update + or when something goes wrong with the ached library. + You can do so by setting the environment variable ``MLC_JIT_POLICY=REDO`` - vim ./src/main/assets/app-config.json + .. code:: bash -Then bundle the android library ``${MODEL_NAME}-${QUANTIZATION}-android.tar`` compiled from ``mlc_llm compile`` in the previous steps, with TVM Unity's Java runtime by running the commands below: + MLC_JIT_POLICY=REDO mlc_llm package -.. code-block:: bash - ./prepare_libs.sh +Step 3. Build Android App +^^^^^^^^^^^^^^^^^^^^^^^^^ -which generates the two files below: +Open folder ``./android`` as an Android Studio Project. +Connect your Android device to your machine. +In the menu bar of Android Studio, click **"Build → Make Project"**. +Once the build is finished, click **"Run → Run 'app'"** and you will see the app launched on your phone. -.. code-block:: bash +.. note:: + ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. - >>> find ./build/output -type f - ./build/output/arm64-v8a/libtvm4j_runtime_packed.so - ./build/output/tvm4j_core.jar -The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``, while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ to it. +Customize the App +----------------- -**Build the Android app**. Open folder ``./android`` as an Android Studio Project. Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone. +We can customize the models built in the Android app by customizing `MLCChat/mlc-package-config.json `_. +We introduce each field of the JSON file here. -.. note:: - ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. 
+Each entry in ``"model_list"`` of the JSON file has the following fields: -Incorporate Model Weights -------------------------- +``model`` + (Required) The path to the MLC-converted model to be built into the app. + It is a Hugging Face URL (e.g., ``"model": "HF://mlc-ai/phi-2-q4f16_1-MLC"```) that contains + the pre-converted model weights. -Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. +``model_id`` + (Required) A unique local identifier to identify the model. + It can be an arbitrary one. -**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/app/release/app-release.apk``. +``estimated_vram_bytes`` + (Required) Estimated requirements of vRAM to run the model. + +``bundle_weight`` + (Optional) A boolean flag indicating whether to bundle model weights into the app. See :ref:`android-bundle-model-weights` below. + +``overrides`` + (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model temporary execution memory). + Example: + + .. code:: json + + { + "device": "android", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 1948348579, + "overrides": { + "context_window_size": 512, + "prefill_chunk_size": 128 + } + } + ] + } -**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: +``model_lib`` + (Optional) A string specifying the system library prefix to use for the model. + Usually this is used when you want to build multiple model variants with the same architecture into the app. + **This field does not affect any app functionality.** + The ``"model_lib_path_for_prepare_libs"`` introduced below is also related. + Example: + + .. code:: json + + { + "device": "android", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 1948348579, + "model_lib": "gpt_neox_q4f16_1" + } + ] + } + + +Besides ``model_list`` in ``MLCChat/mlc-package-config.json``, +you can also **optionally** specify a dictionary of ``"model_lib_path_for_prepare_libs"``, +**if you want to use model libraries that are manually compiled**. +The keys of this dictionary should be the ``model_lib`` that specified in model list, +and the values of this dictionary are the paths (absolute, or relative) to the manually compiled model libraries. +The model libraries specified in ``"model_lib_path_for_prepare_libs"`` will be built into the app when running ``mlc_llm package``. +Example: + +.. 
code:: json + + { + "device": "android", + "model_list": [ + { + "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", + "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", + "estimated_vram_bytes": 1948348579, + "model_lib": "gpt_neox_q4f16_1" + } + ], + "model_lib_path_for_prepare_libs": { + "gpt_neox_q4f16_1": "../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar" + } + } + +.. _android-bundle-model-weights: + +Bundle Model Weights +-------------------- + +Instructions have been provided to build an Android App with MLC LLM in previous sections, +but it requires run-time weight downloading from HuggingFace, +as configured in ``MLCChat/mlc-package-config.json``. +However, it could be desirable to bundle weights together into the app to avoid downloading over the network. +In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. + +**Enable weight bundle**. +Set the field ``"bundle_weight": true`` for any model you want to bundle weights +in ``MLCChat/mlc-package-config.json``, and run ``mlc_llm package`` again. +Below is an example: + +.. code:: json + + { + "device": "android", + "model_list": [ + { + "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", + "model_id": "gemma-2b-q4f16_1", + "estimated_vram_bytes": 3000000000, + "bundle_weight": true + } + ] + } + +The outcome of running ``mlc_llm package`` should be as follows: + +.. code:: + + dist + ├── bundle + │ ├── gemma-2b-q4f16_1 # The model weights that will be bundled into the app. + │ └── mlc-app-config.json + └── ... + + +**Generating APK**. Enter Android Studio, and click **"Build → Generate Signed Bundle/APK"** to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. +This APK will be placed under ``android/MLCChat/app/release/app-release.apk``. + +**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. +In "SDK manager - SDK Tools", install `Android SDK Platform-Tools `_. +Add the path to platform-tool path to the environment variable ``PATH`` (on macOS, it is ``$HOME/Library/Android/sdk/platform-tools``). +Run the following commands, and if ADB is installed correctly, your phone will appear as a device: .. code-block:: bash adb devices -**Install the APK and weights to your phone**. Run the commands below replacing ``${MODEL_NAME}`` and ``${QUANTIZATION}`` with the actual model name (e.g. Llama-2-7b-chat-hf) and quantization format (e.g. q4f16_1). +**Install the APK and weights to your phone**. +Run the commands below to install the app, and push the local weights to the app data directory on your device. +Once it finishes, you can start the MLCChat app on your device. +The models with ``bundle_weight`` set to true will have their weights already on device. .. code-block:: bash - adb install android/app/release/app-release.apk - adb push dist/${MODEL_NAME}-${QUANTIZATION}-MLC /data/local/tmp/${MODEL_NAME}-${QUANTIZATION}/ - adb shell "mkdir -p /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" - adb shell "mv /data/local/tmp/${MODEL_NAME}-${QUANTIZATION} /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" + cd /path/to/MLCChat # e.g., "android/MLCChat" + python bundle_weight.py --apk-path app/release/app-release.apk diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index b90c48a84d..02aaa55952 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -1,7 +1,7 @@ .. 
_deploy-ios: -iOS App and Swift API -===================== +iOS and Swift SDK +================= .. contents:: Table of Contents :local: @@ -53,41 +53,44 @@ Step 2. Build Runtime and Model Libraries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The models to be built for the iOS app are specified in ``MLCChat/mlc-package-config.json``: -in the ``model_list`` field of this file, ``model`` points to the Hugging Face model repository, -where model weights are downloaded from. ``model_id`` is a unique model identifier. -``estimated_vram_bytes`` is an estimation of the vRAM the model takes at runtime. +in the ``model_list``, + +* ``model`` points to the Hugging Face repository which contains the pre-converted model weights. The iOS app will download model weights from the Hugging Face URL. +* ``model_id`` is a unique model identifier. +* ``estimated_vram_bytes`` is an estimation of the vRAM the model takes at runtime. +* ``"bundle_weight": true`` means the model weights of the model will be bundled into the app when building. +* ``overrides`` specifies some model config parameter overrides. + We have a one-line command to build and prepare all the model libraries: .. code:: bash - cd /path/to/MLCChat - ./prepare_package.sh + cd /path/to/MLCChat # e.g., "ios/MLCChat" + export MLC_LLM_HOME=/path/to/mlc-llm # e.g., "../.." + mlc_llm package This command mainly executes the following two steps: -1. **Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM. -2. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library. +1. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library. +2. **Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM. The command creates a ``./dist/`` directory that contains the runtime and model build output. -Please make sure all the following files exist in ``./dist/``. - -.. code:: bash - - >>> ls ./dist - bundle # The directory for mlc-app-config.json (and optionally model weights) - # that will be bundled into the iOS app. - lib # The directory for runtime and model libraries. +Please make sure ``dist/`` follows the structure below, except the optional model weights. - >>> ls ./dist/bundle - mlc-app-config.json # The app config JSON file. +.. code:: - >>> ls ./dist/lib - libmlc_llm.a # A lightweight interface to interact with LLM, tokenizer, and TVM Unity runtime - libmodel_iphone.a # The compiled model lib - libsentencepiece.a # SentencePiece tokenizer - libtokenizers_cpp.a # Huggingface tokenizer - libtvm_runtime.a # TVM Unity runtime + dist + ├── bundle # The directory for mlc-app-config.json (and optionally model weights) + │ │ # that will be bundled into the iOS app. + │ ├── mlc-app-config.json # The app config JSON file. + │ └── [optional model weights] + └── lib + ├── libmlc_llm.a # A lightweight interface to interact with LLM, tokenizer, and TVM Unity runtime. + ├── libmodel_iphone.a # The compiled model lib. + ├── libsentencepiece.a # SentencePiece tokenizer + ├── libtokenizers_cpp.a # Huggingface tokenizer. + └── libtvm_runtime.a # TVM Unity runtime. .. note:: @@ -99,7 +102,7 @@ Please make sure all the following files exist in ``./dist/``. .. code:: bash - MLC_JIT_POLICY=REDO ./prepare_package.sh + MLC_JIT_POLICY=REDO mlc_llm package .. 
_ios-bundle-model-weights: @@ -109,12 +112,13 @@ Step 3. (Optional) Bundle model weights into the app By default, we download the model weights from Hugging Face when running the app. **As an option,**, we bundle model weights into the app: set the field ``"bundle_weight": true`` for any model you want to bundle weights -in ``MLCChat/mlc-package-config.json``, and run ``prepare_package.sh`` again. +in ``MLCChat/mlc-package-config.json``, and run ``mlc_llm package`` again. Below is an example: .. code:: json { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/gemma-2b-it-q4f16_1-MLC", @@ -128,13 +132,15 @@ Below is an example: ] } -The outcome of running ``prepare_package.sh`` should be as follows: +The outcome of running ``mlc_llm package`` should be as follows: -.. code:: bash +.. code:: - >>> ls ./dist/bundle - mlc-app-config.json - gemma-2b-it-q4f16_1-MLC # The model weights that will be bundled into the app. + dist + ├── bundle + │ ├── gemma-2b-q4f16_1 # The model weights that will be bundled into the app. + │ └── mlc-app-config.json + └── ... .. _ios-build-app: @@ -190,6 +196,7 @@ Each entry in ``"model_list"`` of the JSON file has the following fields: .. code:: json { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", @@ -213,6 +220,7 @@ Each entry in ``"model_list"`` of the JSON file has the following fields: .. code:: json { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", @@ -229,12 +237,13 @@ you can also **optionally** specify a dictionary of ``"model_lib_path_for_prepar **if you want to use model libraries that are manually compiled**. The keys of this dictionary should be the ``model_lib`` that specified in model list, and the values of this dictionary are the paths (absolute, or relative) to the manually compiled model libraries. -The model libraries specified in ``"model_lib_path_for_prepare_libs"`` will be built into the app when running ``prepare_package.sh``. +The model libraries specified in ``"model_lib_path_for_prepare_libs"`` will be built into the app when running ``mlc_llm package``. Example: .. code:: json { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", @@ -326,6 +335,7 @@ Finally, we add the model into the ``model_list`` of .. code:: json { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", @@ -346,9 +356,9 @@ Build Apps with MLC Swift API We also provide a Swift package that you can use to build your own app. The package is located under ``ios/MLCSwift``. -- First, create `mlc-package-config.json` and `prepare_package.sh` in your project folder. +- First, create ``mlc-package-config.json`` in your project folder. You do so by copying the files in MLCChat folder. - Run `prepare_package.sh` + Run ``mlc_llm package``. This will give us the necessary libraries under ``/path/to/project/dist``. - Under "Build phases", add ``/path/to/project/dist/bundle`` this will copying this folder into your app to include bundled weights and configs. diff --git a/docs/index.rst b/docs/index.rst index 2d5597d18e..f406908219 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,7 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a compilation/convert_weights.rst compilation/compile_models.rst + compilation/package_model_libraries_weights.rst compilation/define_new_models.rst .. 
toctree:: diff --git a/ios/.gitignore b/ios/.gitignore index 31d064cacb..f75e36783f 100644 --- a/ios/.gitignore +++ b/ios/.gitignore @@ -1,2 +1,3 @@ xuserdata +MLCSwift/tvm_home *~ diff --git a/ios/MLCChat/README.md b/ios/MLCChat/README.md index 831d7eee73..f4f4820e24 100644 --- a/ios/MLCChat/README.md +++ b/ios/MLCChat/README.md @@ -2,5 +2,5 @@ Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information. -- run `./prepare_package.sh` -- open the xcode project +- run `mlc_llm package` +- open the Xcode project diff --git a/ios/MLCChat/mlc-package-config.json b/ios/MLCChat/mlc-package-config.json index 66ca1379f7..094e6e0ddb 100644 --- a/ios/MLCChat/mlc-package-config.json +++ b/ios/MLCChat/mlc-package-config.json @@ -1,4 +1,5 @@ { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC", diff --git a/ios/MLCChat/prepare_package.sh b/ios/MLCChat/prepare_package.sh deleted file mode 100755 index 6dedca46ae..0000000000 --- a/ios/MLCChat/prepare_package.sh +++ /dev/null @@ -1,10 +0,0 @@ -# This script does two things -# It calls prepare_libs.sh in $MLC_LLM_HOME/ios/ to setup the iOS package and build binaries -# It then calls mlc_llm package to setup the weight and library bundle -# Feel free to copy this file and mlc-package-config.json to your project - -MLC_LLM_HOME="${MLC_LLM_HOME:-../..}" -cd ${MLC_LLM_HOME}/ios && ./prepare_libs.sh $@ && cd - -mkdir -p dist/lib -cp ${MLC_LLM_HOME}/ios/build/lib/* dist/lib/ -python -m mlc_llm package mlc-package-config.json --device iphone -o dist diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift index cf4d3dae53..26361977ce 100644 --- a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -5,7 +5,8 @@ // example and quick testing purposes. // // To build this app, select target My Mac(Designed for iPad) and run -// Make sure you run prepare_package.sh first with "MLCChat" replaced by "MLCEngineExample" +// Make sure you run "mlc_llm package" first with "MLCChat" +// replaced by "MLCEngineExample" // to ensure the "dist/bundle" folder populates with the right model file // and we have the model lib packaged correctly import Foundation @@ -22,9 +23,9 @@ class AppState: ObservableObject { private let bundleURL = Bundle.main.bundleURL.appending(path: "bundle") // model path, this must match a builtin // file name in prepare_params.sh - private let modelPath = "Llama-3-8B-Instruct-q3f16_1-MLC" + private let modelPath = "llama3" // model lib identifier of within the packaged library - // make sure we run prepare_package.sh + // make sure we run "mlc_llm package" private let modelLib = "llama_q3f16_1" // this is a message to be displayed in app diff --git a/ios/MLCEngineExample/README.md b/ios/MLCEngineExample/README.md index 67bf06089b..2e930e497b 100644 --- a/ios/MLCEngineExample/README.md +++ b/ios/MLCEngineExample/README.md @@ -8,5 +8,5 @@ things may not yet be fully functioning and are subject to change Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information. 
-- run `./prepare_package.sh` -- open the xcode project +- run `mlc_llm package` +- open the Xcode project diff --git a/ios/MLCEngineExample/mlc-package-config.json b/ios/MLCEngineExample/mlc-package-config.json index 066fe7fa10..6a3bcaaa5a 100644 --- a/ios/MLCEngineExample/mlc-package-config.json +++ b/ios/MLCEngineExample/mlc-package-config.json @@ -1,4 +1,5 @@ { + "device": "iphone", "model_list": [ { "model": "HF://mlc-ai/Llama-3-8B-Instruct-q3f16_1-MLC", diff --git a/ios/MLCEngineExample/prepare_package.sh b/ios/MLCEngineExample/prepare_package.sh deleted file mode 100755 index d1f022166d..0000000000 --- a/ios/MLCEngineExample/prepare_package.sh +++ /dev/null @@ -1,10 +0,0 @@ -# This script does two things -# It calls prepare_libs.sh in $MLC_LLM_HOME/ios/ to setup the iOS package and build binaries -# It then calls mlc_llm package to setup the weight and library bundle -# Feel free to copy this file and mlc-package-config.json to your project - -MLC_LLM_HOME="${MLC_LLM_HOME:-../..}" -cd ${MLC_LLM_HOME}/ios && ./prepare_libs.sh $@ && cd - -rm -rf dist/lib && mkdir -p dist/lib -cp ${MLC_LLM_HOME}/ios/build/lib/* dist/lib/ -python -m mlc_llm package mlc-package-config.json --device iphone -o dist diff --git a/ios/MLCSwift/tvm_home b/ios/MLCSwift/tvm_home deleted file mode 120000 index e15bf649f5..0000000000 --- a/ios/MLCSwift/tvm_home +++ /dev/null @@ -1 +0,0 @@ -../../3rdparty/tvm \ No newline at end of file diff --git a/ios/README.md b/ios/README.md index de94ee75a0..39f0e0b4b6 100644 --- a/ios/README.md +++ b/ios/README.md @@ -1,3 +1,3 @@ -# MLC-LLM IOS +# MLC-LLM iOS [Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) diff --git a/ios/prepare_libs.sh b/ios/prepare_libs.sh index 58e6468637..ede58c32e0 100755 --- a/ios/prepare_libs.sh +++ b/ios/prepare_libs.sh @@ -1,5 +1,5 @@ # Command to prepare the mlc llm static libraries -# This command will be invoked by prepare_package.sh in the subfolder +# This command will be invoked by the "mlc_llm package" command function help { echo -e "OPTION:" echo -e " -s, --simulator Build for Simulator" @@ -7,6 +7,7 @@ function help { echo -e " -h, --help Prints this help\n" } +MLC_LLM_HOME="${MLC_LLM_HOME:-..}" is_simulator="false" arch="arm64" @@ -53,7 +54,7 @@ fi mkdir -p build/ && cd build/ -cmake ../..\ +cmake $MLC_LLM_HOME\ -DCMAKE_BUILD_TYPE=$type\ -DCMAKE_SYSTEM_NAME=iOS\ -DCMAKE_SYSTEM_VERSION=14.0\ @@ -71,5 +72,5 @@ cmake --build . --config release --target mlc_llm_static -j cmake --build . --target install --config release -j cd .. 
-rm -rf MLCSwift/tvm_home -ln -s ../../3rdparty/tvm MLCSwift/tvm_home +rm -rf $MLC_LLM_HOME/ios/MLCSwift/tvm_home +ln -s $MLC_LLM_HOME/3rdparty/tvm $MLC_LLM_HOME/ios/MLCSwift/tvm_home diff --git a/python/mlc_llm/cli/package.py b/python/mlc_llm/cli/package.py index f605858d67..b8c6b994c2 100644 --- a/python/mlc_llm/cli/package.py +++ b/python/mlc_llm/cli/package.py @@ -1,5 +1,6 @@ """Command line entrypoint of package.""" +import os from pathlib import Path from typing import Union @@ -22,6 +23,10 @@ def _parse_package_config(path: Union[str, Path]) -> Path: raise ValueError(f"Path {str(path)} is expected to be a JSON file.") return path + def _parse_mlc_llm_home(path: str) -> Path: + os.environ["MLC_LLM_HOME"] = path + return Path(path) + def _parse_output(path: Union[str, Path]) -> Path: path = Path(path) if not path.is_dir(): @@ -29,27 +34,34 @@ def _parse_output(path: Union[str, Path]) -> Path: return path parser.add_argument( - "package_config", + "--package-config", type=_parse_package_config, - help=HELP["config_package"] + " (required)", + default="mlc-package-config.json", + help=HELP["config_package"] + ' (default: "%(default)s")', ) parser.add_argument( - "--device", - type=str, - choices=["iphone", "android"], - required=True, - help=HELP["device_package"] + " (required)", + "--mlc-llm-home", + type=_parse_mlc_llm_home, + default=os.environ.get("MLC_LLM_HOME", None), + help=HELP["mlc_llm_home"] + " (default: the $MLC_LLM_HOME environment variable)", ) parser.add_argument( "--output", "-o", type=_parse_output, - required=True, - help=HELP["output_package"] + " (required)", + default="dist", + help=HELP["output_package"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) + if parsed.mlc_llm_home is None: + raise ValueError( + "MLC LLM home is not specified. " + "Please obtain a copy of MLC LLM source code by " + "cloning https://github.com/mlc-ai/mlc-llm, and set environment variable " + '"MLC_LLM_HOME=path/to/mlc-llm"' + ) package( package_config_path=parsed.package_config, - device=parsed.device, + mlc_llm_home=parsed.mlc_llm_home, output=parsed.output, ) diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index a9b8917990..50e5a3a69a 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -189,40 +189,39 @@ "--additional-models model_path_1:model_lib_1 model_path_2 ...". When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. -""", +""".strip(), "gpu_memory_utilization_serve": """ A number in (0, 1) denoting the fraction of GPU memory used by the server in total. It is used to infer to maximum possible KV cache capacity. When it is unspecified, it defaults to 0.85. Under mode "local" or "interactive", the actual memory usage may be significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. -""", +""".strip(), "speculative_mode_serve": """ The speculative decoding mode. Right now three options are supported: - "disable", where speculative decoding is not enabled, - "small_draft", denoting the normal speculative decoding (small draft) style, - "eagle", denoting the eagle-style speculative decoding. The default mode is "disable". -""", +""".strip(), "spec_draft_length_serve": """ The number of draft tokens to generate in speculative proposal. The default values is 4. -""", +""".strip(), "engine_config_serve": """ The MLCEngine execution configuration. 
Currently speculative decoding mode is specified via engine config. For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=eagle'" to specify the eagle-style speculative decoding. Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. -""", +""".strip(), "config_package": """ The path to "mlc-package-config.json" which is used for package build. -See "ios/MLCChat/mlc-package-config.json" as an example. -""", - "device_package": """ -The device to build package for. -Options are ["iphone", "android"]. -""", +See "https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCChat/mlc-package-config.json" as an example. +""".strip(), + "mlc_llm_home": """ +The source code path to MLC LLM. +""".strip(), "output_package": """ The path of output directory for the package build outputs. -""", +""".strip(), } diff --git a/python/mlc_llm/interface/package.py b/python/mlc_llm/interface/package.py index d342ff589d..58ff119cc0 100644 --- a/python/mlc_llm/interface/package.py +++ b/python/mlc_llm/interface/package.py @@ -4,12 +4,11 @@ import json import os import shutil +import subprocess import sys from dataclasses import asdict from pathlib import Path -from typing import List, Literal - -from tvm.contrib import cc +from typing import Any, Dict, List, Literal from mlc_llm.chat_module import ChatConfig, _get_chat_config, _get_model_path from mlc_llm.interface import jit @@ -18,125 +17,14 @@ logging.enable_logging() logger = logging.getLogger(__name__) +SUPPORTED_DEVICES = ["iphone", "android"] -def _get_model_libs(lib_path: Path) -> List[str]: - """Get the model lib prefixes in the given static lib path.""" - global_symbol_map = cc.get_global_symbol_section_map(lib_path) - libs = [] - suffix = "___tvm_dev_mblob" - for name, _ in global_symbol_map.items(): - if name.endswith(suffix): - model_lib = name[: -len(suffix)] - if model_lib.startswith("_"): - model_lib = model_lib[1:] - libs.append(model_lib) - return libs - - -def validate_model_lib( # pylint: disable=too-many-locals - app_config_path: Path, - package_config_path: Path, - model_lib_path_for_prepare_libs: dict, - device: Literal["iphone", "android"], - output: Path, -) -> None: - """Validate the model lib prefixes of model libraries.""" - # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported - if device == "android": - from tvm.contrib import ndk as cc - else: - from tvm.contrib import cc - # pylint: enable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported - - with open(app_config_path, "r", encoding="utf-8") as file: - app_config = json.load(file) - - tar_list = [] - model_set = set() - - for model, model_lib_path in model_lib_path_for_prepare_libs.items(): - model_lib_path = os.path.join(model_lib_path) - lib_path_valid = os.path.isfile(model_lib_path) - if not lib_path_valid: - raise RuntimeError(f"Cannot find file {model_lib_path} as an {device} model library") - tar_list.append(model_lib_path) - model_set.add(model) - - os.makedirs(output / "lib", exist_ok=True) - lib_path = ( - output / "lib" / ("libmodel_iphone.a" if device == "iphone" else "libmodel_android.a") - ) - - cc.create_staticlib(lib_path, tar_list) - available_model_libs = _get_model_libs(lib_path) - logger.info("Creating lib from %s", str(tar_list)) - logger.info("Validating the library %s", str(lib_path)) - logger.info( - "List of available model libs packaged: %s," - " if we have '-' in the model_lib string, it will be turned into '_'", - str(available_model_libs), - ) - 
global_symbol_map = cc.get_global_symbol_section_map(lib_path) - error_happened = False - - for item in app_config["model_list"]: - model_lib = item["model_lib"] - model_id = item["model_id"] - if model_lib not in model_set: - # NOTE: this cannot happen under new setting - # since if model_lib is not included, it will be jitted - raise RuntimeError( - f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " - "is not included in model_lib_path_for_prepare_libs argument, " - "This will cause the specific model not being able to load, " - f"model_lib_path_for_prepare_libs={model_lib_path_for_prepare_libs}" - ) - - model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" - if ( - model_prefix_pattern not in global_symbol_map - and "_" + model_prefix_pattern not in global_symbol_map - ): - # NOTE: no lazy format is ok since this is a slow pass - model_lib_path = model_lib_path_for_prepare_libs[model_lib] - log_msg = ( - "ValidationError:\n" - f"\tmodel_lib {model_lib} requested in {str(app_config_path)}" - f" is not found in {str(lib_path)}\n" - f"\tspecifically the model_lib for {model_lib_path}.\n" - f"\tcurrent available model_libs in {str(lib_path)}: {available_model_libs}\n" - f"\tThis can happen when we manually specified model_lib_path_for_prepare_libs" - f" in {str(package_config_path)}\n" - f"\tConsider remove model_lib_path_for_prepare_libs (so library can be jitted)" - "or check the compile command" - ) - logger.info(log_msg) - error_happened = True - - if not error_happened: - logger.info(style.green("Validation pass")) - else: - logger.info(style.red("Validation failed")) - sys.exit(255) - - -def package( # pylint: disable=too-many-locals,too-many-statements,too-many-branches - package_config_path: Path, - device: Literal["iphone", "android"], - output: Path, -) -> None: - """Python entrypoint of package.""" - # - Read package config. - with open(package_config_path, "r", encoding="utf-8") as file: - package_config = json.load(file) - if not isinstance(package_config, dict): - raise ValueError( - "The content of MLC package config is expected to be a dict with " - f'field "model_list". However, the content of "{package_config_path}" is not a dict.' - ) +def build_model_library( # pylint: disable=too-many-branches,too-many-locals,too-many-statements + package_config: Dict[str, Any], device: str, bundle_dir: Path, app_config_path: Path +) -> Dict[str, str]: + """Build model libraries. Return the dictionary of "library prefix to lib path".""" # - Create the bundle directory. - bundle_dir = output / "bundle" os.makedirs(bundle_dir, exist_ok=True) # Clean up all the directories in `output/bundle`. logger.info('Clean up all directories under "%s"', str(bundle_dir)) @@ -242,7 +130,7 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra ) ) # Overwrite the model weight directory in bundle. 
- bundle_model_weight_path = bundle_dir / model_path.name + bundle_model_weight_path = bundle_dir / model_id logger.info( "Bundle weight for %s, copy into %s", style.bold(model_id), @@ -251,7 +139,8 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra if bundle_model_weight_path.exists(): shutil.rmtree(bundle_model_weight_path) shutil.copytree(model_path, bundle_model_weight_path) - app_config_model_entry["model_path"] = model_path.name + if bundle_weight and device == "iphone": + app_config_model_entry["model_path"] = model_id else: app_config_model_entry["model_url"] = model.replace("HF://", "https://huggingface.co/") @@ -265,15 +154,217 @@ def package( # pylint: disable=too-many-locals,too-many-statements,too-many-bra {"model_list": app_config_model_list}, indent=2, ) - app_config_path = bundle_dir / "mlc-app-config.json" with open(app_config_path, "w", encoding="utf-8") as file: print(app_config_json_str, file=file) logger.info( - 'Dump the app config below to "dist/bundle/mlc-app-config.json":\n%s', + 'Dump the app config below to "%s":\n%s', + str(app_config_path), style.green(app_config_json_str), ) + return model_lib_path_for_prepare_libs + + +def validate_model_lib( # pylint: disable=too-many-locals + app_config_path: Path, + package_config_path: Path, + model_lib_path_for_prepare_libs: dict, + device: Literal["iphone", "android"], + output: Path, +) -> None: + """Validate the model lib prefixes of model libraries.""" + # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + if device == "android": + from tvm.contrib import ndk as cc + else: + from tvm.contrib import cc + # pylint: enable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + + with open(app_config_path, "r", encoding="utf-8") as file: + app_config = json.load(file) + + tar_list = [] + model_set = set() + + for model, model_lib_path in model_lib_path_for_prepare_libs.items(): + model_lib_path = os.path.join(model_lib_path) + lib_path_valid = os.path.isfile(model_lib_path) + if not lib_path_valid: + raise RuntimeError(f"Cannot find file {model_lib_path} as an {device} model library") + tar_list.append(model_lib_path) + model_set.add(model) + + os.makedirs(output / "lib", exist_ok=True) + lib_path = ( + output / "lib" / ("libmodel_iphone.a" if device == "iphone" else "libmodel_android.a") + ) + + def _get_model_libs(lib_path: Path) -> List[str]: + """Get the model lib prefixes in the given static lib path.""" + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name, _ in global_symbol_map.items(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + cc.create_staticlib(lib_path, tar_list) + available_model_libs = _get_model_libs(lib_path) + logger.info("Creating lib from %s", str(tar_list)) + logger.info("Validating the library %s", str(lib_path)) + logger.info( + "List of available model libs packaged: %s," + " if we have '-' in the model_lib string, it will be turned into '_'", + str(available_model_libs), + ) + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + error_happened = False + + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + # NOTE: this cannot happen under new setting + # since if model_lib is not included, it will be jitted + raise RuntimeError( + 
f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs argument, " + "This will cause the specific model not being able to load, " + f"model_lib_path_for_prepare_libs={model_lib_path_for_prepare_libs}" + ) + + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + # NOTE: no lazy format is ok since this is a slow pass + model_lib_path = model_lib_path_for_prepare_libs[model_lib] + log_msg = ( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {str(app_config_path)}" + f" is not found in {str(lib_path)}\n" + f"\tspecifically the model_lib for {model_lib_path}.\n" + f"\tcurrent available model_libs in {str(lib_path)}: {available_model_libs}\n" + f"\tThis can happen when we manually specified model_lib_path_for_prepare_libs" + f" in {str(package_config_path)}\n" + f"\tConsider remove model_lib_path_for_prepare_libs (so library can be jitted)" + "or check the compile command" + ) + logger.info(log_msg) + error_happened = True + + if not error_happened: + logger.info(style.green("Validation pass")) + else: + logger.info(style.red("Validation failed")) + sys.exit(255) + + +def build_android_binding(mlc_llm_home: Path, output: Path) -> None: + """Build android binding in MLC LLM""" + mlc4j_path = mlc_llm_home / "android" / "mlc4j" + + # Move the model libraries to "build/lib/" for linking + os.makedirs(Path("build") / "lib", exist_ok=True) + src_path = str(output / "lib" / "libmodel_android.a") + dst_path = str(Path("build") / "lib" / "libmodel_android.a") + logger.info('Moving "%s" to "%s"', src_path, dst_path) + shutil.move(src_path, dst_path) + + # Build mlc4j + logger.info("Building mlc4j") + subprocess.run([sys.executable, mlc4j_path / "prepare_libs.py"], check=True, env=os.environ) + # Copy built files back to output directory. + lib_path = output / "lib" / "mlc4j" + os.makedirs(lib_path, exist_ok=True) + logger.info('Clean up all directories under "%s"', str(lib_path)) + for content_path in lib_path.iterdir(): + if content_path.is_dir(): + shutil.rmtree(content_path) + + src_path = str(mlc4j_path / "src") + dst_path = str(lib_path / "src") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copytree(src_path, dst_path) + + src_path = str(mlc4j_path / "build.gradle") + dst_path = str(lib_path / "build.gradle") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copy(src_path, dst_path) + + src_path = str(Path("build") / "output") + dst_path = str(lib_path / "output") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copytree(src_path, dst_path) + + os.makedirs(lib_path / "src" / "main" / "assets") + src_path = str(output / "bundle" / "mlc-app-config.json") + dst_path = str(lib_path / "src" / "main" / "assets" / "mlc-app-config.json") + logger.info('Moving "%s" to "%s"', src_path, dst_path) + shutil.move(src_path, dst_path) + + +def build_iphone_binding(mlc_llm_home: Path, output: Path) -> None: + """Build iOS binding in MLC LLM""" + # Build iphone binding + logger.info("Build iphone binding") + subprocess.run(["bash", mlc_llm_home / "ios" / "prepare_libs.sh"], check=True, env=os.environ) + + # Copy built libraries back to output directory. 
+ for static_library in (Path("build") / "lib").iterdir(): + dst_path = str(output / "lib" / static_library.name) + logger.info('Copying "%s" to "%s"', static_library, dst_path) + shutil.copy(static_library, dst_path) + + +def package( + package_config_path: Path, + mlc_llm_home: Path, + output: Path, +) -> None: + """Python entrypoint of package.""" + logger.info('MLC LLM HOME: "%s"', mlc_llm_home) + + # - Read package config. + with open(package_config_path, "r", encoding="utf-8") as file: + package_config = json.load(file) + if not isinstance(package_config, dict): + raise ValueError( + "The content of MLC package config is expected to be a dict with " + f'field "model_list". However, the content of "{package_config_path}" is not a dict.' + ) + + # - Read device. + if "device" not in package_config: + raise ValueError(f'JSON file "{package_config_path}" is required to have field "device".') + device = package_config["device"] + if device not in SUPPORTED_DEVICES: + raise ValueError( + f'The "device" field of JSON file {package_config_path} is expected to be one of ' + f'{SUPPORTED_DEVICES}, while "{device}" is given in the JSON.' + ) + bundle_dir = output / "bundle" + app_config_path = bundle_dir / "mlc-app-config.json" + # - Build model libraries. + model_lib_path_for_prepare_libs = build_model_library( + package_config, device, bundle_dir, app_config_path + ) # - Validate model libraries. validate_model_lib( app_config_path, package_config_path, model_lib_path_for_prepare_libs, device, output ) + + # - Copy model libraries + if device == "android": + build_android_binding(mlc_llm_home, output) + elif device == "iphone": + build_iphone_binding(mlc_llm_home, output) + else: + assert False, "Cannot reach here" + + logger.info("All finished.") From b62dd91ddea3be4c6548d5c4836eb100ba119f33 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 9 May 2024 21:50:47 -0400 Subject: [PATCH 292/531] [DOCS] Minor cleanup (#2308) Shorten titles so they fit into one line of navbar, add mention of jit cache. Remote old project overview --- docs/compilation/compile_models.rst | 22 +++-- docs/compilation/convert_weights.rst | 3 +- ....rst => package_libraries_and_weights.rst} | 17 +++- docs/deploy/ide_integration.rst | 4 +- docs/deploy/ios.rst | 4 +- docs/deploy/mlc_chat_config.rst | 6 +- docs/deploy/rest.rst | 39 ++++---- docs/get_started/project_overview.rst | 88 ------------------- docs/index.rst | 2 +- 9 files changed, 59 insertions(+), 126 deletions(-) rename docs/compilation/{package_model_libraries_weights.rst => package_libraries_and_weights.rst} (93%) delete mode 100644 docs/get_started/project_overview.rst diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index a98de7d97a..1e18b8d441 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -3,18 +3,29 @@ Compile Model Libraries ======================= -To run a model with MLC LLM in any platform, you need: +To run a model with MLC LLM in any platform, we need: 1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `__.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -If you are simply adding a model variant, follow :ref:`convert-weights-via-MLC` suffices. +2. **Model library** that comprises the inference logic This page describes how to compile a model library with MLC LLM. 
Model compilation optimizes the model inference for a given platform, allowing users to bring their own new model architecture, use different quantization modes, and customize the overall model optimization flow. + + +Notably, in many cases you do not need to explicitly call compile. + +- If you are using the Python API, you can skip specifying ``model_lib`` and + the system will JIT compile the library. + +- If you are building an iOS/Android package, check out :ref:`package-libraries-and-weights`, + which provides a simpler high-level command that runs the compilation behind the scenes. + + +This page is still helpful for understanding the compilation flow behind the scenes, +or for explicitly creating model libraries. We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for all platforms. .. note:: @@ -23,8 +34,7 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. - Finally, we strongly recommend you to read :ref:`project-overview` first to get - familiarized with the high-level terminologies. + .. contents:: Table of Contents :depth: 1 diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index e350ba4ac5..e9e57e14b1 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -26,8 +26,7 @@ This can be extended to, e.g.: Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain the CLI app / Python API that can be used to chat with the compiled model. - Finally, we strongly recommend you to read :ref:`project-overview` first to get - familiarized with the high-level terminologies. + .. contents:: Table of Contents :depth: 1 diff --git a/docs/compilation/package_model_libraries_weights.rst b/docs/compilation/package_libraries_and_weights.rst similarity index 93% rename from docs/compilation/package_model_libraries_weights.rst rename to docs/compilation/package_libraries_and_weights.rst index 0bab235eb4..5e9679bb26 100644 --- a/docs/compilation/package_model_libraries_weights.rst +++ b/docs/compilation/package_libraries_and_weights.rst @@ -1,7 +1,7 @@ -.. _package-model-libraries-weights: +.. _package-libraries-and-weights: -Package Model Libraries & Weights -================================= +Package Libraries and Weights +============================= When we want to build LLM applications with MLC LLM (e.g., iOS/Android apps), usually we need to build static model libraries and app binding libraries, @@ -177,6 +177,17 @@ Example: } } +Compilation Cache +----------------- +``mlc_llm package`` leverages a local JIT cache to avoid repetitive compilation of the same input. +It also leverages a local cache to download weights from remote sources. These caches +are shared across the entire project. Sometimes it is helpful to force a rebuild when +we have a new compiler update or when something goes wrong with the cached library. +You can do so by setting the environment variable ``MLC_JIT_POLICY=REDO``: + +.. code:: bash + + MLC_JIT_POLICY=REDO mlc_llm package Arguments of ``mlc_llm package`` -------------------------------- diff --git a/docs/deploy/ide_integration.rst b/docs/deploy/ide_integration.rst index 7e0735d8e0..89a9edb530 100644 --- a/docs/deploy/ide_integration.rst +++ b/docs/deploy/ide_integration.rst @@ -1,7 +1,7 @@ .. 
_deploy-ide-integration: -Code Completion IDE Integration -=============================== +IDE Integration +=============== .. contents:: Table of Contents :local: diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 02aaa55952..8e481b5b3d 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -1,7 +1,7 @@ .. _deploy-ios: -iOS and Swift SDK -================= +iOS Swift SDK +============= .. contents:: Table of Contents :local: diff --git a/docs/deploy/mlc_chat_config.rst b/docs/deploy/mlc_chat_config.rst index 948d50bddd..3132323d8c 100644 --- a/docs/deploy/mlc_chat_config.rst +++ b/docs/deploy/mlc_chat_config.rst @@ -1,7 +1,7 @@ .. _configure-mlc-chat-json: -Customize MLC Config File in JSON -================================= +Customize MLC Chat Config +========================= ``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes: @@ -112,7 +112,7 @@ Conversation Structure ^^^^^^^^^^^^^^^^^^^^^^ MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by -specifying ``--conv-template [name]`` when generating config. Below is a list (not complete) of +specifying ``--conv-template [name]`` when generating config. Below is a list (not complete) of supported conversation templates: - ``llama-2`` diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index a82c914004..7351791bf1 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -73,6 +73,7 @@ MODEL The model folder after compiling with MLC-LLM build proce (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model folder. In the former case, we will use the provided name to search for the model folder over possible paths. + --model-lib A field to specify the full path to the model library file to use (e.g. a ``.so`` file). --device The description of the device to run on. User should provide a string in the form of 'device_name:device_id' or 'device_name', where 'device_name' is one of @@ -137,7 +138,7 @@ The REST API provides the following endpoints: - **name** (*Optional[str]*): An optional name for the sender of the message. - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. - + - **model** (*str*, required): The model to be used for generating responses. - **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. @@ -183,51 +184,51 @@ The REST API provides the following endpoints: **ChatCompletionResponseChoice** - **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. - + - **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - + - **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. 
- + - **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token **ChatCompletionStreamResponseChoice** - **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. Valid reasons are "stop", "length", and "tool_calls". - + - **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - + - **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. - + - **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token **ChatCompletionResponse** - **id** (*str*, required): A unique identifier for the chat completion session. - + - **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. - + - **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. - + - **model** (*str*, required): The name of the model used to generate the chat completions. - + - **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. - + - **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". - + - **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. **ChatCompletionStreamResponse** - **id** (*str*, required): A unique identifier for the streaming chat completion session. - + - **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. - + - **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. - + - **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. - + - **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - + - **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. ------------------------------------------------ @@ -238,7 +239,7 @@ The REST API provides the following endpoints: Below is an example of using the API to interact with MLC-LLM in Python with Streaming. .. code:: bash - + import requests import json diff --git a/docs/get_started/project_overview.rst b/docs/get_started/project_overview.rst deleted file mode 100644 index ef631e40c8..0000000000 --- a/docs/get_started/project_overview.rst +++ /dev/null @@ -1,88 +0,0 @@ -.. _project-overview: - -Project Overview -================ - -This page introduces high-level project concepts to help us use and customize MLC LLM. -The MLC-LLM project consists of three distinct submodules: model definition, model compilation, and runtimes. - -.. 
figure:: /_static/img/project-structure.svg - :width: 600 - :align: center - :alt: Project Structure - - Three independent submodules in MLC LLM - -**➀ Model definition in Python.** MLC offers a variety of pre-defined architectures, such as Llama (e.g., Llama2, Vicuna, OpenLlama, Wizard), GPT-NeoX (e.g., RedPajama, Dolly), RNNs (e.g., RWKV), and GPT-J (e.g., MOSS). Model developers could solely define the model in pure Python, without having to touch code generation and runtime. - -**➁ Model compilation in Python.** Models are compiled by :doc:`TVM Unity ` compiler, where the compilation is configured in pure Python. MLC LLM quantizes and exports the Python-based model to a model library and quantized model weights. Quantization and optimization algorithms can be developed in pure Python to compress and accelerate LLMs for specific usecases. - -**➂ Platform-native runtimes.** Variants of MLCChat are provided on each platform: **C++** for command line, **Javascript** for web, **Swift** for iOS, and **Java** for Android, configurable with a JSON chat config. App developers only need to familiarize with the platform-naive runtimes to integrate MLC-compiled LLMs into their projects. - -.. _terminologies: - -Terminologies -------------- - -It is helpful for us to familiarize the basic terminologies used in the MLC chat applications. Below are the -three things you need to run a model with MLC. - -- **model lib**: The model library refers to the executable libraries that enable - the execution of a specific model architecture. On Linux and M-chip macOS, these libraries have the suffix - ``.so``; on intel macOS, the suffix is ``.dylib``; on Windows, the library file ends with ``.dll``; - on web browser, the library suffix is ``.wasm``. (see `binary-mlc-llm-libs `__). - -- **model weights**: The model weight is a folder that contains the quantized neural network weights - of the language models as well as the tokenizer configurations. (e.g. `Llama-2-7b-chat-hf-q4f16_1-MLC `__) - -- **chat config**: The chat configuration includes settings that allow customization of parameters such as temperature and system prompt. - The default chat config usually resides in the same directory as model weights. (e.g. see ``Llama-2-7b-chat-hf-q4f16_1``'s - `mlc-chat-config.json `__) - -Model Preparation ------------------ - - -There are several ways to prepare the model weights and model lib. - -- :ref:`Model Prebuilts` contains models that can be directly used. -- You can also :doc:`run model compilation ` for model weight variants for given supported architectures. -- Finally, you can incorporate a new model architecture/inference logic following :doc:`Define New Models `. - -A default chat config usually comes with the model weight directory. You can further customize -the system prompt, temperature, and other options by modifying the JSON file. -MLC chat runtimes also provide API to override these options during model reload. -Please refer to :ref:`configure-mlc-chat-json` for more details. - - -Runtime Flow Overview ---------------------- - -Once the model weights, model library, and chat configuration are prepared, an MLC chat runtime can consume them as an engine to drive a chat application. -The diagram below shows a typical workflow for a MLC chat application. - -.. 
image:: https://raw.githubusercontent.com/mlc-ai/web-data/a05d4598bae6eb5a3133652d5cc0323ced3b0e17/images/mlc-llm/tutorials/mlc-llm-flow-slm.svg - :width: 90% - :align: center - -On the right side of the figure, you can see pseudo-code illustrating the structure of an MLC chat API during the execution of a chat app. -Typically, there is a ``ChatModule`` that manages the model. We instantiate the chat app with two files: the model weights (which include an ``mlc-chat-config.json``) -and the model library. We also have an optional chat configuration, which allows for overriding settings such as the system prompt and temperature. - -All MLC runtimes, including iOS, Web, CLI, and others, use these three elements. -All the runtime can read the same model weight folder. The packaging of the model libraries may vary depending on the runtime. -For the CLI, the model libraries are stored in a DLL directory. -iOS and Android include pre-packaged model libraries within the app due to dynamic loading restrictions. -WebLLM utilizes URLs of local or Internet-hosted WebAssembly (Wasm) files. - -What to Do Next ---------------- - -Thank you for reading and learning the high-level concepts. -Moving next, feel free to check out documents on the left navigation panel and -learn about topics you are interested in. - -- :ref:`configure-mlc-chat-json` shows how to configure specific chat behavior. -- Build and Deploy App section contains guides to build apps - and platform-specific MLC chat runtimes. -- Compile models section provides guidelines to convert model weights and produce model libs. diff --git a/docs/index.rst b/docs/index.rst index f406908219..1180d00be9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,7 +45,7 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a compilation/convert_weights.rst compilation/compile_models.rst - compilation/package_model_libraries_weights.rst + compilation/package_libraries_and_weights.rst compilation/define_new_models.rst .. toctree:: From 37230db673bfe658b3726a864b7a3c49cffc20c5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 9 May 2024 22:36:14 -0400 Subject: [PATCH 293/531] [DOCS] Update android doc (#2309) Avoid showing full tree and mention what the dist/lib/mlc4j stands for --- docs/deploy/android.rst | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 0a0d66b704..2a729349f1 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -132,13 +132,18 @@ Please make sure all the following files exist in ``./dist/``. ├── assets │ └── mlc-app-config.json └── java - └── ai - └── mlc - └── mlcllm - └── ChatModule.java + └── ... The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``, -while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ to it. +while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ +to it. ``dist/lib/mlc4j`` is a gradle subproject that you should include in your app +so the Android project can reference the mlc4j (MLC LLM java library). +This library packages the dependent model libraries and necessary runtime to execute the model. + +.. code:: + + include ':mlc4j' + project(':mlc4j').projectDir = file('dist/lib/mlc4j') .. 
note:: From 8bb1d6e26443dfe721245b66bd2815e972ab4d32 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 9 May 2024 22:46:21 -0400 Subject: [PATCH 294/531] [DOCS] Update android doc (#2310) Avoid showing full tree and mention what the dist/lib/mlc4j stands for Avoid python3 instead directly use python, since python3 sometimes will points to system python. --- docs/install/mlc_llm.rst | 18 ++++++++++-------- docs/install/tvm.rst | 14 +++++++------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index ce15616957..398a23c54a 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -17,6 +17,7 @@ Select your operating system/compute platform and run the command in your termin .. note:: ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + Please make sure your conda environment has Python and pip installed. .. tabs:: @@ -29,35 +30,35 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu121 mlc-ai-nightly-cu121 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu121 mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu122 mlc-ai-nightly-cu122 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu122 mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm56 mlc-ai-nightly-rocm56 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm56 mlc-ai-nightly-rocm56 .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm57 mlc-ai-nightly-rocm57 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm57 mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -94,7 +95,7 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: @@ -115,9 +116,10 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly .. note:: + Please make sure your conda environment comes with python and pip. 
Make sure you also install vulkan loader and clang to avoid vulkan not found error or clang not found(needed for jit compile) @@ -195,7 +197,7 @@ This step is useful when you want to make modification or obtain a specific vers # create build directory mkdir -p build && cd build # generate build configuration - python3 ../cmake/gen_cmake_config.py + python ../cmake/gen_cmake_config.py # build mlc_llm libraries cmake .. && cmake --build . --parallel $(nproc) && cd .. diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index ed4977e5e3..591b5e89a3 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -37,35 +37,35 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -88,7 +88,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: @@ -109,7 +109,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. code-block:: bash conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: Make sure you also install vulkan loader and clang to avoid vulkan From 459ffe3907353d3b1c4de3ca3b3d74e28efb31ab Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Thu, 9 May 2024 23:55:40 -0400 Subject: [PATCH 295/531] [SLM] Support BERT architecture. 
Implement a text embedding module (#2249) --- python/mlc_llm/embeddings/embeddings.py | 181 ++++++++++++ python/mlc_llm/model/bert/__init__.py | 0 python/mlc_llm/model/bert/bert_loader.py | 86 ++++++ python/mlc_llm/model/bert/bert_model.py | 262 ++++++++++++++++++ .../mlc_llm/model/bert/bert_quantization.py | 53 ++++ python/mlc_llm/model/model.py | 15 + python/mlc_llm/model/model_preset.py | 20 ++ python/mlc_llm/op/attention.py | 1 - 8 files changed, 617 insertions(+), 1 deletion(-) create mode 100644 python/mlc_llm/embeddings/embeddings.py create mode 100644 python/mlc_llm/model/bert/__init__.py create mode 100644 python/mlc_llm/model/bert/bert_loader.py create mode 100644 python/mlc_llm/model/bert/bert_model.py create mode 100644 python/mlc_llm/model/bert/bert_quantization.py diff --git a/python/mlc_llm/embeddings/embeddings.py b/python/mlc_llm/embeddings/embeddings.py new file mode 100644 index 0000000000..c43b24df9c --- /dev/null +++ b/python/mlc_llm/embeddings/embeddings.py @@ -0,0 +1,181 @@ +"""The Python API for MLC Embeddings.""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax +from tvm.contrib import tvmjs +from tvm.runtime import Device, Module +from tvm.runtime.relax_vm import VirtualMachine + +from mlc_llm.chat_module import _get_model_path +from mlc_llm.serve import engine_utils +from mlc_llm.support.auto_device import detect_device +from mlc_llm.tokenizer import Tokenizer + + +def _extract_metadata(mod: Module): + return json.loads(VirtualMachine(mod, tvm.runtime.device("cpu"))["_metadata"]()) + + +def _load_params( + model_weight_path: str, device: Device, model_metadata: Dict[str, Any] +) -> List[tvm.nd.NDArray]: + params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) + param_names = [param["name"] for param in model_metadata["params"]] + assert len(param_names) == meta["ParamSize"] + + plist = [] + for param_name in param_names: + plist.append(params[param_name]) + return plist + + +def _get_tvm_module( + model_weight_path: str, lib_path: str, device: Device, instrument: tvm.runtime.PackedFunc = None +): + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, device) + if instrument: + vm.set_instrument(instrument) + metadata = _extract_metadata(ex) + params = _load_params(model_weight_path, device, metadata) + return vm.module, params, metadata + + +class DefaultDebugInstrument: + """The default debug instrument to use if users don't specify + a customized one. + + This debug instrument will dump the arguments and output of each + VM Call instruction into a .npz file. It will also alert the user + if any function outputs are NaN or INF. 
+ """ + + def __init__(self, debug_out: Path): + """Constructor + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def reset(self, debug_out: Path): + """Reset the state of the Instrument class + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def __call__(self, func, name, before_run, ret_val, *args): + # Determine what functions to look at + if before_run: # Whether before the function is called or after + return + if name.startswith("vm.builtin.") and "attention_with_fused_qkv" not in name: + return + + # Decide what to print or save about the function's arguments (where args[-1] is the + # buffer we write the result to) + func_name = f"f{self.counter}_{name}" + + # Save the arguments to npz + arg_dict = {} + for i, arg in enumerate(args): + if isinstance(arg, tvm.nd.NDArray): + arg_dict[f"arg_{i}"] = arg.numpy() + + np.savez(self.debug_out / f"{func_name}.npz", **arg_dict) + + self.counter += 1 + + +class MLCEmbeddings: # pylint: disable=too-few-public-methods + """A class to embed queries using MLC LLM encoder models. + + Parameters + ---------- + model: str + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. + + model_lib_path : str + The full path to the model library file to use (e.g. a ``.so`` file). + + device : Optional[str] + The description of the device to run on. User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. If no 'device_id' + is provided, it will be set to 0 by default. + + debug_dir: Path + The output folder to store the dumped debug files. If None, will not dump any debug files. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + device: Optional[str] = "auto", + debug_dir: Optional[str] = None, + ): + self.device = detect_device(device) + instrument = DefaultDebugInstrument(Path(debug_dir)) if debug_dir else None + self.mod, self.params, self.metadata = _get_tvm_module( + model, model_lib_path, self.device, instrument + ) + self.model_path, _ = _get_model_path(model) + self.tokenizer = Tokenizer(self.model_path) + self.prefill_func = self.mod["prefill"] + + def embed(self, queries: List[str]) -> tvm.runtime.NDArray: + """ + Embeds a list of queries in a single batch. + + Parameters + ---------- + queries : List[str] + A list of queries to embed. + + Returns + ------- + List[float] + A list of embeddings for the queries. 
+ """ + tokens, attention_mask = self._tokenize_queries(queries) + tokens_tvm = tvm.nd.array(tokens.astype("int32"), device=self.device) + attention_mask_tvm = tvm.nd.array(attention_mask.astype("int32"), device=self.device) + output = self.prefill_func(tokens_tvm, attention_mask_tvm, self.params) + return output + + def _tokenize_queries(self, queries: List[str]) -> Tuple[np.ndarray, np.ndarray]: + tokens = engine_utils.process_prompts(queries, self.tokenizer.encode) # type: ignore + max_query_length = max(len(token_seq) for token_seq in tokens) + + token_inputs = np.zeros((len(tokens), max_query_length), dtype=np.int32) + attention_mask = np.zeros((len(tokens), max_query_length), dtype=np.int32) + + for i, token_seq in enumerate(tokens): + token_inputs[i, : len(token_seq)] = token_seq + attention_mask[i, : len(token_seq)] = 1 + + return token_inputs, attention_mask diff --git a/python/mlc_llm/model/bert/__init__.py b/python/mlc_llm/model/bert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/bert/bert_loader.py b/python/mlc_llm/model/bert/bert_loader.py new file mode 100644 index 0000000000..12bf9406fc --- /dev/null +++ b/python/mlc_llm/model/bert/bert_loader.py @@ -0,0 +1,86 @@ +""" +This file specifies how MLC's BERT parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .bert_model import BertConfig, BertModel + + +def huggingface(model_config: BertConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : BertConfig + The configuration of the BERT model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = BertModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + attn = f"encoder.layer.{i}.attention.self" + mlc_name = f"{attn}.qkv.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.query.weight", + f"{attn}.key.weight", + f"{attn}.value.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{attn}.qkv.bias" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.query.bias", + f"{attn}.key.bias", + f"{attn}.value.bias", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_llm/model/bert/bert_model.py b/python/mlc_llm/model/bert/bert_model.py new file mode 100644 index 0000000000..504e0f3a03 --- /dev/null +++ b/python/mlc_llm/model/bert/bert_model.py @@ -0,0 +1,262 @@ +""" +Implementation for BERT architecture. +""" + +import dataclasses +from functools import partial +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class BertConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the BERT model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + intermediate_size: int + hidden_act: str + layer_norm_eps: float + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + head_dim: int = 0 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.intermediate_size is None or self.intermediate_size == -1: + self.intermediate_size = 4 * self.hidden_size + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals + + +class BertSelfAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: BertConfig): + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim + + self.qkv = nn.Linear( + in_features=config.hidden_size, + out_features=3 * self.num_heads * self.head_dim, + bias=True, + ) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + q, k, v = op.split(qkv, 3, axis=2) + + # Attention + output = op_ext.attention(q, k, v, attention_mask) + return output + + +class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config: BertConfig): + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + self_output = self.self(hidden_states, attention_mask) + attention_output = self.output(self_output, hidden_states) + return attention_output + + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.silu, + "swish": nn.silu, + "gelu_new": partial(nn.gelu, approximate=True), +} + + +class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config: BertConfig): + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + attention_output = self.attention(hidden_states, 
attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config: BertConfig): + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + for layer in self.layer: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states + + +class BertEmbeddings(nn.Module): + def __init__(self, config: BertConfig): + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, dtype="float32") + self.position_embeddings = nn.Embedding( + config.context_window_size, config.hidden_size, dtype="float32" + ) + self.token_type_embeddings = nn.Embedding(2, config.hidden_size, dtype="float32") + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, input_ids: Tensor, token_type_ids: Tensor, position_ids: Tensor): + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertModel(nn.Module): + def __init__(self, config: BertConfig): + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, attention_mask: Tensor): + def _input_positions(inputs: te.Tensor): + b, s = inputs.shape + return te.compute((b, s), lambda _, j: j.astype("int32"), name="input_positions") + + input_positions = op.tensor_expr_op( + _input_positions, + name_hint="input_positions", + args=[inputs], + ) + + token_type_ids = op.zeros(inputs.shape, dtype="int32") + + embeddings = self.embeddings(inputs, token_type_ids, input_positions) + encoder_output = self.encoder(embeddings, attention_mask) + return encoder_output + + def prefill(self, inputs: Tensor, attention_mask: Tensor): + def _attention_mask(mask: te.Tensor, zero, batch_size, seq_len): + return te.compute( + (batch_size, 1, seq_len, seq_len), + lambda b, _, i, j: tir.if_then_else( + tir.any(mask[b, i] == zero, mask[b, j] == zero), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask_2d = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[attention_mask, tir.IntImm("int32", 0), batch_size, seq_len], + ) + return self.forward(inputs, attention_mask_2d) + + def get_default_spec(self): + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor(["batch_size", "seq_len"], "int32"), + "attention_mask": nn.spec.Tensor(["batch_size", "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/bert/bert_quantization.py b/python/mlc_llm/model/bert/bert_quantization.py new file mode 100644 index 0000000000..5f6d86f5ab --- /dev/null +++ b/python/mlc_llm/model/bert/bert_quantization.py @@ -0,0 +1,53 @@ +"""This file specifies how MLC's BERT parameters are quantized using group quantization +or other formats.""" +from typing import 
Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .bert_model import BertConfig, BertModel + + +def group_quant( + model_config: BertConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT-architecture model using group quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: BertConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT-architecture model using FasterTransformer quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: BertConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT model without quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 84d47ffd68..08d272f409 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -9,6 +9,7 @@ from mlc_llm.quantization.quantization import Quantization from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization +from .bert import bert_loader, bert_model, bert_quantization from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization from .eagle import eagle_loader, eagle_model, eagle_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization @@ -370,4 +371,18 @@ class Model: "awq": eagle_quantization.awq_quant, }, ), + "bert": Model( + name="bert", + model=bert_model.BertModel, + config=bert_model.BertConfig, + source={ + "huggingface-torch": bert_loader.huggingface, + "huggingface-safetensor": bert_loader.huggingface, + }, + quantize={ + "no-quant": bert_quantization.no_quant, + "group-quant": bert_quantization.group_quant, + "ft-quant": bert_quantization.ft_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index a7276308b7..7473443f45 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -743,4 +743,24 @@ "use_cache": True, "vocab_size": 128256, }, + "bert": { + "architectures": ["BertModel"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "vocab_size": 30522, + }, } diff --git a/python/mlc_llm/op/attention.py b/python/mlc_llm/op/attention.py index dc41a5f5ef..734edda89e 100644 --- a/python/mlc_llm/op/attention.py +++ b/python/mlc_llm/op/attention.py @@ -62,7 +62,6 @@ def attention( # pylint: disable=invalid-name,too-many-locals,too-many-statemen b, s, h_q, d = q.shape t, h_kv, _ = 
k.shape[-3:] group_size = h_q // h_kv - assert b == 1, "batch size must be 1" def _fallback(): nonlocal q, k, v, qk_dtype From ea391de4d601dd91500818c69c3d831d963bf607 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 10 May 2024 06:18:48 -0700 Subject: [PATCH 296/531] [Serving] Log batch size in NVTX (#2312) --- cpp/serve/model.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 0bd4126b40..9f3aa799f2 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -334,7 +334,7 @@ class ModelImpl : public ModelObj { } NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { - NVTXScopedRange nvtx_scope("BatchDecode"); + NVTXScopedRange nvtx_scope("BatchDecode num_seqs=" + std::to_string(seq_ids.size())); int num_sequence = seq_ids.size(); CHECK(ft_.decode_func_.defined()) @@ -395,7 +395,8 @@ class ModelImpl : public ModelObj { ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd, const std::vector& seq_ids) final { - NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden"); + NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden num_seqs=" + + std::to_string(seq_ids.size())); int num_sequence = seq_ids.size(); CHECK(ft_.decode_to_last_hidden_func_.defined()) @@ -443,7 +444,6 @@ class ModelImpl : public ModelObj { NDArray BatchVerify(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { - NVTXScopedRange nvtx_scope("BatchVerify"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); @@ -452,6 +452,8 @@ class ModelImpl : public ModelObj { total_length += lengths[i]; } + NVTXScopedRange nvtx_scope("BatchVerify num_tokens=" + std::to_string(total_length)); + CHECK(ft_.verify_func_.defined()) << "`verify_with_embed` function is not found in the model. 
Please make sure the model is " "compiled with flag `--sep-embed` and `--enable-batching`"; @@ -504,7 +506,6 @@ class ModelImpl : public ModelObj { ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings, const std::vector& seq_ids, const std::vector& lengths) final { - NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden"); CHECK(!seq_ids.empty()); CHECK_EQ(seq_ids.size(), lengths.size()); int num_sequences = seq_ids.size(); @@ -512,6 +513,8 @@ class ModelImpl : public ModelObj { for (int i = 0; i < num_sequences; ++i) { total_length += lengths[i]; } + NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden num_tokens=" + + std::to_string(total_length)); CHECK(ft_.verify_to_last_hidden_func_.defined()) << "`batch_verify_to_last_hidden_states` function is not found in the model."; From b01cfab812d88ea627a6bbb86e6064dfd346e9ae Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 10 May 2024 06:19:05 -0700 Subject: [PATCH 297/531] [Model] Removing unnecessary reshapes in get_logits (#2314) --- cpp/serve/engine_actions/eagle_batch_draft.cc | 10 +++------- .../engine_actions/eagle_batch_verify.cc | 19 ++++++------------- .../eagle_new_request_prefill.cc | 5 +---- cpp/serve/model.cc | 8 ++------ cpp/serve/model.h | 4 +--- python/mlc_llm/model/llama/llama_model.py | 2 +- 6 files changed, 14 insertions(+), 34 deletions(-) diff --git a/cpp/serve/engine_actions/eagle_batch_draft.cc b/cpp/serve/engine_actions/eagle_batch_draft.cc index b4e7ec4c39..dfff7fe7a3 100644 --- a/cpp/serve/engine_actions/eagle_batch_draft.cc +++ b/cpp/serve/engine_actions/eagle_batch_draft.cc @@ -116,20 +116,16 @@ class EagleBatchDraftActionObj : public EngineActionObj { request_internal_ids); NDArray logits; if (models_[model_id]->CanGetLogits()) { - logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, - /*seq_len*/ 1); + logits = models_[model_id]->GetLogits(hidden_states); } else { // - Use base model's head. - logits = - models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + logits = models_[0]->GetLogits(hidden_states); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); - ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->ndim, 2); ICHECK_EQ(logits->shape[0], num_rsentries); - ICHECK_EQ(logits->shape[1], 1); // - Update logits. - logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 0f5fba4a5a..71daaf1bf9 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -114,16 +114,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj { RECORD_EVENT(trace_recorder_, request_ids, "start verify"); ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden( embeddings, request_internal_ids, verify_lengths); - NDArray logits = - models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]); + NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states); RECORD_EVENT(trace_recorder_, request_ids, "finish verify"); - ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], 1); - ICHECK_EQ(logits->shape[1], cum_verify_lengths[num_rsentries]); + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], cum_verify_lengths.back()); // - Update logits. 
- logits = - logits.CreateView({cum_verify_lengths[num_rsentries], logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates, request_ids, &cum_verify_lengths, &draft_output_tokens); @@ -273,19 +269,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj { fused_embedding_hidden_states, request_internal_ids); if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, - /*seq_len*/ 1); + logits = models_[draft_model_id_]->GetLogits(hidden_states); } else { // - Use base model's head. - logits = models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + logits = models_[0]->GetLogits(hidden_states); } RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); - ICHECK_EQ(logits->ndim, 3); + ICHECK_EQ(logits->ndim, 2); ICHECK_EQ(logits->shape[0], num_rsentries); - ICHECK_EQ(logits->shape[1], 1); // - Update logits. - logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype); logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); // - Compute probability distributions. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 2844f76c6b..e2d2d661f8 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -183,8 +183,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); // logits_for_sample: (b * s, v) - logits_for_sample = - models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries); + logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample); // - Update logits. 
ICHECK(logits_for_sample.defined()); Array generation_cfg; @@ -195,8 +194,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]); } - logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]}, - logits_for_sample->dtype); logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, mstates_for_logitproc, request_ids); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 9f3aa799f2..e16432c222 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -125,7 +125,7 @@ class ModelImpl : public ModelObj { return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined(); } - NDArray GetLogits(const ObjectRef& hidden_states, int batch_size, int seq_len) final { + NDArray GetLogits(const ObjectRef& hidden_states) final { NVTXScopedRange nvtx_scope("GetLogits"); CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; @@ -139,18 +139,14 @@ class ModelImpl : public ModelObj { if (trace_enabled_) { TVMSynchronize(device_.device_type, device_.device_id, nullptr); } - NDArray logits{nullptr}; if (ft_.use_disco) { logits = Downcast(ret)->DebugGetFromRemote(0); } else { logits = Downcast(ret); } - CHECK(logits.defined()); // logits: (b * s, v) - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], batch_size * seq_len); - return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype); + return logits; } ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 1ac4e4001c..96d2ecb401 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -135,11 +135,9 @@ class ModelObj : public Object { /*! * \brief Compute logits for last hidden_states. * \param last_hidden_states The last hidden_states to compute logits for. - * \param batch_size The batch size of last_hidden_states - * \param seq_len The length of tokens in last_hidden_states * \return The computed logits. */ - virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0; + virtual NDArray GetLogits(const ObjectRef& last_hidden_states) = 0; /*! * \brief Batch prefill function. Embedding in, logits out. 
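For intuition, here is a small sketch (an editorial illustration, not part of the patch) of the layout this change assumes: ``GetLogits`` now hands back logits already flattened to ``(num_tokens, vocab_size)``, so the engine actions gather per-request rows directly instead of first viewing a ``(batch, seq_len, vocab)`` tensor and reshaping it. The names below are stand-ins for the engine's bookkeeping.

.. code:: python

   import numpy as np

   # Stand-in for the flattened logits returned after this patch: one row per token.
   num_tokens, vocab_size = 7, 32000
   logits = np.zeros((num_tokens, vocab_size), dtype=np.float32)

   # cum_verify_lengths-style bookkeeping: two requests with 3 and 4 tokens.
   cum_lengths = [0, 3, 7]
   last_rows = [end - 1 for end in cum_lengths[1:]]

   # Rows used for sampling, selected without any CreateView reshape.
   sample_logits = logits[last_rows]          # shape: (num_requests, vocab_size)
   assert sample_logits.shape == (2, vocab_size)
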
diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index cd99301132..1b76a92453 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -352,7 +352,7 @@ def get_default_spec(self): }, }, "get_logits": { - "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "$": { "param_mode": "packed", "effect_mode": "none", From 347222cfc158e0a2cf28ac26e5de4f3e75d3778d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 10 May 2024 06:19:12 -0700 Subject: [PATCH 298/531] Skip cublas dispatch for single batch (#2315) --- python/mlc_llm/compiler_pass/cublas_dispatch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/cublas_dispatch.py b/python/mlc_llm/compiler_pass/cublas_dispatch.py index 231048628c..f5af94cc4b 100644 --- a/python/mlc_llm/compiler_pass/cublas_dispatch.py +++ b/python/mlc_llm/compiler_pass/cublas_dispatch.py @@ -20,10 +20,14 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR model_names = [ gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function) ] + model_names = [name for name in model_names if "batch" not in name] mod = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=True + patterns, + bind_constants=False, + annotate_codegen=True, + entry_functions=model_names, ), relax.transform.RunCodegen({}, entry_functions=model_names), ] From 73b733da20c8b8fbabe572f68bf79f52dd87d985 Mon Sep 17 00:00:00 2001 From: Git bot Date: Fri, 10 May 2024 14:12:22 +0000 Subject: [PATCH 299/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index ced07e8878..c8f7ec8dc0 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit ced07e88781c0d6416e276d9cd084bb46aaf3da5 +Subproject commit c8f7ec8dc0377ad362e1c81b194c6e2322f27a75 From 3a0b42c986bf923af32ce7da8fc44489d9c2ddb6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 10 May 2024 13:49:52 -0400 Subject: [PATCH 300/531] [DOCS] Remove mention of legacy modules (#2318) This PR removes mention of legacy modules and prebuilt in favor of JIT. --- docs/community/guideline.rst | 2 +- docs/compilation/compile_models.rst | 2 +- docs/compilation/convert_weights.rst | 44 +- docs/compilation/define_new_models.rst | 8 +- docs/deploy/python_chat_module.rst | 369 ------------ docs/deploy/python_engine.rst | 2 - docs/index.rst | 7 - docs/prebuilt_models.rst | 773 ------------------------- 8 files changed, 16 insertions(+), 1191 deletions(-) delete mode 100644 docs/deploy/python_chat_module.rst delete mode 100644 docs/prebuilt_models.rst diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst index 33e8982543..467ffe65eb 100644 --- a/docs/community/guideline.rst +++ b/docs/community/guideline.rst @@ -53,7 +53,7 @@ on GitHub directly. Once your update is complete, you can click the ``contribute Contribute New Models to MLC-LLM ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. 
Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page. Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together! +* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. * If you add a new model variant to MLC-LLM by following our :doc:`/compilation/define_new_models` tutorial. Please create a pull request to add your model architecture (currently model architectures are placed under diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 1e18b8d441..a22981b20c 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -32,7 +32,7 @@ We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for al Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-engine` to obtain the CLI app / Python API that can be used to chat with the compiled model. diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index e9e57e14b1..667f0c2e6a 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -1,15 +1,11 @@ .. _convert-weights-via-MLC: -Convert Weights via MLC -======================= +Convert Model Weights +===================== -To run a model with MLC LLM in any platform, you need: - -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `_.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -In many cases, we only need to convert weights and reuse existing model library. -This page demonstrates adding a model variant with ``mlc_llm convert_weight``, which +To run a model with MLC LLM, +we need to convert model weights into MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `_.) +This page walks us through the process of adding a model variant with ``mlc_llm convert_weight``, which takes a hugginface model as input and converts/quantizes into MLC-compatible weights. Specifically, we add RedPjama-INCITE-**Instruct**-3B-v1, while MLC already @@ -24,7 +20,7 @@ This can be extended to, e.g.: Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required backend to compile models with MLC LLM. - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain + Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-engine` to obtain the CLI app / Python API that can be used to chat with the compiled model. @@ -151,31 +147,11 @@ for **Instruct** instead of **Chat**. Good job, you have successfully distributed the model you compiled. Next, we will talk about how we can consume the model weights in applications. 
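As a quick illustration of consuming the converted weights from Python (a sketch that assumes the `MLCEngine` API referenced in the Python engine docs exposes the OpenAI-style `chat.completions` interface; the repository name below is the placeholder used above, not a real model):

.. code:: python

   from mlc_llm import MLCEngine

   # Placeholder weight repo from the example above; replace with your own upload.
   model = "HF://my-huggingface-account/my-redpajama3b-weight-huggingface-repo"
   engine = MLCEngine(model)
   # Stream a short chat completion served from the converted weights.
   for response in engine.chat.completions.create(
       messages=[{"role": "user", "content": "hi"}],
       model=model,
       stream=True,
   ):
       for choice in response.choices:
           print(choice.delta.content, end="", flush=True)
   print()
   engine.terminate()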
-Download the Distributed Models and Run in Python -------------------------------------------------- +Download the Distributed Models +------------------------------- -Running the distributed models are similar to running prebuilt model weights and libraries in :ref:`Model Prebuilts`. +You can now use the existing mlc tools such as chat/serve/package with the converted weights. .. code:: shell - # Clone prebuilt libs so we can reuse them: - mkdir -p dist/ - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - # Or download the model library (only needed if we do not reuse the model lib): - cd dist/prebuilt_libs - wget url-to-my-model-lib - cd ../.. - - # Download the model weights - cd dist - git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC - cd .. - - # Run the model in Python; note that we reuse `-Chat` model library - python - >>> from mlc_llm import ChatModule - >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ - model_lib="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend - >>> cm.generate("hi") - 'Hi! How can I assist you today?' + mlc_llm chat HF://my-huggingface-account/my-redpajama3b-weight-huggingface-repo diff --git a/docs/compilation/define_new_models.rst b/docs/compilation/define_new_models.rst index 4c73864104..92b3af8dde 100644 --- a/docs/compilation/define_new_models.rst +++ b/docs/compilation/define_new_models.rst @@ -4,7 +4,7 @@ Define New Model Architectures This page guides you how to add a new model architecture in MLC. This notebook (runnable in Colab) should contain all necessary information to add a model in -MLC LLM: +MLC LLM: https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_add_new_model_architecture_in_tvm_nn_module.ipynb In the notebook, we leverage ``tvm.nn.module`` to define a model in MLC LLM. We also use ``JIT`` @@ -16,10 +16,10 @@ You can also refer to the PRs below on specific examples of adding a model archi - `GPT-2 PR `_ - `Mistral PR `_ -.. note:: +.. note:: - As mentioned in :ref:`Model Prebuilts`, when adding a model variant that has - its architecture already supported in mlc-llm , you **only need to convert weights** + When adding a model variant that has + its architecture already supported in mlc-llm , you **only need to convert weights** (e.g. adding ``CodeLlama`` when MLC supports ``llama-2``; adding ``OpenHermes Mistral`` when MLC supports ``mistral``). On the other hand, a new model architecture (or inference logic) requires more work (following the tutorial above). \ No newline at end of file diff --git a/docs/deploy/python_chat_module.rst b/docs/deploy/python_chat_module.rst deleted file mode 100644 index 14e9f3ed03..0000000000 --- a/docs/deploy/python_chat_module.rst +++ /dev/null @@ -1,369 +0,0 @@ -.. _deploy-python-chat-module: - -Python API (Chat Module) -======================== - -.. note:: - ❗ The Python API with :class:`mlc_llm.ChatModule` introduced in this page will be - deprecated in the near future. - Please go to :ref:`deploy-python-engine` for the latest Python API with complete - OpenAI API support. - -.. contents:: Table of Contents - :local: - :depth: 2 - -We expose ChatModule Python API for the MLC-LLM for easy integration into other Python projects. - -The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via -the :doc:`installation page <../install/mlc_llm>`. 
- -Instead of following this page, you could also checkout the following tutorials in -Python notebook (all runnable in Colab): - -- `Getting Started with MLC-LLM `_: - how to quickly download prebuilt models and chat with it -- `Raw Text Generation with MLC-LLM `_: - how to perform raw text generation with MLC-LLM in Python - -.. These notebooks are not up-to-date with SLM yet -.. - `Compiling Llama-2 with MLC-LLM `_: -.. how to use Python APIs to compile models with the MLC-LLM workflow -.. - `Extensions to More Model Variants `_: -.. how to use Python APIs to compile and chat with any model variant you'd like - - -Verify Installation -------------------- - -.. code:: bash - - python -c "from mlc_llm import ChatModule; print(ChatModule)" - -You are expected to see the information about the :class:`mlc_llm.ChatModule` class. - -If the command above results in error, follow :ref:`install-mlc-packages` (either install the prebuilt pip wheels -or :ref:`mlcchat_build_from_source`). - -Run MLC Models w/ Python ------------------------- - -To run a model with MLC LLM in any platform/runtime, you need: - -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC - `_.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -There are two ways to obtain the model weights and libraries: - -1. Compile your own model weights and libraries following :doc:`the model compilation page `. -2. Use off-the-shelf `prebuilt models weights `__ and - `prebuilt model libraries `__ (see :ref:`Model Prebuilts` for details). - -We use off-the-shelf prebuilt models in this page. However, same steps apply if you want to run -the models you compiled yourself. - -**Step 1: Download prebuilt model weights and libraries** - -Skip this step if you have already obtained the model weights and libraries. - -.. code:: shell - - # Activate your conda environment - conda install -c conda-forge git-lfs - - # Download pre-conveted weights - git lfs install && mkdir dist/ - git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ - dist/Llama-2-7b-chat-hf-q4f16_1-MLC - - # Download pre-compiled model library - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - -**Step 2: Run the model in Python** - -Use the conda environment you used to install ``mlc_llm``. -From the ``mlc-llm`` directory, you can create a Python -file ``sample_mlc_llm.py`` and paste the following lines: - -.. 
code:: python - - from mlc_llm import ChatModule - from mlc_llm.callback import StreamToStdout - - # Create a ChatModule instance - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - # You can change to other models that you downloaded - # Model variants of the same architecture can reuse the same model library - # Here WizardMath reuses Mistral's model library - # cm = ChatModule( - # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" - # model_lib="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" - # ) - - # Generate a response for a given prompt - output = cm.generate( - prompt="What is the meaning of life?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # Print prefill and decode performance statistics - print(f"Statistics: {cm.stats()}\n") - - output = cm.generate( - prompt="How many points did you list out?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # Reset the chat module by - # cm.reset_chat() - - -Now run the Python file to start the chat - -.. code:: bash - - python sample_mlc_llm.py - - -.. collapse:: See output - - .. code:: - - Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 - Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json - Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so - - Thank you for your question! The meaning of life is a complex and subjective topic that has been debated by philosophers, theologians, scientists, and many others for centuries. There is no one definitive answer to this question, as it can vary depending on a person's beliefs, values, experiences, and perspectives. - - However, here are some possible ways to approach the question: - - 1. Religious or spiritual beliefs: Many people believe that the meaning of life is to fulfill a divine or spiritual purpose, whether that be to follow a set of moral guidelines, to achieve spiritual enlightenment, or to fulfill a particular destiny. - 2. Personal growth and development: Some people believe that the meaning of life is to learn, grow, and evolve as individuals, to develop one's talents and abilities, and to become the best version of oneself. - 3. Relationships and connections: Others believe that the meaning of life is to form meaningful connections and relationships with others, to love and be loved, and to build a supportive and fulfilling social network. - 4. Contribution and impact: Some people believe that the meaning of life is to make a positive impact on the world, to contribute to society in a meaningful way, and to leave a lasting legacy. - 5. Simple pleasures and enjoyment: Finally, some people believe that the meaning of life is to simply enjoy the present moment, to find pleasure and happiness in the simple things in life, and to appreciate the beauty and wonder of the world around us. - - Ultimately, the meaning of life is a deeply personal and subjective question, and each person must find their own answer based on their own beliefs, values, and experiences. 
- - Statistics: prefill: 3477.5 tok/s, decode: 153.6 tok/s - - I listed out 5 possible ways to approach the question of the meaning of life. - -| - -**Running other models** - -Checkout the :doc:`/prebuilt_models` page to run other pre-compiled models. - -For models other than the prebuilt ones we provided: - -1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral`` as - shown in the code snippet), follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. -2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. - - -Configure MLCChat in Python ---------------------------- -If you have checked out :ref:`Configure MLCChat in JSON`, you would know -that you could configure MLCChat through various fields such as ``temperature``. We provide the -option of overriding any field you'd like in Python, so that you do not need to manually edit -``mlc-chat-config.json``. - -Since there are two concepts -- `MLCChat Configuration` and `Conversation Configuration` -- we correspondingly -provide two dataclasses :class:`mlc_llm.ChatConfig` and :class:`mlc_llm.ConvConfig`. - -We provide an example below. - -.. code:: python - - from mlc_llm import ChatModule, ChatConfig, ConvConfig - from mlc_llm.callback import StreamToStdout - - # Using a `ConvConfig`, we modify `system`, a field in the conversation template - # `system` refers to the prompt encoded before starting the chat - conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.') - - # We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len` - # Note that `conv_config` is an optional subfield of `chat_config` - chat_config = ChatConfig(max_gen_len=256, conv_config=conv_config) - - # Using the `chat_config` we created, instantiate a `ChatModule` - cm = ChatModule( - chat_config=chat_config, - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - output = cm.generate( - prompt="What is one plus one?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # You could also pass in a `ConvConfig` instance to `reset_chat()` - conv_config = ConvConfig(system='Please show as much sadness as you can when talking to me.') - chat_config = ChatConfig(max_gen_len=128, conv_config=conv_config) - cm.reset_chat(chat_config) - - output = cm.generate( - prompt="What is one plus one?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - -.. collapse:: See output - - .. code:: - - Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 - Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json - Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so - - Oh, wow, *excitedly* one plus one? *grinning* Well, let me see... *counting on fingers* One plus one is... *eureka* Two! - ... - - *Sobs* Oh, the tragedy of it all... *sobs* One plus one... *chokes back tears* It's... *gulps* it's... *breaks down in tears* TWO! - ... - -| - -.. note:: - You do not need to specify the entire ``ChatConfig`` or ``ConvConfig``. 
Instead, we will first - load all the fields defined in ``mlc-chat-config.json``, a file required when instantiating - a :class:`mlc_llm.ChatModule`. Then, we will load in the optional ``ChatConfig`` you provide, overriding the - fields specified. - - It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template - specified by the field ``conv_template`` in the chat configuration. Learn more about it in - :ref:`Configure MLCChat in JSON`. - -Raw Text Generation in Python ------------------------------ - -Raw text generation allows the user to have more flexibility over his prompts, -without being forced to create a new conversational template, making prompt customization easier. -This serves other demands for APIs to handle LLM generation without the usual system prompts and other items. - -We provide an example below. - -.. code:: python - - from mlc_llm import ChatModule, ChatConfig, ConvConfig - from mlc_llm.callback import StreamToStdout - - # Use a `ConvConfig` to define the generation settings - # Since the "LM" template only supports raw text generation, - # System prompts will not be executed even if provided - conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]") - - # Note that `conv_config` is an optional subfield of `chat_config` - # The "LM" template serves the basic purposes of raw text generation - chat_config = ChatConfig(conv_config=conv_config, conv_template="LM") - - # Using the `chat_config` we created, instantiate a `ChatModule` - cm = ChatModule( - chat_config=chat_config, - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - # To make the model follow conversations a chat structure should be provided - # This allows users to build their own prompts without building a new template - system_prompt = "<>\nYou are a helpful, respectful and honest assistant.\n<>\n\n" - inst_prompt = "What is mother nature?" - - # Concatenate system and instruction prompts, and add instruction tags - output = cm.generate( - prompt=f"[INST] {system_prompt+inst_prompt} [/INST]", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # The LM template has no memory, so it will be reset every single generation - # In this case the model will just follow normal text completion - # because there isn't a chat structure - output = cm.generate( - prompt="Life is a quality that distinguishes", - progress_callback=StreamToStdout(callback_interval=2), - ) - -.. note:: - The ``LM`` is a template without memory, which means that every execution will be cleared. - Additionally, system prompts will not be run when instantiating a `mlc_llm.ChatModule`, - unless explicitly given inside the prompt. - -Stream Iterator in Python -------------------------- - -Stream Iterator gives users an option to stream generated text to the function that the API is called from, -instead of streaming to stdout, which could be a necessity when building services on top of MLC Chat. - -We provide an example below. - -.. 
code:: python - - from mlc_llm import ChatModule - from mlc_llm.callback import StreamIterator - - # Create a ChatModule instance - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - # Stream to an Iterator - from threading import Thread - - stream = StreamIterator(callback_interval=2) - generation_thread = Thread( - target=cm.generate, - kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, - ) - generation_thread.start() - - output = "" - for delta_message in stream: - output += delta_message - - generation_thread.join() - - -API Reference -------------- - -User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-LLM model. -The :class:`mlc_llm.ChatModule` class provides the following methods: - -.. currentmodule:: mlc_llm - -.. autoclass:: ChatModule - :members: - :exclude-members: evaluate - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - -.. autoclass:: ChatConfig - :members: - -.. autoclass:: ConvConfig - :members: - -.. autoclass:: GenerationConfig - :members: diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index 2ef4d5bd23..86a9e7d4af 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -5,8 +5,6 @@ Python API .. note:: This page introduces the Python API with MLCEngine in MLC LLM. - If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, - please go to :ref:`deploy-python-chat-module` .. contents:: Table of Contents :local: diff --git a/docs/index.rst b/docs/index.rst index 1180d00be9..7a6ab491db 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -48,13 +48,6 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a compilation/package_libraries_and_weights.rst compilation/define_new_models.rst -.. toctree:: - :maxdepth: 1 - :caption: Model Prebuilts - :hidden: - - prebuilt_models.rst - .. toctree:: :maxdepth: 1 :caption: Dependency Installation diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst deleted file mode 100644 index 2f772a5d7e..0000000000 --- a/docs/prebuilt_models.rst +++ /dev/null @@ -1,773 +0,0 @@ -.. _Model Prebuilts: - -Model Prebuilts -================== - -.. contents:: Table of Contents - :depth: 3 - :local: - -.. _model-prebuilts-overview: - -Overview --------- - -MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ -(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the -help of :doc:`TVM Unity `. - -There are two ways to run a model on MLC-LLM (this page focuses on the second one): - -1. Compile your own models following :doc:`the model compilation page `. -2. Use off-the-shelf prebuilt models following this current page. - -In order to run a specific model on MLC-LLM, you need: - -**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). -See the full list of all precompiled model libraries `here `__. - -**2. 
Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model -(e.g. https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC). See the full list of all precompiled weights `here `__. - -In this page, we first quickly go over :ref:`how to use prebuilts ` for different platforms, -then track what current :ref:`prebuilt models we provide `. - - -.. _using-model-prebuilts: - -Using Prebuilt Models for Different Platforms -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. - -.. _using-prebuilt-models-cli: - -**Prebuilt Models on CLI / Python** - -For more, please see :ref:`the CLI page `, and the :ref:`the Python page `. - -.. collapse:: Click to show details - - First create the conda environment if you have not done so. - - .. code:: shell - - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - conda install git git-lfs - git lfs install - - Download the prebuilt model libraries from github. - - .. code:: shell - - mkdir dist/ - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - Run the model with CLI: - - .. code:: shell - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - - - To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). - - -.. for a blank line - -| - -.. _using-prebuilt-models-ios: - -**Prebuilt Models on iOS** - -For more, please see :doc:`the iOS page `. - -.. collapse:: Click to show details - - The `iOS app `_ has builtin RedPajama-3B and Mistral-7B-Instruct-v0.2 support. - - All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: - - .. list-table:: Prebuilt Models for iOS - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model Code - - Model Series - - Quantization Mode - - MLC HuggingFace Weights Repo - * - `Mistral-7B-Instruct-v0.2-q3f16_1` - - `Mistral `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `phi-2-q4f16_1` - - `Microsoft Phi-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ -.. for a blank line - -| - -.. _prebuilt-models-android: - -**Prebuilt Models on Android** - -For more, please see :doc:`the Android page `. - -.. collapse:: Click to show details - - The apk for demo Android app includes the following models. To add more, check out the Android page. - - .. list-table:: Prebuilt Models for Android - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q4f16_1` - - `Llama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ -.. for a blank line - -| - -.. 
_supported-model-architectures: - -Level 1: Supported Model Architectures (The All-In-One Table) -------------------------------------------------------------- - -For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. - -Each entry below hyperlinks to the corresponding level 2 and level 3 tables. - -MLC-LLM supports the following model architectures: - -.. list-table:: Supported Model Architectures - :widths: 10 10 15 15 - :header-rows: 1 - - * - Model Architecture - - Support - - Available MLC Prebuilts - - Unavailable in MLC Prebuilts - * - `LLaMA `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Llama-2-chat ` - - * `Code Llama `__ - * `Vicuna `__ - * `WizardLM `__ - * `WizardCoder (new) `__ - * `OpenOrca Platypus2 `__ - * `FlagAlpha Llama-2 Chinese `__ - * `georgesung Llama-2 Uncensored `__ - * `Alpaca `__ - * `Guanaco `__ - * `OpenLLaMA `__ - * `Gorilla `__ - * `YuLan-Chat `__ - * - `Mistral `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Mistral-7B-Instruct-v0.2 ` - * :ref:`NeuralHermes-2.5-Mistral-7B ` - * :ref:`OpenHermes-2.5-Mistral-7B ` - * :ref:`WizardMath-7B-V1.1 ` - - - * - `GPT-NeoX `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`RedPajama ` - - * `Dolly `__ - * `Pythia `__ - * `StableCode `__ - * - `GPTBigCode `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - - - * `StarCoder `__ - * `SantaCoder `__ - * `WizardCoder (old) `__ - * - `Phi `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Phi-1_5 ` - * :ref:`Phi-2 ` - - - * - `GPT2 `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`GPT2 ` - - - -If the model variant you are interested in uses one of these model architectures we support, -(but we have not provided the prebuilt weights yet), you can check out -:doc:`/compilation/convert_weights` on how to convert the weights. -Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt -weights to hugging face, and submit a PR that adds an entry to this page, -contributing to the community. - -For models structured in an architecture we have not supported yet, you could: - -- Either `create a [Model Request] issue `__ which - automatically shows up on our `Model Request Tracking Board `__. - -- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. - - -.. _model-library-tables: - -Level 2: Model Library Tables (Precompiled Binary Files) --------------------------------------------------------- - -As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). - -Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: - -- **Size**: each size of model has its own distinct model library file (e.g. 
7B or 13B number of parameters) - -- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) - -- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`quantization page ` - (e.g. ``q3f16_1`` vs. ``q4f16_1``). - -Each entry links to the specific model library file found in `this github repo `__. - -If the model library you found is not available as a prebuilt, you can compile it yourself by following :doc:`the model compilation page `, -and submit a PR to the repo `binary-mlc-llm-libs `__ afterwards. - -.. _llama_library_table: - -Llama -^^^^^ -.. list-table:: Llama - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 7B - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - - * - 13B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - - - - - `q4f16_1 `__ - - - * - 34B - - - - - - - - - - - - - - - - - - - - - * - 70B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - - - - - `q4f16_1 `__ - - - -.. _mistral_library_table: - -Mistral -^^^^^^^ -.. list-table:: Mistral - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 7B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q3f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - - - -.. _gpt_neox_library_table: - -GPT-NeoX (RedPajama-INCITE) -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. list-table:: GPT-NeoX (RedPajama-INCITE) - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 3B - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - - -.. _gpt_big_code_library_table: - -GPTBigCode -^^^^^^^^^^ - -.. list-table:: GPTBigCode - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 15B - - - - - - - - - - - - - - - - - - - - - -.. _phi_library_table: - -Phi -^^^ -.. list-table:: Phi - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - Phi-2 - - (2.7B) - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - - - - - `q0f16 `__ - - `q4f16_1 `__ - - - * - Phi-1.5 - - (1.3B) - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - - - - - `q0f16 `__ - - `q4f16_1 `__ - - - -.. _gpt2_library_table: - -GPT2 -^^^^ -.. 
list-table:: GPT2 - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - GPT2 - - (124M) - - `q0f16 `__ - - - - `q0f16 `__ - - - - `q0f16 `__ - - - - - - - - `q0f16 `__ - - - * - GPT2-med - - (355M) - - `q0f16 `__ - - - - `q0f16 `__ - - - - `q0f16 `__ - - - - - - - - `q0f16 `__ - - - -.. _model-variant-tables: - -Level 3: Model Variant Tables (Precompiled Weights) ---------------------------------------------------- - -Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. - -Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. - -Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. - -Some of these files are uploaded by our community contributors--thank you! - -.. _llama2_variant_table: - -`Llama-2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-2`` - -.. list-table:: Llama-2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 (Chat) `__ - * `q4f32_1 (Chat) `__ - - * - 13B - - * `q4f16_1 `__ - - * - 70B - - * `q4f16_1 `__ - -.. _mistralinstruct_variant_table: - -`Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``mistral_default`` - -.. list-table:: Mistral - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q3f16_1 (Instruct) `__ - * `q4f16_1 (Instruct) `__ - -.. _neuralhermes_variant_table: - -`NeuralHermes-2.5-Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``neural_hermes_mistral`` - -.. list-table:: Neural Hermes - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - -.. _openhermes_variant_table: - -`OpenHermes-2-Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``open_hermes_mistral`` - -.. list-table:: Open Hermes - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - - - -.. _wizardmathv1.1_variant_table: - -`WizardMath V1.1 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``wizard_coder_or_math`` - -.. list-table:: WizardMath - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - - -.. _red_pajama_variant_table: - -`RedPajama `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``redpajama_chat`` - -.. list-table:: Red Pajama - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 3B - - * `q4f16_1 (Chat) `__ - * `q4f32_1 (Chat) `__ - - -.. _phi_variant_table: - -`Phi `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``phi-2`` - -.. list-table:: Phi - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - Phi-2 (2.7B) - - * `q0f16 `__ - * `q4f16_1 `__ - * - Phi-1.5 (1.3B) - - * `q0f16 `__ - * `q4f16_1 `__ - - -.. 
_gpt2_variant_table: - -`GPT2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``gpt2`` - -.. list-table:: GPT2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - GPT2 (124M) - - * `q0f16 `__ - * - GPT2-medium (355M) - - * `q0f16 `__ - - ------------------- - - -.. _contribute-models-to-mlc-llm: - -Contribute Models to MLC-LLM ----------------------------- - -Ready to contribute your compiled models/new model architectures? Awesome! Please check :ref:`contribute-new-models` on how to contribute new models to MLC-LLM. From 2b8aadf57479ffffa24380846cfb9976d00d437c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 10 May 2024 18:39:40 -0400 Subject: [PATCH 301/531] [Android] Add `-j` option to cmake build (#2321) This PR adds the `-j` option to cmake build to parallelize the build job over CPU cores. --- android/mlc4j/prepare_libs.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/android/mlc4j/prepare_libs.py b/android/mlc4j/prepare_libs.py index 19f80718f0..b1c490c354 100644 --- a/android/mlc4j/prepare_libs.py +++ b/android/mlc4j/prepare_libs.py @@ -43,13 +43,31 @@ def run_cmake(mlc4j_path: Path): def run_cmake_build(): logger.info("Running cmake build") - cmd = ["cmake", "--build", ".", "--target", "tvm4j_runtime_packed", "--config", "release"] + cmd = [ + "cmake", + "--build", + ".", + "--target", + "tvm4j_runtime_packed", + "--config", + "release", + f"-j{os.cpu_count()}", + ] subprocess.run(cmd, check=True, env=os.environ) def run_cmake_install(): logger.info("Running cmake install") - cmd = ["cmake", "--build", ".", "--target", "install", "--config", "release", "-j"] + cmd = [ + "cmake", + "--build", + ".", + "--target", + "install", + "--config", + "release", + f"-j{os.cpu_count()}", + ] subprocess.run(cmd, check=True, env=os.environ) From 98f042460f97a0953d5e8b6198531b78617ce9d5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 11 May 2024 16:59:57 -0400 Subject: [PATCH 302/531] [DOCS] More clear android instruction (#2327) This PR sets a more clear instruction for android JDK setup --- docs/deploy/android.rst | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index 2a729349f1..ed75befa02 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -36,17 +36,19 @@ Prerequisite TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang **JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. -We recommended setting the ``JAVA_HOME`` to the JDK bundled with Android Studio. e.g. ``export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home`` for macOS. -In other ways, it could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. +We strongly recommend setting the ``JAVA_HOME`` to the JDK bundled with Android Studio. e.g. +``export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home`` for macOS. +Using Android Studio's JBR bundle as recommended `here https://developer.android.com/build/jdks` +will reduce the chances of potential errors in JNI compilation. Set up the following environment variable: -- ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. +- ``export JAVA_HOME=/path/to/java_home`` you can then cross check and make sure ``$JAVA_HOME/bin/java`` exists. 
Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. **TVM Unity runtime** is placed under `3rdparty/tvm `__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable: -- ``TVM_HOME`` so that its headers are available under ``$TVM_HOME/include/tvm/runtime``. +- ``export TVM_HOME=/path/to/mlc-llm/3rdparty/tvm``. (Optional) **TVM Unity compiler** Python package (:ref:`install ` or :ref:`build from source `). It is *NOT* required if models are prebuilt, but to compile PyTorch models from HuggingFace in the following section, the compiler is a must-dependency. From 21feb7010db02e0c2149489f5972d6a8a796b5a0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 12 May 2024 03:24:55 -0700 Subject: [PATCH 303/531] [Serving] Refactor to consolidate new request prefill (#2329) --- .../engine_actions/batch_prefill_base.cc | 313 +++++++++++++++++ cpp/serve/engine_actions/batch_prefill_base.h | 107 ++++++ .../eagle_new_request_prefill.cc | 315 ++---------------- .../engine_actions/new_request_prefill.cc | 302 +---------------- 4 files changed, 452 insertions(+), 585 deletions(-) create mode 100644 cpp/serve/engine_actions/batch_prefill_base.cc create mode 100644 cpp/serve/engine_actions/batch_prefill_base.h diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc new file mode 100644 index 0000000000..df6df2b3d9 --- /dev/null +++ b/cpp/serve/engine_actions/batch_prefill_base.cc @@ -0,0 +1,313 @@ +/*! + * Copyright (c) 2024 by Contributors + * \file serve/engine_actions/batch_prefill_base.h + */ + +#include "batch_prefill_base.h" + +namespace mlc { +namespace llm { +namespace serve { + +BatchPrefillBaseActionObj::BatchPrefillBaseActionObj(Array models, + EngineConfig engine_config, + Optional trace_recorder) + : models_(models), engine_config_(engine_config), trace_recorder_(trace_recorder) {} + +/*! + * \brief Find one or multiple request state entries to run prefill. + * \param estate The engine state. + * \return The request entries to prefill, together with their input lengths. + */ +std::vector +BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { + if (estate->waiting_queue.empty()) { + // No request to prefill. + return {}; + } + + std::vector prefill_inputs; + + // - Try to prefill pending requests. + int total_input_length = 0; + int total_required_pages = 0; + int num_available_pages = models_[0]->GetNumAvailablePages(); + int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); + int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); + + int num_prefill_rsentries = 0; + for (const Request& request : estate->waiting_queue) { + RequestState rstate = estate->GetRequestState(request); + bool prefill_stops = false; + for (const RequestStateEntry& rsentry : rstate->entries) { + // A request state entry can be prefilled only when: + // - it has inputs, and + // - it has no parent or its parent is alive and has no remaining input. 
+ if (rsentry->mstates[0]->inputs.empty() || + (rsentry->parent_idx != -1 && + (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || + !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { + continue; + } + + int input_length = rsentry->mstates[0]->GetInputLength(); + int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + // - Attempt 1. Check if the entire request state entry can fit for prefill. + bool can_prefill = false; + for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; + --num_child_to_activate) { + if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, + total_input_length, total_required_pages, num_available_pages, + current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); + num_prefill_rsentries += 1 + num_child_to_activate; + can_prefill = true; + break; + } + } + if (can_prefill) { + continue; + } + total_input_length -= input_length; + total_required_pages -= num_require_pages; + + // - Attempt 2. Check if the request state entry can partially fit by input chunking. + ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); + if (engine_config_->prefill_chunk_size - total_input_length >= input_length || + engine_config_->prefill_chunk_size == total_input_length) { + // 1. If the input length can fit the remaining prefill chunk size, + // it means the failure of attempt 1 is not because of the input + // length being too long, and thus chunking does not help. + // 2. If the total input length already reaches the prefill chunk size, + // the current request state entry will not be able to be processed. + // So we can safely return in either case. + prefill_stops = true; + break; + } + input_length = engine_config_->prefill_chunk_size - total_input_length; + num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / + engine_config_->kv_cache_page_size; + total_input_length += input_length; + total_required_pages += num_require_pages; + if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, + num_available_pages, current_total_seq_len, num_running_rsentries)) { + prefill_inputs.push_back({rsentry, input_length, 0}); + num_prefill_rsentries += 1; + } + + // - Prefill stops here. + prefill_stops = true; + break; + } + if (prefill_stops) { + break; + } + } + + return prefill_inputs; +} + +/*! \brief Check if the input requests can be prefilled under conditions. */ +bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_rsentries, + int total_input_length, int num_required_pages, + int num_available_pages, int current_total_seq_len, + int num_running_rsentries) { + ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); + + // For RNN State, it can prefill as long as it can be instantiated. + if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + return true; + } + + // No exceeding of the maximum allowed requests that can + // run simultaneously. + int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable + ? 
(engine_config_->spec_draft_length + 1) + : 1; + if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > + std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { + return false; + } + + // NOTE: The conditions are heuristic and can be revised. + // Cond 1: total input length <= prefill chunk size. + // Cond 2: at least one decode can be performed after prefill. + // Cond 3: number of total tokens after 8 times of decode does not + // exceed the limit, where 8 is a watermark number can + // be configured and adjusted in the future. + int new_batch_size = num_running_rsentries + num_prefill_rsentries; + return total_input_length <= engine_config_->prefill_chunk_size && + num_required_pages + new_batch_size <= num_available_pages && + current_total_seq_len + total_input_length + 8 * new_batch_size <= + engine_config_->max_total_sequence_length; +} + +/*! + * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ +std::pair, int> BatchPrefillBaseActionObj::ChunkPrefillInputData( + const RequestModelState& mstate, int max_prefill_length) { + if (mstate->inputs.empty()) { + } + ICHECK(!mstate->inputs.empty()); + std::vector inputs; + int cum_input_length = 0; + inputs.reserve(mstate->inputs.size()); + for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { + inputs.push_back(mstate->inputs[i]); + int input_length = mstate->inputs[i]->GetLength(); + cum_input_length += input_length; + // Case 0. the cumulative input length does not reach the maximum prefill length. + if (cum_input_length < max_prefill_length) { + continue; + } + + // Case 1. the cumulative input length equals the maximum prefill length. + if (cum_input_length == max_prefill_length) { + if (i == static_cast(mstate->inputs.size()) - 1) { + // - If `i` is the last input, we just copy and reset `mstate->inputs`. + mstate->inputs.clear(); + } else { + // - Otherwise, set the new input array. + mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Case 2. cum_input_length > max_prefill_length + // The input `i` itself needs chunking if it is TokenData, + // or otherwise it cannot be chunked. + Data input = mstate->inputs[i]; + inputs.pop_back(); + cum_input_length -= input_length; + const auto* token_input = input.as(); + if (token_input == nullptr) { + // Cannot chunk the input. + if (i != 0) { + mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; + } + return {inputs, cum_input_length}; + } + + // Split the token data into two parts. + // Return the first part for prefill, and keep the second part. 
+ int chunked_input_length = max_prefill_length - cum_input_length; + ICHECK_GT(input_length, chunked_input_length); + TokenData chunked_input(IntTuple{token_input->token_ids.begin(), + token_input->token_ids.begin() + chunked_input_length}); + TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, + token_input->token_ids.end()}); + inputs.push_back(chunked_input); + cum_input_length += chunked_input_length; + std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; + remaining_inputs.insert(remaining_inputs.begin(), remaining_input); + mstate->inputs = remaining_inputs; + return {inputs, cum_input_length}; + } + + ICHECK(false) << "Cannot reach here"; +} + +void BatchPrefillBaseActionObj::UpdateRequestToAlive( + const std::vector& prefill_inputs, + const EngineState& estate, Array* request_ids, + std::vector* rstates_of_entries, + std::vector* status_before_prefill) { + int num_rsentries = prefill_inputs.size(); + request_ids->reserve(num_rsentries); + rstates_of_entries->reserve(num_rsentries); + status_before_prefill->reserve(num_rsentries); + for (const PrefillInput& prefill_input : prefill_inputs) { + const RequestStateEntry& rsentry = prefill_input.rsentry; + const Request& request = rsentry->request; + RequestState request_rstate = estate->GetRequestState(request); + request_ids->push_back(request->id); + status_before_prefill->push_back(rsentry->status); + rsentry->status = RequestStateStatus::kAlive; + + if (status_before_prefill->back() == RequestStateStatus::kPending) { + // - Add the request to running queue if the request state + // status was pending and all its request states were pending. + bool alive_state_existed = false; + for (const RequestStateEntry& rsentry_ : request_rstate->entries) { + if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { + alive_state_existed = true; + } + } + if (!alive_state_existed) { + estate->running_queue.push_back(request); + } + } + rstates_of_entries->push_back(std::move(request_rstate)); + } +} + +std::vector BatchPrefillBaseActionObj::RemoveProcessedRequests( + const std::vector& prefill_inputs, + const EngineState& estate, const std::vector& rstates_of_entries) { + // - Remove the request from waiting queue if all its request states + // are now alive and have no remaining chunked inputs. 
+ std::vector processed_requests; + int num_rsentries = prefill_inputs.size(); + processed_requests.reserve(num_rsentries); + std::unordered_set dedup_map; + for (int i = 0; i < num_rsentries; ++i) { + const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; + if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { + continue; + } + dedup_map.insert(rsentry->request.get()); + processed_requests.push_back(rsentry->request); + + bool pending_state_exists = false; + for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { + if (rsentry_->status == RequestStateStatus::kPending || + !rsentry_->mstates[0]->inputs.empty()) { + pending_state_exists = true; + break; + } + } + if (!pending_state_exists) { + auto it = + std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request); + ICHECK(it != estate->waiting_queue.end()); + estate->waiting_queue.erase(it); + } + } + return processed_requests; +} + +void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults( + const std::vector& rsentries_for_sample, + const std::vector& rsentry_activated, const std::vector& sample_results) { + auto tnow = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + // Update all model states of the request state entry. + for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { + mstate->CommitToken(sample_results[i]); + if (!rsentry_activated[i]) { + // When the child rsentry is not activated, + // add the sampled token as an input of the mstate for prefill. + mstate->inputs.push_back( + TokenData(std::vector{sample_results[i].sampled_token_id.first})); + } + } + if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { + rsentries_for_sample[i]->tprefill_finish = tnow; + } + } +} + +} // namespace serve +} // namespace llm +} // namespace mlc \ No newline at end of file diff --git a/cpp/serve/engine_actions/batch_prefill_base.h b/cpp/serve/engine_actions/batch_prefill_base.h new file mode 100644 index 0000000000..54b257dc21 --- /dev/null +++ b/cpp/serve/engine_actions/batch_prefill_base.h @@ -0,0 +1,107 @@ +/*! + * Copyright (c) 2024 by Contributors + * \file serve/engine_actions/batch_prefill_base.h + */ + +#include + +#include "../config.h" +#include "../model.h" +#include "action.h" +#include "action_commons.h" + +namespace mlc { +namespace llm { +namespace serve { + +/*! + * \brief The base action of that prefills requests in the `waiting_queue` of + * the engine state. + */ +class BatchPrefillBaseActionObj : public EngineActionObj { + protected: + /*! \brief The class of request state entry and its maximum allowed length for prefill. */ + struct PrefillInput { + RequestStateEntry rsentry; + int max_prefill_length = 0; + int num_child_to_activate = 0; + }; + + BatchPrefillBaseActionObj(Array models, EngineConfig engine_config, + Optional trace_recorder); + + /*! + * \brief Find one or multiple request state entries to run prefill. + * \param estate The engine state. + * \return The request entries to prefill, together with their input lengths. + */ + std::vector GetRequestStateEntriesToPrefill(EngineState estate); + + /*! \brief Check if the input requests can be prefilled under conditions. */ + bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, + int num_required_pages, int num_available_pages, int current_total_seq_len, + int num_running_rsentries); + + /*! 
+ * \brief Chunk the input of the given RequestModelState for prefill + * with regard to the provided maximum allowed prefill length. + * Return the list of input for prefill and the total prefill length. + * The `inputs` field of the given `mstate` will be mutated to exclude + * the returned input. + * \param mstate The RequestModelState whose input data is to be chunked. + * \param max_prefill_length The maximum allowed prefill length for the mstate. + * \return The list of input for prefill and the total prefill length. + */ + std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, + int max_prefill_length); + + /*! + * \brief Update status of request states from pending to alive and collect request state entries + * from the prefill input. + * \param prefill_inputs The prefill input. + * \param estate The engine state. + * \param[out] request_ids The array to store the request ids of the request state entries. + * \param[out] rstates_of_entries The vector to store the request state entries. + * \param[out] status_before_prefill The vector to store the status of the request state entries + * before prefill. + */ + void UpdateRequestToAlive(const std::vector& prefill_inputs, + const EngineState& estate, Array* request_ids, + std::vector* rstates_of_entries, + std::vector* status_before_prefill); + + /*! + * \brief Remove the request from waiting queue if all its request states are now alive and have + * no remaining chunked inputs. + * \param prefill_inputs The prefill input. + * \param estate The engine state. + * \param rstates_of_entries The request state entries for each prefill input. + * \return The processed requests. + */ + std::vector RemoveProcessedRequests(const std::vector& prefill_inputs, + const EngineState& estate, + const std::vector& rstates_of_entries); + + /*! + * \brief Update the committed tokens of states. If a request is first-time prefilled, set the + * prefill finish time. + * \param rsentries_for_sample The request state entries for sample. + * \param + * rsentry_activated The activation status of the request state entries. + * \param sample_results The sample results. + */ + void UpdateRequestStateEntriesWithSampleResults( + const std::vector& rsentries_for_sample, + const std::vector& rsentry_activated, const std::vector& sample_results); + + /*! \brief The models to run prefill in. */ + Array models_; + /*! \brief The engine config. */ + EngineConfig engine_config_; + /*! \brief Event trace recorder. */ + Optional trace_recorder_; +}; + +} // namespace serve +} // namespace llm +} // namespace mlc \ No newline at end of file diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index e2d2d661f8..2190cf61ed 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -10,6 +10,7 @@ #include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" +#include "batch_prefill_base.h" namespace mlc { namespace llm { @@ -19,7 +20,7 @@ namespace serve { * \brief The action that prefills requests in the `waiting_queue` of * the engine state. 
*/ -class EagleNewRequestPrefillActionObj : public EngineActionObj { +class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { public: explicit EagleNewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, @@ -27,13 +28,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config, Optional trace_recorder) - : models_(std::move(models)), + : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config), + std::move(trace_recorder)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), model_workspaces_(std::move(model_workspaces)), - draft_token_workspace_manager_(std::move(draft_token_workspace_manager)), - engine_config_(std::move(engine_config)), - trace_recorder_(std::move(trace_recorder)) {} + draft_token_workspace_manager_(std::move(draft_token_workspace_manager)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. @@ -53,32 +53,8 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { Array request_ids; std::vector rstates_of_entries; std::vector status_before_prefill; - request_ids.reserve(num_rsentries); - rstates_of_entries.reserve(num_rsentries); - status_before_prefill.reserve(num_rsentries); - for (const PrefillInput& prefill_input : prefill_inputs) { - const RequestStateEntry& rsentry = prefill_input.rsentry; - const Request& request = rsentry->request; - RequestState request_rstate = estate->GetRequestState(request); - request_ids.push_back(request->id); - status_before_prefill.push_back(rsentry->status); - rsentry->status = RequestStateStatus::kAlive; - - if (status_before_prefill.back() == RequestStateStatus::kPending) { - // - Add the request to running queue if the request state - // status was pending and all its request states were pending. - bool alive_state_existed = false; - for (const RequestStateEntry& rsentry_ : request_rstate->entries) { - if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { - alive_state_existed = true; - } - } - if (!alive_state_existed) { - estate->running_queue.push_back(request); - } - } - rstates_of_entries.push_back(std::move(request_rstate)); - } + UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries, + &status_before_prefill); // - Get embedding and run prefill for each model. std::vector prefill_lengths; @@ -285,30 +261,19 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { // - If a request is first-time prefilled, set the prefill finish time. auto tnow = std::chrono::high_resolution_clock::now(); if (model_id == 0) { + UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, + sample_results); + // Add the sampled token as an input of the eagle models. for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - for (int mid = 0; mid < static_cast(models_.size()); ++mid) { - rsentries_for_sample[i]->mstates[mid]->CommitToken(sample_results[i]); - if (!rsentry_activated[i]) { - // When the child rsentry is not activated, - // add the sampled token as an input of the mstate for prefill. - rsentries_for_sample[i]->mstates[mid]->inputs.push_back( - TokenData(std::vector{sample_results[i].sampled_token_id.first})); - } - if (mid > 0) { - // Add the sampled token as an input of the eagle models. 
- TokenData token_data = - Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); - std::vector token_ids = {token_data->token_ids.begin(), - token_data->token_ids.end()}; - token_ids.push_back(sample_results[i].sampled_token_id.first); - int ninputs = static_cast(rsentries_for_sample[i]->mstates[mid]->inputs.size()); - rsentries_for_sample[i]->mstates[mid]->inputs.Set( - ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); - } - } - // Only base model trigger timing records. - if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { - rsentries_for_sample[i]->tprefill_finish = tnow; + for (int mid = 1; mid < static_cast(models_.size()); ++mid) { + TokenData token_data = + Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); + std::vector token_ids = {token_data->token_ids.begin(), + token_data->token_ids.end()}; + token_ids.push_back(sample_results[i].sampled_token_id.first); + int ninputs = static_cast(rsentries_for_sample[i]->mstates[mid]->inputs.size()); + rsentries_for_sample[i]->mstates[mid]->inputs.Set( + ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); } } } else { @@ -332,246 +297,12 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - // - Remove the request from waiting queue if all its request states - // are now alive and have no remaining chunked inputs. - std::vector processed_requests; - { - processed_requests.reserve(num_rsentries); - std::unordered_set dedup_map; - for (int i = 0; i < num_rsentries; ++i) { - const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; - if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { - continue; - } - dedup_map.insert(rsentry->request.get()); - processed_requests.push_back(rsentry->request); - - bool pending_state_exists = false; - for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { - if (rsentry_->status == RequestStateStatus::kPending || - !rsentry_->mstates[0]->inputs.empty()) { - pending_state_exists = true; - break; - } - } - if (!pending_state_exists) { - auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), - rsentry->request); - ICHECK(it != estate->waiting_queue.end()); - estate->waiting_queue.erase(it); - } - } - } + std::vector processed_requests = + RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries); return processed_requests; } private: - /*! \brief The class of request state entry and its maximum allowed length for prefill. */ - struct PrefillInput { - RequestStateEntry rsentry; - int max_prefill_length = 0; - int num_child_to_activate = 0; - }; - - /*! - * \brief Find one or multiple request state entries to run prefill. - * \param estate The engine state. - * \return The request entries to prefill, together with their input lengths. - */ - std::vector GetRequestStateEntriesToPrefill(EngineState estate) { - if (estate->waiting_queue.empty()) { - // No request to prefill. - return {}; - } - - std::vector prefill_inputs; - - // - Try to prefill pending requests. 
- int total_input_length = 0; - int total_required_pages = 0; - int num_available_pages = models_[0]->GetNumAvailablePages(); - int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); - int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); - - int num_prefill_rsentries = 0; - for (const Request& request : estate->waiting_queue) { - RequestState rstate = estate->GetRequestState(request); - bool prefill_stops = false; - for (const RequestStateEntry& rsentry : rstate->entries) { - // A request state entry can be prefilled only when: - // - it has inputs, and - // - it has no parent or its parent is alive and has no remaining input. - if (rsentry->mstates[0]->inputs.empty() || - (rsentry->parent_idx != -1 && - (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || - !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { - continue; - } - - int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - // - Attempt 1. Check if the entire request state entry can fit for prefill. - bool can_prefill = false; - for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; - --num_child_to_activate) { - if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); - num_prefill_rsentries += 1 + num_child_to_activate; - can_prefill = true; - break; - } - } - if (can_prefill) { - continue; - } - total_input_length -= input_length; - total_required_pages -= num_require_pages; - - // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); - if (engine_config_->prefill_chunk_size - total_input_length >= input_length || - engine_config_->prefill_chunk_size == total_input_length) { - // 1. If the input length can fit the remaining prefill chunk size, - // it means the failure of attempt 1 is not because of the input - // length being too long, and thus chunking does not help. - // 2. If the total input length already reaches the prefill chunk size, - // the current request state entry will not be able to be processed. - // So we can safely return in either case. - prefill_stops = true; - break; - } - input_length = engine_config_->prefill_chunk_size - total_input_length; - num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, - num_available_pages, current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, 0}); - num_prefill_rsentries += 1; - } - - // - Prefill stops here. - prefill_stops = true; - break; - } - if (prefill_stops) { - break; - } - } - - return prefill_inputs; - } - - /*! \brief Check if the input requests can be prefilled under conditions. 
*/ - bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, - int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? (engine_config_->spec_draft_length + 1) - : 1; - if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= prefill chunk size. - // Cond 2: at least one decode can be performed after prefill. - // Cond 3: number of total tokens after 8 times of decode does not - // exceed the limit, where 8 is a watermark number can - // be configured and adjusted in the future. - int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= engine_config_->prefill_chunk_size && - num_required_pages + new_batch_size <= num_available_pages && - current_total_seq_len + total_input_length + 8 * new_batch_size <= - engine_config_->max_total_sequence_length; - } - - /*! - * \brief Chunk the input of the given RequestModelState for prefill - * with regard to the provided maximum allowed prefill length. - * Return the list of input for prefill and the total prefill length. - * The `inputs` field of the given `mstate` will be mutated to exclude - * the returned input. - * \param mstate The RequestModelState whose input data is to be chunked. - * \param max_prefill_length The maximum allowed prefill length for the mstate. - * \return The list of input for prefill and the total prefill length. - */ - std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, - int max_prefill_length) { - if (mstate->inputs.empty()) { - } - ICHECK(!mstate->inputs.empty()); - std::vector inputs; - int cum_input_length = 0; - inputs.reserve(mstate->inputs.size()); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - inputs.push_back(mstate->inputs[i]); - int input_length = mstate->inputs[i]->GetLength(); - cum_input_length += input_length; - // Case 0. the cumulative input length does not reach the maximum prefill length. - if (cum_input_length < max_prefill_length) { - continue; - } - - // Case 1. the cumulative input length equals the maximum prefill length. - if (cum_input_length == max_prefill_length) { - if (i == static_cast(mstate->inputs.size()) - 1) { - // - If `i` is the last input, we just copy and reset `mstate->inputs`. - mstate->inputs.clear(); - } else { - // - Otherwise, set the new input array. - mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Case 2. cum_input_length > max_prefill_length - // The input `i` itself needs chunking if it is TokenData, - // or otherwise it cannot be chunked. - Data input = mstate->inputs[i]; - inputs.pop_back(); - cum_input_length -= input_length; - const auto* token_input = input.as(); - if (token_input == nullptr) { - // Cannot chunk the input. - if (i != 0) { - mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Split the token data into two parts. - // Return the first part for prefill, and keep the second part. 
- int chunked_input_length = max_prefill_length - cum_input_length; - ICHECK_GT(input_length, chunked_input_length); - TokenData chunked_input(IntTuple{token_input->token_ids.begin(), - token_input->token_ids.begin() + chunked_input_length}); - TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, - token_input->token_ids.end()}); - inputs.push_back(chunked_input); - cum_input_length += chunked_input_length; - std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - remaining_inputs.insert(remaining_inputs.begin(), remaining_input); - mstate->inputs = remaining_inputs; - return {inputs, cum_input_length}; - } - - ICHECK(false) << "Cannot reach here"; - } - - /*! \brief The models to run prefill in. */ - Array models_; /*! \brief The logit processor. */ LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ @@ -580,10 +311,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj { std::vector model_workspaces_; /*! \brief The draft token workspace manager. */ DraftTokenWorkspaceManager draft_token_workspace_manager_; - /*! \brief The engine config. */ - EngineConfig engine_config_; - /*! \brief Event trace recorder. */ - Optional trace_recorder_; /*! \brief Temporary buffer to store the slots of the current draft tokens */ std::vector draft_token_slots_; }; diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index 5a5847aaa0..038a6cc66c 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -10,6 +10,7 @@ #include "../sampler/sampler.h" #include "action.h" #include "action_commons.h" +#include "batch_prefill_base.h" namespace mlc { namespace llm { @@ -19,18 +20,17 @@ namespace serve { * \brief The action that prefills requests in the `waiting_queue` of * the engine state. */ -class NewRequestPrefillActionObj : public EngineActionObj { +class NewRequestPrefillActionObj : public BatchPrefillBaseActionObj { public: explicit NewRequestPrefillActionObj(Array models, LogitProcessor logit_processor, Sampler sampler, std::vector model_workspaces, EngineConfig engine_config, Optional trace_recorder) - : models_(std::move(models)), + : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config), + std::move(trace_recorder)), logit_processor_(std::move(logit_processor)), sampler_(std::move(sampler)), - model_workspaces_(std::move(model_workspaces)), - engine_config_(std::move(engine_config)), - trace_recorder_(std::move(trace_recorder)) {} + model_workspaces_(std::move(model_workspaces)) {} Array Step(EngineState estate) final { // - Find the requests in `waiting_queue` that can prefill in this step. 
@@ -50,32 +50,8 @@ class NewRequestPrefillActionObj : public EngineActionObj { Array request_ids; std::vector rstates_of_entries; std::vector status_before_prefill; - request_ids.reserve(num_rsentries); - rstates_of_entries.reserve(num_rsentries); - status_before_prefill.reserve(num_rsentries); - for (const PrefillInput& prefill_input : prefill_inputs) { - const RequestStateEntry& rsentry = prefill_input.rsentry; - const Request& request = rsentry->request; - RequestState request_rstate = estate->GetRequestState(request); - request_ids.push_back(request->id); - status_before_prefill.push_back(rsentry->status); - rsentry->status = RequestStateStatus::kAlive; - - if (status_before_prefill.back() == RequestStateStatus::kPending) { - // - Add the request to running queue if the request state - // status was pending and all its request states were pending. - bool alive_state_existed = false; - for (const RequestStateEntry& rsentry_ : request_rstate->entries) { - if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) { - alive_state_existed = true; - } - } - if (!alive_state_existed) { - estate->running_queue.push_back(request); - } - } - rstates_of_entries.push_back(std::move(request_rstate)); - } + UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries, + &status_before_prefill); // - Get embedding and run prefill for each model. std::vector prefill_lengths; @@ -237,280 +213,24 @@ class NewRequestPrefillActionObj : public EngineActionObj { // - Update the committed tokens of states. // - If a request is first-time prefilled, set the prefill finish time. - auto tnow = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) { - mstate->CommitToken(sample_results[i]); - if (!rsentry_activated[i]) { - // When the child rsentry is not activated, - // add the sampled token as an input of the mstate for prefill. - mstate->inputs.push_back( - TokenData(std::vector{sample_results[i].sampled_token_id.first})); - } - } - if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) { - rsentries_for_sample[i]->tprefill_finish = tnow; - } - } + UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, + sample_results); auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_prefill_time += static_cast((tend - tstart).count()) / 1e9; - // - Remove the request from waiting queue if all its request states - // are now alive and have no remaining chunked inputs. 
- std::vector processed_requests; - { - processed_requests.reserve(num_rsentries); - std::unordered_set dedup_map; - for (int i = 0; i < num_rsentries; ++i) { - const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; - if (dedup_map.find(rsentry->request.get()) != dedup_map.end()) { - continue; - } - dedup_map.insert(rsentry->request.get()); - processed_requests.push_back(rsentry->request); - - bool pending_state_exists = false; - for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) { - if (rsentry_->status == RequestStateStatus::kPending || - !rsentry_->mstates[0]->inputs.empty()) { - pending_state_exists = true; - break; - } - } - if (!pending_state_exists) { - auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), - rsentry->request); - ICHECK(it != estate->waiting_queue.end()); - estate->waiting_queue.erase(it); - } - } - } + std::vector processed_requests = + RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries); return processed_requests; } private: - /*! \brief The class of request state entry and its maximum allowed length for prefill. */ - struct PrefillInput { - RequestStateEntry rsentry; - int max_prefill_length = 0; - int num_child_to_activate = 0; - }; - - /*! - * \brief Find one or multiple request state entries to run prefill. - * \param estate The engine state. - * \return The request entries to prefill, together with their input lengths. - */ - std::vector GetRequestStateEntriesToPrefill(EngineState estate) { - if (estate->waiting_queue.empty()) { - // No request to prefill. - return {}; - } - - std::vector prefill_inputs; - - // - Try to prefill pending requests. - int total_input_length = 0; - int total_required_pages = 0; - int num_available_pages = models_[0]->GetNumAvailablePages(); - int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); - int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); - - int num_prefill_rsentries = 0; - for (const Request& request : estate->waiting_queue) { - RequestState rstate = estate->GetRequestState(request); - bool prefill_stops = false; - for (const RequestStateEntry& rsentry : rstate->entries) { - // A request state entry can be prefilled only when: - // - it has inputs, and - // - it has no parent or its parent is alive and has no remaining input. - if (rsentry->mstates[0]->inputs.empty() || - (rsentry->parent_idx != -1 && - (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending || - !rstate->entries[rsentry->parent_idx]->mstates[0]->inputs.empty()))) { - continue; - } - - int input_length = rsentry->mstates[0]->GetInputLength(); - int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - // - Attempt 1. Check if the entire request state entry can fit for prefill. 
- bool can_prefill = false; - for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0; - --num_child_to_activate) { - if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, - total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); - num_prefill_rsentries += 1 + num_child_to_activate; - can_prefill = true; - break; - } - } - if (can_prefill) { - continue; - } - total_input_length -= input_length; - total_required_pages -= num_require_pages; - - // - Attempt 2. Check if the request state entry can partially fit by input chunking. - ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size); - if (engine_config_->prefill_chunk_size - total_input_length >= input_length || - engine_config_->prefill_chunk_size == total_input_length) { - // 1. If the input length can fit the remaining prefill chunk size, - // it means the failure of attempt 1 is not because of the input - // length being too long, and thus chunking does not help. - // 2. If the total input length already reaches the prefill chunk size, - // the current request state entry will not be able to be processed. - // So we can safely return in either case. - prefill_stops = true; - break; - } - input_length = engine_config_->prefill_chunk_size - total_input_length; - num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) / - engine_config_->kv_cache_page_size; - total_input_length += input_length; - total_required_pages += num_require_pages; - if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, - num_available_pages, current_total_seq_len, num_running_rsentries)) { - prefill_inputs.push_back({rsentry, input_length, 0}); - num_prefill_rsentries += 1; - } - - // - Prefill stops here. - prefill_stops = true; - break; - } - if (prefill_stops) { - break; - } - } - - return prefill_inputs; - } - - /*! \brief Check if the input requests can be prefilled under conditions. */ - bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, - int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries) { - ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); - - // For RNN State, it can prefill as long as it can be instantiated. - if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { - return true; - } - - // No exceeding of the maximum allowed requests that can - // run simultaneously. - int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable - ? (engine_config_->spec_draft_length + 1) - : 1; - if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { - return false; - } - - // NOTE: The conditions are heuristic and can be revised. - // Cond 1: total input length <= prefill chunk size. - // Cond 2: at least one decode can be performed after prefill. - // Cond 3: number of total tokens after 8 times of decode does not - // exceed the limit, where 8 is a watermark number can - // be configured and adjusted in the future. 
- int new_batch_size = num_running_rsentries + num_prefill_rsentries; - return total_input_length <= engine_config_->prefill_chunk_size && - num_required_pages + new_batch_size <= num_available_pages && - current_total_seq_len + total_input_length + 8 * new_batch_size <= - engine_config_->max_total_sequence_length; - } - - /*! - * \brief Chunk the input of the given RequestModelState for prefill - * with regard to the provided maximum allowed prefill length. - * Return the list of input for prefill and the total prefill length. - * The `inputs` field of the given `mstate` will be mutated to exclude - * the returned input. - * \param mstate The RequestModelState whose input data is to be chunked. - * \param max_prefill_length The maximum allowed prefill length for the mstate. - * \return The list of input for prefill and the total prefill length. - */ - std::pair, int> ChunkPrefillInputData(const RequestModelState& mstate, - int max_prefill_length) { - if (mstate->inputs.empty()) { - } - ICHECK(!mstate->inputs.empty()); - std::vector inputs; - int cum_input_length = 0; - inputs.reserve(mstate->inputs.size()); - for (int i = 0; i < static_cast(mstate->inputs.size()); ++i) { - inputs.push_back(mstate->inputs[i]); - int input_length = mstate->inputs[i]->GetLength(); - cum_input_length += input_length; - // Case 0. the cumulative input length does not reach the maximum prefill length. - if (cum_input_length < max_prefill_length) { - continue; - } - - // Case 1. the cumulative input length equals the maximum prefill length. - if (cum_input_length == max_prefill_length) { - if (i == static_cast(mstate->inputs.size()) - 1) { - // - If `i` is the last input, we just copy and reset `mstate->inputs`. - mstate->inputs.clear(); - } else { - // - Otherwise, set the new input array. - mstate->inputs = Array{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Case 2. cum_input_length > max_prefill_length - // The input `i` itself needs chunking if it is TokenData, - // or otherwise it cannot be chunked. - Data input = mstate->inputs[i]; - inputs.pop_back(); - cum_input_length -= input_length; - const auto* token_input = input.as(); - if (token_input == nullptr) { - // Cannot chunk the input. - if (i != 0) { - mstate->inputs = Array{mstate->inputs.begin() + i, mstate->inputs.end()}; - } - return {inputs, cum_input_length}; - } - - // Split the token data into two parts. - // Return the first part for prefill, and keep the second part. - int chunked_input_length = max_prefill_length - cum_input_length; - ICHECK_GT(input_length, chunked_input_length); - TokenData chunked_input(IntTuple{token_input->token_ids.begin(), - token_input->token_ids.begin() + chunked_input_length}); - TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length, - token_input->token_ids.end()}); - inputs.push_back(chunked_input); - cum_input_length += chunked_input_length; - std::vector remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()}; - remaining_inputs.insert(remaining_inputs.begin(), remaining_input); - mstate->inputs = remaining_inputs; - return {inputs, cum_input_length}; - } - - ICHECK(false) << "Cannot reach here"; - } - - /*! \brief The models to run prefill in. */ - Array models_; /*! \brief The logit processor. */ LogitProcessor logit_processor_; /*! \brief The sampler to sample new tokens. */ Sampler sampler_; /*! \brief Workspace of each model. */ std::vector model_workspaces_; - /*! \brief The engine config. 
*/ - EngineConfig engine_config_; - /*! \brief Event trace recorder. */ - Optional trace_recorder_; }; EngineAction EngineAction::NewRequestPrefill(Array models, LogitProcessor logit_processor, From 45a0487ac399bf4b5587d4fcdf406480f226051c Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 12 May 2024 16:45:24 -0400 Subject: [PATCH 304/531] [iOS] Make MLCEngine input to take in structured data (#2330) This PR modifies the MLCEngine chatCompletion to take in structured data. Co-authored-by: Vivian Zhai <98248913+YiyanZhai@users.noreply.github.com> --- cpp/json_ffi/openai_api_protocol.cc | 3 +- cpp/json_ffi/openai_api_protocol.h | 2 +- .../MLCEngineExampleApp.swift | 25 +-- ios/MLCSwift/Sources/Swift/LLMEngine.swift | 52 +++++- .../Sources/Swift/OpenAIProtocol.swift | 165 ++++++++++++++++-- 5 files changed, 211 insertions(+), 36 deletions(-) diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index c07de8fef5..22d95c72c1 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -282,7 +282,8 @@ Result ChatCompletionRequest::FromJSON(const std::string& request.messages = messages; // model - Result model_res = json::LookupWithResultReturn(json_obj, "model"); + Result> model_res = + json::LookupOptionalWithResultReturn(json_obj, "model"); if (model_res.IsErr()) { return TResult::Error(model_res.UnwrapErr()); } diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 914366c2f1..da9002f994 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -102,7 +102,7 @@ class RequestResponseFormat { class ChatCompletionRequest { public: std::vector messages; - std::string model; + std::optional model = std::nullopt; std::optional frequency_penalty = std::nullopt; std::optional presence_penalty = std::nullopt; bool logprobs = false; diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift index 26361977ce..0049cee7e7 100644 --- a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -38,24 +38,15 @@ class AppState: ObservableObject { // Step 0: load the engine await engine.reload(modelPath: modelLocalPath, modelLib: modelLib) - // TODO(mlc-team) update request so it is also structure based - // as in open ai api - // sent a request - let jsonRequest = """ - { - "model": "llama3", - "messages": [ - { - "role": "user", - "content": [ - { "type": "text", "text": "What is the meaning of life?" } - ] - } - ] - } - """ // run chat completion as in OpenAI API style - for await res in await engine.chatCompletion(jsonRequest: jsonRequest) { + for await res in await engine.chatCompletion( + messages: [ + ChatCompletionMessage( + role: .user, + content: "What is the meaning of life?" 
+ ) + ] + ) { // publish at main event loop DispatchQueue.main.async { // parse the result content in structured form diff --git a/ios/MLCSwift/Sources/Swift/LLMEngine.swift b/ios/MLCSwift/Sources/Swift/LLMEngine.swift index 91a4d20b81..a57da15cc5 100644 --- a/ios/MLCSwift/Sources/Swift/LLMEngine.swift +++ b/ios/MLCSwift/Sources/Swift/LLMEngine.swift @@ -61,8 +61,56 @@ public actor MLCEngine { jsonFFIEngine.unload() } - // TODO(mlc-team) turn into a structured interface - public func chatCompletion(jsonRequest: String) -> AsyncStream { + // offer a direct convenient method to pass in messages + public func chatCompletion( + messages: [ChatCompletionMessage], + model: Optional = nil, + frequency_penalty: Optional = nil, + presence_penalty: Optional = nil, + logprobs: Bool = false, + top_logprobs: Int = 0, + logit_bias: Optional<[Int : Float]> = nil, + max_tokens: Optional = nil, + n: Int = 1, + seed: Optional = nil, + stop: Optional<[String]> = nil, + stream: Bool = false, + temperature: Optional = nil, + top_p: Optional = nil, + tools: Optional<[ChatTool]> = nil, + user: Optional = nil, + response_format: Optional = nil + ) -> AsyncStream { + let request = ChatCompletionRequest( + messages: messages, + model: model, + frequency_penalty: frequency_penalty, + presence_penalty: presence_penalty, + logprobs: logprobs, + top_logprobs: top_logprobs, + logit_bias: logit_bias, + max_tokens: max_tokens, + n: n, + seed: seed, + stop: stop, + stream: stream, + temperature: temperature, + top_p: top_p, + tools: tools, + user: user, + response_format: response_format + ) + return self.chatCompletion(request: request) + } + + // completion function + public func chatCompletion( + request: ChatCompletionRequest + ) -> AsyncStream { + let encoder = JSONEncoder() + let data = try! encoder.encode(request) + let jsonRequest = String(data: data, encoding: .utf8)! 
+ // generate a UUID for the request let requestID = UUID().uuidString let stream = AsyncStream(ChatCompletionStreamResponse.self) { continuation in diff --git a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift index 1f36933a15..c364fad3a3 100644 --- a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift +++ b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift @@ -5,14 +5,14 @@ import Foundation // API reference: https://platform.openai.com/docs/api-reference/chat/create public struct TopLogProbs : Codable { - public let token: String - public let logprob: Float - public let bytes: Optional<[Int]> + public var token: String + public var logprob: Float + public var bytes: Optional<[Int]> } public struct LogProbsContent : Codable { - public let token: String - public let logprob: Float + public var token: String + public var logprob: Float public var bytes: Optional<[Int]> = nil public var top_logprobs: [TopLogProbs] = [] } @@ -22,49 +22,184 @@ public struct LogProbs : Codable { } public struct ChatFunction : Codable { - public let name: String + public var name: String public var description: Optional = nil - public let parameters: [String: String] + public var parameters: [String: String] + + public init( + name: String, + description: Optional = nil, + parameters: [String : String] + ) { + self.name = name + self.description = description + self.parameters = parameters + } } public struct ChatTool : Codable { public var type: String = "function" public let function: ChatFunction + + public init(type: String, function: ChatFunction) { + self.type = type + self.function = function + } } public struct ChatFunctionCall : Codable { - public let name: String + public var name: String // NOTE: arguments shold be dict str to any codable // for now only allow string output due to typing issues public var arguments: Optional<[String: String]> = nil + + public init(name: String, arguments: Optional<[String : String]> = nil) { + self.name = name + self.arguments = arguments + } } public struct ChatToolCall : Codable { public var id: String = UUID().uuidString public var type: String = "function" - public let function: ChatFunctionCall + public var function: ChatFunctionCall + + public init( + id: String = UUID().uuidString, + type: String = "function", + function: ChatFunctionCall + ) { + self.id = id + self.type = type + self.function = function + } } -public struct ChatCompletionMessage : Codable { - public let role: String +public enum ChatCompletionRole: String, Codable { + case system = "system" + case user = "user" + case assistant = "assistant" + case tool = "tool" +} + +public struct ChatCompletionMessage: Codable { + public var role: ChatCompletionRole public var content: Optional<[[String: String]]> = nil public var name: Optional = nil public var tool_calls: Optional<[ChatToolCall]> = nil public var tool_call_id: Optional = nil + + // more complicated content construction + public init( + role: ChatCompletionRole, + content: Optional<[[String : String]]> = nil, + name: Optional = nil, + tool_calls: Optional<[ChatToolCall]> = nil, + tool_call_id: Optional = nil + ) { + self.role = role + self.content = content + self.name = name + self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + } + + // convenient method to construct content from string + public init( + role: ChatCompletionRole, + content: String, + name: Optional = nil, + tool_calls: Optional<[ChatToolCall]> = nil, + tool_call_id: Optional = nil + ) { + self.role = role + 
self.content = [["type": "text", "text": content]] + self.name = name + self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + } } public struct ChatCompletionStreamResponseChoice: Codable { public var finish_reason: Optional = nil - public let index: Int - public let delta: ChatCompletionMessage + public var index: Int + public var delta: ChatCompletionMessage public var lobprobs: Optional = nil } public struct ChatCompletionStreamResponse: Codable { - public let id : String + public var id : String public var choices: [ChatCompletionStreamResponseChoice] = [] public var created: Optional = nil public var model: Optional = nil - public let system_fingerprint: String + public var system_fingerprint: String public var object: Optional = nil } + +public struct ResponseFormat: Codable { + public var type: String + public var schema: Optional = nil + + public init(type: String, schema: Optional = nil) { + self.type = type + self.schema = schema + } +} + +public struct ChatCompletionRequest: Codable { + public var messages: [ChatCompletionMessage] + public var model: Optional = nil + public var frequency_penalty: Optional = nil + public var presence_penalty: Optional = nil + public var logprobs: Bool = false + public var top_logprobs: Int = 0 + public var logit_bias: Optional<[Int: Float]> = nil + public var max_tokens: Optional = nil + public var n: Int = 1 + public var seed: Optional = nil + public var stop: Optional<[String]> = nil + public var stream: Bool = false + public var temperature: Optional = nil + public var top_p: Optional = nil + public var tools: Optional<[ChatTool]> = nil + public var user: Optional = nil + public var response_format: Optional = nil + + public init( + messages: [ChatCompletionMessage], + model: Optional = nil, + frequency_penalty: Optional = nil, + presence_penalty: Optional = nil, + logprobs: Bool = false, + top_logprobs: Int = 0, + logit_bias: Optional<[Int : Float]> = nil, + max_tokens: Optional = nil, + n: Int = 1, + seed: Optional = nil, + stop: Optional<[String]> = nil, + stream: Bool = false, + temperature: Optional = nil, + top_p: Optional = nil, + tools: Optional<[ChatTool]> = nil, + user: Optional = nil, + response_format: Optional = nil + ) { + self.messages = messages + self.model = model + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.logit_bias = logit_bias + self.max_tokens = max_tokens + self.n = n + self.seed = seed + self.stop = stop + self.stream = stream + self.temperature = temperature + self.top_p = top_p + self.tools = tools + self.user = user + self.response_format = response_format + } +} From 679d3a8eecc4abb5991e9eee2d6b40384d8a1abc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 12 May 2024 20:27:54 -0700 Subject: [PATCH 305/531] [REFACTOR] Refactor JSONFFI Conv template (#2331) This PR refactors JSONFFI conv template to use immutable processing. This helps to prevent bugs from multiple requests and concurrent access to the conversation data structure. It also reduces the need to deep copy the struct. 
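As a rough illustration of the immutable-processing idea (simplified stand-in types, not the
actual MLC-LLM Conversation/ChatCompletionRequest classes): the shared conversation template
is taken by const reference and each request builds its prompt into a fresh value, so
concurrent requests can read one template without copying or locking it.

    #include <string>
    #include <vector>

    // Hypothetical, simplified stand-ins used only for this sketch.
    struct SimpleMessage {
      std::string role;
      std::string content;
    };

    struct SimpleConversation {
      std::string system_prefix;           // e.g. "<<SYS>>...<</SYS>>\n"
      std::string role_content_sep = ": ";
      std::string sep = "\n";
    };

    // The template is read-only; no field of `conv` is mutated, so many
    // in-flight requests can safely share a single instance.
    std::string CreatePromptSketch(const SimpleConversation& conv,
                                   const std::vector<SimpleMessage>& request_messages) {
      std::string prompt = conv.system_prefix;
      for (const SimpleMessage& msg : request_messages) {
        prompt += msg.role + conv.role_content_sep + msg.content + conv.sep;
      }
      // End with an empty assistant turn so the model continues from here.
      prompt += "assistant" + conv.role_content_sep;
      return prompt;
    }

The previous flow instead copied the template and mutated its `messages` field per request,
which is what the refactor below removes.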
--- cpp/json_ffi/conv_template.cc | 299 +++++++++++------- cpp/json_ffi/conv_template.h | 37 ++- cpp/json_ffi/json_ffi_engine.cc | 50 +-- cpp/json_ffi/openai_api_protocol.cc | 118 +++---- cpp/json_ffi/openai_api_protocol.h | 37 ++- .../MLCEngineExampleApp.swift | 2 +- ios/MLCSwift/Sources/Swift/LLMEngine.swift | 1 - .../Sources/Swift/OpenAIProtocol.swift | 47 ++- tests/python/json_ffi/test_json_ffi_engine.py | 11 +- 9 files changed, 348 insertions(+), 254 deletions(-) diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc index e23258f0b8..a386e09921 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/conv_template.cc @@ -131,7 +131,7 @@ ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) { /****************** Conversation template ******************/ -std::map PLACEHOLDERS = { +std::unordered_map PLACEHOLDERS = { {MessagePlaceholders::SYSTEM, "{system_message}"}, {MessagePlaceholders::USER, "{user_message}"}, {MessagePlaceholders::ASSISTANT, "{assistant_message}"}, @@ -153,120 +153,213 @@ Conversation::Conversation() {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} -Result> Conversation::AsPrompt(ModelConfig config, DLDevice device) { - using TResult = Result>; - // Get the system message - std::string system_msg = system_template; - size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); +std::string Conversation::GetSystemText(const std::string& system_msg) const { + std::string system_text = this->system_template; + static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM]; + size_t pos = system_text.find(system_placeholder); if (pos != std::string::npos) { - system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(), - this->system_message); + system_text.replace(pos, system_placeholder.length(), system_msg); } + return system_text; +} - // Get the message strings - std::vector message_list; - std::vector separators = seps; - if (separators.size() == 1) { - separators.push_back(separators[0]); +std::string Conversation::GetRoleText(const std::string& role, const std::string& content, + const std::optional& fn_call_string) const { + std::string role_text = this->role_templates.at(role); + std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; + size_t pos = role_text.find(placeholder); + if (pos != std::string::npos) { + role_text.replace(pos, placeholder.length(), content); + } + if (fn_call_string) { + // replace placeholder[FUNCTION] with function_string + // this assumes function calling is used for a single request scenario only + pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); + if (pos != std::string::npos) { + role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(), + fn_call_string.value()); + } } + return role_text; +} - if (!system_msg.empty()) { - system_msg += separators[0]; - message_list.push_back(TextData(system_message)); +/// Try to detect if function calling is needed, if so, return the function calling string +Result> TryGetFunctionCallingString( + const Conversation& conv, const ChatCompletionRequest& request) { + using TResult = Result>; + if (!request.tools.has_value() || + (request.tool_choice.has_value() && request.tool_choice.value() == "none")) { + return TResult::Ok(std::nullopt); + } + std::vector tools_ = request.tools.value(); + std::string tool_choice_ = request.tool_choice.value(); + + // TODO: support with tool choice as dict 
+ for (const auto& tool : tools_) { + if (tool.function.name == tool_choice_) { + picojson::value function_str(tool.function.AsJSON()); + return TResult::Ok(function_str.serialize()); + } } - for (int i = 0; i < messages.size(); i++) { - std::string role = messages[i].role; - // Todo(mlc-team): support content to be a single string. - std::optional>> content = - messages[i].content; - if (roles.find(role) == roles.end()) { - return TResult::Error("Role \"" + role + "\" is not supported"); - } + if (tool_choice_ != "auto") { + return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_); + } + + picojson::array function_list; + for (const auto& tool : tools_) { + function_list.push_back(picojson::value(tool.function.AsJSON())); + } - std::string separator = separators[role == "assistant"]; // check assistant role + picojson::value function_list_json(function_list); + return TResult::Ok(function_list_json.serialize()); +}; - // If content is empty, add the role and separator - // assistant's turn to generate text - if (!content.has_value()) { - message_list.push_back(TextData(roles[role] + role_empty_sep)); - continue; - } +Result> CreatePrompt(const Conversation& conv, + const ChatCompletionRequest& request, + const ModelConfig& config, DLDevice device) { + using TResult = Result>; + + Result> fn_call_str_tmp = TryGetFunctionCallingString(conv, request); + if (fn_call_str_tmp.IsErr()) { + return TResult::Error(fn_call_str_tmp.UnwrapErr()); + } + std::optional fn_call_string = fn_call_str_tmp.Unwrap(); + + // Handle system message + // concz + bool has_custom_system = false; + std::string custom_system_inputs; - std::string message = ""; - std::string role_prefix = ""; - // Do not append role prefix if this is the first message and there - // is already a system message - if (add_role_after_system_message || system_msg.empty() || i != 0) { - role_prefix = roles[role] + role_content_sep; + auto f_populate_system_message = [&](const std::vector& msg_vec) { + for (ChatCompletionMessage msg : msg_vec) { + if (msg.role == "system") { + ICHECK(msg.content.IsText()) << "System message must be text"; + custom_system_inputs += msg.content.Text(); + has_custom_system = true; + } } + }; + // go through messages in template and passed in. + f_populate_system_message(conv.messages); + f_populate_system_message(request.messages); - message += role_prefix; + // pending text records the text to be put into data + // we lazily accumulate the pending text + // to reduce amount of segments in the Data vector + std::string pending_text = + conv.GetSystemText(has_custom_system ? custom_system_inputs : conv.system_message); - for (const auto& item : content.value()) { - auto it_type = item.find("type"); - if (it_type == item.end()) { - return TResult::Error("The content of a message does not have \"type\" field"); + // the seperator after system message. 
+ if (!pending_text.empty()) { + pending_text += conv.seps[0]; + } + + // Get the message strings + std::vector message_list; + size_t non_system_msg_count = 0; + + // returns error if error happens + auto f_process_messages = + [&](const std::vector& msg_vec) -> std::optional { + for (size_t i = 0; i < msg_vec.size(); ++i) { + const ChatCompletionMessage& msg = msg_vec[i]; + auto role_it = conv.roles.find(msg.role); + if (role_it == conv.roles.end()) { + return TResult::Error("Role \"" + msg.role + "\" is not supported"); } - if (it_type->second == "text") { - auto it_text = item.find("text"); - if (it_text == item.end()) { - return TResult::Error("The text type content of a message does not have \"text\" field"); - } - // replace placeholder[ROLE] with input message from role - std::string role_text = role_templates[role]; - std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; - size_t pos = role_text.find(placeholder); - if (pos != std::string::npos) { - role_text.replace(pos, placeholder.length(), it_text->second); - } - if (use_function_calling) { - // replace placeholder[FUNCTION] with function_string - // this assumes function calling is used for a single request scenario only - if (!function_string.has_value()) { - return TResult::Error( - "The function string in conversation template is not defined for function " - "calling."); + const std::string& role_name = role_it->second; + // skip system message as it is already processed + if (msg.role == "system") continue; + // skip when content is empty + if (msg.content.IsNull()) { + pending_text += role_name + conv.role_empty_sep; + continue; + } + ++non_system_msg_count; + // assistant uses conv.seps[1] if there are two seps + int sep_offset = msg.role == "assistant" ? 1 : 0; + const std::string& seperator = conv.seps[sep_offset % conv.seps.size()]; + // setup role prefix + std::string role_prefix = ""; + // Do not append role prefix if this is the first message and there is already a system + // message + if (conv.add_role_after_system_message || pending_text.empty() || non_system_msg_count != 1) { + role_prefix = role_name + conv.role_content_sep; + } + pending_text += role_prefix; + + if (msg.content.IsParts()) { + for (const auto& item : msg.content.Parts()) { + auto it_type = item.find("type"); + if (it_type == item.end()) { + return TResult::Error("The content of a message does not have \"type\" field"); } - pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); - if (pos != std::string::npos) { - role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(), - function_string.value()); + if (it_type->second == "text") { + auto it_text = item.find("text"); + if (it_text == item.end()) { + return TResult::Error( + "The text type content of a message does not have \"text\" field"); + } + // replace placeholder[ROLE] with input message from role + pending_text += conv.GetRoleText(msg.role, it_text->second, fn_call_string); + } else if (it_type->second == "image_url") { + if (item.find("image_url") == item.end()) { + return TResult::Error("Content should have an image_url field"); + } + std::string image_url = + item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this + // should be a map, with a "url" key containing the URL, but + // we are just assuming this as the URL for now + std::string base64_image = image_url.substr(image_url.find(",") + 1); + Result image_data_res = LoadImageFromBase64(base64_image); + if (image_data_res.IsErr()) { + return 
TResult::Error(image_data_res.UnwrapErr()); + } + if (!config.vision_config.has_value()) { + return TResult::Error("Vision config is required for image input"); + } + int image_size = config.vision_config.value().image_size; + int patch_size = config.vision_config.value().patch_size; + + int embed_size = (image_size * image_size) / (patch_size * patch_size); + + auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device); + // lazily commit text data + if (pending_text.length() != 0) { + message_list.push_back(TextData(pending_text)); + pending_text = ""; + } + message_list.push_back(ImageData(image_ndarray, embed_size)); + } else { + return TResult::Error("Unsupported content type: " + it_type->second); } } - message += role_text; - } else if (it_type->second == "image_url") { - if (item.find("image_url") == item.end()) { - return TResult::Error("Content should have an image_url field"); - } - std::string image_url = - item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this - // should be a map, with a "url" key containing the URL, but - // we are just assuming this as the URL for now - std::string base64_image = image_url.substr(image_url.find(",") + 1); - Result image_data_res = LoadImageFromBase64(base64_image); - if (image_data_res.IsErr()) { - return TResult::Error(image_data_res.UnwrapErr()); - } - if (!config.vision_config.has_value()) { - return TResult::Error("Vision config is required for image input"); - } - int image_size = config.vision_config.value().image_size; - int patch_size = config.vision_config.value().patch_size; - - int embed_size = (image_size * image_size) / (patch_size * patch_size); - - auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device); - message_list.push_back(ImageData(image_ndarray, embed_size)); } else { - return TResult::Error("Unsupported content type: " + it_type->second); + ICHECK(msg.content.IsText()); + pending_text += conv.GetRoleText(msg.role, msg.content.Text(), fn_call_string); } + pending_text += seperator; } + return std::nullopt; + }; - message += separator; - message_list.push_back(TextData(message)); + if (auto err = f_process_messages(conv.messages)) { + return err.value(); + } + if (auto err = f_process_messages(request.messages)) { + return err.value(); + } + // append last assistant begin message + ChatCompletionMessage last_assistant_begin; + last_assistant_begin.role = "assistant"; + last_assistant_begin.content = std::nullopt; + if (auto err = f_process_messages({last_assistant_begin})) { + return err.value(); + } + if (pending_text.length() != 0) { + message_list.push_back(TextData(pending_text)); } - return TResult::Ok(message_list); } @@ -383,7 +476,10 @@ Result Conversation::FromJSON(const picojson::object& json_obj) { content.push_back(std::move(item_map)); } } - conv.messages.push_back({role_res.Unwrap(), content}); + ChatCompletionMessage msg; + msg.role = role_res.Unwrap(); + msg.content = content; + conv.messages.push_back(msg); } Result seps_arr_res = @@ -438,21 +534,6 @@ Result Conversation::FromJSON(const picojson::object& json_obj) { } conv.stop_token_ids.push_back(stop.get()); } - - Result> function_string_res = - json::LookupOptionalWithResultReturn(json_obj, "function_string"); - if (function_string_res.IsErr()) { - return TResult::Error(function_string_res.UnwrapErr()); - } - conv.function_string = function_string_res.Unwrap(); - - Result use_function_calling_res = json::LookupOrDefaultWithResultReturn( - json_obj, "use_function_calling", 
conv.use_function_calling); - if (use_function_calling_res.IsErr()) { - return TResult::Error(use_function_calling_res.UnwrapErr()); - } - conv.use_function_calling = use_function_calling_res.Unwrap(); - return TResult::Ok(conv); } diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h index 8217c5d6e5..e6c8e784f7 100644 --- a/cpp/json_ffi/conv_template.h +++ b/cpp/json_ffi/conv_template.h @@ -11,6 +11,7 @@ #include "../serve/data.h" #include "../support/result.h" +#include "openai_api_protocol.h" #include "picojson.h" using namespace mlc::llm::serve; @@ -62,12 +63,6 @@ enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; MessagePlaceholders MessagePlaceholderFromString(const std::string& role); -class Message { - public: - std::string role; - std::optional>> content = std::nullopt; -}; - /** * @brief A struct that specifies the convention template of conversation * and contains the conversation history. @@ -102,7 +97,7 @@ struct Conversation { // The conversation history messages. // Each message is a pair of strings, denoting "(role, content)". // The content can be None. - std::vector messages; + std::vector messages; // The separators between messages when concatenating into a single prompt. // List size should be either 1 or 2. @@ -121,15 +116,24 @@ struct Conversation { std::vector stop_str; std::vector stop_token_ids; - // Function call fields - // whether using function calling or not, helps check for output message format in API call - std::optional function_string = std::nullopt; - bool use_function_calling = false; - Conversation(); - /*! \brief Create the list of prompts from the messages based on the conversation template. */ - Result> AsPrompt(ModelConfig config, DLDevice device); + /*! + * \brief Get the system text(with the prompt template) given the system prompt message + * \param system_msg The system prompt message. + * \return The created system text. + */ + std::string GetSystemText(const std::string& system_msg) const; + + /*! + * \brief replace the content from role by the correct role text in template + * \param role The input role + * \param content The input content from the role + * \param fn_call_str The function calling string if any. + * \return The created text. + */ + std::string GetRoleText(const std::string& role, const std::string& content, + const std::optional& fn_call_str) const; /*! \brief Create a Conversation instance from the given JSON object. */ static Result FromJSON(const picojson::object& json); @@ -137,6 +141,11 @@ struct Conversation { static Result FromJSON(const std::string& json_str); }; +/*! \brief Create the list of prompts from the messages based on the conversation template. 
*/ +Result> CreatePrompt(const Conversation& conv, + const ChatCompletionRequest& request, + const ModelConfig& config, DLDevice device); + } // namespace json_ffi } // namespace llm } // namespace mlc diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 65f3183424..343266135c 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -31,7 +31,7 @@ void JSONFFIEngine::StreamBackError(std::string request_id) { ChatCompletionMessage delta; delta.content = std::vector>{ {{"type", "text"}, {"text", this->err_}}}; - delta.role = Role::assistant; + delta.role = "assistant"; ChatCompletionStreamResponseChoice choice; choice.finish_reason = FinishReason::error; @@ -54,38 +54,9 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request return false; } ChatCompletionRequest request = request_res.Unwrap(); - // Create Request - // TODO: Check if request_id is present already - - // inputs - Conversation conv_template = this->conv_template_; - std::vector messages; - for (const auto& message : request.messages) { - std::string role; - if (message.role == Role::user) { - role = "user"; - } else if (message.role == Role::assistant) { - role = "assistant"; - } else if (message.role == Role::tool) { - role = "tool"; - } else { - role = "system"; - } - messages.push_back({role, message.content}); - } - messages.push_back({"assistant", std::nullopt}); - conv_template.messages = messages; - - // check function calling - Result updated_conv_template = request.CheckFunctionCalling(conv_template); - if (updated_conv_template.IsErr()) { - err_ = updated_conv_template.UnwrapErr(); - return false; - } - conv_template = updated_conv_template.Unwrap(); - - // get prompt - Result> inputs_obj = conv_template.AsPrompt(this->model_config_, this->device_); + // get prompt: note, assistant was appended in the end. 
+ Result> inputs_obj = + CreatePrompt(this->conv_template_, request, this->model_config_, this->device_); if (inputs_obj.IsErr()) { err_ = inputs_obj.UnwrapErr(); return false; @@ -94,8 +65,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request // generation_cfg Array stop_strs; - stop_strs.reserve(conv_template.stop_str.size()); - for (const std::string& stop_str : conv_template.stop_str) { + stop_strs.reserve(this->conv_template_.stop_str.size()); + for (const std::string& stop_str : this->conv_template_.stop_str) { stop_strs.push_back(stop_str); } if (request.stop.has_value()) { @@ -110,7 +81,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request /*repetition_penalty=*/std::nullopt, request.logprobs, request.top_logprobs, request.logit_bias, request.seed, request.ignore_eos, request.max_tokens, std::move(stop_strs), - conv_template.stop_token_ids, /*response_format=*/std::nullopt, + conv_template_.stop_token_ids, /*response_format=*/std::nullopt, this->default_generation_cfg_json_str_); Request engine_request(request_id, inputs, generation_cfg); @@ -232,11 +203,8 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { // Size of delta_output->group_delta_token_ids Array should be 1 IntTuple delta_token_ids = delta_output->group_delta_token_ids[0]; std::vector delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end()); - delta.content = std::vector>(); - delta.content.value().push_back(std::unordered_map{ - {"type", "text"}, {"text", this->streamer_->Put(delta_token_ids_vec)}}); - - delta.role = Role::assistant; + delta.content = this->streamer_->Put(delta_token_ids_vec); + delta.role = "assistant"; choice.delta = delta; diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 22d95c72c1..525366440a 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -170,25 +170,37 @@ picojson::object ChatToolCall::AsJSON() const { Result ChatCompletionMessage::FromJSON(const picojson::object& json_obj) { using TResult = Result; ChatCompletionMessage message; + ChatCompletionMessageContent content; // content - Result content_arr_res = - json::LookupWithResultReturn(json_obj, "content"); - if (content_arr_res.IsErr()) { - return TResult::Error(content_arr_res.UnwrapErr()); - } - std::vector> content; - for (const auto& item : content_arr_res.Unwrap()) { - // Todo(mlc-team): allow content item to be a single string. 
- if (!item.is()) { - return TResult::Error("The content of chat completion message is not an object"); + auto it = json_obj.find("content"); + if (it == json_obj.end()) { + return TResult::Error("ValueError: key \"content\" not found in the chat completion."); + } + if (it->second.is()) { + content = it->second.get(); + } else if (it->second.is()) { + // skip + } else { + // most complicated case + std::vector> parts; + Result content_arr_res = + json::LookupWithResultReturn(json_obj, "content"); + if (content_arr_res.IsErr()) { + return TResult::Error(content_arr_res.UnwrapErr()); } - picojson::object item_obj = item.get(); - std::unordered_map item_map; - for (const auto& [key, value] : item_obj) { - item_map[key] = value.to_str(); + for (const auto& item : content_arr_res.Unwrap()) { + if (!item.is()) { + return TResult::Error("The content of chat completion message is not an object"); + } + picojson::object item_obj = item.get(); + std::unordered_map item_map; + for (const auto& [key, value] : item_obj) { + item_map[key] = value.to_str(); + } + parts.push_back(std::move(item_map)); } - content.push_back(std::move(item_map)); + content = parts; } message.content = content; @@ -198,14 +210,8 @@ Result ChatCompletionMessage::FromJSON(const picojson::ob return TResult::Error(role_str_res.UnwrapErr()); } std::string role_str = role_str_res.Unwrap(); - if (role_str == "system") { - message.role = Role::system; - } else if (role_str == "user") { - message.role = Role::user; - } else if (role_str == "assistant") { - message.role = Role::assistant; - } else if (role_str == "tool") { - message.role = Role::tool; + if (role_str == "system" || role_str == "user" || role_str == "assistant" || role_str == "tool") { + message.role = role_str; } else { return TResult::Error("Invalid role in chat completion message: " + role_str); } @@ -345,30 +351,28 @@ Result ChatCompletionRequest::FromJSON(const std::string& } // TODO: Other parameters - return TResult::Ok(request); } picojson::object ChatCompletionMessage::AsJSON() const { picojson::object obj; - picojson::array content_arr; - for (const auto& item : this->content.value()) { - picojson::object item_obj; - for (const auto& pair : item) { - item_obj[pair.first] = picojson::value(pair.second); + + if (this->content.IsText()) { + obj["content"] = picojson::value(this->content.Text()); + } else if (this->content.IsParts()) { + picojson::array content_arr; + for (const auto& item : this->content.Parts()) { + picojson::object item_obj; + for (const auto& pair : item) { + item_obj[pair.first] = picojson::value(pair.second); + } + content_arr.push_back(picojson::value(item_obj)); } - content_arr.push_back(picojson::value(item_obj)); - } - obj["content"] = picojson::value(content_arr); - if (this->role == Role::system) { - obj["role"] = picojson::value("system"); - } else if (this->role == Role::user) { - obj["role"] = picojson::value("user"); - } else if (this->role == Role::assistant) { - obj["role"] = picojson::value("assistant"); - } else if (this->role == Role::tool) { - obj["role"] = picojson::value("tool"); + obj["content"] = picojson::value(content_arr); } + + obj["role"] = picojson::value(this->role); + if (this->name.has_value()) { obj["name"] = picojson::value(this->name.value()); } @@ -385,40 +389,6 @@ picojson::object ChatCompletionMessage::AsJSON() const { return obj; } -Result ChatCompletionRequest::CheckFunctionCalling(Conversation conv_template) { - using TResult = Result; - if (!tools.has_value() || (tool_choice.has_value() && 
tool_choice.value() == "none")) { - conv_template.use_function_calling = false; - return TResult::Ok(conv_template); - } - std::vector tools_ = tools.value(); - std::string tool_choice_ = tool_choice.value(); - - // TODO: support with tool choice as dict - for (const auto& tool : tools_) { - if (tool.function.name == tool_choice_) { - conv_template.use_function_calling = true; - picojson::value function_str(tool.function.AsJSON()); - conv_template.function_string = function_str.serialize(); - return TResult::Ok(conv_template); - } - } - - if (tool_choice_ != "auto") { - return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_); - } - - picojson::array function_list; - for (const auto& tool : tools_) { - function_list.push_back(picojson::value(tool.function.AsJSON())); - } - - conv_template.use_function_calling = true; - picojson::value function_list_json(function_list); - conv_template.function_string = function_list_json.serialize(); - return TResult::Ok(conv_template); -}; - picojson::object ChatCompletionResponseChoice::AsJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index da9002f994..50f7315778 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -14,14 +14,12 @@ #include #include "../support/result.h" -#include "conv_template.h" #include "picojson.h" namespace mlc { namespace llm { namespace json_ffi { -enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; @@ -80,11 +78,41 @@ class ChatToolCall { picojson::object AsJSON() const; }; +class ChatCompletionMessageContent { + public: + ChatCompletionMessageContent() = default; + + ChatCompletionMessageContent(std::nullopt_t) {} // NOLINT(*) + + ChatCompletionMessageContent(std::string text) : text_(text) {} // NOLINT(*) + + ChatCompletionMessageContent( + std::vector> parts) // NOLINT(*) + : parts_(parts) {} + + bool IsNull() const { return !IsText() && !IsParts(); } + + bool IsText() const { return text_.operator bool(); } + + bool IsParts() const { return parts_.operator bool(); } + + const std::string& Text() const { return text_.value(); } + + const std::vector>& Parts() const { + return parts_.value(); + } + + private: + /*! \brief used to store text content */ + std::optional text_; + std::optional>> parts_; +}; + class ChatCompletionMessage { public: - std::optional>> content = + ChatCompletionMessageContent content = std::nullopt; // Assuming content is a list of string key-value pairs - Role role; + std::string role; std::optional name = std::nullopt; std::optional> tool_calls = std::nullopt; std::optional tool_call_id = std::nullopt; @@ -124,7 +152,6 @@ class ChatCompletionRequest { /*! \brief Parse and create a ChatCompletionRequest instance from the given JSON string. 
*/ static Result FromJSON(const std::string& json_str); - Result CheckFunctionCalling(Conversation conv_template); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; diff --git a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift index 0049cee7e7..991149be2b 100644 --- a/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift +++ b/ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift @@ -51,7 +51,7 @@ class AppState: ObservableObject { DispatchQueue.main.async { // parse the result content in structured form // and stream back to the display - self.displayText += res.choices[0].delta.content![0]["text"]! + self.displayText += res.choices[0].delta.content!.asText() } } } diff --git a/ios/MLCSwift/Sources/Swift/LLMEngine.swift b/ios/MLCSwift/Sources/Swift/LLMEngine.swift index a57da15cc5..ce167b7dd3 100644 --- a/ios/MLCSwift/Sources/Swift/LLMEngine.swift +++ b/ios/MLCSwift/Sources/Swift/LLMEngine.swift @@ -110,7 +110,6 @@ public actor MLCEngine { let encoder = JSONEncoder() let data = try! encoder.encode(request) let jsonRequest = String(data: data, encoding: .utf8)! - // generate a UUID for the request let requestID = UUID().uuidString let stream = AsyncStream(ChatCompletionStreamResponse.self) { continuation in diff --git a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift index c364fad3a3..edb0fa5211 100644 --- a/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift +++ b/ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift @@ -82,9 +82,46 @@ public enum ChatCompletionRole: String, Codable { case tool = "tool" } +public enum ChatCompletionMessageContent: Codable { + case text(String) + case parts([[String: String]]) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let text = try? container.decode(String.self) { + self = .text(text) + } else { + let parts = try container.decode([[String: String]].self) + self = .parts(parts) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .text(let text): try container.encode(text) + case .parts(let parts): try container.encode(parts) + } + } + + public func asText() -> String { + switch (self) { + case .text(let text): return text + case .parts(let parts): + var res = "" + for item in parts { + if item["type"]! == "text" { + res += item["text"]! 
+ } + } + return res + } + } +} + public struct ChatCompletionMessage: Codable { public var role: ChatCompletionRole - public var content: Optional<[[String: String]]> = nil + public var content: Optional = nil public var name: Optional = nil public var tool_calls: Optional<[ChatToolCall]> = nil public var tool_call_id: Optional = nil @@ -98,7 +135,11 @@ public struct ChatCompletionMessage: Codable { tool_call_id: Optional = nil ) { self.role = role - self.content = content + if let cvalue = content { + self.content = .parts(cvalue) + } else { + self.content = nil + } self.name = name self.tool_calls = tool_calls self.tool_call_id = tool_call_id @@ -113,7 +154,7 @@ public struct ChatCompletionMessage: Codable { tool_call_id: Optional = nil ) { self.role = role - self.content = [["type": "text", "text": content]] + self.content = .text(content) self.name = name self.tool_calls = tool_calls self.tool_call_id = tool_call_id diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py index b438c2a352..ca2e7deb98 100644 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ b/tests/python/json_ffi/test_json_ffi_engine.py @@ -66,9 +66,8 @@ def run_chat_completion( ): for choice in response.choices: assert choice.delta.role == "assistant" - assert isinstance(choice.delta.content[0], Dict) - assert choice.delta.content[0]["type"] == "text" - output_texts[rid][choice.index] += choice.delta.content[0]["text"] + assert isinstance(choice.delta.content, str) + output_texts[rid][choice.index] += choice.delta.content # Print output. print("Chat completion all finished") @@ -83,7 +82,7 @@ def run_chat_completion( def test_chat_completion(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" engine = JSONFFIEngine( model, max_total_sequence_length=1024, @@ -101,7 +100,7 @@ def test_chat_completion(): def test_reload_reset_unload(): # Create engine. - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" + model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC" engine = JSONFFIEngine( model, max_total_sequence_length=1024, @@ -136,4 +135,4 @@ def test_function_calling(): if __name__ == "__main__": test_chat_completion() test_reload_reset_unload() - test_function_calling() + # test_function_calling() From 821ee5dbdb415f61459cc4f183f5af87f3707c43 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 13 May 2024 16:27:23 -0700 Subject: [PATCH 306/531] [Eagle] Fix the requests for additional decode in eagle verify (#2336) --- cpp/serve/engine_actions/eagle_batch_verify.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 71daaf1bf9..9f31ed22d6 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -218,11 +218,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { hidden_states, hidden_states_positions_for_fully_accepted, &model_workspaces_[draft_model_id_].hidden_states); // - Invoke model decode. 
- ObjectRef fused_embedding_hidden_states = - models_[draft_model_id_]->FuseEmbedHidden(embeddings, hidden_states_for_fully_accepted, - /*batch_size*/ num_rsentries, /*seq_len*/ 1); + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states_for_fully_accepted, + /*batch_size*/ fully_accepted_rsentries.size(), /*seq_len*/ 1); hidden_states_for_fully_accepted = models_[draft_model_id_]->BatchDecodeToLastHidden( - fused_embedding_hidden_states, request_internal_ids); + fused_embedding_hidden_states, fully_accepted_request_internal_ids); // - We explicitly synchronize to avoid the input tokens getting overriden in the // next runs of BatchDecode. // This is because we do not do sample for this round of batch decode.

From bc6e3eddbd0979d365d8f8586c2c88d480bc1699 Mon Sep 17 00:00:00 2001
From: Yixin Dong
Date: Tue, 14 May 2024 05:48:23 -0700
Subject: [PATCH 307/531] [Serving][Grammar] Refactor GrammarStateMatcher and support LLaMA-3 (#2335)

This PR refactors GrammarStateMatcher and supports the LLaMA-3 tokenizer. Common tokenizers, including Phi-2, Gemma, LLaMA-2, etc., are also supported. The performance is optimized for the LLaMA-3 tokenizer, since its token table has size 128k, much larger than that of the LLaMA-2 tokenizer.

These changes are introduced to the grammar library:
1. Introduce the ByteString rule expression and simplify CharacterClass and CharacterClassStar.
2. Refactor BNFGrammarVisitor and BNFGrammarMutator for visiting and mutating grammar rules.
3. GrammarStateMatcherBase, the internal implementation of GrammarStateMatcher, now accepts input char by char instead of codepoint by codepoint, so it supports any valid UTF-8 string, even if a token is not a complete codepoint.
4. Support lookahead assertions for rules, specifying that a rule must be followed by a given sequence. This can eliminate some uncertain tokens in preprocessing (a sketch of the syntax follows this message).

Minor changes:
1. Introduce the template hash function HashCombine.
2. Update the UTF-8 encoding handling functions.

Performance:
1. For JSON, finding the mask requires <30us on a 5900X with a single thread. The number of uncertain tokens is <30 in most cases.
2. For JSON schema, finding the mask requires <30us on a 5900X with a single thread. The number of uncertain tokens is <30 in most cases.
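Below is a minimal, hypothetical sketch (not part of this patch) of how the new lookahead assertion can be written through the EBNF front end. Only BNFGrammar::FromEBNFString, BNFGrammarNode::NumRules, and the "(= ...)" lookahead notation are taken from this patch; the toy main/item rules, the include path, and the build setup are assumptions made purely for illustration.

// Sketch only: assumes a build against the MLC-LLM C++ sources.
#include <iostream>
#include <string>

#include "serve/grammar/grammar.h"  // include path is an assumption; adjust to your build

using namespace mlc::llm::serve;

int main() {
  // The "(= ...)" suffix attaches a lookahead assertion to the rule "item":
  // after "item" is matched, the input must continue with optional
  // whitespace followed by ',' or ']'.
  std::string ebnf = R"(
main ::= "[" item elements_rest "]"
elements_rest ::= "" | "," item elements_rest
item ::= ("true" | "false") (= [ \n\t]* [,\]])
)";
  BNFGrammar grammar = BNFGrammar::FromEBNFString(ebnf, "main");
  // FromEBNFString normalizes the grammar by default (nested rules expanded,
  // consecutive strings fused into ByteString expressions).
  std::cout << "rules in normalized grammar: " << grammar->NumRules() << std::endl;
  return 0;
}

As with the lookahead assertions used in the built-in JSON grammar, the extra "(= ...)" context lets the token-mask preprocessing rule out tokens that cannot follow the rule, which is what reduces the number of uncertain tokens mentioned above.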
--- cpp/serve/engine.cc | 8 +- cpp/serve/grammar/grammar.cc | 135 +++--- cpp/serve/grammar/grammar.h | 45 +- cpp/serve/grammar/grammar_builder.h | 79 ++- ...ammar_simplifier.cc => grammar_functor.cc} | 187 ++++--- ...grammar_simplifier.h => grammar_functor.h} | 145 +++--- cpp/serve/grammar/grammar_parser.cc | 70 ++- cpp/serve/grammar/grammar_parser.h | 2 +- cpp/serve/grammar/grammar_serializer.cc | 52 +- cpp/serve/grammar/grammar_serializer.h | 6 +- cpp/serve/grammar/grammar_state_matcher.cc | 399 ++++++++------- cpp/serve/grammar/grammar_state_matcher.h | 23 +- .../grammar/grammar_state_matcher_base.h | 356 ++++++++------ .../grammar/grammar_state_matcher_preproc.h | 459 ++++++++++-------- .../grammar/grammar_state_matcher_state.h | 80 +-- cpp/serve/grammar/json_schema_converter.cc | 8 +- cpp/serve/grammar/support.h | 84 +++- cpp/support/encoding.cc | 77 ++- cpp/support/encoding.h | 76 ++- cpp/support/utils.h | 18 + cpp/tokenizers.cc | 3 +- python/mlc_llm/serve/grammar.py | 86 ++-- tests/python/serve/test_grammar_parser.py | 173 ++++--- .../test_grammar_state_matcher_custom.py | 37 +- .../serve/test_grammar_state_matcher_json.py | 96 +++- .../python/serve/test_serve_engine_grammar.py | 2 +- web/emcc/mlc_wasm_runtime.cc | 2 +- 27 files changed, 1684 insertions(+), 1024 deletions(-) rename cpp/serve/grammar/{grammar_simplifier.cc => grammar_functor.cc} (54%) rename cpp/serve/grammar/{grammar_simplifier.h => grammar_functor.h} (58%) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 616c463d9c..9b9cf81fe7 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -122,7 +122,7 @@ class EngineImpl : public Engine { } n->token_table_ = Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method); - n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_); + n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_); // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -499,9 +499,9 @@ class EngineImpl : public Engine { if (response_format.type != "json_object") { return std::nullopt; } else if (!response_format.schema) { - return grammar_init_context_storage_->GetInitContextForJSON(); + return grammar_init_context_cache_->GetInitContextForJSON(); } else { - return grammar_init_context_storage_->GetInitContextForJSONSchema( + return grammar_init_context_cache_->GetInitContextForJSONSchema( response_format.schema.value()); } } @@ -513,7 +513,7 @@ class EngineImpl : public Engine { Tokenizer tokenizer_; std::vector token_table_; // Helper to get the grammar init context for requests. - GrammarInitContextStorage grammar_init_context_storage_; + GrammarInitContextCache grammar_init_context_cache_; // Models Array models_; // Device that the models run on. 
diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c8d760538c..2f0d7f565f 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -5,9 +5,9 @@ #include "grammar.h" +#include "grammar_functor.h" #include "grammar_parser.h" #include "grammar_serializer.h" -#include "grammar_simplifier.h" #include "json_schema_converter.h" namespace mlc { @@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule, - bool normalize, bool simplify) { +BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, + const std::string& main_rule) { auto grammar = EBNFParser::Parse(ebnf_string, main_rule); - if (normalize) { - grammar = NestedRuleUnwrapper(grammar).Apply(); - } + // Normalize the grammar by default + grammar = BNFGrammarNormalizer().Apply(grammar); return grammar; } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") - .set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) { - return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); + .set_body_typed([](String ebnf_string, String main_rule) { + return BNFGrammar::FromEBNFString(ebnf_string, main_rule); + }); + +// Parse the EBNF string but not normalize it +BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string, + const std::string& main_rule) { + return EBNFParser::Parse(ebnf_string, main_rule); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize") + .set_body_typed([](String ebnf_string, String main_rule) { + return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule); }); BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) { @@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args, *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]); }); +// Optimized json grammar for the speed of the grammar state matcher const std::string kJSONGrammarString = R"( main ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace ) -value ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace | - "\"" characters "\"" | - [0-9] fraction exponent | - [1-9] digits fraction exponent | +value_non_str ::= ( + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace | + "0" fraction exponent | + [1-9] [0-9]* fraction exponent | "-" [0-9] fraction exponent | - "-" [1-9] digits fraction exponent | + "-" [1-9] [0-9]* fraction exponent | "true" | "false" | "null" -) -members_or_embrace ::= ( - "\"" characters "\"" ws ":" ws value members_rest ws "}" | - "}" -) -members ::= "\"" characters "\"" ws ":" ws value members_rest -members_rest ::= ( - "" | - "," ws "\"" characters "\"" ws ":" ws value members_rest | - " " ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest -) +) (= [ \n\t,}\]]) +members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]]) +members_suffix ::= ( + value_non_str [ \n\t]* member_suffix_suffix | + "\"" characters_and_embrace | + "\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) +member_suffix_suffix ::= ( + "}" | + "," [ \n\t]* "\"" 
characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) elements_or_embrace ::= ( - "{" ws members_or_embrace elements_rest ws "]" | - "[" ws elements_or_embrace elements_rest ws "]" | - "\"" characters "\"" elements_rest ws "]" | - [0-9] fraction exponent elements_rest ws "]" | - [1-9] digits fraction exponent elements_rest ws "]" | - "-" [0-9] fraction exponent elements_rest ws "]" | - "-" [1-9] digits fraction exponent elements_rest ws "]" | - "true" elements_rest ws "]" | - "false" elements_rest ws "]" | - "null" elements_rest ws "]" | + "{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" | + "[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" | + "\"" characters_item elements_rest [ \n\t]* "]" | + "0" fraction exponent elements_rest [ \n\t]* "]" | + [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "-" "0" fraction exponent elements_rest [ \n\t]* "]" | + "-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "true" elements_rest [ \n\t]* "]" | + "false" elements_rest [ \n\t]* "]" | + "null" elements_rest [ \n\t]* "]" | "]" ) elements ::= ( - "{" ws members_or_embrace elements_rest | - "[" ws elements_or_embrace elements_rest | - "\"" characters "\"" elements_rest | - [0-9] fraction exponent elements_rest | - [1-9] digits fraction exponent elements_rest | + "{" [ \n\t]* members_and_embrace elements_rest | + "[" [ \n\t]* elements_or_embrace elements_rest | + "\"" characters_item elements_rest | + "0" fraction exponent elements_rest | + [1-9] [0-9]* fraction exponent elements_rest | "-" [0-9] fraction exponent elements_rest | - "-" [1-9] digits fraction exponent elements_rest | + "-" [1-9] [0-9]* fraction exponent elements_rest | "true" elements_rest | "false" elements_rest | "null" elements_rest ) elements_rest ::= ( "" | - "," ws elements | - " " ws "," ws elements | - "\n" ws "," ws elements | - "\t" ws "," ws elements + [ \n\t]* "," [ \n\t]* elements ) -characters ::= "" | [^"\\\r\n] characters | "\\" escape characters +characters_and_colon ::= ( + "\"" [ \n\t]* ":" | + [^"\\\x00-\x1F] characters_and_colon | + "\\" escape characters_and_colon +) (=[ \n\t]* [\"{[0-9tfn-]) +characters_and_comma ::= ( + "\"" [ \n\t]* "," | + [^"\\\x00-\x1F] characters_and_comma | + "\\" escape characters_and_comma +) (=[ \n\t]* "\"") +characters_and_embrace ::= ( + "\"" [ \n\t]* "}" | + [^"\\\x00-\x1F] characters_and_embrace | + "\\" escape characters_and_embrace +) (=[ \n\t]* [},]) +characters_item ::= ( + "\"" | + [^"\\\x00-\x1F] characters_item | + "\\" escape characters_item +) (= [ \n\t]* [,\]]) escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -digits ::= [0-9] | [0-9] digits -fraction ::= "" | "." digits -exponent ::= "" | "e" sign digits | "E" sign digits +fraction ::= "" | "." [0-9] [0-9]* +exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]* sign ::= "" | "+" | "-" -ws ::= [ \n\t]* )"; BNFGrammar BNFGrammar::GetGrammarOfJSON() { - static const BNFGrammar grammar = - BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false); + static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main"); return grammar; } diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index ba15e58af3..b7922301cb 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -44,16 +44,15 @@ using namespace tvm::runtime; * #### Types of RuleExprs * Every RuleExpr is represented by a type as well as a variable-length array containing its data. 
* RuleExpr has several types: + * - Byte string: a string of bytes (0~255). Supports UTF-8 strings. * - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z], - * [ac-z]. - * A single character is represented by a character class with the same lower and upper bound. - * A string is represented by a sequence of character classes. - * - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z] + * [ac-z]. Can be negated: [^a-z], [^ac-z]. Now only ascii chars is allowed in [], but this + * expression can accept/reject unicode chars. + * - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*. * - EmptyStr: an empty string, i.e. "" * - Rule reference: a reference to another rule * - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together. * - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched. - * - Character class star: special support for a repetition of a character class. e.g. [a-z]* * * #### Storage of RuleExprs * Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see @@ -76,6 +75,9 @@ class BNFGrammarNode : public Object { std::string name; /*! \brief The RuleExpr id of the body of the rule. */ int32_t body_expr_id; + /*! \brief The id of the associated lookahead assertion expr. For now it must be a id of a + * sequence RuleExpr. -1 if not exists. */ + int32_t lookahead_assertion_id = -1; }; /*! \brief Get the number of rules. */ @@ -86,6 +88,8 @@ class BNFGrammarNode : public Object { << "rule_id " << rule_id << " is out of bound"; return rules_[rule_id]; } + /*! \brief Get the main rule id of the grammar. */ + int32_t GetMainRuleId() const { return main_rule_id_; } /*! \brief Get the main rule of the grammar. */ const Rule& GetMainRule() const { DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast(rules_.size())) @@ -95,10 +99,11 @@ class BNFGrammarNode : public Object { /*! \brief The type of the rule expr. */ enum class RuleExprType : int32_t { - // data format: [lower0, upper0, lower1, upper1, ...] + // data format: [byte0, byte1, ...] + kByteString, + // data format: [is_negative, lower0, upper0, lower1, upper1, ...] kCharacterClass, - // data format: [lower0, upper0, lower1, upper1, ...] - kNegCharacterClass, + kCharacterClassStar, // data format: [] kEmptyStr, // data format: [rule_id] @@ -107,8 +112,6 @@ class BNFGrammarNode : public Object { kSequence, // data format: [rule_expr_id0, rule_expr_id1, ...] kChoices, - // data format: [rule_expr_id] - kCharacterClassStar, }; /*! \brief The object representing a rule expr. */ @@ -154,8 +157,8 @@ class BNFGrammarNode : public Object { std::vector rules_; /*! \brief The data of all rule_exprs. */ std::vector rule_expr_data_; - /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the - * index of this vector. */ + /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index + * to the elements in this vector. */ std::vector rule_expr_indptr_; /*! \brief The id of the main rule. */ int32_t main_rule_id_ = -1; @@ -168,25 +171,13 @@ class BNFGrammarNode : public Object { class BNFGrammar : public ObjectRef { public: /*! - * \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and - * transform it into BNF AST. + * \brief Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized + * (simplified) by default. 
* \param ebnf_string The EBNF-formatted string. * \param main_rule The name of the main rule. - * \param normalize Whether to normalize the grammar. Default: true. Only set to false for the - * purpose of testing. - * - * \note In The normalized form of a BNF grammar, every rule is in the form: - * `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character - * class or a rule reference. And if the rule can be empty, the first choice will be an empty - * string. - * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. - * Not implemented yet. */ static BNFGrammar FromEBNFString(const std::string& ebnf_string, - const std::string& main_rule = "main", bool normalize = true, - bool simplify = true); + const std::string& main_rule = "main"); /*! * \brief Construct a BNF grammar from the dumped JSON string. diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index 0854cc9789..7987a67f98 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -56,6 +56,16 @@ class BNFGrammarBuilder { return static_cast(grammar_->rule_expr_indptr_.size()) - 1; } + /*! + * \brief Add a RuleExpr for string stored in bytes. + * \param bytes A vector of int32_t, each representing a byte (0~255) in the string. + * The string is stored in int32 vector to match the storage format of the grammar. + */ + int32_t AddByteString(const std::vector& bytes) { + return AddRuleExpr( + {RuleExprType::kByteString, bytes.data(), static_cast(bytes.size())}); + } + /*! * \brief One element of a character class, containing a lower and a upper bound. Both bounds are * inclusive. @@ -66,19 +76,39 @@ class BNFGrammarBuilder { }; /*! - * \brief Add a RuleExpr for character class. + * \brief Add a RuleExpr for a character class. * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. - * \param is_neg_range Whether the character class is negated. + * \param is_negative Whether the character class is negated. */ int32_t AddCharacterClass(const std::vector& elements, - bool is_neg_range = false) { + bool is_negative = false) { std::vector data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast(is_negative)); for (const auto& range : elements) { data.push_back(range.lower); data.push_back(range.upper); } - auto type = is_neg_range ? RuleExprType::kNegCharacterClass : RuleExprType::kCharacterClass; - return AddRuleExpr({type, data.data(), static_cast(data.size())}); + return AddRuleExpr( + {RuleExprType::kCharacterClass, data.data(), static_cast(data.size())}); + } + + /*! + * \brief Add a RuleExpr for a star quantifier of a character class. + * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. + * \param is_negative Whether the character class is negated. + */ + int32_t AddCharacterClassStar(const std::vector& elements, + bool is_negative = false) { + std::vector data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast(is_negative)); + for (const auto& range : elements) { + data.push_back(range.lower); + data.push_back(range.upper); + } + return AddRuleExpr( + {RuleExprType::kCharacterClassStar, data.data(), static_cast(data.size())}); } /*! \brief Add a RuleExpr for empty string.*/ @@ -93,23 +123,14 @@ class BNFGrammarBuilder { /*! 
\brief Add a RuleExpr for RuleExpr sequence.*/ int32_t AddSequence(const std::vector& elements) { - std::vector data; - data.insert(data.end(), elements.begin(), elements.end()); - return AddRuleExpr({RuleExprType::kSequence, data.data(), static_cast(data.size())}); + return AddRuleExpr( + {RuleExprType::kSequence, elements.data(), static_cast(elements.size())}); } /*! \brief Add a RuleExpr for RuleExpr choices.*/ int32_t AddChoices(const std::vector& choices) { - std::vector data; - data.insert(data.end(), choices.begin(), choices.end()); - return AddRuleExpr({RuleExprType::kChoices, data.data(), static_cast(data.size())}); - } - - int32_t AddCharacterClassStar(int32_t element) { - std::vector data; - data.push_back(element); return AddRuleExpr( - {RuleExprType::kCharacterClassStar, data.data(), static_cast(data.size())}); + {RuleExprType::kChoices, choices.data(), static_cast(choices.size())}); } size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); } @@ -154,7 +175,7 @@ class BNFGrammarBuilder { * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule. */ void UpdateRuleBody(int32_t rule_id, int32_t body_expr_id) { - CHECK(rule_id < static_cast(grammar_->rules_.size())) + CHECK(rule_id >= 0 && rule_id < static_cast(grammar_->rules_.size())) << "Rule id " << rule_id << " is out of range."; grammar_->rules_[rule_id].body_expr_id = body_expr_id; } @@ -169,6 +190,28 @@ class BNFGrammarBuilder { UpdateRuleBody(rule_id, body_expr_id); } + /*! + * \brief Add a lookahead assertion to a rule referred by the given rule_id. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(int32_t rule_id, int32_t lookahead_assertion_id) { + CHECK(rule_id < static_cast(grammar_->rules_.size())) + << "Rule id " << rule_id << " is out of range."; + CHECK(grammar_->rules_[rule_id].lookahead_assertion_id == -1) + << "Rule " << rule_id << " already has a lookahead assertion."; + grammar_->rules_[rule_id].lookahead_assertion_id = lookahead_assertion_id; + } + + /*! + * \brief Add a lookahead assertion to a rule referred by the given name. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(std::string rule_name, int32_t lookahead_assertion_id) { + int32_t rule_id = GetRuleId(rule_name); + CHECK(rule_id != -1) << "Rule " << rule_name << " is not found."; + AddLookaheadAssertion(rule_id, lookahead_assertion_id); + } + /*! * \brief Find a name for a new rule starting with the given name hint. Some integer suffix (_1, * _2, ...) may be added to avoid name conflict. diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_functor.cc similarity index 54% rename from cpp/serve/grammar/grammar_simplifier.cc rename to cpp/serve/grammar/grammar_functor.cc index 109b5d85e1..ae4e108233 100644 --- a/cpp/serve/grammar/grammar_simplifier.cc +++ b/cpp/serve/grammar/grammar_functor.cc @@ -1,56 +1,101 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.cc + * \file serve/grammar/grammar_functor.cc */ -#include "grammar_simplifier.h" +#include "grammar_functor.h" + +#include "../../support/encoding.h" namespace mlc { namespace llm { namespace serve { /*! - * \brief Eliminates single-element sequence or choice nodes in the grammar. - * \example The sequence `(a)` or the choice `(a)` will be replaced by `a` in a rule. 
- * \example The rule `A ::= ((b) (((d))))` will be replaced by `A ::= (b d)`. + * \brief Eliminates single-element sequence or choice or character class in the grammar. + * \example `A ::= choices("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= sequence("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= [a-a]` --> `A ::= "a"` (the body is a string) */ -class SingleElementSequenceOrChoiceEliminator : public BNFGrammarMutator { +class SingleElementExprEliminator : public BNFGrammarMutator { public: using BNFGrammarMutator::Apply; using BNFGrammarMutator::BNFGrammarMutator; private: - int32_t VisitSequence(const RuleExpr& rule_expr) { + // Keep the sequence expr in lookahead assertion + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto rule_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + CHECK(rule_expr.type == RuleExprType::kSequence); + + std::vector sequence_ids; + for (int32_t i : rule_expr) { + sequence_ids.push_back(VisitExpr(i)); + } + return builder_.AddSequence(sequence_ids); + } + + int32_t VisitSequence(const RuleExpr& rule_expr) final { std::vector sequence_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } if (sequence_ids.size() == 1) { return sequence_ids[0]; - } else { - return builder_.AddSequence(sequence_ids); } + return builder_.AddSequence(sequence_ids); } - int32_t VisitChoices(const RuleExpr& rule_expr) { + int32_t VisitChoices(const RuleExpr& rule_expr) final { std::vector choice_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } if (choice_ids.size() == 1) { return choice_ids[0]; - } else { - return builder_.AddChoices(choice_ids); } + return builder_.AddChoices(choice_ids); + } + + int32_t VisitCharacterClass(const RuleExpr& rule_expr) final { + if (rule_expr.data_len == 3 && rule_expr[0] == 0 && rule_expr[1] == rule_expr[2]) { + std::string str = PrintAsUTF8(rule_expr[1]); + std::vector bytes; + bytes.reserve(str.size()); + for (char c : str) { + bytes.push_back(static_cast(c)); + } + return builder_.AddByteString(bytes); + } + return builder_.AddRuleExpr(rule_expr); } }; -class NestedRuleUnwrapperImpl : public BNFGrammarMutator { +/*! + * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in + * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + * + * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class + * or a rule reference. And if the rule can be empty, the first choice will be an empty string. + * + * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice + * containing a sequence of three elements. The empty string is removed. + * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by + * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three + * choices is a sequence containing a single element. + * \example The rule `A ::= (a | (b (c | d)))` will be replaced by + * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested + * choices. 
+ */ +class NestedRuleUnwrapper : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final { - grammar_ = SingleElementSequenceOrChoiceEliminator(grammar_).Apply(); + BNFGrammar Apply(const BNFGrammar& grammar) final { + Init(grammar); for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { builder_.AddEmptyRule(grammar_->GetRule(i).name); } @@ -60,11 +105,20 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { cur_rule_name_ = rule.name; auto new_body_expr_id = VisitRuleBody(rule_expr); builder_.UpdateRuleBody(i, new_body_expr_id); + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); } private: + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto assertion_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + return builder_.AddSequence(VisitSequence_(assertion_expr)); + } + /*! \brief Visit a RuleExpr as a rule body. */ int32_t VisitRuleBody(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -74,12 +128,11 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { return builder_.AddChoices(VisitChoices_(rule_expr)); case RuleExprType::kEmptyStr: return builder_.AddChoices({builder_.AddEmptyStr()}); + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})}); - case RuleExprType::kCharacterClassStar: - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } @@ -104,14 +157,12 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kEmptyStr: found_empty = true; break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: VisitElementInChoices(choice_expr, &new_choice_ids); break; - case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInChoices(choice_expr, &new_choice_ids); - break; default: LOG(FATAL) << "Unexpected choice type: " << static_cast(choice_expr.type); } @@ -154,16 +205,6 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { new_choice_ids->push_back(builder_.AddSequence({sub_expr_id})); } - /*! \brief Visit a character class star RuleExpr that is one of a list of choices. */ - void VisitCharacterClassStarInChoices(const RuleExpr& rule_expr, - std::vector* new_choice_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_choice_ids->push_back(builder_.AddSequence({new_rule_ref_id})); - } - /*! * \brief Visit a RuleExpr containing a sequence. * \returns A list of new sequence RuleExpr ids. 
@@ -171,26 +212,24 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { std::vector VisitSequence_(const RuleExpr& rule_expr) { std::vector new_sequence_ids; for (auto i : rule_expr) { - auto seq_expr = grammar_->GetRuleExpr(i); - switch (seq_expr.type) { + auto element_expr = grammar_->GetRuleExpr(i); + switch (element_expr.type) { case RuleExprType::kSequence: - VisitSequenceInSequence(seq_expr, &new_sequence_ids); + VisitSequenceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kChoices: - VisitChoiceInSequence(seq_expr, &new_sequence_ids); + VisitChoiceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kEmptyStr: break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: - case RuleExprType::kRuleRef: - VisitElementInSequence(seq_expr, &new_sequence_ids); - break; case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInSequence(seq_expr, &new_sequence_ids); + case RuleExprType::kRuleRef: + VisitElementInSequence(element_expr, &new_sequence_ids); break; default: - LOG(FATAL) << "Unexpected sequence type: " << static_cast(seq_expr.type); + LOG(FATAL) << "Unexpected sequence type: " << static_cast(element_expr.type); } } return new_sequence_ids; @@ -223,22 +262,58 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { void VisitElementInSequence(const RuleExpr& rule_expr, std::vector* new_sequence_ids) { new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr)); } +}; - /*! \brief Visit a character class star RuleExpr that is in a sequence. */ - void VisitCharacterClassStarInSequence(const RuleExpr& rule_expr, - std::vector* new_sequence_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_sequence_ids->push_back(new_rule_ref_id); - } +class ByteStringFuser : public BNFGrammarMutator { + public: + using BNFGrammarMutator::Apply; + using BNFGrammarMutator::BNFGrammarMutator; - /*! \brief The name of the current rule being visited. */ - std::string cur_rule_name_; + private: + /*! + * \brief Visit a RuleExpr containing a sequence. + * \returns A list of new sequence RuleExpr ids. + */ + int32_t VisitSequence(const RuleExpr& rule_expr) final { + std::vector new_sequence_ids; + std::vector cur_byte_string; + for (auto i : rule_expr) { + auto element_expr = grammar_->GetRuleExpr(i); + if (element_expr.type == RuleExprType::kByteString) { + cur_byte_string.insert(cur_byte_string.end(), element_expr.begin(), element_expr.end()); + continue; + } else { + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + cur_byte_string.clear(); + } + new_sequence_ids.push_back(builder_.AddRuleExpr(element_expr)); + } + } + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + } + return builder_.AddSequence(new_sequence_ids); + } }; -BNFGrammar NestedRuleUnwrapper::Apply() { return NestedRuleUnwrapperImpl(grammar_).Apply(); } +// Return the list of all normalizers in the class. The normalizers are applied one by one. 
+std::vector> BNFGrammarNormalizer::GetNormalizerList() { + std::vector> normalizer_mutators; + normalizer_mutators.emplace_back(std::make_unique()); + normalizer_mutators.emplace_back(std::make_unique()); + normalizer_mutators.emplace_back(std::make_unique()); + return normalizer_mutators; +} + +BNFGrammar BNFGrammarNormalizer::Apply(const BNFGrammar& grammar) { + std::vector> normalizer_mutators = GetNormalizerList(); + grammar_ = grammar; + for (auto& mutator : normalizer_mutators) { + grammar_ = mutator->Apply(grammar_); + } + return grammar_; +} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_functor.h similarity index 58% rename from cpp/serve/grammar/grammar_simplifier.h rename to cpp/serve/grammar/grammar_functor.h index 50f3804387..123700778e 100644 --- a/cpp/serve/grammar/grammar_simplifier.h +++ b/cpp/serve/grammar/grammar_functor.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.h + * \file serve/grammar/grammar_functor.h * \brief The header for the simplification of the BNF AST. */ -#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ -#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ #include #include @@ -27,29 +27,44 @@ namespace serve { * are void (for visitor) and BNFGrammar (for mutator). */ template -class BNFGrammarMutator { +class BNFGrammarFunctor { public: /*! * \brief Constructor. * \param grammar The grammar to visit or mutate. */ - explicit BNFGrammarMutator(const BNFGrammar& grammar) : grammar_(grammar) {} + explicit BNFGrammarFunctor() {} /*! * \brief Apply the transformation to the grammar, or visit the grammar. * \return The transformed grammar, or the visiting result, or void. - * \note Should be called only once after the mutator is constructed. */ - virtual ReturnType Apply() { - if constexpr (std::is_same::value && std::is_same::value) { + virtual ReturnType Apply(const BNFGrammar& grammar) { + Init(grammar); + if constexpr (std::is_same::value) { for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { auto rule = grammar_->GetRule(i); - auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id); - auto new_body_expr_id = VisitExpr(rule_expr); - builder_.AddRule(rule.name, new_body_expr_id); + cur_rule_name_ = rule.name; + VisitExpr(rule.body_expr_id); + VisitLookaheadAssertion(rule.lookahead_assertion_id); + } + } else if constexpr (std::is_same::value && + std::is_same::value) { + // First add empty rules to ensure the new rule ids the same as the old ones, then update + // the rule bodies + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + builder_.AddEmptyRule(grammar_->GetRule(i).name); + } + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + auto rule = grammar_->GetRule(i); + cur_rule_name_ = rule.name; + auto new_body_expr_id = VisitExpr(rule.body_expr_id); + builder_.UpdateRuleBody(i, new_body_expr_id); + // Handle lookahead assertion + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); - } else if constexpr (!std::is_same::value) { + } else { return ReturnType(); } } @@ -59,6 +74,25 @@ class BNFGrammarMutator { using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Initialize the functor. Should be called at the beginning of Apply(). 
*/ + virtual void Init(const BNFGrammar& grammar) { + grammar_ = grammar; + builder_ = BNFGrammarBuilder(); + } + + /*! \brief Visit a lookahead assertion expr referred by id. */ + virtual T VisitLookaheadAssertion(int32_t lookahead_assertion_id) { + if (lookahead_assertion_id == -1) { + return -1; + } + return VisitExpr(lookahead_assertion_id); + } + + /*! \brief Visit a RuleExpr by id. */ + virtual T VisitExpr(int32_t old_rule_expr_id) { + return VisitExpr(grammar_->GetRuleExpr(old_rule_expr_id)); + } + /*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */ virtual T VisitExpr(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -68,47 +102,48 @@ class BNFGrammarMutator { return VisitChoices(rule_expr); case RuleExprType::kEmptyStr: return VisitEmptyStr(rule_expr); + case RuleExprType::kByteString: + return VisitByteString(rule_expr); case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: return VisitCharacterClass(rule_expr); - case RuleExprType::kRuleRef: - return VisitRuleRef(rule_expr); case RuleExprType::kCharacterClassStar: return VisitCharacterClassStar(rule_expr); + case RuleExprType::kRuleRef: + return VisitRuleRef(rule_expr); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } } - /*! \brief Visit a sequence RuleExpr. */ - virtual T VisitSequence(const RuleExpr& rule_expr) { + /*! \brief Visit a choices RuleExpr. */ + virtual T VisitChoices(const RuleExpr& rule_expr) { if constexpr (std::is_same::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same::value) { - std::vector sequence_ids; + std::vector choice_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } - return builder_.AddSequence(sequence_ids); + return builder_.AddChoices(choice_ids); } else { return T(); } } - /*! \brief Visit a choices RuleExpr. */ - virtual T VisitChoices(const RuleExpr& rule_expr) { + /*! \brief Visit a sequence RuleExpr. */ + virtual T VisitSequence(const RuleExpr& rule_expr) { if constexpr (std::is_same::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same::value) { - std::vector choice_ids; + std::vector sequence_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } - return builder_.AddChoices(choice_ids); + return builder_.AddSequence(sequence_ids); } else { return T(); } @@ -128,23 +163,18 @@ class BNFGrammarMutator { /*! \brief Visit an empty string RuleExpr. */ virtual T VisitEmptyStr(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a character class RuleExpr. */ + virtual T VisitByteString(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a character class RuleExpr. */ virtual T VisitCharacterClass(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a star quantifier RuleExpr. */ + virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a rule reference RuleExpr. */ virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } - /*! \brief Visit a star quantifier RuleExpr. 
*/ - virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { - if constexpr (std::is_same::value) { - VisitExpr(grammar_->GetRuleExpr(rule_expr[0])); - } else if constexpr (std::is_same::value) { - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); - } else { - return T(); - } - } - /*! \brief The grammar to visit or mutate. */ BNFGrammar grammar_; /*! @@ -152,33 +182,38 @@ class BNFGrammarMutator { * can be used to build a new grammar in subclasses. */ BNFGrammarBuilder builder_; + /*! \brief The name of the current rule being visited. */ + std::string cur_rule_name_; }; /*! - * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in - * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class - * or a rule reference. And if the rule can be empty, the first choice will be an empty string. - * - * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice - * containing a sequence of three elements. The empty string is removed. - * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by - * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three - * choices is a sequence containing a single element. - * \example The rule `A ::= (a | (b (c | d)))` will be replaced by - * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested - * choices. + * \brief Visitor of BNFGrammar. + * \tparam ReturnType The return type of the Apply() function. Denotes the collected information. */ -class NestedRuleUnwrapper : public BNFGrammarMutator { +template +using BNFGrammarVisitor = BNFGrammarFunctor; + +/*! + * \brief Mutator of BNFGrammar. The Apply() function returns the updated grammar. + */ +using BNFGrammarMutator = BNFGrammarFunctor; + +/*! + * \brief Normalize a BNFGrammar: expand the nested rules, combine consequent sequences and strings, + * etc. 
+ */ +class BNFGrammarNormalizer : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final; + BNFGrammar Apply(const BNFGrammar& grammar) final; + + private: + std::vector> GetNormalizerList(); }; } // namespace serve } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index a4eda4e395..2799ee4ba9 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -29,6 +29,7 @@ class EBNFParserImpl { int32_t ParseRuleRef(); int32_t ParseElement(); int32_t ParseQuantifier(); + int32_t ParseLookaheadAssertion(); int32_t ParseSequence(); int32_t ParseChoices(); Rule ParseRule(); @@ -157,10 +158,10 @@ int32_t EBNFParserImpl::ParseCharacterClass() { } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid UTF8 sequence"); } - if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); @@ -189,26 +190,37 @@ int32_t EBNFParserImpl::ParseCharacterClass() { // parse a c style string with utf8 support int32_t EBNFParserImpl::ParseString() { - std::vector character_classes; + std::vector codepoints; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid utf8 sequence"); } - if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); - character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); + codepoints.push_back(codepoint); } - if (character_classes.empty()) { + if (codepoints.empty()) { return builder_.AddEmptyStr(); } - return builder_.AddSequence(character_classes); + + // convert codepoints to string + std::string str; + for (auto codepoint : codepoints) { + str += PrintAsUTF8(codepoint); + } + // convert str to int32_t vector + std::vector bytes; + for (auto c : str) { + bytes.push_back(static_cast(c)); + } + return builder_.AddByteString(bytes); } int32_t EBNFParserImpl::ParseRuleRef() { @@ -264,9 +276,11 @@ int32_t EBNFParserImpl::ParseElement() { } int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - if (builder_.GetRuleExpr(rule_expr_id).type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { + BNFGrammarNode::RuleExpr rule_expr = builder_.GetRuleExpr(rule_expr_id); + if (rule_expr.type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { // We have special handling for character class star, e.g. 
[a-z]* - return builder_.AddCharacterClassStar(rule_expr_id); + rule_expr.type = BNFGrammarBuilder::RuleExprType::kCharacterClassStar; + return builder_.AddRuleExpr(rule_expr); } else { // For other star quantifiers, we transform it into a rule: // a* --> rule ::= a rule | "" @@ -327,12 +341,11 @@ int32_t EBNFParserImpl::ParseQuantifier() { int32_t EBNFParserImpl::ParseSequence() { std::vector elements; - elements.push_back(ParseQuantifier()); - ConsumeSpace(in_parentheses_); - while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r') { + do { elements.push_back(ParseQuantifier()); ConsumeSpace(in_parentheses_); - } + } while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r' && + (Peek() != '(' || Peek(1) != '=')); return builder_.AddSequence(elements); } @@ -350,6 +363,24 @@ int32_t EBNFParserImpl::ParseChoices() { return builder_.AddChoices(choices); } +int32_t EBNFParserImpl::ParseLookaheadAssertion() { + if (Peek() != '(' || Peek(1) != '=') { + return -1; + } + Consume(2); + auto prev_in_parentheses = in_parentheses_; + in_parentheses_ = true; + ConsumeSpace(in_parentheses_); + auto result = ParseSequence(); + ConsumeSpace(in_parentheses_); + if (Peek() != ')') { + ThrowParseError("Expect )"); + } + Consume(); + in_parentheses_ = prev_in_parentheses; + return result; +} + EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { std::string name = ParseName(); cur_rule_name_ = name; @@ -359,7 +390,10 @@ EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { } Consume(3); ConsumeSpace(); - return {name, ParseChoices()}; + auto body_id = ParseChoices(); + ConsumeSpace(); + auto lookahead_id = ParseLookaheadAssertion(); + return {name, body_id, lookahead_id}; } void EBNFParserImpl::BuildRuleNameToId() { @@ -399,8 +433,14 @@ BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rul ResetStringIterator(ebnf_string.c_str()); ConsumeSpace(); while (Peek()) { + // Throw error when there are multiple lookahead assertions + if (Peek() == '(' && Peek(1) == '=') { + ThrowParseError("Unexpected lookahead assertion"); + } auto new_rule = ParseRule(); builder_.UpdateRuleBody(new_rule.name, new_rule.body_expr_id); + // Update the lookahead assertion + builder_.AddLookaheadAssertion(new_rule.name, new_rule.lookahead_assertion_id); ConsumeSpace(); } diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index 4d10e8eb0d..94ac3d4ce1 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -23,7 +23,7 @@ using namespace tvm::runtime; * \details This function accepts the EBNF notation defined in the W3C XML Specification * (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following * changes: - * - Using # as comment mark instead of /**\/ + * - Using # as comment mark instead of C-style comments * - Accept C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 * - Rule A-B (match A and not match B) is not supported yet * diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index c3c2c88baa..5176b9f102 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -18,7 +18,11 @@ namespace serve { using namespace tvm::runtime; std::string BNFGrammarPrinter::PrintRule(const Rule& rule) { - return rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + std::string res = rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + if 
(rule.lookahead_assertion_id != -1) { + res += " (=" + PrintRuleExpr(rule.lookahead_assertion_id) + ")"; + } + return res; } std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { @@ -28,10 +32,12 @@ std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { std::string result; switch (rule_expr.type) { + case RuleExprType::kByteString: + return PrintByteString(rule_expr); case RuleExprType::kCharacterClass: return PrintCharacterClass(rule_expr); - case RuleExprType::kNegCharacterClass: - return PrintCharacterClass(rule_expr); + case RuleExprType::kCharacterClassStar: + return PrintCharacterClassStar(rule_expr); case RuleExprType::kEmptyStr: return PrintEmptyStr(rule_expr); case RuleExprType::kRuleRef: @@ -40,8 +46,6 @@ std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { return PrintSequence(rule_expr); case RuleExprType::kChoices: return PrintChoices(rule_expr); - case RuleExprType::kCharacterClassStar: - return PrintCharacterClassStar(rule_expr); default: LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast(rule_expr.type); } @@ -51,14 +55,29 @@ std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) { return PrintRuleExpr(grammar_->GetRuleExpr(rule_expr_id)); } +std::string BNFGrammarPrinter::PrintByteString(const RuleExpr& rule_expr) { + std::string internal_str; + internal_str.reserve(rule_expr.data_len); + for (int i = 0; i < rule_expr.data_len; ++i) { + internal_str += static_cast(rule_expr[i]); + } + auto codepoints = ParseUTF8(internal_str.c_str(), UTF8ErrorPolicy::kReturnByte); + std::string result; + for (auto codepoint : codepoints) { + result += PrintAsEscaped(codepoint); + } + return "\"" + result + "\""; +} + std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { static const std::unordered_map kCustomEscapeMap = {{'-', "\\-"}, {']', "\\]"}}; std::string result = "["; - if (rule_expr.type == RuleExprType::kNegCharacterClass) { + bool is_negative = static_cast(rule_expr[0]); + if (is_negative) { result += "^"; } - for (auto i = 0; i < rule_expr.data_len; i += 2) { + for (auto i = 1; i < rule_expr.data_len; i += 2) { result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; @@ -70,6 +89,10 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { return result; } +std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { + return PrintCharacterClass(rule_expr) + "*"; +} + std::string BNFGrammarPrinter::PrintEmptyStr(const RuleExpr& rule_expr) { return "\"\""; } std::string BNFGrammarPrinter::PrintRuleRef(const RuleExpr& rule_expr) { @@ -103,10 +126,6 @@ std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) { return result; } -std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { - return PrintRuleExpr(rule_expr[0]) + "*"; -} - std::string BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); @@ -121,7 +140,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToString").set_body_typed([](const BNFG }); std::string BNFGrammarJSONSerializer::ToString() { - picojson::object grammar_json; + picojson::object grammar_json_obj; picojson::array rules_json; for (const auto& rule : grammar_->rules_) { @@ -130,20 +149,21 @@ std::string BNFGrammarJSONSerializer::ToString() { rule_json["body_expr_id"] = picojson::value(static_cast(rule.body_expr_id)); 
rules_json.push_back(picojson::value(rule_json)); } - grammar_json["rules"] = picojson::value(rules_json); + grammar_json_obj["rules"] = picojson::value(rules_json); picojson::array rule_expr_data_json; for (const auto& data : grammar_->rule_expr_data_) { rule_expr_data_json.push_back(picojson::value(static_cast(data))); } - grammar_json["rule_expr_data"] = picojson::value(rule_expr_data_json); + grammar_json_obj["rule_expr_data"] = picojson::value(rule_expr_data_json); picojson::array rule_expr_indptr_json; for (const auto& index_ptr : grammar_->rule_expr_indptr_) { rule_expr_indptr_json.push_back(picojson::value(static_cast(index_ptr))); } - grammar_json["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); + grammar_json_obj["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); - return picojson::value(grammar_json).serialize(prettify_); + auto grammar_json = picojson::value(grammar_json_obj); + return grammar_json.serialize(prettify_); } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToJSON") diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 4ad5c2103b..f0837d9638 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -62,8 +62,12 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintRuleExpr(int32_t rule_expr_id); private: + /*! \brief Print a RuleExpr for byte string. */ + std::string PrintByteString(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for character class. */ std::string PrintCharacterClass(const RuleExpr& rule_expr); + /*! \brief Print a RuleExpr for a star quantifier of a character class. */ + std::string PrintCharacterClassStar(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for empty string. */ std::string PrintEmptyStr(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule reference. */ @@ -72,8 +76,6 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintSequence(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule_expr choices. */ std::string PrintChoices(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for star quantifier. */ - std::string PrintCharacterClassStar(const RuleExpr& rule_expr); }; /*! 
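A note on the new character-class encoding used by the printer hunks above: a character-class RuleExpr now stores its negation flag in element 0, followed by [lower, upper] codepoint pairs starting at element 1. The following standalone sketch is illustrative only (it mirrors, but is not, the matcher code later in this patch); the function name and the flat std::vector layout are assumptions made for the example.

#include <cstdint>
#include <vector>

// Membership test for a character class stored as
// {is_negative, lo1, hi1, lo2, hi2, ...} -- the layout implied by the hunks above.
bool CharacterClassContains(const std::vector<int32_t>& class_data, int32_t codepoint) {
  bool is_negative = static_cast<bool>(class_data[0]);
  for (size_t i = 1; i + 1 < class_data.size(); i += 2) {
    if (class_data[i] <= codepoint && codepoint <= class_data[i + 1]) {
      // Inside one of the ranges: matched unless the class is negated, e.g. [^a-z].
      return !is_negative;
    }
  }
  // Outside every range: matched only when the class is negated.
  return is_negative;
}

For example, with class_data = {0, 'a', 'z'} (a plain [a-z]) the function returns true for 'c' and false for '0', while {1, 'a', 'z'} (the negated [^a-z]) inverts both results.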
diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 451127e746..e6e68f376f 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -2,6 +2,7 @@ * Copyright (c) 2023 by Contributors * \file serve/grammar/grammar_state_matcher.cc */ +// #define TVM_LOG_DEBUG 1 #include "grammar_state_matcher.h" #include @@ -123,13 +124,15 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + using SaveType = CatagorizedTokens::SaveType; public: GrammarStateMatcherNodeImpl(std::shared_ptr init_ctx, int max_rollback_steps = 0) : GrammarStateMatcherBase(init_ctx->grammar), init_ctx_(init_ctx), - max_rollback_steps_(max_rollback_steps) {} + max_rollback_steps_(max_rollback_steps), + tmp_accepted_bitset_(init_ctx_->vocab_size) {} bool AcceptToken(int32_t token_id) final; @@ -143,8 +146,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm void ResetState() final { stack_tops_history_.Reset(); - token_size_history_.clear(); - InitStackState(); + token_length_history.clear(); + PushInitialState(kInvalidRulePosition, true); } private: @@ -160,14 +163,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm const std::vector& uncertain_tokens_bitset); /*! \brief Set the acceptable next token in next_token_bitmask. */ - void SetTokenBitmask(DLTensor* next_token_bitmask, std::vector& accepted_indices, - std::vector& rejected_indices, bool can_reach_end); - - /*! \brief Check if a token is a stop token. */ - bool IsStopToken(int32_t token_id) const { - return std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), - token_id) != init_ctx_->stop_token_ids.end(); - } + void SetTokenBitmask(DLTensor* next_token_bitmask, const DynamicBitset& accepted_bitset, + const std::vector& rejected_indices, bool can_reach_end); /*! * \brief Accept the stop token and terminates the matcher. @@ -180,14 +177,12 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm std::shared_ptr init_ctx_; int max_rollback_steps_; - std::deque token_size_history_; + std::deque token_length_history; // Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation. 
- std::vector tmp_accepted_indices_; + DynamicBitset tmp_accepted_bitset_; std::vector tmp_rejected_indices_; - std::vector tmp_accepted_indices_delta_; std::vector tmp_rejected_indices_delta_; - std::vector tmp_uncertain_tokens_bitset_; }; bool GrammarStateMatcherNodeImpl::AcceptStopToken() { @@ -204,23 +199,31 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { "accept another token id " << token_id; + CHECK(token_id >= 0 && token_id < init_ctx_->vocab_size) + << "Invalid token id " << token_id << " for GrammarStateMatcher"; + // Handle the stop token - if (IsStopToken(token_id)) { + if (std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), token_id) != + init_ctx_->stop_token_ids.end()) { return AcceptStopToken(); } - CHECK(init_ctx_->id_to_token_codepoints.count(token_id) > 0) - << "Token id " << token_id << " is not supported in generation"; - const auto& token = init_ctx_->id_to_token_codepoints[token_id].token; - for (auto codepoint : token) { - if (!AcceptCodepoint(codepoint, false)) { + if (init_ctx_->special_token_ids.count(token_id) > 0) { + LOG(FATAL) + << "Token id " << token_id << ": " << init_ctx_->token_table[token_id] + << " is regarded as a special token, and cannot be accepted by the GrammarStateMatcher"; + } + + const auto& token = init_ctx_->token_table[token_id]; + for (auto char_value : token) { + if (!AcceptChar(char_value, false)) { return false; } } - token_size_history_.push_back(token.size()); - if (token_size_history_.size() > max_rollback_steps_) { - DiscardEarliestCodepoints(token_size_history_.front()); - token_size_history_.pop_front(); + token_length_history.push_back(token.size()); + if (token_length_history.size() > max_rollback_steps_) { + DiscardEarliestChars(token_length_history.front()); + token_length_history.pop_front(); } return true; } @@ -229,7 +232,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm CHECK(!IsTerminated()) << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " "find the next token mask"; - const auto& sorted_token_codepoints = init_ctx_->sorted_token_codepoints; + const auto& sorted_token_table = init_ctx_->sorted_token_table; const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar; const auto& latest_stack_tops = stack_tops_history_.GetLatest(); @@ -238,113 +241,132 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm // The final accepted token set is the union of the accepted token sets of all stacks. // The final rejected token set is the intersection of the rejected token sets of all stacks. - // Note these indices store the indices in sorted_token_codepoints, instead of the token ids. - tmp_accepted_indices_.clear(); + // Note these indices store the indices in sorted_token_table, instead of the token ids. + tmp_accepted_bitset_.Reset(); // {-1} means the universal set, i.e. all tokens initially tmp_rejected_indices_.assign({-1}); + // std::chrono::microseconds time_unc(0); + // std::chrono::microseconds time_idx(0); + int check_cnt = 0; + for (auto top : latest_stack_tops) { - // Step 1. 
Find the current catagorized_tokens auto cur_rule_position = tree_[top]; - auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); - if (cur_rule_position.parent_id == RulePosition::kNoParent && - cur_rule_position.element_id == current_sequence.size()) { + if (tree_.IsEndPosition(cur_rule_position)) { continue; } - const auto& catagorized_tokens = catagorized_tokens_for_grammar.at( - {cur_rule_position.sequence_id, cur_rule_position.element_id}); + const auto& catagorized_tokens = catagorized_tokens_for_grammar.at(cur_rule_position); + + // auto start = std::chrono::high_resolution_clock::now(); // For each stack, we will check every uncertain token and put them into the accepted or // rejected list. - // If the accepted tokens are saved, it means it is likely to be smaller than the rejected - // tokens, so we will just find the accepted tokens, and vice versa. - bool is_find_accept_mode = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kAccepted; - - // If uncertain tokens are saved, we will iterate over the uncertain tokens. - // Otherwise, we will iterate over all_tokens - accepted_tokens - rejected_tokens. - bool is_uncertain_saved = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kUncertain; // Step 2. Update the accepted tokens in accepted_indices_delta, or the rejected tokens in // rejected_indices_delta. - // Examine only the current one stack - stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - - const std::vector* prev_token = nullptr; - int prev_matched_size = 0; + // If the accepted tokens are saved, it means it is likely to be smaller than the rejected + // tokens, so we will just find the accepted tokens, and vice versa. - tmp_accepted_indices_delta_.clear(); tmp_rejected_indices_delta_.clear(); - if (!is_uncertain_saved) { - // unc_tokens = all_tokens - accepted_tokens - rejected_tokens - tmp_uncertain_tokens_bitset_.assign(sorted_token_codepoints.size(), true); - for (auto idx : catagorized_tokens.accepted_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - for (auto idx : catagorized_tokens.rejected_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - } + // Examine only the current one stack + stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - int iterator_uncertain = -1; + const std::string* prev_token = nullptr; + int prev_matched_size = 0; - while (true) { - // Step 2.1. Find the current token. 
- auto idx = - GetNextUncertainToken(is_uncertain_saved, &iterator_uncertain, - catagorized_tokens.uncertain_indices, tmp_uncertain_tokens_bitset_); - if (idx == -1) { - break; - } - const auto& cur_token = sorted_token_codepoints[idx].token; + // std::cout << tree_.PrintNode(top) << std::endl; + + // std::cout << "Accepted count: " << catagorized_tokens.accepted_indices.size() + // << ", rejected count: " << catagorized_tokens.rejected_indices.size() + // << ", uncertain count: " << catagorized_tokens.uncertain_indices.size() + // << ", save type: " << static_cast(catagorized_tokens.save_type) << std::endl; + + // if (catagorized_tokens.accepted_indices.size() < 200) { + // std::cout << "Accpeted: "; + // for (int i = 0; i < catagorized_tokens.accepted_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.accepted_indices[i]].second) + // << "> "; + // } + // std::cout << "\n"; + // } + + // if (catagorized_tokens.uncertain_indices.size() > 100) { + // std::cout << "Uncertain: "; + // for (int i = 0; i < catagorized_tokens.uncertain_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.uncertain_indices[i]].second) + // << "> "; + // } + // std::cout << "\n"; + // } + + for (auto cur_token_idx : catagorized_tokens.uncertain_indices) { + const auto& cur_token = sorted_token_table[cur_token_idx].second; + bool accepted = true; - // Step 2.2. Find the longest common prefix with the accepted part of the previous token. + // Step 2.1. Find the longest common prefix with the accepted part of the previous token. // We can reuse the previous matched size to avoid unnecessary matching. - int prev_useful_size = 0; if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast(cur_token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (cur_token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } + int lcp_len = std::mismatch(cur_token.begin(), cur_token.end(), prev_token->begin(), + prev_token->end()) + .first - + cur_token.begin(); + if (lcp_len > prev_matched_size) { + accepted = false; + } else if (lcp_len < prev_matched_size) { + RollbackChars(prev_matched_size - lcp_len); } - RollbackCodepoints(prev_matched_size - prev_useful_size); + prev_matched_size = std::min(prev_matched_size, lcp_len); } - // Step 2.3. Find if the current token is accepted or rejected. - bool accepted = true; - prev_matched_size = prev_useful_size; - - for (int j = prev_useful_size; j < cur_token.size(); ++j) { - if (!AcceptCodepoint(cur_token[j], false)) { - accepted = false; - break; + // Step 2.2. Find if the current token is accepted or rejected. + if (accepted) { + for (int j = prev_matched_size; j < cur_token.size(); ++j) { + ++check_cnt; + if (!AcceptChar(cur_token[j], false)) { + accepted = false; + break; + } + prev_matched_size = j + 1; } - prev_matched_size = j + 1; } - // Step 2.4. Push the result to the delta list. - if (accepted && is_find_accept_mode) { - tmp_accepted_indices_delta_.push_back(idx); - } else if (!accepted && !is_find_accept_mode) { - tmp_rejected_indices_delta_.push_back(idx); + // Step 2.3. Push the result to the delta list. 
+ if (catagorized_tokens.save_type == SaveType::kAcceptedBitset || + catagorized_tokens.save_type == SaveType::kAccepted) { + if (accepted) { + tmp_accepted_bitset_.Set(sorted_token_table[cur_token_idx].first, true); + } + } else { + if (!accepted) { + tmp_rejected_indices_delta_.push_back(cur_token_idx); + } } prev_token = &cur_token; } - RollbackCodepoints(prev_matched_size + 1); + RollbackChars(prev_matched_size + 1); + + // auto end = std::chrono::high_resolution_clock::now(); + + // time_unc += std::chrono::duration_cast(end - start); + + // start = std::chrono::high_resolution_clock::now(); // Step 3. Update the accepted_indices and rejected_indices - if (is_find_accept_mode) { - // accepted_indices += catagorized_tokens.accepted_indices + accepted_indices_delta - IntsetUnion(&tmp_accepted_indices_delta_, catagorized_tokens.accepted_indices); - IntsetUnion(&tmp_accepted_indices_, tmp_accepted_indices_delta_); + if (catagorized_tokens.save_type == SaveType::kAcceptedBitset) { + tmp_accepted_bitset_ |= catagorized_tokens.accepted_bitset; + } else if (catagorized_tokens.save_type == SaveType::kAccepted) { + for (auto idx : catagorized_tokens.accepted_indices) { + tmp_accepted_bitset_.Set(sorted_token_table[idx].first, true); + } } else { // rejected_indices = Intersect( // rejected_indices, @@ -352,72 +374,81 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices); IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_); } + // end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); } // Finally update the rejected_ids bitset + // auto start = std::chrono::high_resolution_clock::now(); bool can_reach_end = CanReachEnd(); - SetTokenBitmask(next_token_bitmask, tmp_accepted_indices_, tmp_rejected_indices_, can_reach_end); + SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end); + // auto end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); + // std::cout << "Time for uncertain: " << time_unc.count() + // << "us, time for index: " << time_idx.count() << "us" << std::endl; + // std::cout << "Check cnt " << check_cnt << std::endl; } void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { - CHECK(num_tokens <= token_size_history_.size()) + CHECK(num_tokens <= token_length_history.size()) << "Intended to rollback " << num_tokens << " tokens, but only the last " - << token_size_history_.size() << " steps of history are saved"; + << token_length_history.size() << " steps of history are saved"; while (num_tokens > 0) { - int steps = token_size_history_.back(); - RollbackCodepoints(steps); - token_size_history_.pop_back(); + int steps = token_length_history.back(); + RollbackChars(steps); + token_length_history.pop_back(); --num_tokens; } } void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, - std::vector& accepted_indices, - std::vector& rejected_indices, + const DynamicBitset& accepted_bitset, + const std::vector& rejected_indices, bool can_reach_end) { - // accepted_ids = Union(accepted_indices, all_tokens - rejected_indices) - // rejected_ids = Intersect(all_tokens - accepted_indices, rejected_indices) + // next_token_bitmask = set(all accepted tokens) = + // 1. all_tokens - (rejected_ids / accepted_ids) + // (when rejected_ids != {-1}, i.e. rejected_ids is not the universal set) + // 2. 
accepted_ids + // (otherwise, when rejected_ids is the universal set) CHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape) << "The provied bitmask's shape or dtype is not valid."; + CHECK(next_token_bitmask->shape[0] >= DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size)) + << "The provided bitmask is not large enough to store the token set. The length should be " + << DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size) << " at least"; - BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), - next_token_bitmask->shape[0], init_ctx_->vocab_size); + DynamicBitset next_token_bitset(init_ctx_->vocab_size, + reinterpret_cast(next_token_bitmask->data)); + const auto& sorted_token_table = init_ctx_->sorted_token_table; if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { // If rejected_indices is the universal set, the final accepted token set is just // accepted_indices - next_token_bitset.Reset(false); - for (int idx : accepted_indices) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true); - } + next_token_bitset = accepted_bitset; if (can_reach_end) { // add end tokens - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, true); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, true); } } } else { // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices) - next_token_bitset.Reset(true); + next_token_bitset.Set(); - auto it_acc = accepted_indices.begin(); for (auto i : rejected_indices) { - while (it_acc != accepted_indices.end() && *it_acc < i) { - ++it_acc; - } - if (it_acc == accepted_indices.end() || *it_acc != i) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[i].id, false); + auto id = sorted_token_table[i].first; + if (!accepted_bitset[id]) { + next_token_bitset.Set(id, false); } } - for (int idx : init_ctx_->special_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->special_token_ids) { + next_token_bitset.Set(id, false); } if (!can_reach_end) { - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, false); } } } @@ -452,16 +483,24 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr tokenizer, int max_rollback_steps) { + .set_body_typed([](BNFGrammar grammar, Optional tokenizer, int max_rollback_steps, + String token_table_postproc_method) { auto preproc_start = std::chrono::high_resolution_clock::now(); - auto init_ctx = GrammarStateMatcher::CreateInitContext( - grammar, tokenizer ? 
tokenizer.value()->TokenTable() : std::vector()); + std::shared_ptr init_ctx; + if (tokenizer) { + auto token_table = Tokenizer::PostProcessTokenTable(tokenizer.value()->TokenTable(), + token_table_postproc_method); + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); + } else { + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {}); + } + auto preproc_end = std::chrono::high_resolution_clock::now(); - std::cerr << "Preprocess takes " + LOG(INFO) << "GrammarStateMatcher preprocess takes " << std::chrono::duration_cast(preproc_end - preproc_start) .count() - << "us" << std::endl; + << "us"; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); #endif @@ -479,11 +518,11 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") *rv = GrammarStateMatcher(init_ctx, max_rollback_steps); }); -TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptCodepoint") - .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint) { +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptChar") + .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint, bool verbose) { auto mutable_node = const_cast(matcher.as()); - return mutable_node->AcceptCodepoint(codepoint); + return mutable_node->AcceptChar(codepoint, verbose); }); TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherAcceptToken") @@ -507,32 +546,43 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") /*! \brief Check if a matcher can accept the complete string, and then reach the end of the * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. */ -bool MatchCompleteString(GrammarStateMatcher matcher, String str) { +bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; - for (auto codepoint : codepoints) { - if (!mutable_node->AcceptCodepoint(codepoint, false)) { - mutable_node->RollbackCodepoints(accepted_cnt); + for (auto char_value : str.operator std::string()) { + if (!mutable_node->AcceptChar(char_value, verbose)) { + if (verbose) { + LOG(INFO) << "Matching failed after accepting " << accepted_cnt << " characters"; + } + mutable_node->RollbackChars(accepted_cnt); return false; } ++accepted_cnt; } auto accepted = mutable_node->CanReachEnd(); - mutable_node->RollbackCodepoints(accepted_cnt); + if (verbose) { + if (accepted) { + LOG(INFO) << "Matching succeed after accepting " << accepted_cnt << " characters"; + } else { + LOG(INFO) << "Matching failed due to the end state not reached after all " << accepted_cnt + << " characters are accepted"; + } + } + mutable_node->RollbackChars(accepted_cnt); return accepted; } TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") - .set_body_typed([](GrammarStateMatcher matcher, String str) { - return MatchCompleteString(matcher, str); + .set_body_typed([](GrammarStateMatcher matcher, String str, bool verbose) { + return MatchCompleteString(matcher, str, verbose); }); /*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. 
*/ -void PrintAcceptedRejectedTokens( +std::string PrintAcceptedRejectedTokens( const std::shared_ptr& init_ctx, - const BitsetManager& bitset, int threshold = 500) { + const DynamicBitset& bitset, int threshold = 300) { + std::stringstream ss; auto vocab_size = init_ctx->vocab_size; std::vector accepted_ids; std::vector rejected_ids; @@ -544,42 +594,27 @@ void PrintAcceptedRejectedTokens( } } - if (accepted_ids.size() < threshold) { - std::cerr << "Accepted: "; - for (auto id : accepted_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && (static_cast(token[0]) >= 128 || token[0] == 0)) { - // First cast to unsigned, then cast to int - std::cerr << static_cast(static_cast(token[0])); - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Accepted: "; + auto end_it = + accepted_ids.size() > threshold ? accepted_ids.begin() + threshold : accepted_ids.end(); + for (auto it = accepted_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (accepted_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; - if (rejected_ids.size() < threshold) { - std::cerr << "Rejected: "; - for (auto id : rejected_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { - std::cerr << (int)(unsigned char)token[0]; - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Rejected: "; + end_it = rejected_ids.size() > threshold ? rejected_ids.begin() + threshold : rejected_ids.end(); + for (auto it = rejected_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (rejected_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; + return ss.str(); } /*! 
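The hunks above and below replace the previous BitsetManager with a DynamicBitset backed by 32-bit words. The minimal sketch below is for illustration only: the class name, member set, and storage details are assumptions inferred from the calls visible in this diff (CalculateBufferSize, Set(), Set(id, value), operator[]), not the actual implementation shipped with the patch.

#include <algorithm>
#include <cstdint>
#include <vector>

class DynamicBitsetSketch {
 public:
  // Number of 32-bit words needed to hold `size` bits (matches the sizing used above).
  static int CalculateBufferSize(int size) { return (size + 31) / 32; }

  explicit DynamicBitsetSketch(int size)
      : size_(size), data_(CalculateBufferSize(size), 0) {}

  // Set every bit to 1 (the "start from all tokens accepted" case).
  void Set() { std::fill(data_.begin(), data_.end(), 0xFFFFFFFFu); }

  // Set or clear a single bit.
  void Set(int i, bool value) {
    if (value) {
      data_[i / 32] |= (1u << (i % 32));
    } else {
      data_[i / 32] &= ~(1u << (i % 32));
    }
  }

  // Test a single bit.
  bool operator[](int i) const { return (data_[i / 32] >> (i % 32)) & 1u; }

 private:
  int size_;
  std::vector<uint32_t> data_;
};

With vocab_size tokens, a buffer of CalculateBufferSize(vocab_size) uint32 elements is sufficient to back one such bitset, which is exactly the shape check added to SetTokenBitmask above.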
@@ -591,7 +626,7 @@ void PrintAcceptedRejectedTokens( IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(ndarray.operator->()); @@ -605,7 +640,7 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals end = std::chrono::high_resolution_clock::now(); } - auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size, vocab_size); + auto bitset = DynamicBitset(vocab_size, reinterpret_cast(dltensor->data)); std::vector rejected_ids; for (int i = 0; i < vocab_size; i++) { if (bitset[i] == 0) { @@ -614,10 +649,10 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals } if (verbose) { - std::cerr << "FindNextTokenBitmask takes " + LOG(INFO) << "FindNextTokenBitmask takes " << std::chrono::duration_cast(end - start).count() << "us" << ", found accepted: " << vocab_size - rejected_ids.size() - << ", rejected: " << rejected_ids.size() << std::endl; + << ", rejected: " << rejected_ids.size(); } auto ret = IntTuple(rejected_ids); @@ -634,7 +669,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextRejectedTokens") NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto bitmask = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(bitmask.operator->()); diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h index eceaa75d07..eedf7a1989 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -130,14 +130,13 @@ class GrammarStateMatcher : public ObjectRef { }; /*! - * \brief Helper class to get the grammar state init context for grammars or schemas. This class - * maintains cache internally, so the same grammar or schema will not be preprocessed multiple - * times. + * \brief A cache to get the grammar state init context for grammar or schema. This class avoids + * redundant preprocessing of the grammar or schema when constructing a GrammarStateInitContext. * \note This class is associated with a token table when constructed. The token table is used to * create every grammar state init context. If multiple toke tables are used to create init * contexts, an instance of this class for each token table should be created. */ -class GrammarInitContextStorageNode : public Object { +class GrammarInitContextCacheNode : public Object { public: /*! \brief Get the init context for pure JSON. */ virtual std::shared_ptr GetInitContextForJSON() = 0; @@ -147,25 +146,25 @@ class GrammarInitContextStorageNode : public Object { const std::string& schema) = 0; /*! \brief Clear the interal cache of init contexts. 
*/ - virtual void ClearCache() = 0; + virtual void Clear() = 0; - static constexpr const char* _type_key = "mlc.serve.GrammarInitContextStorageNode"; + static constexpr const char* _type_key = "mlc.serve.GrammarInitContextCacheNode"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextStorageNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextCacheNode, Object); }; -class GrammarInitContextStorage : public ObjectRef { +class GrammarInitContextCache : public ObjectRef { public: /*! - * \brief Construct a GrammarInitContextStorage with a token table. This class will always create + * \brief Construct a GrammarInitContextCache with a token table. This class will always create * grammar state init contexts with this token table. * \param token_table The token table that the grammar will use. */ - GrammarInitContextStorage(const std::vector& token_table); + GrammarInitContextCache(const std::vector& token_table); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextStorage, ObjectRef, - GrammarInitContextStorageNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextCache, ObjectRef, + GrammarInitContextCacheNode); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 5b774d33a4..1241e7307a 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -32,95 +32,172 @@ class GrammarStateMatcherBase { * \param grammar The grammar to match. * \param init_rule_position The initial rule position. If not specified, the main rule will be * used. + * \param expand_init_rule_position Whether to expand the initial rule position to all possible + * locations. See ExpandRulePosition. */ - GrammarStateMatcherBase(const BNFGrammar& grammar, RulePosition init_rule_position = {}) + GrammarStateMatcherBase(const BNFGrammar& grammar, + RulePosition init_rule_position = kInvalidRulePosition, + bool expand_init_rule_position = true) : grammar_(grammar), tree_(grammar), stack_tops_history_(&tree_) { - InitStackState(init_rule_position); + PushInitialState(init_rule_position, expand_init_rule_position); } - /*! \brief Accept one codepoint. */ - bool AcceptCodepoint(TCodepoint codepoint, bool verbose = false); + /*! \brief Accept one character. */ + bool AcceptChar(uint8_t char_value, bool verbose = false); /*! \brief Check if the end of the main rule is reached. If so, the stop token can be accepted. */ bool CanReachEnd() const; - /*! \brief Rollback the matcher to a previous state. */ - void RollbackCodepoints(int rollback_codepoint_cnt); + /*! \brief Rollback the matcher to a previous state by the number of characters. */ + void RollbackChars(int rollback_cnt); - /*! \brief Discard the earliest history. */ - void DiscardEarliestCodepoints(int discard_codepoint_cnt); + /*! \brief Discard the earliest history by the number of characters. */ + void DiscardEarliestChars(int discard_cnt); /*! \brief Print the stack state. */ std::string PrintStackState(int steps_behind_latest = 0) const; protected: - // Init the stack state according to the given rule position. - // If init_rule_position is {}, init the stack with the main rule. - void InitStackState(RulePosition init_rule_position = {}); + // Push an initial stack state according to the given rule position. 
+ // If init_rule_position is kInvalidRulePosition, init the stack with the main rule. + void PushInitialState(RulePosition init_rule_position, bool expand_init_rule_position); - // Update the char_class_star_id field of the given rule_position, if it refers to a character - // class star rule. - void UpdateCharClassStarId(RulePosition* rule_position) const; + // Check if the character is accepted by the current rule position. + bool CheckIfAccepted(const RulePosition& rule_position, uint8_t char_value) const; /*! * \brief Find the next position in the rule. If the next position is at the end of the rule, - * the result depends on the consider_parent parameter: - * - false: kInvalidRulePosition will be returned. - * - true: the next position of the parent rule will be returned. If the current rule is the root - * rule, the RulePosition will be returned as is to indicate the end of the grammar. + * and consider_parent is true, will iteratively find the next position in the parent rule. * \param rule_position The current position. - * \param consider_parent Whether to consider the parent position if the current position is at - * the end of the rule. + * \param consider_parent Whether to consider the parent position if the current position is + * at the end of the rule. + * \returns (success, next_rule_position), indicating if the iteration is successful and the + * next rule position. */ - RulePosition IterateToNextPosition(const RulePosition& rule_position, bool consider_parent) const; + std::pair GetNextPositionInSequence(const RulePosition& rule_position, + bool consider_parent) const; + + // Return the updated rule position after accepting the char + RulePosition UpdatePositionWithChar(const RulePosition& rule_position, uint8_t char_value) const; /*! - * \brief Expand the given rule position (may be a RuleRef element) s.t. every new position is a - * CharacterClass or refers to a CharacterClassStar rule. Push all new positions into - * new_stack_tops. - * \details This method will start from cur_rule_position and continuously iterate to the next - * position as long as the current position can be empty (e.g. the current position is a - * reference to an rule that can be empty, or to a character class star rule). If the current - * position can not be empty, stop expanding. All positions collected will be pushed into - * new_stack_tops. + * \brief Expand the given rule position to all possible positions approachable in the grammar. + * The expanded positions must refers to an element (CharacterClass or CharacterClassStar or + * ByteString) in a rule. Push all new positions into new_stack_tops. + * \example + * A ::= "a" B [a-z]* "c" + * B ::= "b" | "" * - * If the end of the current rule is reached: - * - If is_outmost_level is true, we can go to the next position in the parent rule. - * - Otherwise, stop iteration. + * Input position: (rule=A, position=B) + * Approachable positions: (rule=B, position="b"), (rule=A, position=[a-z]*), + * (rule=A, position="c"), since B and [a-z]* can be empty. * \param cur_rule_position The current rule position. * \param new_stack_tops The vector to store the new stack tops. - * \param is_outmost_level Whether the current position is the outmost level of the rule. - * \param first_id_if_inserted Being not -1 means the first node is already inserted. This is the - * id of the first node. This is used to avoid inserting the same node twice. - * \return Whether the end of the rule can be reached. Used as the condition of recursion. 
+ * \param consider_parent Whether consider expanding the elements in the parent rule. Useful for + * inner recursion. + * \param first_id_if_inserted An optimization. When cur_rule_position is already inserted to + * the state tree, pass its id to avoid inserting it again. -1 (ignore it) by default. + * \return Whether the end of the rule can be reached. Useful for inner recursion. */ bool ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, int32_t first_id_if_inserted = -1); + bool consider_parent = true, int32_t first_id_if_inserted = -1); + // The matched grammar. BNFGrammar grammar_; + // The tree storing all states RulePositionTree tree_; + // The tracked history of stack tops (each stack top refers to a node in the tree). + // We store the stack tops in different steps in the history to support rollback. StackTopsHistory stack_tops_history_; - // Temporary data for AcceptCodepoint. + // Temporary data for AcceptChar. std::vector tmp_new_stack_tops_; }; /*! \brief Check the codepoint is contained in the character class. */ -inline bool CharacterClassContains(const BNFGrammarNode::RuleExpr& rule_expr, - TCodepoint codepoint) { - DCHECK(rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass || - rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass); - for (int i = 0; i < rule_expr.size(); i += 2) { - if (rule_expr.data[i] <= codepoint && codepoint <= rule_expr.data[i + 1]) { - return rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass; +inline bool GrammarStateMatcherBase::CheckIfAccepted(const RulePosition& rule_position, + uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + if (current_element.type == RuleExprType::kCharacterClass || + current_element.type == RuleExprType::kCharacterClassStar) { + if (rule_position.left_utf8_bytes > 0) { + return (char_value & 0xC0) == 0x80; + } + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + if (!accepted) { + return false; + } + bool is_negative = static_cast(current_element[0]); + if (num_bytes > 1) { + return is_negative; + } + for (int i = 1; i < current_element.size(); i += 2) { + if (current_element[i] <= char_value && char_value <= current_element[i + 1]) { + return !is_negative; + } + } + return is_negative; + } else if (current_element.type == RuleExprType::kByteString) { + return current_element[rule_position.element_in_string] == char_value; + } else { + LOG(FATAL) << "Unexpected RuleExprType in CheckIfAccepted: " + << static_cast(current_element.type); + } +} + +inline RulePosition GrammarStateMatcherBase::UpdatePositionWithChar( + const RulePosition& rule_position, uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + RulePosition new_rule_position = rule_position; + switch (current_element.type) { + case RuleExprType::kCharacterClass: { + if (rule_position.left_utf8_bytes > 1) { + new_rule_position.left_utf8_bytes -= 1; + return new_rule_position; + } else if (rule_position.left_utf8_bytes == 1) { + return GetNextPositionInSequence(rule_position, true).second; + } + // If no left utf8 bytes, check the first byte to find the left bytes needed. 
+ DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + if (num_bytes > 1) { + new_rule_position.left_utf8_bytes = num_bytes - 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; + } + case RuleExprType::kCharacterClassStar: { + if (rule_position.left_utf8_bytes >= 1) { + new_rule_position.left_utf8_bytes -= 1; + } else { + DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + new_rule_position.left_utf8_bytes = num_bytes - 1; + } + return new_rule_position; + } + case RuleExprType::kByteString: { + if (rule_position.element_in_string + 1 < current_element.size()) { + new_rule_position.element_in_string += 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; } + default: + LOG(FATAL) << "Unexpected RuleExprType in UpdatePositionWithChar: " + << static_cast(current_element.type); } - return rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass; } -inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool verbose) { +inline bool GrammarStateMatcherBase::AcceptChar(uint8_t char_value, bool verbose) { if (verbose) { - std::cout << "Stack before accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Matching char: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\""; + LOG(INFO) << "Previous stack: " << PrintStackState(); } const auto& prev_stack_tops = stack_tops_history_.GetLatest(); @@ -135,37 +212,31 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool continue; } - auto current_char_class = - cur_rule_position.char_class_star_id != -1 - ? 
grammar_->GetRuleExpr(cur_rule_position.char_class_star_id) - : grammar_->GetRuleExpr(current_sequence[cur_rule_position.element_id]); - DCHECK(current_char_class.type == RuleExprType::kCharacterClass || - current_char_class.type == RuleExprType::kNegCharacterClass); - auto ok = CharacterClassContains(current_char_class, codepoint); - if (!ok) { + auto accepted = CheckIfAccepted(cur_rule_position, char_value); + if (!accepted) { continue; } - if (cur_rule_position.char_class_star_id == -1) { - auto next_rule_position = IterateToNextPosition(cur_rule_position, true); - DCHECK(next_rule_position != kInvalidRulePosition); - ExpandRulePosition(next_rule_position, &tmp_new_stack_tops_, true); + auto new_rule_position = UpdatePositionWithChar(cur_rule_position, char_value); + + if (new_rule_position == cur_rule_position) { + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, prev_top); } else { - ExpandRulePosition(cur_rule_position, &tmp_new_stack_tops_, true, prev_top); + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true); } } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" - << std::endl; + LOG(INFO) << "Character " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\" Rejected"; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" - << std::endl; - std::cout << "Stack after accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Character: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\" Accepted"; + LOG(INFO) << "New stack after acceptance: " << PrintStackState(); } #if TVM_LOG_DEBUG stack_tops_history_.CheckWellFormed(); @@ -179,80 +250,92 @@ inline bool GrammarStateMatcherBase::CanReachEnd() const { [&](int32_t id) { return tree_.IsEndPosition(tree_[id]); }); } -inline void GrammarStateMatcherBase::RollbackCodepoints(int rollback_codepoint_cnt) { - stack_tops_history_.Rollback(rollback_codepoint_cnt); +inline void GrammarStateMatcherBase::RollbackChars(int rollback_cnt) { + stack_tops_history_.Rollback(rollback_cnt); } -inline void GrammarStateMatcherBase::DiscardEarliestCodepoints(int discard_codepoint_cnt) { - stack_tops_history_.DiscardEarliest(discard_codepoint_cnt); +inline void GrammarStateMatcherBase::DiscardEarliestChars(int discard_cnt) { + stack_tops_history_.DiscardEarliest(discard_cnt); } inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_latest) const { return stack_tops_history_.PrintHistory(steps_behind_latest); } -inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_position) { +inline void GrammarStateMatcherBase::PushInitialState(RulePosition init_rule_position, + bool expand_init_rule_position) { if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the main rule. 
auto main_rule = grammar_->GetMainRule(); auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); - std::vector new_stack_tops; + std::vector stack_tops; for (auto i : main_rule_body) { auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent); - UpdateCharClassStarId(&init_rule_position); - ExpandRulePosition(init_rule_position, &new_stack_tops, true); + if (expand_init_rule_position) { + ExpandRulePosition(init_rule_position, &stack_tops, true); + } else { + stack_tops.push_back(tree_.NewNode(init_rule_position)); + } } - stack_tops_history_.PushHistory(new_stack_tops); + stack_tops_history_.PushHistory(stack_tops); } else { - stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); - } -} - -inline void GrammarStateMatcherBase::UpdateCharClassStarId(RulePosition* rule_position) const { - auto rule_expr = grammar_->GetRuleExpr(rule_position->sequence_id); - auto element = grammar_->GetRuleExpr(rule_expr[rule_position->element_id]); - if (element.type == RuleExprType::kRuleRef) { - auto sub_rule_body = grammar_->GetRuleExpr(grammar_->GetRule(element[0]).body_expr_id); - if (sub_rule_body.type == RuleExprType::kCharacterClassStar) { - rule_position->char_class_star_id = sub_rule_body[0]; + if (expand_init_rule_position) { + std::vector stack_tops; + ExpandRulePosition(init_rule_position, &stack_tops, true); + stack_tops_history_.PushHistory(stack_tops); + } else { + stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); } } } -inline RulePosition GrammarStateMatcherBase::IterateToNextPosition( +inline std::pair GrammarStateMatcherBase::GetNextPositionInSequence( const RulePosition& rule_position, bool consider_parent) const { - auto next_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, - rule_position.element_id + 1, rule_position.parent_id); - auto rule_expr = grammar_->GetRuleExpr(rule_position.sequence_id); - auto current_sequence_length = rule_expr.size(); - DCHECK(next_position.element_id <= current_sequence_length); - - if (next_position.element_id < current_sequence_length) { - // Update char_class_star_id if the position refers to a character class star rule. 
- UpdateCharClassStarId(&next_position); - return next_position; + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + + auto next_position = rule_position; + next_position.element_id += 1; + next_position.element_in_string = 0; + next_position.left_utf8_bytes = 0; + + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + return {true, next_position}; } if (!consider_parent) { - return kInvalidRulePosition; + return {false, kInvalidRulePosition}; } - if (next_position.parent_id == RulePosition::kNoParent) { - return next_position; - } else { - auto parent_rule_position = tree_[next_position.parent_id]; - return IterateToNextPosition(parent_rule_position, true); + // Find the next position in the parent rule + while (next_position.parent_id != RulePosition::kNoParent) { + next_position = tree_[next_position.parent_id]; + next_position.element_id += 1; + DCHECK(next_position.element_in_string == 0); + DCHECK(next_position.left_utf8_bytes == 0); + + sequence = grammar_->GetRuleExpr(next_position.sequence_id); + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + break; + } } + + return {true, next_position}; } inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, + bool consider_parent, int32_t first_id_if_inserted) { bool is_first = false; + bool is_iteration_successful = true; - for (; cur_rule_position != kInvalidRulePosition; - cur_rule_position = IterateToNextPosition(cur_rule_position, is_outmost_level)) { + for (; is_iteration_successful; + std::tie(is_iteration_successful, cur_rule_position) = + GetNextPositionInSequence(cur_rule_position, consider_parent)) { // Insert the node to the tree, if not inserted before. int32_t new_node_id; if (is_first && first_id_if_inserted != -1) { @@ -263,7 +346,7 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po is_first = false; // Case 1. The current position points to the end of the grammar. - if (is_outmost_level) { + if (consider_parent) { if (tree_.IsEndPosition(cur_rule_position)) { new_stack_tops->push_back(new_node_id); return true; @@ -272,42 +355,39 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po DCHECK(!tree_.IsEndPosition(cur_rule_position)); } - // Case 2. The current position refers to a character class star rule. It can be empty. - if (cur_rule_position.char_class_star_id != -1) { - new_stack_tops->push_back(new_node_id); - continue; - } - - // Case 3. Character class: cannot be empty. auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); - if (element.type == RuleExprType::kCharacterClass || - element.type == RuleExprType::kNegCharacterClass) { - new_stack_tops->push_back(new_node_id); - return false; - } - - // Case 4. The current position refers to a normal rule, i.e. a rule of choices of sequences. 
- DCHECK(element.type == RuleExprType::kRuleRef); - auto sub_rule_id = element[0]; - auto sub_rule = grammar_->GetRule(sub_rule_id); - auto sub_rule_body = grammar_->GetRuleExpr(sub_rule.body_expr_id); - DCHECK(sub_rule_body.type == RuleExprType::kChoices); - - bool contain_empty = false; - - for (auto sequence_id : sub_rule_body) { - auto sequence = grammar_->GetRuleExpr(sequence_id); - if (sequence.type == RuleExprType::kEmptyStr) { - contain_empty = true; - continue; + bool can_be_empty = false; + + if (element.type == RuleExprType::kRuleRef) { + // Case 2. The current position refers to another rule. + auto ref_rule = grammar_->GetRule(element[0]); + auto ref_rule_body = grammar_->GetRuleExpr(ref_rule.body_expr_id); + DCHECK(ref_rule_body.type == RuleExprType::kChoices); + + for (auto sequence_id : ref_rule_body) { + auto ref_rule_sequence = grammar_->GetRuleExpr(sequence_id); + if (ref_rule_sequence.type == RuleExprType::kEmptyStr) { + can_be_empty = true; + continue; + } + auto ref_rule_position = RulePosition(element[0], sequence_id, 0, new_node_id); + // Find the positions in every choice of the referred rule + can_be_empty |= ExpandRulePosition(ref_rule_position, new_stack_tops, false); } - auto sub_rule_position = RulePosition(sub_rule_id, sequence_id, 0, new_node_id); - UpdateCharClassStarId(&sub_rule_position); - contain_empty |= ExpandRulePosition(sub_rule_position, new_stack_tops, false); + } else if (element.type == RuleExprType::kCharacterClass || + element.type == RuleExprType::kByteString) { + // Case 3. Character class or byte string. cannot be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = false; + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar); + // Case 4. Character class star. Might be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = cur_rule_position.left_utf8_bytes == 0; } - if (!contain_empty) { + if (!can_be_empty) { return false; } } diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index f63eee2c5c..dc9fb9646e 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -9,6 +9,7 @@ #include #include "../../support/encoding.h" +#include "../../support/utils.h" #include "grammar.h" #include "grammar_state_matcher_base.h" @@ -18,34 +19,47 @@ namespace serve { using namespace tvm::runtime; -/*! \brief A token and its id. */ -struct TokenAndId { - std::vector token; - int32_t id; - /*! \brief Compare tokens by their unicode codepoint sequence. */ - bool operator<(const TokenAndId& other) const; -}; - /*! - * \brief Preprocessed information, for a given specific rule and position, divides the token set + * \brief Preprocessed information, for a given specific RulePosition, divides the token set * into three categories: accepted, rejected, and uncertain. - * \note Since the union of these three sets is the whole token set, we only need to store the - * smaller two sets. The unsaved set is specified by not_saved_index. - * \note These indices are the indices of sorted_token_codepoints in the GrammarStateInitContext + * Accepted: tokens that can be determined by the current RulePosition to be acceptable + * Rejected: tokens that can be determined by the current RulePosition to be unacceptable + * Uncertain: tokens that need the state of the parent RulePositions to determine if acceptable + * + * \note uncertain indices are stored directly. 
Accepted / rejected indices have three ways to + * store to reduce memory and computation usage. See SaveType. + * \note These indices are the indices of sorted_token_table in the GrammarStateInitContext * object, instead of the token ids. That helps the matching process. */ struct CatagorizedTokens { + enum class SaveType { + // Only store all accepted token indices. Then rejected indices = all_indices - accepted_indices + // - uncertain_indices. This is useful when |accepted_indices| < |rejected_indices|. + kAccepted = 0, + // Only store all accepted token indices. Then accepted indices = all_indices - rejected_indices + // - uncertain_indices. This is useful when |accepted_indices| > |rejected_indices|. + kRejected = 1, + // Store all accepted token indices in a bitset. This is useful when both |accepted_indices| and + // |rejected_indices| are large. + kAcceptedBitset = 2 + }; + SaveType save_type; + + static constexpr int USE_BITSET_THRESHOLD = 200; + std::vector accepted_indices; std::vector rejected_indices; + DynamicBitset accepted_bitset; + std::vector uncertain_indices; - enum class NotSavedIndex { kAccepted = 0, kRejected = 1, kUncertain = 2 }; - NotSavedIndex not_saved_index; CatagorizedTokens() = default; - CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices); + CatagorizedTokens(int vocab_size, + const std::vector>& sorted_token_table, + const std::vector& accepted_indices, + const std::vector& rejected_indices, + const std::vector& uncertain_indices); }; /*! @@ -57,189 +71,227 @@ class GrammarStateInitContext { public: /******************* Information about the tokenizer *******************/ - /*! \brief The token table. Now only used for debug purpose. */ - std::vector token_table; - /*! \brief The vocabulary size of the tokenizer. */ + /*! \brief The vocabulary size of the tokenizer. Special tokens are included. */ size_t vocab_size; - /*! \brief All tokens represented by the id and codepoints of each. The tokens are sorted by - * codepoint values to reuse the common prefix during matching. */ - std::vector sorted_token_codepoints; - /*! \brief The mapping from token id to token represented by codepoints. Only contains - * non-special and non-stop tokens. */ - std::unordered_map id_to_token_codepoints; - /*! \brief The stop tokens. They can be accepted iff GramamrMatcher can reach the end of the - * grammar. */ + /*! \brief The token table. Special tokens are included. */ + std::vector token_table; + /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to + * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */ + std::vector> sorted_token_table; + /*! \brief The stop tokens. When the GrammarStateMatcher can reach the end of the= grammar, + * stop tokens can be accepted. */ std::vector stop_token_ids; - /*! \brief The special tokens. Currently we will ignore these tokens during grammar-guided - * matching. */ - std::vector special_token_ids; + /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided + * generation. */ + std::unordered_set special_token_ids; /******************* Information about the grammar *******************/ + /*! \brief The grammar for the GrammarStateMatcher. */ BNFGrammar grammar; /******************* Grammar-specific tokenizer information *******************/ - /*! \brief A sequence id and its position. 
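// The SaveType selection described above reduces, in effect, to the standalone sketch
// below (enum and function names here are illustrative; the threshold mirrors
// USE_BITSET_THRESHOLD): use a vocab-sized bitset when both the accepted and rejected
// sets are large, otherwise store only the smaller index list.
#include <cstddef>

enum class SaveTypeSketch { kAccepted, kRejected, kAcceptedBitset };

inline SaveTypeSketch ChooseSaveType(std::size_t num_accepted, std::size_t num_rejected,
                                     std::size_t threshold = 200) {
  if (num_accepted >= threshold && num_rejected >= threshold) {
    // Both lists are large: a vocab_size-bit bitset is cheaper than either list of
    // 32-bit indices.
    return SaveTypeSketch::kAcceptedBitset;
  }
  // Keep the smaller list; the other category is implied as "everything else except
  // the uncertain indices".
  return num_accepted < num_rejected ? SaveTypeSketch::kAccepted : SaveTypeSketch::kRejected;
}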
*/ - struct SequenceIdAndPosition { - int32_t sequence_id; - int32_t element_id; - bool operator==(const SequenceIdAndPosition& other) const { - return sequence_id == other.sequence_id && element_id == other.element_id; + struct RulePositionEqual { + std::size_t operator()(const RulePosition& lhs, const RulePosition& rhs) const noexcept { + return lhs.sequence_id == rhs.sequence_id && lhs.element_id == rhs.element_id && + lhs.left_utf8_bytes == rhs.left_utf8_bytes && + lhs.element_in_string == rhs.element_in_string; } }; - /*! \brief Hash function for SequenceIdAndPosition. */ - struct SequenceIdAndPositionHash { - std::size_t operator()(const SequenceIdAndPosition& k) const { - return std::hash()(k.sequence_id) ^ (std::hash()(k.element_id) << 1); + struct RulePositionHash { + std::size_t operator()(const RulePosition& rule_position) const noexcept { + return HashCombine(rule_position.sequence_id, rule_position.element_id, + rule_position.left_utf8_bytes, rule_position.element_in_string); } }; - /*! \brief Mapping from sequence id and its position to the catagorized tokens. */ - std::unordered_map + /*! \brief Mapping from RulePositions to the catagorized tokens. */ + std::unordered_map catagorized_tokens_for_grammar; }; -/* \brief The concrete implementation of GrammarStateMatcherNode. */ +/*! \brief The concrete implementation of GrammarStateMatcherNode. */ class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase { public: + // Do not expand the initial rule position: we want to find the accepted/rejected tokens + // that exactly start from the initial rule position. GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position) - : GrammarStateMatcherBase(grammar, init_rule_position) {} - - CatagorizedTokens GetCatagorizedTokens(const std::vector& sorted_token_codepoints, - bool is_main_rule); + : GrammarStateMatcherBase(grammar, init_rule_position, false), + init_rule_id(init_rule_position.rule_id) {} + + /*! + * \brief Get the catagorized tokens for the given RulePosition. + * \param consider_parent_rule Whether to consider the parent rule. If false, there will be + * no uncertain tokens. Useful for the main rule. + */ + CatagorizedTokens GetCatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule); private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Check if a token can pass the lookahead assertion. */ + bool IsTokenPassLookaheadAssertion(const std::string& token, + const std::vector& can_reach_end_stack); + + // The id of the initial rule. + int32_t init_rule_id; + // Temporary data for GetCatagorizedTokens. 
std::vector tmp_accepted_indices_; std::vector tmp_rejected_indices_; std::vector tmp_uncertain_indices_; - std::vector tmp_can_see_end_stack_; + std::vector tmp_can_reach_end_stack_; + std::vector tmp_can_reach_end_prefix_or_stack_; }; -inline bool TokenAndId::operator<(const TokenAndId& other) const { - for (size_t i = 0; i < token.size(); ++i) { - if (i >= other.token.size()) { - return false; - } - if (token[i] < other.token[i]) { - return true; - } else if (token[i] > other.token[i]) { - return false; +inline CatagorizedTokens::CatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + const std::vector& accepted_indices, const std::vector& rejected_indices, + const std::vector& uncertain_indices) { + auto size_acc = accepted_indices.size(); + auto size_rej = rejected_indices.size(); + + save_type = size_acc >= USE_BITSET_THRESHOLD && size_rej >= USE_BITSET_THRESHOLD + ? SaveType::kAcceptedBitset + : size_acc < size_rej ? SaveType::kAccepted + : SaveType::kRejected; + + if (save_type == SaveType::kAcceptedBitset) { + accepted_bitset = DynamicBitset(vocab_size); + for (auto idx : accepted_indices) { + accepted_bitset.Set(sorted_token_table[idx].first, true); } + } else if (save_type == SaveType::kAccepted) { + this->accepted_indices = accepted_indices; + } else { + this->rejected_indices = rejected_indices; } - return token.size() < other.token.size(); + + this->uncertain_indices = uncertain_indices; } -inline CatagorizedTokens::CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices) { - auto size_acc = accepted_indices.size(); - auto size_rej = rejected_indices.size(); - auto size_unc = uncertain_indices.size(); - not_saved_index = - (size_acc >= size_rej && size_acc >= size_unc) - ? NotSavedIndex::kAccepted - : (size_rej >= size_unc ? NotSavedIndex::kRejected : NotSavedIndex::kUncertain); - - if (not_saved_index != NotSavedIndex::kAccepted) { - this->accepted_indices = std::move(accepted_indices); +bool GrammarStateMatcherForInitContext::IsTokenPassLookaheadAssertion( + const std::string& token, const std::vector& can_reach_end_stack) { + auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id; + if (lookahead_assertion_id == -1) { + return true; } - if (not_saved_index != NotSavedIndex::kRejected) { - this->rejected_indices = std::move(rejected_indices); - } - if (not_saved_index != NotSavedIndex::kUncertain) { - this->uncertain_indices = std::move(uncertain_indices); + auto lookahead_rule_position = RulePosition(-1, lookahead_assertion_id, 0); + PushInitialState(lookahead_rule_position, true); + int token_len = token.size(); + + // Find all positions that can come to and end. Then check if the suffix from that position + // can be accepted by the lookahead assertion. + for (int i = static_cast(can_reach_end_stack.size()); i >= 0; --i) { + if (!can_reach_end_stack[i]) { + continue; + } + int last_accept_pos = i - 1; + for (int pos = i; pos < token_len; ++pos) { + if (!AcceptChar(token[pos])) { + break; + } + last_accept_pos = pos; + // Case 1. The whole rule is finished. + if (CanReachEnd()) { + // accepted chars: pos - i + 1 + // we need to rollback the pushed initial state as well + RollbackChars(pos - i + 2); + return true; + } + } + // Case 2. The whole token is accepted + if (last_accept_pos == token_len - 1) { + RollbackChars(last_accept_pos - i + 2); + return true; + } + // Case 3. The token is not accepted. Check the next position. 
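      // Note on the rollback accounting in this function: the two branches above roll back
      // `pos - i + 2` / `last_accept_pos - i + 2` history entries, i.e. the accepted
      // characters plus the snapshot pushed by PushInitialState. Here only the accepted
      // characters are rolled back, so the lookahead matcher stays initialized for the next
      // candidate start position; the trailing RollbackChars(1) below pops that snapshot.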
+ RollbackChars(last_accept_pos - i + 1); } + + RollbackChars(1); + return false; } inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( - const std::vector& sorted_token_codepoints, bool is_main_rule) { - // Support the current stack contains only one stack with one RulePosition. - // Iterate over all tokens. Split them into three categories: - // - accepted_indices: If a token is accepted by current rule - // - rejected_indices: If a token is rejected by current rule - // - uncertain_indices: If a prefix of a token is accepted by current rule and comes to the end - // of the rule. - - // Note many tokens may contain the same prefix, so we will avoid unnecessary matching - + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule) { tmp_accepted_indices_.clear(); tmp_rejected_indices_.clear(); tmp_uncertain_indices_.clear(); + // For every character in the current token, stores whether it is possible to reach the end of - // the rule when matching until this character. Useful for rollback. - tmp_can_see_end_stack_.assign({CanReachEnd()}); + // the rule when matching until this character. Store it in a stack for later rollback. + tmp_can_reach_end_stack_.assign({CanReachEnd()}); + tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()}); int prev_matched_size = 0; - for (int i = 0; i < static_cast(sorted_token_codepoints.size()); ++i) { - const auto& token = sorted_token_codepoints[i].token; - const auto* prev_token = i > 0 ? &sorted_token_codepoints[i - 1].token : nullptr; - - // Find the longest common prefix with the accepted part of the previous token. - auto prev_useful_size = 0; - if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast(token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } - } - RollbackCodepoints(prev_matched_size - prev_useful_size); - tmp_can_see_end_stack_.erase( - tmp_can_see_end_stack_.end() - (prev_matched_size - prev_useful_size), - tmp_can_see_end_stack_.end()); - } + for (int i = 0; i < static_cast(sorted_token_table.size()); ++i) { + const auto& token = sorted_token_table[i].second; - // Find if the current token is accepted or rejected or uncertain. bool accepted = true; - bool can_see_end = tmp_can_see_end_stack_.back(); - prev_matched_size = prev_useful_size; - for (int j = prev_useful_size; j < token.size(); ++j) { - if (!AcceptCodepoint(token[j], false)) { + + // Many tokens may contain the same prefix, so we will avoid unnecessary matching + // by finding the longest common prefix with the previous token. + if (i > 0) { + const auto& prev_token = sorted_token_table[i - 1].second; + int lcp_len = + std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first - + token.begin(); + if (lcp_len > prev_matched_size) { + // Case 1. The common prefix is rejected by the matcher in the last token. Reject directly. accepted = false; - break; + } else if (lcp_len < prev_matched_size) { + // Case 2. The common prefix is shorter than the previous matched size. Rollback + // the non-common part. 
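        // Worked example of this rollback: if the previous (sorted) token was "apple" and
        // all 5 characters were accepted (prev_matched_size = 5), and the current token is
        // "apply", then lcp_len = 4 ("appl"), so prev_matched_size - lcp_len = 1 character
        // is rolled back and only the final 'y' is re-matched below. The two
        // can_reach_end stacks are truncated by the same amount to stay in sync with the
        // matcher state.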
+ RollbackChars(prev_matched_size - lcp_len); + tmp_can_reach_end_stack_.erase( + tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_stack_.end()); + tmp_can_reach_end_prefix_or_stack_.erase( + tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_prefix_or_stack_.end()); } - if (CanReachEnd()) { - can_see_end = true; + prev_matched_size = std::min(prev_matched_size, lcp_len); + } + + if (accepted) { + // Accept the rest chars one by one + for (int j = prev_matched_size; j < token.size(); ++j) { + if (!AcceptChar(token[j], false)) { + accepted = false; + break; + } + tmp_can_reach_end_stack_.push_back(CanReachEnd()); + tmp_can_reach_end_prefix_or_stack_.push_back(tmp_can_reach_end_stack_.back() || + tmp_can_reach_end_prefix_or_stack_.back()); + prev_matched_size = j + 1; } - tmp_can_see_end_stack_.push_back(can_see_end); - prev_matched_size = j + 1; } + + bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); + if (accepted) { tmp_accepted_indices_.push_back(i); - } else if (can_see_end && !is_main_rule) { - // If the current rule is the main rule, there will be no uncertain indices since we will - // never consider its parent rule. Unaccepted tokens are just rejected. + } else if (can_reach_end && consider_parent_rule && + IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) { + // 1. If the current rule is the main rule (consider_parent_rule=false), there are no + // uncertain tokens. Not accepted tokens are just rejected. + // 2. If a token cannot pass the lookahead assertion, it is rejected. tmp_uncertain_indices_.push_back(i); } else { tmp_rejected_indices_.push_back(i); } } - RollbackCodepoints(prev_matched_size); - return CatagorizedTokens(std::move(tmp_accepted_indices_), std::move(tmp_rejected_indices_), - std::move(tmp_uncertain_indices_)); -} - -inline std::string ReplaceUnderscoreWithSpace(const std::string& str, - const std::string& kSpecialUnderscore) { - std::string res; - size_t pos = 0; - while (pos < str.size()) { - size_t found = str.find(kSpecialUnderscore, pos); - if (found == std::string::npos) { - res += str.substr(pos); - break; - } - res += str.substr(pos, found - pos) + " "; - pos = found + kSpecialUnderscore.size(); - } - return res; + // Rollback the last matched part + RollbackChars(prev_matched_size); + return CatagorizedTokens(vocab_size, sorted_token_table, tmp_accepted_indices_, + tmp_rejected_indices_, tmp_uncertain_indices_); } inline std::shared_ptr GrammarStateMatcher::CreateInitContext( @@ -248,87 +300,94 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC auto ptr = std::make_shared(); ptr->grammar = grammar; - ptr->token_table = token_table; ptr->vocab_size = token_table.size(); + ptr->token_table = token_table; if (ptr->vocab_size == 0) { return ptr; } for (int i = 0; i < token_table.size(); ++i) { - auto token = token_table[i]; - if (token == "" || token == "" || token == "") { - ptr->special_token_ids.push_back(i); - } else if (token == "") { + const auto& token = token_table[i]; + // LLaMA2: + // LLaMA3: <|end_of_text|>, <|eot_id|> + // Phi-2: <|endoftext|> + // Gemma: , + if (token == "" || token == "<|end_of_text|>" || token == "<|eot_id|>" || + token == "<|endoftext|>" || token == "" || token == "") { ptr->stop_token_ids.push_back(i); - } else if (token.size() == 1 && - (static_cast(token[0]) >= 128 || token[0] == 0)) { - // Currently we consider all tokens with one character that >= 128 as special tokens, - // and will ignore generating them 
during grammar-guided generation. - ptr->special_token_ids.push_back(i); + } else if ((token[0] == '<' && token[token.size() - 1] == '>' && token.size() >= 3) || + token == "[@BOS@]") { + // gemma treats [@BOS@] as a special token + ptr->special_token_ids.insert(i); } else { - // First replace the special underscore with space. - auto codepoints = ParseUTF8(token.c_str()); - DCHECK(!codepoints.empty() && - codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) - << "Invalid token: " << token; - ptr->sorted_token_codepoints.push_back({codepoints, i}); - ptr->id_to_token_codepoints[i] = {codepoints, i}; + ptr->sorted_token_table.push_back({i, token}); } } - std::sort(ptr->sorted_token_codepoints.begin(), ptr->sorted_token_codepoints.end()); + + auto f_compare_token = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(ptr->sorted_token_table.begin(), ptr->sorted_token_table.end(), f_compare_token); // Find the corresponding catagorized tokens for: - // 1. All character elements in the grammar - // 2. All RuleRef elements that refers to a rule containing a CharacterClassStar RuleExpr. - for (int i = 0; i < static_cast(grammar->NumRules()); ++i) { - auto rule = grammar->GetRule(i); - auto rule_expr = grammar->GetRuleExpr(rule.body_expr_id); - // Skip CharacterClassStar since we just handle it at the reference element during matching. - if (rule_expr.type == RuleExprType::kCharacterClassStar) { - continue; - } - DCHECK(rule_expr.type == RuleExprType::kChoices); - for (auto sequence_id : rule_expr) { - auto sequence_expr = grammar->GetRuleExpr(sequence_id); - if (sequence_expr.type == RuleExprType::kEmptyStr) { + // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) + // 2. All byte strings (with element_in_string=0, 1, 2, ...) + auto main_rule_id = grammar->GetMainRuleId(); + for (int rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { + auto rule = grammar->GetRule(rule_id); + auto rule_body = grammar->GetRuleExpr(rule.body_expr_id); + DCHECK(rule_body.type == RuleExprType::kChoices); + for (auto sequence_id : rule_body) { + auto sequence = grammar->GetRuleExpr(sequence_id); + if (sequence.type == RuleExprType::kEmptyStr) { continue; } - DCHECK(sequence_expr.type == RuleExprType::kSequence); - for (int element_id = 0; element_id < sequence_expr.size(); ++element_id) { - auto element_expr = grammar->GetRuleExpr(sequence_expr[element_id]); - auto cur_rule_position = RulePosition{i, sequence_id, element_id}; - if (element_expr.type == RuleExprType::kRuleRef) { - auto ref_rule = grammar->GetRule(element_expr[0]); - auto ref_rule_expr = grammar->GetRuleExpr(ref_rule.body_expr_id); - if (ref_rule_expr.type == RuleExprType::kChoices) { - continue; - } else { - // Reference to a CharacterClassStar of a character class. 
- cur_rule_position.char_class_star_id = ref_rule_expr[0]; - } + DCHECK(sequence.type == RuleExprType::kSequence); + for (int element_id = 0; element_id < sequence.size(); ++element_id) { + auto element = grammar->GetRuleExpr(sequence[element_id]); + if (element.type == RuleExprType::kRuleRef) { + continue; } - auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, cur_rule_position); - auto cur_catagorized_tokens_for_grammar = - grammar_state_matcher.GetCatagorizedTokens(ptr->sorted_token_codepoints, i == 0); - ptr->catagorized_tokens_for_grammar[{sequence_id, element_id}] = - cur_catagorized_tokens_for_grammar; + auto add_catagorized_tokens = [&](const RulePosition& rule_position) { + auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, rule_position); + auto cur_catagorized_tokens_for_grammar = grammar_state_matcher.GetCatagorizedTokens( + ptr->vocab_size, ptr->sorted_token_table, rule_id != main_rule_id); + ptr->catagorized_tokens_for_grammar[rule_position] = cur_catagorized_tokens_for_grammar; + }; + + auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); + if (element.type == RuleExprType::kByteString) { + for (int idx = 0; idx < element.size(); ++idx) { + cur_rule_position.element_in_string = idx; + add_catagorized_tokens(cur_rule_position); + } + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar || + element.type == RuleExprType::kCharacterClass); + for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { + cur_rule_position.left_utf8_bytes = left_utf8_bytes; + add_catagorized_tokens(cur_rule_position); + } + } } } } return ptr; } -class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { +class GrammarInitContextCacheImpl : public GrammarInitContextCacheNode { public: - GrammarInitContextStorageImpl(const std::vector& token_table); + GrammarInitContextCacheImpl(const std::vector& token_table); - std::shared_ptr GetInitContextForJSONSchema(const std::string& schema); + std::shared_ptr GetInitContextForJSONSchema( + const std::string& schema) final; - std::shared_ptr GetInitContextForJSON(); + std::shared_ptr GetInitContextForJSON() final; - void ClearCache(); + void Clear() final; private: /*! \brief The token table associated with this storage class. 
*/ @@ -340,7 +399,7 @@ class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { std::shared_ptr init_ctx_for_json_; }; -inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( +inline GrammarInitContextCacheImpl::GrammarInitContextCacheImpl( const std::vector& token_table) : token_table_(token_table) { init_ctx_for_json_ = @@ -348,7 +407,7 @@ inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& schema) { +GrammarInitContextCacheImpl::GetInitContextForJSONSchema(const std::string& schema) { auto it = init_ctx_for_schema_cache_.find(schema); if (it != init_ctx_for_schema_cache_.end()) { return it->second; @@ -360,14 +419,14 @@ GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& sc } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSON() { +GrammarInitContextCacheImpl::GetInitContextForJSON() { return init_ctx_for_json_; } -inline void GrammarInitContextStorageImpl::ClearCache() { init_ctx_for_schema_cache_.clear(); } +inline void GrammarInitContextCacheImpl::Clear() { init_ctx_for_schema_cache_.clear(); } -GrammarInitContextStorage::GrammarInitContextStorage(const std::vector& token_table) - : ObjectRef(make_object(token_table)) {} +GrammarInitContextCache::GrammarInitContextCache(const std::vector& token_table) + : ObjectRef(make_object(token_table)) {} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index 47f3e11c7b..1b8a34074f 100644 --- a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -20,18 +20,20 @@ using namespace tvm::runtime; /*! \brief Specifies a position in a rule. */ struct RulePosition { - /*! \brief The rule's id. */ + /*! \brief The rule's id. Used for debug purposes. */ int32_t rule_id = -1; /*! \brief Which choice in this rule is selected. */ int32_t sequence_id = -1; - /*! \brief Which element of the choice sequence is being visited. */ + /*! \brief Which element of the choice sequence is to be visited. */ int32_t element_id = -1; - /*! - * \brief If the element refers to another rule, and the body of another rule is a - * CharacterClassStar RuleExpr, this field will be set to the id of the character class. - * This is for the special support of CharacterClassStar. - */ - int32_t char_class_star_id = -1; + + /*! \brief The number of left utf8 bytes in the current element. Used when the element is + * a character class or a character class star. */ + int32_t left_utf8_bytes = 0; + /*! \brief The next position to match in the current byte string. Used when the element is + * a byte string. */ + int32_t element_in_string = 0; + /*! \brief The id of the parent node in the RulePositionTree. */ int32_t parent_id = -1; /*! \brief The reference count of this RulePosition. 
If reduces to zero, the node will be @@ -43,24 +45,21 @@ struct RulePosition { constexpr RulePosition() = default; constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, - int32_t parent_id = kNoParent, int32_t char_class_star_id = -1) - : rule_id(rule_id), - sequence_id(sequence_id), - element_id(element_id), - char_class_star_id(char_class_star_id), - parent_id(parent_id) {} + int32_t parent_id = kNoParent) + : rule_id(rule_id), sequence_id(sequence_id), element_id(element_id), parent_id(parent_id) {} + + // The position is invalid when sequence_id is -1. + bool IsInvalid() const { return sequence_id == -1; } bool operator==(const RulePosition& other) const { return rule_id == other.rule_id && sequence_id == other.sequence_id && - element_id == other.element_id && char_class_star_id == other.char_class_star_id && - parent_id == other.parent_id; + element_id == other.element_id && parent_id == other.parent_id && + left_utf8_bytes == other.left_utf8_bytes && element_in_string == other.element_in_string; } - - bool operator!=(const RulePosition& other) const { return !(*this == other); } }; /*! \brief A special value for invalid RulePosition. */ -inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1, -1); +inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1); /*! \brief A buffer to manage all RulePositions. */ class RulePositionBuffer { @@ -76,7 +75,7 @@ class RulePositionBuffer { id = buffer_.size() - 1; } else { id = free_nodes_.back(); - DCHECK(buffer_[id] == kInvalidRulePosition); + DCHECK(buffer_[id].IsInvalid()); free_nodes_.pop_back(); } rule_position.reference_count = 0; @@ -86,7 +85,7 @@ class RulePositionBuffer { /*! \brief Free the RulePosition with the given id. */ void Free(int32_t id) { - DCHECK(buffer_[id] != kInvalidRulePosition); + DCHECK(!buffer_[id].IsInvalid()); buffer_[id] = kInvalidRulePosition; free_nodes_.push_back(id); } @@ -102,11 +101,13 @@ class RulePositionBuffer { /*! \brief Get the RulePosition with the given id. */ RulePosition& operator[](int32_t id) { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } const RulePosition& operator[](int32_t id) const { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } @@ -145,7 +146,7 @@ class RulePositionTree { auto id = node_buffer_.Allocate(rule_position); if (rule_position.parent_id != RulePosition::kNoParent) { DCHECK(rule_position.parent_id < static_cast(node_buffer_.Capacity()) && - node_buffer_[rule_position.parent_id] != kInvalidRulePosition); + !node_buffer_[rule_position.parent_id].IsInvalid()); node_buffer_[rule_position.parent_id].reference_count++; } return id; @@ -183,7 +184,7 @@ class RulePositionTree { /*! \brief Get the RulePosition with the given id. 
*/ const RulePosition& operator[](int32_t id) const { DCHECK(id != RulePosition::kNoParent); - DCHECK(node_buffer_[id] != kInvalidRulePosition); + DCHECK(!node_buffer_[id].IsInvalid()); return node_buffer_[id]; } @@ -331,15 +332,26 @@ inline std::string RulePositionTree::PrintNode(int32_t id) const { inline std::string RulePositionTree::PrintNode(const RulePosition& rule_position) const { std::stringstream ss; - ss << "RulePosition: rule " << rule_position.rule_id << ": " - << grammar_->GetRule(rule_position.rule_id).name; + ss << "RulePosition: rule " << rule_position.rule_id; + if (rule_position.rule_id != -1) { + ss << ": " << grammar_->GetRule(rule_position.rule_id).name; + } ss << ", sequence " << rule_position.sequence_id << ": " << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); ss << ", element id: " << rule_position.element_id; - if (rule_position.char_class_star_id != -1) { - ss << ", char class " << rule_position.char_class_star_id << ": " - << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.char_class_star_id) << "*"; + + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + if (rule_position.element_id < static_cast(sequence.size())) { + auto element = grammar_->GetRuleExpr(sequence[rule_position.element_id]); + if (element.type == BNFGrammarNode::RuleExprType::kByteString) { + ss << ", element in string: " << rule_position.element_in_string; + } else { + DCHECK(element.type == BNFGrammarNode::RuleExprType::kCharacterClass || + element.type == BNFGrammarNode::RuleExprType::kCharacterClassStar); + ss << ", left utf8 bytes: " << rule_position.left_utf8_bytes; + } } + ss << ", parent id: " << rule_position.parent_id << ", ref count: " << rule_position.reference_count; return ss.str(); @@ -370,7 +382,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid std::queue visit_queue; for (auto id : outside_pointers) { CHECK(id >= 0 && id < buffer_size); - CHECK(buffer[id] != kInvalidRulePosition); + CHECK(!buffer[id].IsInvalid()); new_reference_counter[id]++; if (visited[id] == false) { visited[id] = true; @@ -383,7 +395,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid const auto& rule_position = buffer[cur_id]; if (rule_position.parent_id != RulePosition::kNoParent) { CHECK(rule_position.parent_id >= 0 && rule_position.parent_id < buffer_size); - CHECK(buffer[rule_position.parent_id] != kInvalidRulePosition); + CHECK(!buffer[rule_position.parent_id].IsInvalid()); new_reference_counter[rule_position.parent_id]++; if (visited[rule_position.parent_id] == false) { visited[rule_position.parent_id] = true; @@ -394,11 +406,11 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid for (int i = 0; i < static_cast(buffer.size()); ++i) { if (free_nodes_set.count(i)) { - CHECK(buffer[i] == kInvalidRulePosition); + CHECK(buffer[i].IsInvalid()); CHECK(visited[i] == false); } else { CHECK(visited[i] == true); - CHECK(buffer[i] != kInvalidRulePosition); + CHECK(!buffer[i].IsInvalid()); CHECK(new_reference_counter[i] == buffer[i].reference_count) << "Reference counters unmatch for node #" << i << ": Updated " << new_reference_counter[i] << ", Original " << buffer[i].reference_count; diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 83be710cf5..e0c465ba9e 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -385,9 +385,9 @@ void JSONSchemaToEBNFConverter::AddBasicRules() { void 
JSONSchemaToEBNFConverter::AddHelperRules() { rules_.push_back(std::make_pair( kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); - rules_.push_back(std::make_pair(kBasicStringSub, "\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + - " | \"\\\\\" " + kBasicEscape + " " + - kBasicStringSub)); + rules_.push_back(std::make_pair( + kBasicStringSub, "(\"\\\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + + kBasicEscape + " " + kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])")); } void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, @@ -648,7 +648,7 @@ std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schem "pattern", "format", }); - return "[\"] " + kBasicStringSub + " [\"]"; + return "[\"] " + kBasicStringSub; } std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h index fb9002dbac..c8b3f34344 100644 --- a/cpp/serve/grammar/support.h +++ b/cpp/serve/grammar/support.h @@ -8,30 +8,72 @@ #include +#include #include #include +#include namespace mlc { namespace llm { namespace serve { -/*! \brief Manages a segment of externally provided memory and use it as a bitset. */ -class BitsetManager { +/*! \brief A bitset with runtime specified length. It manages memory internally or the memory + * provided externally with enough size. */ +class DynamicBitset { public: - BitsetManager(uint32_t* data, int buffer_size, int element_cnt) - : data_(data), buffer_size_(buffer_size), element_cnt_(element_cnt) { - DCHECK(buffer_size >= CalculateBufferSize(element_cnt)); + static int CalculateBufferSize(int element_size) { return (element_size + 31) / 32; } + + DynamicBitset() : size_(0), buffer_size_(0), data_(nullptr), is_internal_(true) {} + + DynamicBitset(int size, uint32_t* data = nullptr) + : size_(size), buffer_size_(CalculateBufferSize(size)) { + if (data == nullptr) { + internal_buffer_.resize(buffer_size_, 0); + data_ = internal_buffer_.data(); + is_internal_ = true; + } else { + data_ = data; + is_internal_ = false; + } } - static int CalculateBufferSize(int element_cnt) { return (element_cnt + 31) / 32; } + DynamicBitset& operator=(const DynamicBitset& other) { + DCHECK(is_internal_ || size_ >= other.size_) << "Expanding bitset size is not allowed when the " + "memory of the bitset is externally managed"; + size_ = other.size_; + buffer_size_ = other.buffer_size_; + if (is_internal_) { + internal_buffer_.reserve(buffer_size_); + data_ = internal_buffer_.data(); + } + if (data_ != other.data_) { + std::memcpy(data_, other.data_, buffer_size_ * sizeof(uint32_t)); + } + return *this; + } + + DynamicBitset& operator=(DynamicBitset&& other) { + size_ = other.size_; + buffer_size_ = other.buffer_size_; + is_internal_ = other.is_internal_; + if (is_internal_) { + internal_buffer_ = std::move(other.internal_buffer_); + data_ = internal_buffer_.data(); + } else { + data_ = other.data_; + } + return *this; + } bool operator[](int index) const { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); return (data_[index / 32] >> (index % 32)) & 1; } + int Size() const { return size_; } + void Set(int index, bool value) { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); if (value) { data_[index / 32] |= 1 << (index % 32); } else { @@ -39,14 +81,30 @@ class BitsetManager { } } - void Reset(bool value) { std::memset(data_, value ? 
0xFF : 0, buffer_size_ * sizeof(uint32_t)); } + void Set() { + DCHECK(data_); + std::memset(data_, 0xFF, buffer_size_ * sizeof(uint32_t)); + } + + void Reset() { + DCHECK(data_); + std::memset(data_, 0, buffer_size_ * sizeof(uint32_t)); + } - int GetElementCnt() const { return element_cnt_; } + DynamicBitset& operator|=(const DynamicBitset& other) { + DCHECK(buffer_size_ <= other.buffer_size_); + for (int i = 0; i < buffer_size_; ++i) { + data_[i] |= other.data_[i]; + } + return *this; + } private: - uint32_t* const data_; - const int buffer_size_; - const int element_cnt_; + int size_; + int buffer_size_; + uint32_t* data_; + std::vector internal_buffer_; + bool is_internal_; }; /*! diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index d9420bbbd5..9f33f98a7e 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -36,14 +36,15 @@ std::string PrintAsUTF8(TCodepoint codepoint) { return utf8; } -std::string PrintAsEscaped(TCodepoint codepoint, - const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped( + TCodepoint codepoint, + const std::unordered_map& additional_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, {'\v', "\\v"}, {'\0', "\\0"}, {'\x1B', "\\e"}}; - if (auto it = custom_escape_map.find(codepoint); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(codepoint); it != additional_escape_map.end()) { return it->second; } @@ -56,14 +57,24 @@ std::string PrintAsEscaped(TCodepoint codepoint, } // convert codepoint to hex - int width = codepoint <= 0xFFFF ? 4 : 8; + char prefix = codepoint <= 0xFF ? 'x' : codepoint <= 0xFFFF ? 'u' : 'U'; + int width = codepoint <= 0xFF ? 2 : codepoint <= 0xFFFF ? 4 : 8; std::stringstream ss; ss << std::setfill('0') << std::setw(width) << std::hex << codepoint; auto hex = ss.str(); - return codepoint <= 0xFFFF ? 
"\\u" + hex : "\\U" + hex; + return std::string("\\") + prefix + hex; } -std::pair ParseNextUTF8(const char* utf8) { +std::string PrintAsEscaped(std::string raw_str) { + std::string res; + auto codepoints = ParseUTF8(raw_str.c_str(), UTF8ErrorPolicy::kReturnByte); + for (auto c : codepoints) { + res += PrintAsEscaped(c); + } + return res; +} + +std::tuple HandleUTF8FirstByte(uint8_t byte) { static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off static const std::array kUtf8Bytes = { @@ -85,30 +96,44 @@ std::pair ParseNextUTF8(const char* utf8) { 4, 4, 4, 4, 4, 4, 4, 4, -1, -1, -1, -1, -1, -1, -1, -1, }; // clang-format on + auto num_bytes = kUtf8Bytes[static_cast(byte)]; + if (num_bytes == -1) { + return {false, 0, 0}; + } + return {true, num_bytes, byte & kFirstByteMask[num_bytes]}; +} - auto bytes = kUtf8Bytes[static_cast(utf8[0])]; - if (bytes == -1) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; +std::pair ParseNextUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { + auto [accepted, num_bytes, res] = HandleUTF8FirstByte(utf8[0]); + if (accepted) { + for (int i = 1; i < num_bytes; ++i) { + if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { + // invalid utf8 + accepted = false; + break; + } + res = (res << 6) | (static_cast(utf8[i]) & 0x3F); + } } - TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; - for (int i = 1; i < bytes; ++i) { - if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + if (!accepted) { + // invalid utf8 + if (error_policy == UTF8ErrorPolicy::kReturnInvalid) { + return {CharHandlingError::kInvalidUTF8, utf8}; + } else { + return {static_cast(utf8[0]), utf8 + 1}; } - res = (res << 6) | (static_cast(utf8[i]) & 0x3F); } - return {res, utf8 + bytes}; + + return {res, utf8 + num_bytes}; } -std::vector ParseUTF8(const char* utf8) { +std::vector ParseUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { std::vector codepoints; while (*utf8 != 0) { TCodepoint codepoint; - std::tie(codepoint, utf8) = ParseNextUTF8(utf8); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + std::tie(codepoint, utf8) = ParseNextUTF8(utf8, error_policy); + if (codepoint == CharHandlingError::kInvalidUTF8) { return {codepoint}; } codepoints.push_back(codepoint); @@ -129,17 +154,17 @@ inline int HexCharToInt(char c) { } std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map) { + const char* utf8, const std::unordered_map& additional_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return ParseNextUTF8(utf8); + return ParseNextUTF8(utf8, UTF8ErrorPolicy::kReturnInvalid); } auto escape_sequence = std::string(utf8, 2); - if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(escape_sequence); it != additional_escape_map.end()) { return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { @@ -159,7 +184,7 @@ std::pair ParseNextUTF8OrEscaped( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; 
} return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { @@ -170,13 +195,13 @@ std::pair ParseNextUTF8OrEscaped( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } codepoint = codepoint * 16 + digit; } return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index 790040e97e..0b18c43b0d 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -17,59 +17,89 @@ namespace llm { using TCodepoint = int32_t; /*! - * \brief Convert a codepoint to a UTF-8 string. + * \brief Handle the utf-8 first byte. + * \returns (is_valid, total_number_of_bytes, initial_codepoint). + */ +std::tuple HandleUTF8FirstByte(uint8_t byte); + +/*! + * \brief Print a codepoint to a UTF-8 string. * \param codepoint The codepoint. * \return The UTF-8 string. */ std::string PrintAsUTF8(TCodepoint codepoint); /*! - * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be + * \brief Print a codepoint to a escaped string. If the codepoint is not printable, it will be * escaped. By default the function support escape sequences in C ("\n", "\t", "\u0123"). User can - * specify more escape sequences using custom_escape_map. + * specify more escape sequences using additional_escape_map. * \param codepoint The codepoint. - * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. - * \return The printable string. + * \param additional_escape_map A map from codepoint to escape sequence. If the codepoint is in the + * map, it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. \return The + * printable string. */ std::string PrintAsEscaped( TCodepoint codepoint, - const std::unordered_map& custom_escape_map = {}); + const std::unordered_map& additional_escape_map = {}); + +/*! + * \brief Print the given string to a escaped string that can be printed. + * \return The escaped string. + */ +std::string PrintAsEscaped(std::string raw_str); /*! * \brief Represents an error when handling characters. Will be returned as a special TCodepoint * value. */ -enum class CharHandlingError : TCodepoint { +enum CharHandlingError : TCodepoint { /*! \brief The UTF-8 string is invalid. */ - kInvalidUtf8 = -10, + kInvalidUTF8 = -10, /*! \brief The escape sequence is invalid. */ kInvalidEscape = -11, }; /*! - * \brief Convert a UTF-8 string to a codepoint. + * \brief The method to handle invalid UTF-8 sequence. + */ +enum class UTF8ErrorPolicy { + /*! \brief Return an error codepoint when an error is encountered. */ + kReturnInvalid, + /*! \brief Skip the error and continue parsing. */ + kReturnByte, +}; + +/*! + * \brief Parse the first codepoint in a UTF-8 string. * \param utf8 The UTF-8 string. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the - * function returns (CharHandlingError::kInvalidUtf8, 0). + * \return The codepoint and new pointer. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns (CharHandlingError::kInvalidUTF8, input char pointer). 
*/ -std::pair ParseNextUTF8(const char* utf8); +std::pair ParseNextUTF8( + const char* utf8, UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); -std::vector ParseUTF8(const char* utf8); +/*! + * \brief Parse all codepoints in a UTF-8 string. + * \param utf8 The UTF-8 string. + * \return All codepoints. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns {CharHandlingError::kInvalidUTF8}. + */ +std::vector ParseUTF8(const char* utf8, + UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); /*! - * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function - * supports escape sequences in C ("\n", "\t", "\u0123"). User can specify more escape sequences - * using custom_escape_map. + * \brief Parse the first codepoint from a UTF-8 string. Also checks escape sequences and converts + * the escaped char to its original value. * \param utf8 The UTF-8 string or the escape sequence. - * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape - * sequence is invalid, the function returns - * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). + * \param additional_escape_map A map from escape sequence to codepoint. If the escape sequence is + * in the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. + * \return The codepoint and the new pointer. If the UTF-8 string or the escape sequence is + * invalid, and the error policy is kReturnInvalid, the function returns + * (CharHandlingError::kInvalidUTF8, input char pointer). */ std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map = {}); + const char* utf8, + const std::unordered_map& additional_escape_map = {}); } // namespace llm } // namespace mlc diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 6c53e35715..2789654a88 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -37,5 +37,23 @@ inline bool StartsWith(const std::string& str, const char* prefix) { return prefix[n] == '\0'; } +/*! + * \brief Hash and combine value into seed. + * \ref https://www.boost.org/doc/libs/1_84_0/boost/intrusive/detail/hash_combine.hpp + */ +inline void HashCombineBinary(uint32_t& seed, uint32_t value) { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +/*! + * \brief Find the hash sum of several uint32_t args. + */ +template +uint32_t HashCombine(Args... 
args) { + uint32_t seed = 0; + (..., HashCombineBinary(seed, args)); + return seed; +} + } // namespace llm } // namespace mlc diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index 6fe9217520..cc1c172697 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -152,7 +152,8 @@ inline std::string ByteLevelDecoder(const std::string& token) { }; // clang-format on - auto unicode_codepoints = ParseUTF8(token.c_str()); + auto unicode_codepoints = ParseUTF8(token.c_str(), UTF8ErrorPolicy::kReturnInvalid); + ICHECK(unicode_codepoints.size() != 1 || unicode_codepoints[0] != kInvalidUTF8); std::string decoded; for (auto unicode_codepoint : unicode_codepoints) { diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index cf491884c2..8b5b7d9649 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -1,6 +1,6 @@ """Classes handling the grammar guided generation of MLC LLM serving""" -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import tvm import tvm._ffi @@ -22,19 +22,20 @@ class BNFGrammar(Object): def from_ebnf_string( ebnf_string: str, main_rule: str = "main", - normalize: bool = True, - simplify: bool = True, ) -> "BNFGrammar": - r"""Parse a BNF grammar from a string in BNF/EBNF format. - - This method accepts the EBNF notation from the W3C XML Specification - (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following - changes: - - Using # as comment mark instead of /**/ - - Using C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 - - Do not support A-B (match A and not match B) yet - - See tests/python/serve/json.ebnf for an example. + r"""Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized + (simplified) by default. + + EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: + 1. Use # as the comment mark + 2. Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB + 3. A-B (match A and not match B) is not supported yet + 4. Lookahead assertion can be added at the end of a rule to speed up matching. E.g. + ``` + main ::= "ab" a [a-z] + a ::= "cd" (=[a-z]) + ``` + The assertion (=[a-z]) means a must be followed by [a-z]. Parameters ---------- @@ -44,28 +45,13 @@ def from_ebnf_string( main_rule : str The name of the main rule. Default: "main". - normalize : bool - Whether to normalize the grammar. Default: true. Only set to false for the purpose of - testing. - - In The normalized form of a BNF grammar, every rule is in the form: - `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - - I.e. a list of choices, each choice is a sequence of elements. Elements can be a - character class or a rule reference. And if the rule can be empty, the first choice - will be an empty string. - - simplify : bool - Whether to simplify the grammar to make matching more efficient. Default: true. Not - implemented yet. - Returns ------- grammar : BNFGrammar The parsed BNF grammar. 
""" return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule, normalize, simplify + ebnf_string, main_rule ) def to_string(self) -> str: @@ -167,6 +153,31 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_from_ebnf_string_no_normalize( + ebnf_string: str, + main_rule: str = "main", + ) -> "BNFGrammar": + r"""Construct a BNF grammar with a EBNF-formatted string, but not normalize it. + For test purposes. + + Parameters + ---------- + ebnf_string : str + The grammar string. + + main_rule : str + The name of the main rule. Default: "main". + + Returns + ------- + grammar : BNFGrammar + The parsed BNF grammar. + """ + return _ffi_api.BNFGrammarDebugFromEBNFStringNoNormalize( # type: ignore # pylint: disable=no-member + ebnf_string, main_rule + ) + @staticmethod def debug_json_schema_to_ebnf( schema: str, @@ -235,6 +246,11 @@ class GrammarStateMatcher(Object): max_rollback_steps : int The maximum number of steps to rollback when backtracking. Default: 0. + + token_table_postproc_method : Literal["byte_fallback", "byte_level"] + A helper parameter for the tokenizer. Only useful when the tokenizer is specified. + The method to postprocess the token table. For LLaMA and LLaMA-2 tokenizer, use + "byte_fallback"; for LLaMA-3 tokenizer, use "byte_level". Default: "byte_fallback". """ def __init__( @@ -242,6 +258,7 @@ def __init__( grammar: BNFGrammar, tokenizer: Union[None, Tokenizer, List[str]] = None, max_rollback_steps: int = 0, + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = "byte_fallback", ): if isinstance(tokenizer, list): self.__init_handle_by_constructor__( @@ -256,6 +273,7 @@ def __init__( grammar, tokenizer, max_rollback_steps, + token_table_postproc_method, ) def accept_token(self, token_id: int) -> bool: @@ -346,7 +364,7 @@ def is_terminated(self) -> bool: """ return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member - def debug_accept_char(self, codepoint: int) -> bool: + def debug_accept_char(self, codepoint: int, verbose: bool = False) -> bool: """Accept one unicode codepoint to the current state. For test purposes. Parameters @@ -354,11 +372,11 @@ def debug_accept_char(self, codepoint: int) -> bool: codepoint : int The unicode codepoint of the character to be accepted. """ - return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member - self, codepoint + return _ffi_api.GrammarStateMatcherDebugAcceptChar( # type: ignore # pylint: disable=no-member + self, codepoint, verbose ) - def debug_match_complete_string(self, string: str) -> bool: + def debug_match_complete_string(self, string: str, verbose: bool = False) -> bool: """Check if the matcher can accept the complete string, and then reach the end of the grammar. Does not change the state of the GrammarStateMatcher. For test purposes. @@ -367,4 +385,4 @@ def debug_match_complete_string(self, string: str) -> bool: string : str The string to be matched. 
""" - return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index 10eacdf9b9..5e335e15c7 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -1,4 +1,5 @@ # pylint: disable=missing-module-docstring,missing-function-docstring +import json import os import pytest @@ -14,11 +15,13 @@ def test_bnf_simple(): c ::= "c" """ expected = """main ::= ((b c)) -b ::= (([b])) -c ::= (([c])) +b ::= (("b")) +c ::= (("c")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) + print(expected) assert after == expected @@ -32,11 +35,11 @@ def test_ebnf(): b ::= ((b_1)) c ::= ((c_1)) d ::= ((d_1)) -b_1 ::= ("" | ([a] [b] b_1)) +b_1 ::= ("" | ("ab" b_1)) c_1 ::= (([acep-z] c_1) | ([acep-z])) -d_1 ::= ("" | ([d])) +d_1 ::= ("" | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -49,18 +52,33 @@ def test_star_quantifier(): e ::= [e]* [f]* | [g]* """ expected = """main ::= ((b c d)) -b ::= [b]* +b ::= (([b]*)) c ::= ((c_1)) d ::= ((d_1)) -e ::= ((e_star e_star_1) | (e_star_2)) -c_1 ::= ("" | ([b] c_1)) +e ::= (([e]* [f]*) | ([g]*)) +c_1 ::= ("" | ("b" c_1)) d_1 ::= ("" | (d_1_choice d_1)) -e_star ::= [e]* -e_star_1 ::= [f]* -e_star_2 ::= [g]* -d_1_choice ::= (([b] [c] [d]) | ([p] [q])) +d_1_choice ::= (("bcd") | ("pq")) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") + after = bnf_grammar.to_string() + assert after == expected + + +def test_lookahead_assertion(): + before = """main ::= ((b c d)) +b ::= (("abc" [a-z])) (=("abc")) +c ::= (("a") | ("b")) (=([a-z] "b")) +d ::= (("ac") | ("b" d_choice)) (=("abc")) +d_choice ::= (("e") | ("d")) +""" + expected = """main ::= ((b c d)) +b ::= (("abc" [a-z])) (=("abc")) +c ::= (("a") | ("b")) (=([a-z] "b")) +d ::= (("ac") | ("b" d_choice)) (=("abc")) +d_choice ::= (("e") | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -68,14 +86,14 @@ def test_star_quantifier(): def test_char(): before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 -rest1 ::= "\?\"\'测试あc" "👀" "" +rest1 ::= "\?\"\'测试あc" "👀" "" [a-a] [b-b] """ - expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest)) + expected = r"""main ::= (([a-z] [A-z] "\u0234\u0345\u00ff" [\-A-Z] [\-\-] [^a] rest)) rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) -rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) +rest1 ::= (("\?\"\'\u6d4b\u8bd5\u3042c\U0001f440ab")) """ # Disable unwrap_nesting_rules to expose the result before unwrapping. 
- bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -88,9 +106,9 @@ def test_space(): "f" | "g" """ - expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) + expected = """main ::= (("abcde") | ("f") | ("g")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -98,10 +116,10 @@ def test_space(): def test_nest(): before = """main::= "a" ("b" | "c" "d") | (("e" "f")) """ - expected = """main ::= (([a] main_choice) | ([e] [f])) -main_choice ::= (([b]) | ([c] [d])) + expected = """main ::= (("a" main_choice) | ("ef")) +main_choice ::= (("b") | ("cd")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -115,15 +133,16 @@ def test_flatten(): empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" """ expected = """main ::= ((or_test sequence_test nested_test empty_test)) -or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z])) -sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test)) -nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest)) -nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g])) -empty_test ::= ("" | ([d]) | ([a])) -sequence_test_choice ::= (([c]) | ([d])) +or_test ::= ("" | ("a") | ("b") | ("de") | (or_test) | ([^a-z])) +sequence_test ::= (("aab" sequence_test_choice "de" sequence_test)) +nested_test ::= (("abcd") | ("a") | ("b") | ("c") | (nested_rest)) +nested_rest ::= (("a") | ("bc") | ("d") | ("ef") | ("g")) +empty_test ::= ("" | ("d") | ("a")) +sequence_test_choice ::= (("c") | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) assert after == expected @@ -135,51 +154,53 @@ def test_json(): before = file.read() expected = r"""main ::= ((element)) -value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l])) -object ::= (([{] ws [}]) | ([{] members [}])) -members ::= ((member) | (member [,] members)) -member ::= ((ws string ws [:] element)) -array ::= (([[] ws [\]]) | ([[] elements [\]])) -elements ::= ((element) | (element [,] elements)) +value ::= ((object) | (array) | (string) | (number) | ("true") | ("false") | ("null")) +object ::= (("{" ws "}") | ("{" members "}")) +members ::= ((member) | (member "," members)) +member ::= ((ws string ws ":" element)) +array ::= (("[" ws "]") | ("[" elements "]")) +elements ::= ((element) | (element "," elements)) element ::= ((ws value ws)) -string ::= (([\"] characters [\"])) +string ::= (("\"" characters "\"")) characters ::= ("" | (character characters)) -character ::= (([^\"\\]) | ([\\] escape)) -escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex)) +character ::= (([^\"\\]) | ("\\" escape)) +escape ::= (("\"") | ("\\") | ("/") | ("b") | ("f") | ("n") | ("r") | ("t") | ("u" hex hex hex hex)) hex ::= (([A-Fa-f0-9])) number ::= ((integer fraction exponent)) -integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits)) +integer ::= ((digit) | (onenine digits) | ("-" 
digit) | ("-" onenine digits)) digits ::= ((digit) | (digit digits)) digit ::= (([0-9])) onenine ::= (([1-9])) -fraction ::= ("" | ([.] digits)) +fraction ::= ("" | ("." digits)) exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) -ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws)) -exponent_choice ::= (([e]) | ([E])) -exponent_choice_1 ::= ("" | ([+]) | ([\-])) +ws ::= ("" | (" " ws) | ("\n" ws) | ("\r" ws) | ("\t" ws)) +exponent_choice ::= (("e") | ("E")) +exponent_choice_1 ::= ("" | ("+") | ("-")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) assert after == expected def test_to_string_roundtrip(): """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" - before = r"""main ::= (b c) | (b main) -b ::= b_1 d -c ::= c_1 -d ::= d_1 -b_1 ::= ([b] b_1) | "" -c_1 ::= (c_2 c_1) | c_2 -c_2 ::= [acep-z] -d_1 ::= [d] | "" + before = r"""main ::= ((b c) | (b main)) +b ::= ((b_1 d)) +c ::= ((c_1)) +d ::= ((d_1)) +b_1 ::= ("" | ("b" b_1)) +c_1 ::= ((c_2 c_1) | (c_2)) (=("abc" [a-z])) +c_2 ::= (([acep-z])) +d_1 ::= ("" | ("d")) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False) + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main") output_string_2 = bnf_grammar_2.to_string() + assert before == output_string_1 assert output_string_1 == output_string_2 @@ -245,34 +266,50 @@ def test_error(): ): BNFGrammar.from_ebnf_string('a ::= "a"') + with pytest.raises( + TVMError, + match="TVMError: EBNF parse error at line 1, column 21: Unexpected lookahead assertion", + ): + BNFGrammar.from_ebnf_string('main ::= "a" (="a") (="b")') + def test_to_json(): before = """main ::= b c | b main b ::= "bcd" c ::= [a-z] """ - expected = ( - '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"' - ":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100," - '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' - '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' - ) - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_json(False) - assert after == expected + expected_obj = { + "rules": [ + {"body_expr_id": 6, "name": "main"}, + {"body_expr_id": 9, "name": "b"}, + {"body_expr_id": 12, "name": "c"}, + ], + "rule_expr_indptr": [0, 3, 6, 10, 13, 16, 20, 24, 29, 32, 35, 40, 43], + "rule_expr_data": [ + # fmt: off + 4,1,1,4,1,2,5,2,0,1,4,1,1,4,1,0,5,2,3,4,6,2,2,5,0,3,98,99, + 100,5,1,7,6,1,8,1,3,0,97,122,5,1,10,6,1,11 + # fmt: on + ], + } + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") + print(bnf_grammar) + after_str = bnf_grammar.to_json(False) + after_obj = json.loads(after_str) + assert after_obj == expected_obj def test_to_json_roundtrip(): before = r"""main ::= ((b c) | (b main)) -b ::= ((b_1 d)) +b ::= ((b_1 d [a]*)) c ::= ((c_1)) d ::= ((d_1)) -b_1 ::= ("" | ([b] b_1)) +b_1 ::= ("" | ("b" b_1)) c_1 ::= ((c_2 c_1) | (c_2)) c_2 ::= (([acep-z])) -d_1 ::= ("" | ([d])) +d_1 ::= ("" | ("d")) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") output_json_1 = 
bnf_grammar_1.to_json(False) bnf_grammar_2 = BNFGrammar.from_json(output_json_1) output_json_2 = bnf_grammar_2.to_json(False) diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index 6fc48705d1..6ad6294d77 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -40,6 +40,20 @@ def json_grammar(): return get_json_grammar() +def test_simple(): + grammar_str = """main ::= rule1 rule2 +rule1 ::= (rule2 | rule3) "a" +rule2 ::= "b" +rule3 ::= "c" +""" + + grammar = BNFGrammar.from_ebnf_string(grammar_str) + matcher = GrammarStateMatcher(grammar) + assert matcher.debug_match_complete_string("bab") + assert not matcher.debug_match_complete_string("abb") + assert matcher.debug_match_complete_string("cab") + + (json_input_accepted,) = tvm.testing.parameters( ('{"name": "John"}',), ('{ "name" : "John" }',), @@ -241,8 +255,8 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 + 31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 # fmt: on ], ), @@ -258,15 +272,15 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): }""", [ # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 + 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, + 31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, + 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, + 31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915, + 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, + 31846, 265, 265, 265, 265, 31974, 31974, 31999 # fmt: on ], ), @@ -395,5 +409,6 @@ class MainModel(BaseModel): if __name__ == "__main__": # Run a benchmark to show the performance before running tests test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') + test_find_next_rejected_tokens_schema() tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py index fc0f79a041..51737e1435 100644 --- 
a/tests/python/serve/test_grammar_state_matcher_json.py +++ b/tests/python/serve/test_grammar_state_matcher_json.py @@ -2,7 +2,7 @@ # pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking """This test uses the optimized JSON grammar provided by the grammar library.""" import sys -from typing import List, Optional +from typing import List, Literal, Optional import pytest import tvm @@ -213,19 +213,40 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) -(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( +( + tokenizer_path, + input_find_rejected_tokens, + expected_rejected_sizes, + token_table_postproc_method, +) = tvm.testing.parameters( ( # short test + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 + 31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 # fmt: on ], + "byte_fallback", + ), + ( + # short test + "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", + '{"id": 1,"name": "Example哈哈"}', + [ + # fmt: off + 128235, 127497, 5002, 5002, 5002, 127849, 126399, 126399, 126760, 127499, 5002, 5002, + 5002, 5002, 5002, 127849, 126399, 126399, 4952, 4952, 4952, 4952, 4952, 4952, 4952, + 4952, 128066, 128111, 4952, 128066, 128111, 4952, 127873, 128254 + # fmt: on + ], + "byte_level", ), ( # long test + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", """{ "id": 1, "na": "ex", @@ -236,40 +257,51 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): }""", [ # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 + 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, + 31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, + 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, + 31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915, + 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, + 31846, 265, 265, 265, 265, 31974, 31974, 31999 # fmt: on ], + "byte_fallback", ), ) def test_find_next_rejected_tokens( json_grammar: BNFGrammar, + tokenizer_path: str, input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]] = None, + 
expected_rejected_sizes: Optional[List[int]], + token_table_postproc_method: Literal["byte_fallback", "byte_level"], ): - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) + grammar_state_matcher = GrammarStateMatcher( + json_grammar, tokenizer, token_table_postproc_method=token_table_postproc_method + ) + input_bytes = input_find_rejected_tokens.encode("utf-8") + rejected_sizes = [] - real_sizes = [] - for c in input_find_rejected_tokens: + for i, c in enumerate(input_bytes): rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c, file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(ord(c)) + rejected_sizes.append(len(rejected_token_ids)) + if expected_rejected_sizes is not None: + assert rejected_sizes[-1] == expected_rejected_sizes[i], ( + len(rejected_token_ids), + expected_rejected_sizes[i], + ) + print("Accepting char:", c, bytes([c]), file=sys.stderr) + assert grammar_state_matcher.debug_accept_char(c) + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) + rejected_sizes.append(len(rejected_token_ids)) if expected_rejected_sizes is not None: - assert real_sizes == expected_rejected_sizes + assert rejected_sizes[-1] == expected_rejected_sizes[-1] def test_token_based_operations(json_grammar: BNFGrammar): @@ -305,7 +337,7 @@ def test_token_based_operations(json_grammar: BNFGrammar): accepted = list(set(range(len(token_table))) - set(rejected)) accepted_tokens = [token_table[i] for i in accepted] result.append(accepted_tokens) - assert id in accepted + assert id in accepted, token_table[id] assert grammar_state_matcher.accept_token(id) rejected = grammar_state_matcher.find_next_rejected_tokens() @@ -407,6 +439,20 @@ def test_termination(json_grammar: BNFGrammar): if __name__ == "__main__": # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens(BNFGrammar.get_grammar_of_json(), '{"id": 1,"name": "Example"}') + test_find_next_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + '{"id": 1,"name": "Example"}', + None, + "byte_fallback", + ) + + test_find_next_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", + '{"id": 1,"name": "Example哈哈"}', + None, + "byte_level", + ) tvm.testing.main() diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 2b3ce29c7f..8bd86a25a1 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -13,7 +13,7 @@ prompts_list = [ "Generate a JSON string containing 20 objects:", - "Generate a JSON containing a list:", + "Generate a JSON containing a non-empty list:", "Generate a JSON with 5 elements:", ] model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index b9a7f55bfa..6ba914ee9f 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -36,9 +36,9 @@ // Grammar related #include "serve/grammar/grammar.cc" +#include "serve/grammar/grammar_functor.cc" #include "serve/grammar/grammar_parser.cc" #include "serve/grammar/grammar_serializer.cc" -#include "serve/grammar/grammar_simplifier.cc" #include "serve/grammar/grammar_state_matcher.cc" 
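A practical note on the new `token_table_postproc_method` argument exercised in the tests above: it has to match the tokenizer family. Below is a sketch of the two configurations the tests use; the model directories are the same placeholder paths as in the tests, and the Tokenizer import path is an assumption.

    from mlc_llm.serve.grammar import BNFGrammar, GrammarStateMatcher
    from mlc_llm.tokenizer import Tokenizer  # import path assumed; the tests construct Tokenizer(path) the same way

    json_grammar = BNFGrammar.get_grammar_of_json()

    # SentencePiece-style tokenizers (LLaMA / LLaMA-2) encode raw bytes as <0xXX> fallback tokens.
    llama2_matcher = GrammarStateMatcher(
        json_grammar,
        Tokenizer("dist/Llama-2-7b-chat-hf-q4f16_1-MLC"),
        token_table_postproc_method="byte_fallback",
    )

    # LLaMA-3 ships a byte-level BPE token table instead.
    llama3_matcher = GrammarStateMatcher(
        json_grammar,
        Tokenizer("dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC"),
        token_table_postproc_method="byte_level",
    )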
#include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc" From 0c03537e284e92bc7b27832ba86cc1dea224b9a5 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Tue, 14 May 2024 15:47:41 -0700 Subject: [PATCH 308/531] [DebugChat] Fix DebugChat softmax function and save logits to debug folder (#2342) * [DebugChat] Fix DebugChat softmax function and save logits to debug folder * Fix lint --- python/mlc_llm/testing/debug_chat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 8ff370e9d9..fee8cb8867 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -351,7 +351,9 @@ def _sample_token_from_logits( if presence_penalty != 0.0 or frequency_penalty != 0.0: self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty) - self._softmax_with_temperature(logits_np, temperature) + logits_np = self._softmax_with_temperature(logits_np, temperature) + np.savez(self.instrument.debug_out / "logits.npz", logits_np) + logits = logits.copyfrom(logits_np) next_token = self.sample_topp_from_prob_func(logits, top_p, random.random()) return next_token From b247f8d2c733c71924c4afc2abc427f6c8d0ab91 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 May 2024 15:48:56 -0700 Subject: [PATCH 309/531] [Serving] Add Medusa speculative decoding (#2337) * [Serving] Add Medusa speculative decoding --- cpp/metadata/model.cc | 15 +- cpp/metadata/model.h | 31 +++ cpp/serve/config.cc | 13 +- cpp/serve/config.h | 56 +++-- cpp/serve/engine.cc | 39 ++-- cpp/serve/engine_actions/action_commons.cc | 20 ++ cpp/serve/engine_actions/action_commons.h | 18 ++ .../engine_actions/batch_prefill_base.cc | 12 +- cpp/serve/engine_actions/batch_prefill_base.h | 4 +- .../engine_actions/eagle_batch_verify.cc | 111 +++++----- .../eagle_new_request_prefill.cc | 196 ++++++++++-------- cpp/serve/function_table.cc | 1 - cpp/serve/model.cc | 70 ++++++- cpp/serve/model.h | 6 +- python/mlc_llm/cli/serve.py | 2 +- python/mlc_llm/interface/compile.py | 9 + python/mlc_llm/interface/serve.py | 2 +- python/mlc_llm/model/medusa/__init__.py | 0 python/mlc_llm/model/medusa/medusa_loader.py | 51 +++++ python/mlc_llm/model/medusa/medusa_model.py | 83 ++++++++ .../model/medusa/medusa_quantization.py | 20 ++ python/mlc_llm/model/model.py | 13 ++ python/mlc_llm/serve/config.py | 5 +- python/mlc_llm/serve/engine.py | 5 +- python/mlc_llm/serve/engine_base.py | 2 +- 25 files changed, 558 insertions(+), 226 deletions(-) create mode 100644 python/mlc_llm/model/medusa/__init__.py create mode 100644 python/mlc_llm/model/medusa/medusa_loader.py create mode 100644 python/mlc_llm/model/medusa/medusa_model.py create mode 100644 python/mlc_llm/model/medusa/medusa_quantization.py diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 62ba2787b9..e3e9a79b3c 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -63,8 +63,17 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib result.attention_sink_size = json::Lookup(metadata, "attention_sink_size"); result.tensor_parallel_shards = json::Lookup(metadata, "tensor_parallel_shards"); - result.kv_cache_metadata = - KVCacheMetadata::FromJSON(json::Lookup(metadata, "kv_cache")); + result.kv_state_kind = KVStateKindFromString( + json::LookupOrDefault(metadata, "kv_state_kind", "kv_cache")); + if (result.kv_state_kind != 
KVStateKind::kNone) { + result.kv_cache_metadata = + KVCacheMetadata::FromJSON(json::Lookup(metadata, "kv_cache")); + } else { + result.kv_cache_metadata = {/*num_hidden_layers=*/0, + /*head_dim=*/0, + /*num_attention_heads=*/0, + /*num_key_value_heads=*/0}; + } { std::vector& params = result.params; picojson::array json_params = json::Lookup(metadata, "params"); @@ -94,7 +103,7 @@ ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module, try { return ModelMetadata::FromJSON(json, model_config); } catch (const std::exception& e) { - LOG(WARNING) << "Failed to parse metadata:\n" << json_str; + LOG(WARNING) << "Failed to parse metadata:\n" << json_str << "\nerror: " << e.what(); throw e; } } diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index ede06b6b3f..4b204f6902 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -16,6 +16,36 @@ namespace mlc { namespace llm { +/*! \brief The kind of cache. */ +enum class KVStateKind : int { + kKVCache = 0, + kRNNState = 1, + kNone = 2, +}; + +inline std::string KVStateKindToString(KVStateKind kv_state_kind) { + if (kv_state_kind == KVStateKind::kKVCache) { + return "kv_cache"; + } else if (kv_state_kind == KVStateKind::kRNNState) { + return "rnn_state"; + } else if (kv_state_kind == KVStateKind::kNone) { + return "none"; + } else { + LOG(FATAL) << "Invalid kv state kind: " << static_cast(kv_state_kind); + } +} + +inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) { + if (kv_state_kind == "kv_cache") { + return KVStateKind::kKVCache; + } else if (kv_state_kind == "rnn_state") { + return KVStateKind::kRNNState; + } else if (kv_state_kind == "none") { + return KVStateKind::kNone; + } else { + LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind; + } +} struct ModelMetadata { struct Param { struct Preproc { @@ -49,6 +79,7 @@ struct ModelMetadata { int64_t attention_sink_size; std::vector params; std::unordered_map memory_usage; + KVStateKind kv_state_kind; KVCacheMetadata kv_cache_metadata; static ModelMetadata FromJSON(const picojson::object& json_str, diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 9b9d5ba65a..cbc4c6c613 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -248,7 +248,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig( CHECK(inferred_config.max_single_sequence_length.has_value()); CHECK(inferred_config.prefill_chunk_size.has_value()); CHECK(inferred_config.max_history_size.has_value()); - CHECK(inferred_config.kv_state_kind.has_value()); ObjectPtr n = make_object(); // - Get models and model libs. 
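Since "kv_state_kind" is now read from the model metadata above rather than passed through EngineConfig, here is a tiny sketch of how the optional field behaves; the metadata fragment itself is made up.

    import json

    # Per ModelMetadata::FromJSON above: the field is optional and defaults to "kv_cache",
    # while "none" (e.g. a model with neither KV cache nor RNN state, which is what the new
    # Medusa heads need) skips the "kv_cache" sub-object entirely.
    metadata = json.loads('{"kv_state_kind": "none", "tensor_parallel_shards": 1}')
    kind = metadata.get("kv_state_kind", "kv_cache")
    assert kind in ("kv_cache", "rnn_state", "none")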
@@ -290,7 +289,6 @@ EngineConfig EngineConfig::FromJSONAndInferredConfig( n->max_single_sequence_length = inferred_config.max_single_sequence_length.value(); n->prefill_chunk_size = inferred_config.prefill_chunk_size.value(); n->max_history_size = inferred_config.max_history_size.value(); - n->kv_state_kind = inferred_config.kv_state_kind.value(); return EngineConfig(n); } @@ -356,7 +354,6 @@ String EngineConfigNode::AsJSONString() const { picojson::value(static_cast(this->max_single_sequence_length)); config["prefill_chunk_size"] = picojson::value(static_cast(this->prefill_chunk_size)); config["max_history_size"] = picojson::value(static_cast(this->max_history_size)); - config["kv_state_kind"] = picojson::value(KVStateKindToString(this->kv_state_kind)); config["speculative_mode"] = picojson::value(SpeculativeModeToString(this->speculative_mode)); config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); config["verbose"] = picojson::value(static_cast(this->verbose)); @@ -428,14 +425,18 @@ Result GetModelConfigLimits(const std::vector(compile_time_model_config, "max_batch_size")); } ICHECK_NE(model_max_prefill_chunk_size, std::numeric_limits::max()); ICHECK_NE(model_max_batch_size, std::numeric_limits::max()); + ICHECK_GT(model_max_prefill_chunk_size, 0); + ICHECK_GT(model_max_batch_size, 0); return Result::Ok( {model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size}); } @@ -689,7 +690,6 @@ Result InferrableEngineConfig::InferForKVCache( << " MB). The actual usage might be slightly larger than the estimated number."; } - inferred_config.kv_state_kind = KVStateKind::kKVCache; inferred_config.max_history_size = 0; return Result::Ok(inferred_config); } @@ -853,7 +853,6 @@ Result InferrableEngineConfig::InferForRNNState( << " MB). The actual usage might be slightly larger than the estimated number."; } - inferred_config.kv_state_kind = KVStateKind::kRNNState; return Result::Ok(inferred_config); } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 8437232d37..2680eb755c 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -114,12 +114,8 @@ enum class SpeculativeMode : int { kSmallDraft = 1, /*! \brief The eagle-style speculative decoding. */ kEagle = 2, -}; - -/*! \brief The kind of cache. */ -enum class KVStateKind : int { - kKVCache = 0, - kRNNState = 1, + /*! \brief The Medusa-style speculative decoding. */ + kMedusa = 3, }; class InferrableEngineConfig; @@ -172,8 +168,6 @@ class EngineConfigNode : public Object { int prefill_chunk_size = 1024; /*! \brief The maximum history size for RNN state. KV cache does not need this. */ int max_history_size = 0; - /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ - KVStateKind kv_state_kind = KVStateKind::kKVCache; /*************** Speculative decoding ***************/ @@ -216,7 +210,6 @@ struct InferrableEngineConfig { std::optional max_single_sequence_length; std::optional prefill_chunk_size; std::optional max_history_size; - std::optional kv_state_kind; /*! \brief Infer the config for KV cache from a given initial config. */ TVM_DLL static Result InferForKVCache( @@ -238,9 +231,16 @@ struct InferrableEngineConfig { Result ModelsUseKVCache(const std::vector& model_configs); inline std::string EngineModeToString(EngineMode mode) { - return mode == EngineMode::kLocal ? "local" - : mode == EngineMode::kInteractive ? 
"interactive" - : "server"; + if (mode == EngineMode::kLocal) { + return "local"; + } else if (mode == EngineMode::kInteractive) { + return "interactive"; + } else if (mode == EngineMode::kServer) { + return "server"; + } else { + LOG(FATAL) << "Invalid engine mode: " << static_cast(mode); + throw; + } } inline EngineMode EngineModeFromString(const std::string& mode) { @@ -252,13 +252,22 @@ inline EngineMode EngineModeFromString(const std::string& mode) { return EngineMode::kServer; } else { LOG(FATAL) << "Invalid engine mode string: " << mode; + throw; } } inline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) { - return speculative_mode == SpeculativeMode::kDisable ? "disable" - : speculative_mode == SpeculativeMode::kSmallDraft ? "small_draft" - : "eagle"; + if (speculative_mode == SpeculativeMode::kDisable) { + return "disable"; + } else if (speculative_mode == SpeculativeMode::kSmallDraft) { + return "small_draft"; + } else if (speculative_mode == SpeculativeMode::kEagle) { + return "eagle"; + } else if (speculative_mode == SpeculativeMode::kMedusa) { + return "medusa"; + } else { + LOG(FATAL) << "Invalid speculative mode: " << static_cast(speculative_mode); + } } inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) { @@ -268,22 +277,11 @@ inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_ return SpeculativeMode::kSmallDraft; } else if (speculative_mode == "eagle") { return SpeculativeMode::kEagle; + } else if (speculative_mode == "medusa") { + return SpeculativeMode::kMedusa; } else { LOG(FATAL) << "Invalid speculative mode string: " << speculative_mode; - } -} - -inline std::string KVStateKindToString(KVStateKind kv_state_kind) { - return kv_state_kind == KVStateKind::kKVCache ? 
"kv_cache" : "rnn_State"; -} - -inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) { - if (kv_state_kind == "kv_cache") { - return KVStateKind::kKVCache; - } else if (kv_state_kind == "rnn_state") { - return KVStateKind::kRNNState; - } else { - LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind; + throw; } } diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 9b9cf81fe7..418cabfc91 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -105,8 +105,7 @@ class EngineImpl : public Engine { model->SetPrefillChunkSize(engine_config->prefill_chunk_size); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, - engine_config->prefill_chunk_size, engine_config->max_history_size, - engine_config->kv_state_kind); + engine_config->prefill_chunk_size, engine_config->max_history_size); n->model_workspaces_.push_back( ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); } @@ -161,6 +160,18 @@ class EngineImpl : public Engine { n->model_workspaces_, draft_token_workspace_manager, engine_config, n->trace_recorder_)}; break; + case SpeculativeMode::kMedusa: + n->actions_ = {EngineAction::EagleNewRequestPrefill(n->models_, // + logit_processor, // + sampler, // + n->model_workspaces_, // + draft_token_workspace_manager, // + engine_config, // + n->trace_recorder_), + EngineAction::EagleBatchVerify( + n->models_, logit_processor, sampler, n->model_workspaces_, + draft_token_workspace_manager, engine_config, n->trace_recorder_)}; + break; default: n->actions_ = { EngineAction::NewRequestPrefill(n->models_, // @@ -422,13 +433,9 @@ class EngineImpl : public Engine { json::LookupOptional(config, "max_history_size"); std::optional kv_state_kind_str = json::LookupOptional(config, "kv_state_kind"); - std::optional kv_state_kind; - if (kv_state_kind_str.has_value()) { - kv_state_kind = KVStateKindFromString(kv_state_kind_str.value()); - } - InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length, + InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, - max_history_size, kv_state_kind}; + max_history_size}; // - Get the model metadata. std::vector model_metadata; @@ -440,28 +447,13 @@ class EngineImpl : public Engine { if (use_kv_cache.IsErr()) { return TResult::Error(use_kv_cache.UnwrapErr()); } - KVStateKind inferred_kv_state_kind; Result inferrable_cfg_res; if (use_kv_cache.Unwrap()) { - inferred_kv_state_kind = KVStateKind::kKVCache; - // - Check if the kv state kind from config is valid. - if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { - return TResult::Error( - "Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is " - "specified in EngineConfig."); - } // - Infer configuration. inferrable_cfg_res = InferrableEngineConfig::InferForKVCache( mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, verbose); } else { - inferred_kv_state_kind = KVStateKind::kRNNState; - // - Check if the kv state kind from config is valid. - if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { - return TResult::Error( - "Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is " - "specified in EngineConfig."); - } // - Infer configuration. 
inferrable_cfg_res = InferrableEngineConfig::InferForRNNState( mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, @@ -477,7 +469,6 @@ class EngineImpl : public Engine { ICHECK(inferrable_cfg.max_single_sequence_length.has_value()); ICHECK(inferrable_cfg.prefill_chunk_size.has_value()); ICHECK(inferrable_cfg.max_history_size.has_value()); - ICHECK(inferrable_cfg.kv_state_kind.has_value()); return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg)); } diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index af0dfe978d..3289ef57c6 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -211,6 +211,26 @@ RequestStateEntry PreemptLastRunningRequestStateEntry( return rsentry; } +std::pair> ApplyLogitProcessorAndSample( + const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits, + const Array& generation_cfg, const Array& request_ids, + const Array& mstates, const std::vector& rngs, + const std::vector& sample_indices) { + // - Update logits. + logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); + + // - Compute probability distributions. + NDArray probs_on_device = + logit_processor->ComputeProbsFromLogits(logits, generation_cfg, request_ids); + + // - Sample tokens. + NDArray renormalized_probs = sampler->BatchRenormalizeProbsByTopP(probs_on_device, sample_indices, + request_ids, generation_cfg); + std::vector sample_results = sampler->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); + return {std::move(probs_on_device), std::move(sample_results)}; +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 07bef2d2d9..de98e11e67 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -75,6 +75,24 @@ inline std::vector GetRunningRequestStateEntries(const Engine return rsentries; } +/*! + * \brief Apply the logit processor to the logits and sample one token for each request. + * \param logit_processor The logit processor to apply. + * \param sampler The sampler to sample tokens. + * \param logits The logits to process. + * \param generation_cfg The generation configurations of the requests. + * \param request_ids The request ids. + * \param mstates The model states of the requests. + * \param rngs The random generators of the requests. + * \param sample_indices The indices of the requests to sample. + * \return The processed logits and the sampled results. 
+ */ +std::pair> ApplyLogitProcessorAndSample( + const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits, + const Array& generation_cfg, const Array& request_ids, + const Array& mstates, const std::vector& rngs, + const std::vector& sample_indices); + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc index df6df2b3d9..f570551417 100644 --- a/cpp/serve/engine_actions/batch_prefill_base.cc +++ b/cpp/serve/engine_actions/batch_prefill_base.cc @@ -34,6 +34,7 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { int num_available_pages = models_[0]->GetNumAvailablePages(); int num_running_rsentries = GetRunningRequestStateEntries(estate).size(); int current_total_seq_len = models_[0]->GetCurrentTotalSequenceLength(); + KVStateKind kv_state_kind = models_[0]->GetMetadata().kv_state_kind; int num_prefill_rsentries = 0; for (const Request& request : estate->waiting_queue) { @@ -61,7 +62,7 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { --num_child_to_activate) { if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate, total_input_length, total_required_pages, num_available_pages, - current_total_seq_len, num_running_rsentries)) { + current_total_seq_len, num_running_rsentries, kv_state_kind)) { prefill_inputs.push_back({rsentry, input_length, num_child_to_activate}); num_prefill_rsentries += 1 + num_child_to_activate; can_prefill = true; @@ -93,7 +94,8 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { total_input_length += input_length; total_required_pages += num_require_pages; if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length, total_required_pages, - num_available_pages, current_total_seq_len, num_running_rsentries)) { + num_available_pages, current_total_seq_len, num_running_rsentries, + kv_state_kind)) { prefill_inputs.push_back({rsentry, input_length, 0}); num_prefill_rsentries += 1; } @@ -114,11 +116,11 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries) { + int num_running_rsentries, KVStateKind kv_state_kind) { ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence); // For RNN State, it can prefill as long as it can be instantiated. - if (engine_config_->kv_state_kind == KVStateKind::kRNNState) { + if (kv_state_kind == KVStateKind::kRNNState || kv_state_kind == KVStateKind::kNone) { return true; } @@ -310,4 +312,4 @@ void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults( } // namespace serve } // namespace llm -} // namespace mlc \ No newline at end of file +} // namespace mlc diff --git a/cpp/serve/engine_actions/batch_prefill_base.h b/cpp/serve/engine_actions/batch_prefill_base.h index 54b257dc21..122a214496 100644 --- a/cpp/serve/engine_actions/batch_prefill_base.h +++ b/cpp/serve/engine_actions/batch_prefill_base.h @@ -40,7 +40,7 @@ class BatchPrefillBaseActionObj : public EngineActionObj { /*! \brief Check if the input requests can be prefilled under conditions. 
*/ bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length, int num_required_pages, int num_available_pages, int current_total_seq_len, - int num_running_rsentries); + int num_running_rsentries, KVStateKind kv_state_kind); /*! * \brief Chunk the input of the given RequestModelState for prefill @@ -104,4 +104,4 @@ class BatchPrefillBaseActionObj : public EngineActionObj { } // namespace serve } // namespace llm -} // namespace mlc \ No newline at end of file +} // namespace mlc diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index 9f31ed22d6..1a8bec2eea 100644 --- a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -179,7 +179,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // - Slice and save hidden_states_for_sample last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1); } - if (!fully_accepted_rsentries.empty()) { + if (!fully_accepted_rsentries.empty() && + engine_config_->speculative_mode == SpeculativeMode::kEagle) { // - Run a step of batch decode for requests whose drafts are fully accepted. // When a request's draft is fully accepted, there is an extra token proposed // by the draft model but not added into the draft model's KV cache. @@ -239,9 +240,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj { // One step draft for the following steps // Gather hidden states for the last accepted tokens. - hidden_states = models_[draft_model_id_]->GatherHiddenStates( - hidden_states, last_accepted_hidden_positions, - &model_workspaces_[draft_model_id_].hidden_states); + // Use the function and the workspace of the verify model because the information about the + // hidden states is not available in the draft model for medusa. + hidden_states = models_[0]->GatherHiddenStates(hidden_states, last_accepted_hidden_positions, + &model_workspaces_[0].hidden_states); std::vector input_tokens; Array mstates; @@ -255,61 +257,50 @@ class EagleBatchVerifyActionObj : public EngineActionObj { input_tokens.push_back(mstates[i]->committed_tokens.back().sampled_token_id.first); } - // - Compute embeddings. - RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); - embeddings = models_[draft_model_id_]->TokenEmbed( - {IntTuple{input_tokens.begin(), input_tokens.end()}}); - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); - - // - Invoke model decode. - RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); - ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( - embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); - hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( - fused_embedding_hidden_states, request_internal_ids); - - if (models_[draft_model_id_]->CanGetLogits()) { - logits = models_[draft_model_id_]->GetLogits(hidden_states); - } else { - // - Use base model's head. - logits = models_[0]->GetLogits(hidden_states); + Array multi_step_logits{nullptr}; // for medusa output + if (engine_config_->speculative_mode == SpeculativeMode::kEagle) { + // - Compute embeddings. + RECORD_EVENT(trace_recorder_, request_ids, "start proposal embedding"); + embeddings = models_[draft_model_id_]->TokenEmbed( + {IntTuple{input_tokens.begin(), input_tokens.end()}}); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal embedding"); + + // - Invoke model decode. 
+ RECORD_EVENT(trace_recorder_, request_ids, "start proposal decode"); + ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden( + embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1); + hidden_states = models_[draft_model_id_]->BatchDecodeToLastHidden( + fused_embedding_hidden_states, request_internal_ids); + + int lm_head_model_id = models_[draft_model_id_]->CanGetLogits() ? draft_model_id_ : 0; + logits = models_[lm_head_model_id]->GetLogits(hidden_states); + RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); + ICHECK_EQ(logits->ndim, 2); + ICHECK_EQ(logits->shape[0], num_rsentries); + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + multi_step_logits = models_[draft_model_id_]->GetMultiStepLogits(hidden_states); } - RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode"); - ICHECK_EQ(logits->ndim, 2); - ICHECK_EQ(logits->shape[0], num_rsentries); - // - Update logits. - logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids); - - // - Compute probability distributions. - probs_on_device = - logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids); - - // - Sample tokens. // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( - probs_on_device, sample_indices, request_ids, generation_cfg); - std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), num_rsentries); - // - Slice and save hidden_states_for_sample - draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_); - models_[draft_model_id_]->ScatterDraftProbs( - renormalized_probs, draft_token_slots_, - &model_workspaces_[verify_model_id_].draft_probs_storage); - models_[draft_model_id_]->ScatterHiddenStates( - hidden_states, draft_token_slots_, - &model_workspaces_[verify_model_id_].draft_hidden_states_storage); - // - Add draft token to the state. 
- for (int i = 0; i < num_rsentries; ++i) { - mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); - estate->stats.total_draft_length += 1; + if (engine_config_->speculative_mode == SpeculativeMode::kEagle) { + const auto& [renormalized_probs, sample_results] = + ApplyLogitProcessorAndSample(logit_processor_, sampler_, logits, generation_cfg, + request_ids, mstates, rngs, sample_indices); + UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_, + renormalized_probs, hidden_states, estate); + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; draft_id++) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids, + mstates, rngs, sample_indices); + UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_, + renormalized_probs, hidden_states, estate); + } } } - auto tend = std::chrono::high_resolution_clock::now(); estate->stats.engine_total_decode_time += static_cast((tend - tstart).count()) / 1e9; @@ -371,6 +362,24 @@ class EagleBatchVerifyActionObj : public EngineActionObj { return num_required_pages <= num_available_pages; } + void UpdateRequestStatesWithDraftProposals(const Array& mstates, + const std::vector& sample_results, + int model_id, const NDArray& renormalized_probs, + const ObjectRef& hidden_states_for_sample, + EngineState estate) { + draft_token_workspace_manager_->AllocSlots(mstates.size(), &draft_token_slots_); + models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + &model_workspaces_[0].draft_probs_storage); + if (engine_config_->speculative_mode == SpeculativeMode::kEagle && + engine_config_->spec_draft_length > 1) { + models_[0]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(mstates.size()); ++i) { + mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]); + estate->stats.total_draft_length += 1; + } + } /*! * \brief The model to run decode in. When there are multiple * models, the `Step` function of the created action will not take effect. diff --git a/cpp/serve/engine_actions/eagle_new_request_prefill.cc b/cpp/serve/engine_actions/eagle_new_request_prefill.cc index 2190cf61ed..a2da53e171 100644 --- a/cpp/serve/engine_actions/eagle_new_request_prefill.cc +++ b/cpp/serve/engine_actions/eagle_new_request_prefill.cc @@ -111,6 +111,11 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { } } request_internal_ids.push_back(mstate->internal_id); + + if (engine_config_->speculative_mode == SpeculativeMode::kMedusa && model_id > 0) { + // Embedding is only needed for the base model in Medusa. + continue; + } RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding"); // Speculative models shift left the input tokens by 1 when base model has committed tokens. // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens. 
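To make the branching above and below easier to follow: the Medusa path never runs a draft-model decode step. The draft "model" is only a stack of extra lm-heads, GetMultiStepLogits returns one logits tensor per head, and one draft token per step is sampled from each. A toy numpy sketch of that shape contract, with all sizes and weights made up:

    import numpy as np

    rng = np.random.default_rng(0)
    vocab, hidden_dim, num_heads = 16, 8, 4                  # toy sizes; one Medusa head per future step

    hidden = rng.normal(size=(1, hidden_dim))                # last hidden state from the base model
    heads = rng.normal(size=(num_heads, hidden_dim, vocab))  # stand-ins for the Medusa lm-heads

    # "GetMultiStepLogits": every head maps the same hidden state to logits for one future step.
    multi_step_logits = np.einsum("bh,khv->kbv", hidden, heads)  # (num_heads, batch, vocab)

    # One greedy draft token per step, analogous to the per-step sampling loop in the patch.
    draft_tokens = multi_step_logits.argmax(axis=-1)[:, 0].tolist()
    print(draft_tokens)  # four draft token ids, one per head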
@@ -125,59 +130,56 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { } RECORD_EVENT(trace_recorder_, request_ids, "start prefill"); - ObjectRef embedding_or_hidden_states{nullptr}; - if (model_id == 0) { - embedding_or_hidden_states = embeddings; - } else { - embedding_or_hidden_states = models_[model_id]->FuseEmbedHidden( - embeddings, hidden_states_for_input, /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); - } - // hidden_states: (b * s, h) - ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden( - embedding_or_hidden_states, request_internal_ids, prefill_lengths); - RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); - if (model_id == 0) { - // We only need to sample for model 0 in prefill. - hidden_states_for_input = hidden_states; - } + Array multi_step_logits{nullptr}; - // Whether to use base model to get logits. - int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; + if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) { + ObjectRef embedding_or_hidden_states{nullptr}; + if (model_id == 0) { + embedding_or_hidden_states = embeddings; + } else { + embedding_or_hidden_states = + models_[model_id]->FuseEmbedHidden(embeddings, hidden_states_for_input, + /*batch_size*/ 1, /*seq_len*/ cum_prefill_length); + } + // hidden_states: (b * s, h) + ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden( + embedding_or_hidden_states, request_internal_ids, prefill_lengths); + RECORD_EVENT(trace_recorder_, request_ids, "finish prefill"); - std::vector logit_positions; - { - // Prepare the logit positions - logit_positions.reserve(prefill_lengths.size()); - int total_len = 0; - for (int i = 0; i < prefill_lengths.size(); ++i) { - total_len += prefill_lengths[i]; - logit_positions.push_back(total_len - 1); + if (model_id == 0) { + // We only need to sample for model 0 in prefill. + hidden_states_for_input = hidden_states; } + + // Whether to use base model to get logits. + int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id; + + std::vector logit_positions; + { + // Prepare the logit positions + logit_positions.reserve(prefill_lengths.size()); + int total_len = 0; + for (int i = 0; i < prefill_lengths.size(); ++i) { + total_len += prefill_lengths[i]; + logit_positions.push_back(total_len - 1); + } + } + // hidden_states_for_sample: (b * s, h) + hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( + hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); + // logits_for_sample: (b * s, v) + logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample); + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + // Note: spec_draft_length in engine config has to be match the model config in Medusa. + multi_step_logits = models_[model_id]->GetMultiStepLogits(hidden_states_for_sample); + } else { + LOG(FATAL) << "unreachable"; } - // hidden_states_for_sample: (b * s, h) - hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates( - hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states); - // logits_for_sample: (b * s, v) - logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample); - // - Update logits. 
- ICHECK(logits_for_sample.defined()); - Array generation_cfg; - Array mstates_for_logitproc; - generation_cfg.reserve(num_rsentries); - mstates_for_logitproc.reserve(num_rsentries); - for (int i = 0; i < num_rsentries; ++i) { - generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg); - mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]); - } - logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg, - mstates_for_logitproc, request_ids); - // - Compute probability distributions. - NDArray probs_on_device = - logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids); + Array request_ids_for_logitproc = request_ids; - // - Sample tokens. + // - Prepare the configurations for the sampler. // For prefill_inputs which have children, sample // one token for each rstate that is depending. // Otherwise, sample a token for the current rstate. @@ -185,12 +187,12 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { std::vector rsentries_for_sample; std::vector rngs; std::vector rsentry_activated; + Array generation_cfg; sample_indices.reserve(num_rsentries); rsentries_for_sample.reserve(num_rsentries); rngs.reserve(num_rsentries); rsentry_activated.reserve(num_rsentries); request_ids.clear(); - generation_cfg.clear(); for (int i = 0; i < num_rsentries; ++i) { const RequestStateEntry& rsentry = prefill_inputs[i].rsentry; // No sample for rsentries with remaining inputs. @@ -251,45 +253,51 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { } } - NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( - probs_on_device, sample_indices, request_ids, generation_cfg); - std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( - renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); - ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); - - // - Update the committed tokens of states. - // - If a request is first-time prefilled, set the prefill finish time. - auto tnow = std::chrono::high_resolution_clock::now(); - if (model_id == 0) { - UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, - sample_results); - // Add the sampled token as an input of the eagle models. - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - for (int mid = 1; mid < static_cast(models_.size()); ++mid) { - TokenData token_data = - Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); - std::vector token_ids = {token_data->token_ids.begin(), - token_data->token_ids.end()}; - token_ids.push_back(sample_results[i].sampled_token_id.first); - int ninputs = static_cast(rsentries_for_sample[i]->mstates[mid]->inputs.size()); - rsentries_for_sample[i]->mstates[mid]->inputs.Set( - ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); + // - Prepare input for logit processor. 
+ ICHECK(logits_for_sample.defined()); + Array generation_cfg_for_logitproc; + Array mstates_for_logitproc; + generation_cfg_for_logitproc.reserve(num_rsentries); + mstates_for_logitproc.reserve(num_rsentries); + for (int i = 0; i < num_rsentries; ++i) { + generation_cfg_for_logitproc.push_back(prefill_inputs[i].rsentry->request->generation_cfg); + mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[model_id]); + } + if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, logits_for_sample, generation_cfg_for_logitproc, + request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices); + if (model_id == 0) { + UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated, + sample_results); + // Add the sampled token as an input of the eagle models. + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + for (int mid = 1; mid < static_cast(models_.size()); ++mid) { + TokenData token_data = + Downcast(rsentries_for_sample[i]->mstates[mid]->inputs.back()); + std::vector token_ids = {token_data->token_ids.begin(), + token_data->token_ids.end()}; + token_ids.push_back(sample_results[i].sampled_token_id.first); + int ninputs = static_cast(rsentries_for_sample[i]->mstates[mid]->inputs.size()); + rsentries_for_sample[i]->mstates[mid]->inputs.Set( + ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end()))); + } } + } else { + // - Slice and save hidden_states_for_sample + UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id, + renormalized_probs, hidden_states_for_sample, + estate); } - } else { - // - Slice and save hidden_states_for_sample - draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), - &draft_token_slots_); - models_[model_id]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, - &model_workspaces_[0].draft_probs_storage); - if (engine_config_->spec_draft_length > 1) { - models_[model_id]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, - &model_workspaces_[0].draft_hidden_states_storage); - } - for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { - rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], - draft_token_slots_[i]); - estate->stats.total_draft_length += 1; + } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) { + for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; ++draft_id) { + const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample( + logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg_for_logitproc, + request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices); + + UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id, + renormalized_probs, + /*hidden_states=*/ObjectRef{nullptr}, estate); } } } @@ -302,6 +310,26 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj { return processed_requests; } + void UpdateRequestStatesWithDraftProposals( + const std::vector& rsentries_for_sample, + const std::vector& sample_results, int model_id, + const NDArray& renormalized_probs, const ObjectRef& hidden_states_for_sample, + EngineState estate) { + draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), &draft_token_slots_); + models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_, + 
&model_workspaces_[0].draft_probs_storage); + if (engine_config_->speculative_mode == SpeculativeMode::kEagle && + engine_config_->spec_draft_length > 1) { + models_[0]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_, + &model_workspaces_[0].draft_hidden_states_storage); + } + for (int i = 0; i < static_cast(rsentries_for_sample.size()); ++i) { + rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i], + draft_token_slots_[i]); + estate->stats.total_draft_length += 1; + } + } + private: /*! \brief The logit processor. */ LogitProcessor logit_processor_; diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index bdf28dfdb5..d63857d539 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -232,7 +232,6 @@ void FunctionTable::_InitFunctions() { } else { this->create_kv_cache_func_ = mod_get_func("create_tir_paged_kv_cache"); } - ICHECK(this->create_kv_cache_func_.defined()); } this->reset_kv_cache_func_ = get_global_func("vm.builtin.kv_state_clear"); this->kv_cache_add_sequence_func_ = get_global_func("vm.builtin.kv_state_add_sequence"); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index e16432c222..89b25827b8 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -67,11 +67,7 @@ class ModelImpl : public ModelObj { // Step 3. Reset this->Reset(); // Step 4. Set model type - if (json::Lookup(model_config, "model_type").find("rwkv") != std::string::npos) { - this->kind = KVStateKind::kRNNState; - } else { - this->kind = KVStateKind::kKVCache; - } + this->kind = GetMetadata().kv_state_kind; } /*********************** Model Computation ***********************/ @@ -149,6 +145,21 @@ class ModelImpl : public ModelObj { return logits; } + Array GetMultiStepLogits(const ObjectRef& hidden_states) final { + NVTXScopedRange nvtx_scope("GetMultiStepLogits"); + CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model."; + + ObjectRef hidden_states_dref_or_nd{nullptr}; + ObjectRef ret = ft_.get_logits_func_(hidden_states, params_); + Array logits{nullptr}; + if (ft_.use_disco) { + logits = Downcast(ret)->DebugGetFromRemote(0); + } else { + logits = Downcast>(ret); + } + return logits; + } + ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states, int batch_size, int seq_len) final { NVTXScopedRange nvtx_scope("FuseEmbedHidden"); @@ -563,8 +574,9 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size, - KVStateKind kv_state_kind) final { + int prefill_chunk_size, int max_history_size) final { + // KVStateKind kv_state_kind) final { + KVStateKind kv_state_kind = GetMetadata().kv_state_kind; if (kv_state_kind == KVStateKind::kKVCache) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; @@ -576,30 +588,51 @@ class ModelImpl : public ModelObj { support_sliding_window); local_kv_cache_ = ft_.use_disco ? Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; - } else { + } else if (kv_state_kind == KVStateKind::kRNNState) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_history_size_tuple = {std::max(max_history_size, 1)}; kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple); local_kv_cache_ = ft_.use_disco ? 
Downcast(kv_cache_)->DebugGetFromRemote(0) : kv_cache_; + } else if (kv_state_kind == KVStateKind::kNone) { + // Do nothing + } else { + LOG(FATAL) << "Unknown kv_state_kind: " << static_cast(kv_state_kind); } } - void AddNewSequence(int64_t seq_id) final { ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); } + void AddNewSequence(int64_t seq_id) final { + if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) { + return; + } + ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id); + } void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos) final { + if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id, fork_pos); } void RemoveSequence(int64_t seq_id) final { + if (this->kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_remove_sequence_func_(kv_cache_, seq_id); } void PopNFromKVCache(int64_t seq_id, int num_tokens) final { + if (this->kind == KVStateKind::kNone) { + return; + } ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens); } void EnableSlidingWindowForSeq(int64_t seq_id) final { + if (this->kind == KVStateKind::kNone) { + return; + } if (sliding_window_size_ != -1) { ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, seq_id, sliding_window_size_, attention_sink_size_); @@ -620,7 +653,7 @@ class ModelImpl : public ModelObj { } int GetCurrentTotalSequenceLength() const final { - if (this->kind == KVStateKind::kRNNState) { + if (this->kind == KVStateKind::kRNNState || this->kind == KVStateKind::kNone) { // RNNState does not have a total sequence length limit return 0; } else { @@ -670,6 +703,9 @@ class ModelImpl : public ModelObj { } ObjectRef AllocEmbeddingTensor() final { + if (!ft_.alloc_embedding_tensor_func_.defined()) { + return ObjectRef{nullptr}; + } // Allocate the embedding tensor. ObjectRef embedding = ft_.alloc_embedding_tensor_func_(); // Get the shape of the embedding tensor for hidden size. @@ -690,6 +726,9 @@ class ModelImpl : public ModelObj { } ObjectRef AllocHiddenStatesTensor() final { + if (!ft_.alloc_embedding_tensor_func_.defined()) { + return ObjectRef{nullptr}; + } // Allocate the hidden_states tensor. // Use the same function as embeddings. ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_(); @@ -778,6 +817,17 @@ class ModelImpl : public ModelObj { ft_.scatter_probs_func_(input, indices_device, *dst); } + Array GetMedusaLogits(const ObjectRef& hidden_states) { + ObjectRef result = ft_.get_logits_func_(hidden_states); + Array logits{nullptr}; + if (ft_.use_disco) { + logits = Downcast(result)->DebugGetFromRemote(0); + } else { + logits = Downcast>(result); + } + return logits; + } + /************** Debug/Profile **************/ void DebugCallFuncOnAllAllWorker(const String& func_name) final { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 96d2ecb401..41fccf8d0b 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -139,6 +139,8 @@ class ModelObj : public Object { */ virtual NDArray GetLogits(const ObjectRef& last_hidden_states) = 0; + virtual Array GetMultiStepLogits(const ObjectRef& last_hidden_states) = 0; + /*! * \brief Batch prefill function. Embedding in, logits out. * The embedding order of sequences in `embedding_arr` follows @@ -224,11 +226,9 @@ class ModelObj : public Object { * are allowed to exist in the KV cache at any time. * \param max_history_size The maximum history size for RNN state to roll back. * The KV cache does not need this. - * \param kv_state_kind The kind of cache. 
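Taken together, the model.cc changes above make the cache kind come from the compiled model metadata and turn every per-sequence cache operation into a no-op when there is no KV state at all, which is the case for the Medusa head model. A condensed Python sketch of that dispatch; the two create_* callables stand in for the packed functions the real implementation resolves from the function table:

    from enum import Enum

    class KVStateKind(Enum):
        KV_CACHE = 0
        RNN_STATE = 1
        NONE = 2

    class CacheHandlingSketch:
        """Illustrative stand-in for ModelImpl's cache handling, not the real class."""

        def __init__(self, kind, create_paged_kv_cache, create_rnn_state):
            self.kind = kind
            self._create_paged_kv_cache = create_paged_kv_cache
            self._create_rnn_state = create_rnn_state
            self.kv_cache = None

        def create_kv_cache(self, max_num_sequence, max_total_sequence_length,
                            prefill_chunk_size, page_size, max_history_size):
            if self.kind is KVStateKind.KV_CACHE:
                self.kv_cache = self._create_paged_kv_cache(
                    max_num_sequence, max_total_sequence_length,
                    prefill_chunk_size, page_size)
            elif self.kind is KVStateKind.RNN_STATE:
                self.kv_cache = self._create_rnn_state(
                    max_num_sequence, max(max_history_size, 1))
            elif self.kind is KVStateKind.NONE:
                pass  # Medusa heads keep no per-sequence state at all
            else:
                raise ValueError(f"unknown kv_state_kind: {self.kind}")

        def add_new_sequence(self, seq_id):
            if self.kind is KVStateKind.NONE:
                return  # same early-return guard as AddNewSequence/ForkSequence/... above
            self.kv_cache.add_sequence(seq_id)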
It can be KV cache or RNN state. */ virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size, - KVStateKind kv_state_kind) = 0; + int prefill_chunk_size, int max_history_size) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index d776ed146b..c6314f2c04 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -50,7 +50,7 @@ def main(argv): parser.add_argument( "--speculative-mode", type=str, - choices=["disable", "small_draft", "eagle"], + choices=["disable", "small_draft", "eagle", "medusa"], default="disable", help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 7aafc64738..a8a170c3ad 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -85,6 +85,14 @@ def _apply_preproc_to_params( return extra_tirs +def _infer_kv_state_kind(model_type) -> str: + if "rwkv" in model_type: + return "rnn_state" + if "medusa" in model_type: + return "none" + return "kv_cache" + + def _compile(args: CompileArgs, model_config: ConfigBase): def _get_variable_bounds(model_config) -> Dict[str, int]: if hasattr(model_config, "sliding_window_size"): @@ -178,6 +186,7 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: "prefill_chunk_size": model_config.prefill_chunk_size, # type: ignore "tensor_parallel_shards": model_config.tensor_parallel_shards, # type: ignore "kv_cache_bytes": kv_cache_bytes, + "kv_state_kind": _infer_kv_state_kind(args.model.name), } logger.info("Registering metadata: %s", metadata) metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params] diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index d1cde12678..acf6ead514 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -23,7 +23,7 @@ def serve( prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: Literal["disable", "small_draft", "eagle"], + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"], spec_draft_length: int, enable_tracing: bool, host: str, diff --git a/python/mlc_llm/model/medusa/__init__.py b/python/mlc_llm/model/medusa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/medusa/medusa_loader.py b/python/mlc_llm/model/medusa/medusa_loader.py new file mode 100644 index 0000000000..41bef4d98d --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_loader.py @@ -0,0 +1,51 @@ +""" +This file specifies how MLC's Medusa parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .medusa_model import MedusaConfig, MedusaModel + + +def huggingface(model_config: MedusaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MedusaConfig + The configuration of the Medusa model. + + quantization : Quantization + The quantization configuration. 
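On the compile side, _infer_kv_state_kind above is what stamps the cache kind into the model metadata that model.cc reads back at load time. A standalone copy of the helper so the mapping can be sanity-checked directly; the model-type strings are just examples:

    def _infer_kv_state_kind(model_type) -> str:
        # Same logic as the helper added to python/mlc_llm/interface/compile.py above.
        if "rwkv" in model_type:
            return "rnn_state"
        if "medusa" in model_type:
            return "none"
        return "kv_cache"

    assert _infer_kv_state_kind("rwkv6") == "rnn_state"   # RNN-state models
    assert _infer_kv_state_kind("medusa") == "none"       # Medusa heads: no cache at all
    assert _infer_kv_state_kind("llama") == "kv_cache"    # everything else: paged KV cache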
+ + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = MedusaModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/medusa/medusa_model.py b/python/mlc_llm/model/medusa/medusa_model.py new file mode 100644 index 0000000000..af21164421 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_model.py @@ -0,0 +1,83 @@ +"""Medusa model definition.""" +import dataclasses +from typing import Any, Dict, Optional + +from tvm.relax.frontend import nn + +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MedusaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + medusa_num_heads: int + medusa_num_layers: int + hidden_size: int + vocab_size: int + max_batch_size: int = 1 + tensor_parallel_shards: int = 1 + + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + # Unused parameters. Kept for compatibility with the compilation flow. + prefill_chunk_size: int = -1 + context_window_size: int = -1 + + +# pylint: disable=missing-docstring + + +class ResBlock(nn.Module): + """Residual block with SiLU activation.""" + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + self.act = nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(nn.Module): + """Medusa model definition.""" + + def __init__(self, config: MedusaConfig): + self.hidden_size = config.hidden_size + self.dtype = "float32" + self.medusa_head = nn.ModuleList( + [ + nn.ModuleList( + [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)] + + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)] + ) + for _ in range(config.medusa_num_heads) + ] + ) + + def get_default_spec(self): + mod_spec = { + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + + def get_logits(self, hidden_states: nn.Tensor): + logits = [] + for head in self.medusa_head: + logits.append(head(hidden_states).astype("float32")) + return logits + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype diff --git a/python/mlc_llm/model/medusa/medusa_quantization.py b/python/mlc_llm/model/medusa/medusa_quantization.py new file mode 100644 index 0000000000..9fb2b6c255 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_quantization.py @@ -0,0 +1,20 @@ +"""This file specifies how MLC's Medusa parameters are quantized.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import NoQuantize + +from .medusa_model import MedusaConfig, MedusaModel + + +def no_quant( + model_config: MedusaConfig, + 
quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model without quantization.""" + model: nn.Module = MedusaModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 08d272f409..042bd7ceaa 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -19,6 +19,7 @@ from .internlm import internlm_loader, internlm_model, internlm_quantization from .llama import llama_loader, llama_model, llama_quantization from .llava import llava_loader, llava_model, llava_quantization +from .medusa import medusa_loader, medusa_model, medusa_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .orion import orion_loader, orion_model, orion_quantization @@ -385,4 +386,16 @@ class Model: "ft-quant": bert_quantization.ft_quant, }, ), + "medusa": Model( + name="medusa", + model=medusa_model.MedusaModel, + config=medusa_model.MedusaConfig, + source={ + "huggingface-torch": medusa_loader.huggingface, + "huggingface-safetensor": medusa_loader.huggingface, + }, + quantize={ + "no-quant": medusa_quantization.no_quant, + }, + ), } diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 916403839a..2dbaaf36a6 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -194,11 +194,12 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] The kind of cache. - speculative_mode : Literal["disable", "small_draft", "eagle"] + speculative_mode : Literal["disable", "small_draft", "eagle", "medusa"] The speculative mode. "disable" means speculative decoding is disabled. "small_draft" means the normal speculative decoding (small draft) mode. "eagle" means the eagle-style speculative decoding. + "medusa" means the medusa-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). @@ -220,7 +221,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefill_chunk_size: Optional[int] = None max_history_size: Optional[int] = None kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] = None - speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable" + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"] = "disable" spec_draft_length: int = 4 verbose: bool = True diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index c99dbd4794..896930e684 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -827,11 +827,12 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. - speculative_mode : Literal["disable", "small_draft", "eagle"] + speculative_mode : Literal["disable", "small_draft", "eagle", "medusa"] The speculative mode. "disable" means speculative decoding is disabled. "small_draft" means the normal speculative decoding (small draft) mode. "eagle" means the eagle-style speculative decoding. + "medusa" means the medusa-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). 
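With the model registration above, `--speculative-mode medusa` is now accepted by `mlc_llm serve`, and the engine-side comment earlier notes that spec_draft_length has to match the Medusa model config (one draft token per head). The short sketch below mirrors what medusa_loader.huggingface() does to instantiate and export the head model; the config numbers are arbitrary placeholders and TVM's relax frontend is assumed to be installed:

    from mlc_llm.model.medusa.medusa_model import MedusaConfig, MedusaModel

    config = MedusaConfig(
        medusa_num_heads=4,   # one draft token per head, so spec_draft_length should be 4
        medusa_num_layers=1,
        hidden_size=4096,
        vocab_size=32000,
    )
    model = MedusaModel(config)
    spec = model.get_default_spec()  # exposes a single get_logits(hidden_states) entry point
    _mod, named_params, _ = model.export_tvm(spec=spec, allow_extern=True)
    print([name for name, _ in named_params][:4])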
@@ -856,7 +857,7 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, verbose: bool = True, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 641c8f6ed5..12b495dfca 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -425,7 +425,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: Literal["disable", "small_draft", "eagle"], + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"], spec_draft_length: int, enable_tracing: bool, verbose: bool, From 2bbbd52cde62aeed2d0a6f7975c5af81ba84da4a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 May 2024 20:30:45 -0700 Subject: [PATCH 310/531] Fix cublas offloading (#2343) --- python/mlc_llm/compiler_pass/cublas_dispatch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mlc_llm/compiler_pass/cublas_dispatch.py b/python/mlc_llm/compiler_pass/cublas_dispatch.py index f5af94cc4b..b8e461e945 100644 --- a/python/mlc_llm/compiler_pass/cublas_dispatch.py +++ b/python/mlc_llm/compiler_pass/cublas_dispatch.py @@ -20,7 +20,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR model_names = [ gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function) ] - model_names = [name for name in model_names if "batch" not in name] + # exclude single batch decode + model_names = [name for name in model_names if "batch" in name or "decode" not in name] mod = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern( From 227dbb87260b2e14d030a6d880c4f69d475c7022 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 14 May 2024 22:28:32 -0700 Subject: [PATCH 311/531] Add false for arg worker0_only in disco.empty (#2344) --- cpp/llm_chat.cc | 2 +- cpp/serve/function_table.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 93de185eb2..a8d2edc11a 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -302,7 +302,7 @@ struct FunctionTable { Device null_device{DLDeviceType(0), 0}; if (this->use_disco) { DRef empty_func = sess->GetGlobalFunc("runtime.disco.empty"); - return sess->CallPacked(empty_func, shape, dtype, null_device); + return sess->CallPacked(empty_func, shape, dtype, null_device, false); } else { return NDArray::Empty(shape, dtype, device); } diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc index d63857d539..2ed864f298 100644 --- a/cpp/serve/function_table.cc +++ b/cpp/serve/function_table.cc @@ -270,7 +270,7 @@ ObjectRef FunctionTable::Empty(ShapeTuple shape, DataType dtype, Device device) Device null_device{DLDeviceType(0), 0}; if (this->use_disco) { DRef empty_func = sess->GetGlobalFunc("runtime.disco.empty"); - return sess->CallPacked(empty_func, shape, dtype, null_device); + return sess->CallPacked(empty_func, shape, dtype, null_device, false); } else { return NDArray::Empty(shape, dtype, device); } From 9b89e048a5bcd84b68f9df3675d1599e502884df Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 15 May 
2024 05:50:15 +0000 Subject: [PATCH 312/531] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c8f7ec8dc0..ce58d63453 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c8f7ec8dc0377ad362e1c81b194c6e2322f27a75 +Subproject commit ce58d63453ff83b930fa2be665647621b2eec4d2 From 56ea1560a02a3672f6b7802853447236e777cd60 Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Wed, 15 May 2024 21:48:39 +0530 Subject: [PATCH 313/531] [JSONFFIEngine] Refactor device argument and request_stream_callback argument (#2334) * 1. Refactor init_background_engine in JSONFFIEngine to use device_type and device_id arguments. 2. request_stream_callback is called on each string of the array of strings. * Calling callback on string of list of JSON dicts instead of each string of JSON dict multiple times --------- Co-authored-by: Animesh Bohara --- cpp/json_ffi/json_ffi_engine.cc | 22 ++++++++++++---------- python/mlc_llm/json_ffi/engine.py | 22 ++++++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 343266135c..98d00061a8 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -44,7 +44,9 @@ void JSONFFIEngine::StreamBackError(std::string request_id) { response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - this->request_stream_callback_(Array{picojson::value(response.AsJSON()).serialize()}); + picojson::array response_arr; + response_arr.push_back(picojson::value(response.AsJSON())); + this->request_stream_callback_(picojson::value(response_arr).serialize()); } bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) { @@ -117,8 +119,9 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(Device device, Optional request_stream_callback, - Optional trace_recorder) { + void InitBackgroundEngine(int device_type, int device_id, + Optional request_stream_callback) { + DLDevice device{static_cast(device_type), device_id}; this->device_ = device; CHECK(request_stream_callback.defined()) << "JSONFFIEngine requires request stream callback function, but it is not given."; @@ -127,13 +130,12 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.size(), 1); Array delta_outputs = args[0]; - Array responses = this->GetResponseFromStreamOutput(delta_outputs); + String responses = this->GetResponseFromStreamOutput(delta_outputs); this->request_stream_callback_(responses); }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), - std::move(trace_recorder)); + this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), NullOpt); } void Reload(String engine_config_json_str) { @@ -169,7 +171,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); } - Array GetResponseFromStreamOutput(Array delta_outputs) { + String GetResponseFromStreamOutput(Array delta_outputs) { std::unordered_map> response_map; for (const 
auto& delta_output : delta_outputs) { std::string request_id = delta_output->request_id; @@ -211,16 +213,16 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { response_map[request_id].push_back(choice); } - Array response_arr; + picojson::array response_arr; for (const auto& [request_id, choices] : response_map) { ChatCompletionStreamResponse response; response.id = request_id; response.choices = choices; response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - response_arr.push_back(picojson::value(response.AsJSON()).serialize()); + response_arr.push_back(picojson::value(response.AsJSON())); } - return response_arr; + return picojson::value(response_arr).serialize(); } }; diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py index 237319a926..9a95d4b0a4 100644 --- a/python/mlc_llm/json_ffi/engine.py +++ b/python/mlc_llm/json_ffi/engine.py @@ -1,5 +1,6 @@ # pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable +import json import queue import threading from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union @@ -20,17 +21,15 @@ class EngineState: sync_queue: queue.Queue - def get_request_stream_callback(self) -> Callable[[List[str]], None]: + def get_request_stream_callback(self) -> Callable[[str], None]: # ChatCompletionStreamResponse - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + def _callback(chat_completion_stream_responses_json_str: str) -> None: self._sync_request_stream_callback(chat_completion_stream_responses_json_str) return _callback - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: + def _sync_request_stream_callback(self, chat_completion_stream_responses_json_str: str) -> None: # Put the delta outputs to the queue in the unblocking way. 
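The net protocol change above is that JSONFFIEngine's stream callback now delivers one string per tick containing a JSON list of ChatCompletionStreamResponse objects, instead of a list of per-response strings. A small Python sketch of decoding such a payload on the receiving side; the payload here is fabricated and carries only a subset of the real response fields:

    import json

    payload = json.dumps([
        {"id": "req-0", "model": "json_ffi",
         "choices": [{"index": 0, "delta": {"content": "Hello"}}]},
        {"id": "req-1", "model": "json_ffi",
         "choices": [{"index": 0, "delta": {"content": "Hi"}}]},
    ])

    def request_stream_callback(responses_json_str: str) -> None:
        # One json.loads per callback, then iterate the per-request responses.
        for response in json.loads(responses_json_str):
            for choice in response["choices"]:
                print(response["id"], choice["delta"].get("content", ""))

    request_stream_callback(payload)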
self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) @@ -125,7 +124,9 @@ def _background_stream_back_loop(): verbose=False, ) - self._ffi["init_background_engine"](device, self.state.get_request_stream_callback(), None) + self._ffi["init_background_engine"]( + device.device_type, device.device_id, self.state.get_request_stream_callback() + ) self._ffi["reload"](self.engine_config.asjson()) def terminate(self): @@ -210,11 +211,12 @@ def _handle_chat_completion( try: while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: + chat_completion_responses_json_str = self.state.sync_queue.get() + chat_completion_responses_list = json.loads(chat_completion_responses_json_str) + for chat_completion_response_json_dict in chat_completion_responses_list: chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str + openai_api_protocol.ChatCompletionStreamResponse.model_validate( + chat_completion_response_json_dict ) ) for choice in chat_completion_response.choices: From 152ecc43cf20158ff9cd89a9d2398142f6a61067 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 16 May 2024 05:55:11 -0700 Subject: [PATCH 314/531] [Serving] Add reset_engine in debug_entrypoints (#2347) --- cpp/serve/engine_state.cc | 2 ++ cpp/serve/threaded_engine.cc | 1 + python/mlc_llm/serve/engine_base.py | 5 ++++ .../serve/entrypoints/debug_entrypoints.py | 30 +++++++++++++++++-- 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/cpp/serve/engine_state.cc b/cpp/serve/engine_state.cc index 7847f53fd5..1882ad59ad 100644 --- a/cpp/serve/engine_state.cc +++ b/cpp/serve/engine_state.cc @@ -43,6 +43,8 @@ void EngineStats::Reset() { total_decode_length = 0; total_accepted_length = 0; total_draft_length = 0; + accept_count.clear(); + draft_count.clear(); } TVM_REGISTER_OBJECT_TYPE(EngineStateObj); diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 8c3cadd358..33fc39e93f 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -374,6 +374,7 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("get_complete_engine_config", &ThreadedEngineImpl::GetCompleteEngineConfigJSONString); TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); + TVM_MODULE_VTABLE_ENTRY("reset", &ThreadedEngineImpl::Reset); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); TVM_MODULE_VTABLE_END(); diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 12b495dfca..22acedc271 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -468,6 +468,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "get_default_generation_config", "get_complete_engine_config", "stats", + "reset", "debug_call_func_on_all_worker", ] } @@ -533,6 +534,10 @@ def stats(self): """Get the engine stats.""" return self._ffi["stats"]() + def reset(self): + """Reset the engine, clear the running data and statistics.""" + return self._ffi["reset"]() + def process_chat_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.ChatCompletionRequest, diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index 
9f6508ea42..d62bd78d77 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -34,7 +34,7 @@ async def debug_dump_event_trace(request: fastapi.Request): HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) - # - Check the requested model. + # Check the requested model. model = request_dict["model"] server_context: ServerContext = ServerContext.current() @@ -99,7 +99,7 @@ async def debug_dump_engine_stats(request: fastapi.Request): HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) - # - Check the requested model. + # Check the requested model. model = request_dict["model"] server_context: ServerContext = ServerContext.current() @@ -107,3 +107,29 @@ async def debug_dump_engine_stats(request: fastapi.Request): res = async_engine.stats() print(res) return json.loads(res) + + +@app.post("/debug/reset_engine") +async def debug_reset_engine_stats(request: fastapi.Request): + """Reset the engine, clean up all running data and statistics.""" + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # Check the requested model. + model = request_dict["model"] + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + async_engine.reset() From ac1cd51b14501cf046203f370a31c4b27ea63c00 Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Fri, 17 May 2024 20:38:55 -0700 Subject: [PATCH 315/531] [Bugfix] Make sequence_length dtype int64 in EngineConfig. Fix Mistral engine serving issue (#2358) * [Bugfix] Make sequence_length dtype int64 in EngineConfig. Fix Mistral engine serving issue --- cpp/serve/config.cc | 7 ++++--- cpp/serve/config.h | 6 +++--- cpp/serve/engine_actions/action_commons.cc | 2 +- cpp/serve/engine_actions/action_commons.h | 2 +- cpp/serve/engine_actions/batch_prefill_base.cc | 3 ++- cpp/serve/model.cc | 4 ++-- cpp/serve/model.h | 4 ++-- cpp/serve/request_state.cc | 2 +- cpp/serve/request_state.h | 3 ++- python/mlc_llm/serve/config.py | 6 +++--- 10 files changed, 21 insertions(+), 18 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index cbc4c6c613..367bda701a 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -392,9 +392,10 @@ Result GetModelConfigLimits(const std::vector(model_configs[i], "model_config"); // - The maximum single sequence length is the minimum context window size among all models. 
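The /debug/reset_engine endpoint added in the patch above can be exercised the same way as the existing debug endpoints: POST a JSON body naming the served model. A hypothetical client call, with host, port, and model name as placeholders for whatever `mlc_llm serve` was launched with (uses the third-party `requests` package):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/debug/reset_engine",
        json={"model": "Llama-3-8B-Instruct-q4f16_1-MLC"},
        timeout=10,
    )
    resp.raise_for_status()  # clears the engine's running data and statistics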
int64_t runtime_context_window_size = - json::Lookup(model_configs[i], "context_window_size"); + json::LookupOptional(model_configs[i], "context_window_size").value_or(-1); int64_t compile_time_context_window_size = - json::Lookup(compile_time_model_config, "context_window_size"); + json::LookupOptional(compile_time_model_config, "context_window_size") + .value_or(-1); if (runtime_context_window_size > compile_time_context_window_size) { return Result::Error( "Model " + std::to_string(i) + "'s runtime context window size (" + @@ -458,7 +459,7 @@ Result EstimateMemoryUsageOnMode( InferrableEngineConfig init_config, bool verbose) { std::ostringstream os; InferrableEngineConfig inferred_config = init_config; - // - 1. max_mum_sequence + // - 1. max_num_sequence if (!init_config.max_num_sequence.has_value()) { if (mode == EngineMode::kLocal) { inferred_config.max_num_sequence = diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 2680eb755c..04b6b637f9 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -158,14 +158,14 @@ class EngineConfigNode : public Object { */ int max_num_sequence = 4; /*! \brief The maximum length allowed for a single sequence in the engine. */ - int max_total_sequence_length = 4096; + int64_t max_total_sequence_length = 4096; /*! * \brief The maximum total number of tokens whose KV data are allowed * to exist in the KV cache at any time. */ - int max_single_sequence_length = 4096; + int64_t max_single_sequence_length = 4096; /*! \brief The maximum total sequence length in a prefill. */ - int prefill_chunk_size = 1024; + int64_t prefill_chunk_size = 1024; /*! \brief The maximum history size for RNN state. KV cache does not need this. */ int max_history_size = 0; diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 3289ef57c6..7354054187 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -85,7 +85,7 @@ void ProcessFinishedRequestStateEntries(std::vector finished_ void ActionStepPostProcess(Array requests, EngineState estate, Array models, const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, - int max_single_sequence_length) { + int64_t max_single_sequence_length) { NVTXScopedRange nvtx_scope("EngineAction postproc"); std::vector finished_rsentries; finished_rsentries.reserve(requests.size()); diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index de98e11e67..1844ba97e9 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -44,7 +44,7 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, Array requests, EngineState estate, Array models, const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, - int max_single_sequence_length); + int64_t max_single_sequence_length); /*! * \brief Preempt the last running request state entry from `running_queue`. diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc index f570551417..b96727b985 100644 --- a/cpp/serve/engine_actions/batch_prefill_base.cc +++ b/cpp/serve/engine_actions/batch_prefill_base.cc @@ -130,7 +130,8 @@ bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_r ? 
(engine_config_->spec_draft_length + 1) : 1; if ((num_running_rsentries + num_prefill_rsentries) * spec_factor > - std::min(engine_config_->max_num_sequence, engine_config_->prefill_chunk_size)) { + std::min(static_cast(engine_config_->max_num_sequence), + engine_config_->prefill_chunk_size)) { return false; } diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 89b25827b8..3fb5f8a4ea 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -573,8 +573,8 @@ class ModelImpl : public ModelObj { /*********************** KV Cache Management ***********************/ - void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size) final { + void CreateKVCache(int page_size, int max_num_sequence, int64_t max_total_sequence_length, + int64_t prefill_chunk_size, int max_history_size) final { // KVStateKind kv_state_kind) final { KVStateKind kv_state_kind = GetMetadata().kv_state_kind; if (kv_state_kind == KVStateKind::kKVCache) { diff --git a/cpp/serve/model.h b/cpp/serve/model.h index 41fccf8d0b..f27795f66f 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -227,8 +227,8 @@ class ModelObj : public Object { * \param max_history_size The maximum history size for RNN state to roll back. * The KV cache does not need this. */ - virtual void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, - int prefill_chunk_size, int max_history_size) = 0; + virtual void CreateKVCache(int page_size, int max_num_sequence, int64_t max_total_sequence_length, + int64_t prefill_chunk_size, int max_history_size) = 0; /*! \brief Add a new sequence with the given sequence id to the KV cache. */ virtual void AddNewSequence(int64_t seq_id) = 0; diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 4c59ae52a2..a542c1c9b5 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -118,7 +118,7 @@ RequestStateEntry::RequestStateEntry( } DeltaRequestReturn RequestStateEntryNode::GetReturnTokenIds(const Tokenizer& tokenizer, - int max_single_sequence_length) { + int64_t max_single_sequence_length) { std::vector return_token_ids; std::vector logprob_json_strs; Optional finish_reason; diff --git a/cpp/serve/request_state.h b/cpp/serve/request_state.h index 79abcb1a24..5eec3fe82a 100644 --- a/cpp/serve/request_state.h +++ b/cpp/serve/request_state.h @@ -202,7 +202,8 @@ class RequestStateEntryNode : public Object { * \return The delta token ids to return, the logprob JSON strings of each delta token id, and * the optional finish reason. */ - DeltaRequestReturn GetReturnTokenIds(const Tokenizer& tokenizer, int max_single_sequence_length); + DeltaRequestReturn GetReturnTokenIds(const Tokenizer& tokenizer, + int64_t max_single_sequence_length); static constexpr const char* _type_key = "mlc.serve.RequestStateEntry"; static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 2dbaaf36a6..722a5bd6af 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -179,12 +179,12 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes processed by the KV cache at any time. max_total_sequence_length : Optional[int] - The maximum length allowed for a single sequence in the engine. - - max_single_sequence_length : Optional[int] The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. 
+ max_single_sequence_length : Optional[int] + The maximum length allowed for a single sequence in the engine. + prefill_chunk_size : Optional[int] The maximum total sequence length in a prefill. From 96fc28994a30c35939c86f28f86a0d7a552a435f Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Sat, 18 May 2024 09:43:50 -0400 Subject: [PATCH 316/531] [JSON FFI] Example Android Application using JSON FFI Engine (#2322) * pass str to callback and not List[str] add json ffif android example fix lint Refactor MLCEngineExample and MLCEngine.kt Use ChatCompletionMessageContent class ChatCompletionMessageContent: text and parts * JSONFFIEngine: Cast request_stream_callback argument to std::string. Decode in Android as List --------- Co-authored-by: Animesh Bohara --- android/MLCChat/settings.gradle | 1 + android/MLCEngineExample/README.md | 6 + android/MLCEngineExample/app/.gitignore | 2 + android/MLCEngineExample/app/build.gradle | 73 +++++++ .../MLCEngineExample/app/proguard-rules.pro | 21 ++ .../app/src/main/AndroidManifest.xml | 41 ++++ .../app/src/main/ic_launcher-playstore.png | Bin 0 -> 47710 bytes .../ai/mlc/mlcengineexample/MainActivity.kt | 73 +++++++ .../ai/mlc/mlcengineexample/ui/theme/Color.kt | 44 ++++ .../ai/mlc/mlcengineexample/ui/theme/Theme.kt | 107 ++++++++++ .../ai/mlc/mlcengineexample/ui/theme/Type.kt | 34 ++++ .../res/drawable/ic_android_black_24dp.xml | 5 + .../src/main/res/drawable/mlc_logo_108.xml | 11 + .../app/src/main/res/values/colors.xml | 10 + .../app/src/main/res/values/strings.xml | 3 + .../app/src/main/res/values/themes.xml | 6 + .../app/src/main/res/xml/backup_rules.xml | 13 ++ .../main/res/xml/data_extraction_rules.xml | 19 ++ android/MLCEngineExample/build.gradle | 5 + android/MLCEngineExample/bundle_weight.py | 65 ++++++ android/MLCEngineExample/gradle.properties | 23 +++ .../gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 59203 bytes .../gradle/wrapper/gradle-wrapper.properties | 6 + android/MLCEngineExample/gradlew | 185 +++++++++++++++++ android/MLCEngineExample/gradlew.bat | 89 ++++++++ .../MLCEngineExample/mlc-package-config.json | 10 + android/MLCEngineExample/settings.gradle | 18 ++ android/mlc4j/build.gradle | 3 +- .../java/ai/mlc/mlcllm/JSONFFIEngine.java | 83 ++++++++ .../src/main/java/ai/mlc/mlcllm/MLCEngine.kt | 133 ++++++++++++ .../main/java/ai/mlc/mlcllm/OpenAIProtocol.kt | 191 ++++++++++++++++++ cpp/json_ffi/json_ffi_engine.cc | 2 +- 32 files changed, 1280 insertions(+), 2 deletions(-) create mode 100644 android/MLCEngineExample/README.md create mode 100644 android/MLCEngineExample/app/.gitignore create mode 100644 android/MLCEngineExample/app/build.gradle create mode 100644 android/MLCEngineExample/app/proguard-rules.pro create mode 100644 android/MLCEngineExample/app/src/main/AndroidManifest.xml create mode 100644 android/MLCEngineExample/app/src/main/ic_launcher-playstore.png create mode 100644 android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/MainActivity.kt create mode 100644 android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Color.kt create mode 100644 android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Theme.kt create mode 100644 android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Type.kt create mode 100644 android/MLCEngineExample/app/src/main/res/drawable/ic_android_black_24dp.xml create mode 100644 android/MLCEngineExample/app/src/main/res/drawable/mlc_logo_108.xml create mode 100644 android/MLCEngineExample/app/src/main/res/values/colors.xml 
create mode 100644 android/MLCEngineExample/app/src/main/res/values/strings.xml create mode 100644 android/MLCEngineExample/app/src/main/res/values/themes.xml create mode 100644 android/MLCEngineExample/app/src/main/res/xml/backup_rules.xml create mode 100644 android/MLCEngineExample/app/src/main/res/xml/data_extraction_rules.xml create mode 100644 android/MLCEngineExample/build.gradle create mode 100644 android/MLCEngineExample/bundle_weight.py create mode 100644 android/MLCEngineExample/gradle.properties create mode 100644 android/MLCEngineExample/gradle/wrapper/gradle-wrapper.jar create mode 100644 android/MLCEngineExample/gradle/wrapper/gradle-wrapper.properties create mode 100755 android/MLCEngineExample/gradlew create mode 100644 android/MLCEngineExample/gradlew.bat create mode 100644 android/MLCEngineExample/mlc-package-config.json create mode 100644 android/MLCEngineExample/settings.gradle create mode 100644 android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java create mode 100644 android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt create mode 100644 android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt diff --git a/android/MLCChat/settings.gradle b/android/MLCChat/settings.gradle index 6866480997..b19a744002 100644 --- a/android/MLCChat/settings.gradle +++ b/android/MLCChat/settings.gradle @@ -16,3 +16,4 @@ rootProject.name = "MLCChat" include ':app' include ':mlc4j' project(':mlc4j').projectDir = file('dist/lib/mlc4j') +include ':mlcengineexample' diff --git a/android/MLCEngineExample/README.md b/android/MLCEngineExample/README.md new file mode 100644 index 0000000000..977c84d295 --- /dev/null +++ b/android/MLCEngineExample/README.md @@ -0,0 +1,6 @@ +# MLC-LLM Android + +Checkout [Documentation page](https://llm.mlc.ai/docs/deploy/android.html) for more information. 
+ +- run `mlc_llm package` +- open this `MLCEngineExample/` folder as a project in Android Studio diff --git a/android/MLCEngineExample/app/.gitignore b/android/MLCEngineExample/app/.gitignore new file mode 100644 index 0000000000..558f311c28 --- /dev/null +++ b/android/MLCEngineExample/app/.gitignore @@ -0,0 +1,2 @@ +/build +/src/main/libs \ No newline at end of file diff --git a/android/MLCEngineExample/app/build.gradle b/android/MLCEngineExample/app/build.gradle new file mode 100644 index 0000000000..c6b902bff5 --- /dev/null +++ b/android/MLCEngineExample/app/build.gradle @@ -0,0 +1,73 @@ +plugins { + id 'com.android.application' + id 'org.jetbrains.kotlin.android' +} + +android { + namespace 'ai.mlc.mlcengineexample' + compileSdk 34 + + defaultConfig { + applicationId "ai.mlc.mlcengineexample" + minSdk 26 + targetSdk 33 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary true + } + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + kotlinOptions { + jvmTarget = '1.8' + } + buildFeatures { + compose true + } + composeOptions { + kotlinCompilerExtensionVersion '1.4.3' + } + packagingOptions { + resources { + excludes += '/META-INF/{AL2.0,LGPL2.1}' + } + } +} + +dependencies { + implementation project(":mlc4j") + implementation 'androidx.core:core-ktx:1.10.1' + implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1' + implementation 'androidx.activity:activity-compose:1.7.1' + implementation platform('androidx.compose:compose-bom:2022.10.00') + implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1' + implementation 'androidx.compose.ui:ui' + implementation 'androidx.compose.ui:ui-graphics' + implementation 'androidx.compose.ui:ui-tooling-preview' + implementation 'androidx.compose.material3:material3:1.1.0' + implementation 'androidx.compose.material:material-icons-extended' + implementation 'androidx.appcompat:appcompat:1.6.1' + implementation 'androidx.navigation:navigation-compose:2.5.3' + implementation 'com.google.code.gson:gson:2.10.1' + implementation fileTree(dir: 'src/main/libs', include: ['*.aar', '*.jar'], exclude: []) + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.5' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' + androidTestImplementation platform('androidx.compose:compose-bom:2022.10.00') + androidTestImplementation 'androidx.compose.ui:ui-test-junit4' + debugImplementation 'androidx.compose.ui:ui-tooling' + debugImplementation 'androidx.compose.ui:ui-test-manifest' + +} \ No newline at end of file diff --git a/android/MLCEngineExample/app/proguard-rules.pro b/android/MLCEngineExample/app/proguard-rules.pro new file mode 100644 index 0000000000..481bb43481 --- /dev/null +++ b/android/MLCEngineExample/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. 
+# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/AndroidManifest.xml b/android/MLCEngineExample/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000..244ca12c34 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/AndroidManifest.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/ic_launcher-playstore.png b/android/MLCEngineExample/app/src/main/ic_launcher-playstore.png new file mode 100644 index 0000000000000000000000000000000000000000..3c16fd65fd66c5db7ab493bcb80b8f1ead7c8ef1 GIT binary patch literal 47710 zcmdpe^LJcrxc0dY@EinZ6{~Xd(K+l|M2}Zvu5r6 z==ODA&qOH7OQ9eVAc8<3lrPfa-#{Qp;435u9v1kxai4htf&4*V#6?ux^-kK|bX52M zy}dpk`^bCp9PaMqFUriM^x;Jh=zU`!!!uSelFg+asHBYyhgTD}HefMe3CFfB*1*74 zj0_}em*Kjuv})Wr_)ub1wYvBGI2wOV)jG0dwR*pzT>0on1s_NTA2@dTKJzY$4lRNX zeJn);0zO2lMfxEA`+-pW|BoMn=`Zfr3+}j9J6^forwxvT=->kl*21C4;IZws7gwVq zWe(KzFicH$mA_e{u1b&QFFYVe$_!uOnZqKBJ|UkPUXZa?R3zG`=YkFdgQ z5^1dHpVQdM;iKE>;VcfSO4Tc)<@KrP(eog$qy4hN@||vO85^B5LzQAKL?Btp$1MA&#cJI{lx9fS<&^uE*jCnUrE;+=~aZAs-^9NcQ3CKp*z2ZfvH}X zf$41&>RM&RqE*Ht!J(*={nNZ4k+_7+--8xllJSCeU@Nkyz8@Ep<+-><%zKMf=A{tY###n)X;%bJM)k4)EDMP z!Q1UEdY~bcZ#}Qq*A!FPi1Wlen8lesoEFjBIVF_+LdWw96?n%Q{0>%Rlnm{LIF*l( z5!35ErPSJ=_o^S!?WK2$r-l4UQSF?6ZfMh>_)iiVL*#aVuL4z`Ra!;^>?OY z!bmWqjLX^>PYY4#!3e@;iI~0t1V|F=zlU1hqkVCU<@wB*XO!*c23YH{b0_Uvmm!8n zlVX7N%Z33Hjs^1b&phcp84!f+&!#o7AzOqN5iUgFe0q@fZ}i>+^we3E#8b(l>v$pf3}<-ZGn#8vyn~N zJ*5){xC;YB@Hii=|I{Gh-}Ebt97crl@_}4vCc^|)-4%ZVYavTP1g{%9=bkV^bLM&< zo7w@S<#Q?$Fqj9HOScJNcMDXs9Wq}g3f6k(|0zeGuur{Km5bOHE3bca-db@f`G1cs zxU#+4EhJRj=uFeK;y_SrU94VkVB!3x(gjb144$YUyE=H!gScQ`i&xF-#7-6EG<2fV zO2LokEMJ@it^7>|@N!RNsQw0Z!Vzd8?qjVJ1!2boX7E3kSgnioo;B;uPZ5S$XutVU z0b0HaC|tGXPd^3e;oP;}xk=^d)3fFB2F~=iOTkq*k16tc(D0=`Iy5E)FyJ&6PU7Uh zTbjcPPI*ZMD6-*;y+L{Qzq47k)Mh0@k-%wHg7?Z!+Y{BADXITj`|Ir$6$>JI8+yYf zhkL=M>(iAGmVAHz$D>#Z_`qxi@cE;KLRKt#Cat}DLkO%H{MwTT@gkGhvBsskrin2} zLDA8DlzIge2;CX%s}-AnC-ese5V&Mkv|xfZvzD?(^M!~Tr$eft=`k)z|4+FC$WXk~ zL@7U5oX84Uu9%|UA6aceZ*XK~ita++{0hB2a5TJcnPtONlH@Zyf zi8V1M&*rH*IqLalR&>92!27TRHt~_lv~2y`|A`Xb4T6S@g8nzZ8c80D?@b|>*4V~n z`>ZNr-@2Ei1C^qQ`;%xo?K?uC0Vmku(gR*}j|T!SsPfn9iNH+3`lr%aS89RPg-R`q z0I8$s{iVRik=Sn{0r?s7z zMRBNn#Wjw#Qi+|O)9TB4w=~r|7quweDk@Y2V8tQ<6f5T!8NEMFNAA{#;o-kUuToXt z)`jM$qt4|j?9(qgEu-^r(-aEGF*&Dh2@j8waR8$x2ERp+knkcmRb#PJ`+&}I%9$MY zQq1MkAflgrJ5N%-D3lN>Lxl)vh6Ol=ltSb}ghQV2;tdGAO>ky7u7_4Nh8_LbZVhl4 z_jUfPsRap`B@=iX93u2&jGrQ~p^shcORB{{OaTgqveO*i?UHet@tsqz z$XX|)=1d?cir9fgMc~L#6&^UQlc#^$IO!*t?hEXQQ;f-<7zs%?LjF4qIM4Xs2O)E^ zfgr2Z%p0zCqaz$UQ>KA;px-c9m%7pDgG;Bs69DWS0@!&V?9L93V%McK0ryf^NWcuC zuVyH1*8x5s9VfY1E+DIe7(ivSA^JZyY2ukz&SNCK_%z%0_XyMCS-&mvyAmnN(gL55NhEsDu4jG%GyR?$lg6Zpkwli)XH3KT|+;@>-T;rd5pc zSIX^~t4f)TZfK9S$-19iiykD8jwiE)XCh$oy351~w52-UK;! 
= engine.chatCompletion(
+                messages=listOf(
+                    ChatCompletionMessage(
+                        role=ChatCompletionRole.user,
+                        content="What is the meaning of life?"
+                    )
+                ),
+                model=modelPath,
+            )
+            coroutineScope.launch {
+                for (it in response) {
+                    responseText.value += it.choices[0].delta.content?.asText()
+                }
+            }
+            Surface(
+                modifier = Modifier
+                    .fillMaxSize()
+            ) {
+                MLCEngineExampleTheme {
+                    Text(text = responseText.value)
+                }
+            }
+        }
+    }
+}
diff --git a/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Color.kt b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Color.kt
new file mode 100644
index 0000000000..8a9a8b5f25
--- /dev/null
+++ b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Color.kt
@@ -0,0 +1,44 @@
+package ai.mlc.mlcengineexample.ui.theme
+
+import androidx.compose.ui.graphics.Color
+
+val Blue10 = Color(0xFF000F5E)
+val Blue20 = Color(0xFF001E92)
+val Blue30 = Color(0xFF002ECC)
+val Blue40 = Color(0xFF1546F6)
+val Blue80 = Color(0xFFB8C3FF)
+val Blue90 = Color(0xFFDDE1FF)
+
+val DarkBlue10 = Color(0xFF00036B)
+val DarkBlue20 = Color(0xFF000BA6)
+val DarkBlue30 = Color(0xFF1026D3)
+val DarkBlue40 = Color(0xFF3648EA)
+val DarkBlue80 = Color(0xFFBBC2FF)
+val DarkBlue90 = Color(0xFFDEE0FF)
+
+val Yellow10 = Color(0xFF261900)
+val Yellow20 = Color(0xFF402D00)
+val Yellow30 = Color(0xFF5C4200)
+val Yellow40 = Color(0xFF7A5900)
+val Yellow80 = Color(0xFFFABD1B)
+val Yellow90 = Color(0xFFFFDE9C)
+
+val Red10 = Color(0xFF410001)
+val Red20 = Color(0xFF680003)
+val Red30 = Color(0xFF930006)
+val Red40 = Color(0xFFBA1B1B)
+val Red80 = Color(0xFFFFB4A9)
+val Red90 = Color(0xFFFFDAD4)
+
+val Grey10 = Color(0xFF191C1D)
+val Grey20 = Color(0xFF2D3132)
+val Grey80 = Color(0xFFC4C7C7)
+val Grey90 = Color(0xFFE0E3E3)
+val Grey95 = Color(0xFFEFF1F1)
+val Grey99 = Color(0xFFFBFDFD)
+
+val BlueGrey30 = Color(0xFF45464F)
+val BlueGrey50 = Color(0xFF767680)
+val BlueGrey60 = Color(0xFF90909A)
+val BlueGrey80 = Color(0xFFC6C5D0)
+val BlueGrey90 = Color(0xFFE2E1EC)
\ No newline at end of file
diff --git a/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Theme.kt b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Theme.kt
new file mode 100644
index 0000000000..aa56c8fca9
--- /dev/null
+++ b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Theme.kt
@@ -0,0 +1,107 @@
+package ai.mlc.mlcengineexample.ui.theme
+
+import android.app.Activity
+import android.os.Build
+import androidx.compose.foundation.isSystemInDarkTheme
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.darkColorScheme
+import androidx.compose.material3.dynamicDarkColorScheme
+import androidx.compose.material3.dynamicLightColorScheme
+import androidx.compose.material3.lightColorScheme
+import androidx.compose.runtime.Composable
+import androidx.compose.runtime.SideEffect
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.graphics.toArgb
+import androidx.compose.ui.platform.LocalContext
+import androidx.compose.ui.platform.LocalView
+import androidx.core.view.WindowCompat
+
+private val DarkColorScheme = darkColorScheme(
+    primary = Blue80,
+    onPrimary = Blue20,
+    primaryContainer = Blue30,
+    onPrimaryContainer = Blue90,
+    inversePrimary = Blue40,
+    secondary = DarkBlue80,
+    onSecondary = DarkBlue20,
+    secondaryContainer = DarkBlue30,
+    onSecondaryContainer = DarkBlue90,
+    tertiary = Yellow80,
+    onTertiary = Yellow20,
+    tertiaryContainer = Yellow30,
+    onTertiaryContainer = Yellow90,
+    error = Red80,
+    onError = Red20,
+    errorContainer = Red30,
+    onErrorContainer = Red90,
+    
background = Grey10, + onBackground = Grey90, + surface = Grey10, + onSurface = Grey80, + inverseSurface = Grey90, + inverseOnSurface = Grey20, + surfaceVariant = BlueGrey30, + onSurfaceVariant = BlueGrey80, + outline = BlueGrey60 +) + +private val LightColorScheme = lightColorScheme( + primary = Blue40, + onPrimary = Color.White, + primaryContainer = Blue90, + onPrimaryContainer = Blue10, + inversePrimary = Blue80, + secondary = DarkBlue40, + onSecondary = Color.White, + secondaryContainer = DarkBlue90, + onSecondaryContainer = DarkBlue10, + tertiary = Yellow40, + onTertiary = Color.White, + tertiaryContainer = Yellow90, + onTertiaryContainer = Yellow10, + error = Red40, + onError = Color.White, + errorContainer = Red90, + onErrorContainer = Red10, + background = Grey99, + onBackground = Grey10, + surface = Grey99, + onSurface = Grey10, + inverseSurface = Grey20, + inverseOnSurface = Grey95, + surfaceVariant = BlueGrey90, + onSurfaceVariant = BlueGrey30, + outline = BlueGrey50 +) + +@Composable +fun MLCEngineExampleTheme( + darkTheme: Boolean = isSystemInDarkTheme(), + // Dynamic color is available on Android 12+ + dynamicColor: Boolean = true, + content: @Composable () -> Unit +) { + val colorScheme = when { + dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { + val context = LocalContext.current + if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) + } + + darkTheme -> DarkColorScheme + else -> LightColorScheme + } + val view = LocalView.current + if (!view.isInEditMode) { + SideEffect { + val window = (view.context as Activity).window + window.statusBarColor = colorScheme.primary.toArgb() + WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme + } + } + + MaterialTheme( + colorScheme = colorScheme, + typography = Typography, + content = content + ) +} \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Type.kt b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Type.kt new file mode 100644 index 0000000000..345efc4749 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Type.kt @@ -0,0 +1,34 @@ +package ai.mlc.mlcengineexample.ui.theme + +import androidx.compose.material3.Typography +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.sp + +// Set of Material typography styles to start with +val Typography = Typography( + bodyLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 16.sp, + lineHeight = 24.sp, + letterSpacing = 0.5.sp + ) + /* Other default text styles to override + titleLarge = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Normal, + fontSize = 22.sp, + lineHeight = 28.sp, + letterSpacing = 0.sp + ), + labelSmall = TextStyle( + fontFamily = FontFamily.Default, + fontWeight = FontWeight.Medium, + fontSize = 11.sp, + lineHeight = 16.sp, + letterSpacing = 0.5.sp + ) + */ +) \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/res/drawable/ic_android_black_24dp.xml b/android/MLCEngineExample/app/src/main/res/drawable/ic_android_black_24dp.xml new file mode 100644 index 0000000000..fe51230740 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/res/drawable/ic_android_black_24dp.xml @@ -0,0 +1,5 @@ + + + diff --git 
a/android/MLCEngineExample/app/src/main/res/drawable/mlc_logo_108.xml b/android/MLCEngineExample/app/src/main/res/drawable/mlc_logo_108.xml new file mode 100644 index 0000000000..d5307e0979 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/res/drawable/mlc_logo_108.xml @@ -0,0 +1,11 @@ + + + diff --git a/android/MLCEngineExample/app/src/main/res/values/colors.xml b/android/MLCEngineExample/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000..f8c6127d32 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ + + + #FFBB86FC + #FF6200EE + #FF3700B3 + #FF03DAC5 + #FF018786 + #FF000000 + #FFFFFFFF + \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/res/values/strings.xml b/android/MLCEngineExample/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000..e6fa718075 --- /dev/null +++ b/android/MLCEngineExample/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + MLCEngineExample + \ No newline at end of file diff --git a/android/MLCEngineExample/app/src/main/res/values/themes.xml b/android/MLCEngineExample/app/src/main/res/values/themes.xml new file mode 100644 index 0000000000..54af29ec8d --- /dev/null +++ b/android/MLCEngineExample/app/src/main/res/values/themes.xml @@ -0,0 +1,6 @@ + + + +