diff --git a/android/library/prepare_model_lib.py b/android/library/prepare_model_lib.py index dc14397a16..9f143d7357 100644 --- a/android/library/prepare_model_lib.py +++ b/android/library/prepare_model_lib.py @@ -1,5 +1,6 @@ import json import os + from tvm.contrib import ndk @@ -23,8 +24,8 @@ def main(): tar_list = [] model_set = set() - for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): - path = os.path.join(artifact_path, model_lib_path) + for model, model_lib in app_config["model_lib_path_for_prepare_libs"].items(): + path = os.path.join(artifact_path, model_lib) if not os.path.isfile(path): raise RuntimeError(f"Cannot find android library {path}") tar_list.append(path) @@ -58,11 +59,11 @@ def main(): model_prefix_pattern not in global_symbol_map and "_" + model_prefix_pattern not in global_symbol_map ): - model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + model_lib = app_config["model_lib_path_for_prepare_libs"][model_lib] print( "ValidationError:\n" f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" - f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tspecifically the model_lib for {model_lib} in model_lib_path_for_prepare_libs.\n" f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" ) error_happened = True diff --git a/cpp/json_ffi/config.cc b/cpp/json_ffi/conv_template.cc similarity index 86% rename from cpp/json_ffi/config.cc rename to cpp/json_ffi/conv_template.cc index 8f5c0e1062..9511bb5b64 100644 --- a/cpp/json_ffi/config.cc +++ b/cpp/json_ffi/conv_template.cc @@ -1,8 +1,8 @@ -#include "config.h" +#include "conv_template.h" #include -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { @@ -10,27 +10,6 @@ namespace json_ffi { using namespace mlc::llm; -/****************** Model-defined generation config ******************/ - -TVM_REGISTER_OBJECT_TYPE(ModelDefinedGenerationConfigNode); - -ModelDefinedGenerationConfig::ModelDefinedGenerationConfig(double temperature, double top_p, - double frequency_penalty, - double presence_penalty) { - ObjectPtr n = make_object(); - n->temperature = temperature; - n->top_p = top_p; - n->frequency_penalty = frequency_penalty; - n->presence_penalty = presence_penalty; - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("mlc.json_ffi.ModelDefinedGenerationConfig") - .set_body_typed([](double temperature, double top_p, double frequency_penalty, - double presence_penalty) { - return ModelDefinedGenerationConfig(temperature, top_p, frequency_penalty, presence_penalty); - }); - /****************** Conversation template ******************/ std::map PLACEHOLDERS = { @@ -334,24 +313,6 @@ std::optional Conversation::FromJSON(const std::string& json_str, return Conversation::FromJSON(json_obj.value(), err); } -/****************** JSON FFI engine config ******************/ - -TVM_REGISTER_OBJECT_TYPE(JSONFFIEngineConfigNode); - -JSONFFIEngineConfig::JSONFFIEngineConfig( - String conv_template, Map model_generation_cfgs) { - ObjectPtr n = make_object(); - n->conv_template = conv_template; - n->model_generation_cfgs = model_generation_cfgs; - data_ = std::move(n); -} - -TVM_REGISTER_GLOBAL("mlc.json_ffi.JSONFFIEngineConfig") - .set_body_typed([](String conv_template, - Map model_generation_cfgs) { - return JSONFFIEngineConfig(std::move(conv_template), std::move(model_generation_cfgs)); - }); - } // namespace json_ffi } // namespace llm } // namespace 
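The string-based entry point survives the rename: callers pass Conversation::FromJSON a JSON payload and an error slot and get back an optional Conversation. A minimal caller sketch, assuming the new include path and an illustrative payload (a real conversation template carries more fields than shown here):

#include <iostream>
#include <optional>
#include <string>

#include "cpp/json_ffi/conv_template.h"  // assumed include path after the rename

int main() {
  std::string err;
  std::optional<mlc::llm::json_ffi::Conversation> conv =
      mlc::llm::json_ffi::Conversation::FromJSON(R"({"name": "llama-2"})", &err);
  if (!conv.has_value()) {
    // An incomplete template JSON lands here with a description in `err`.
    std::cerr << "Invalid conversation template JSON: " << err << std::endl;
    return 1;
  }
  return 0;
}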
mlc diff --git a/cpp/json_ffi/config.h b/cpp/json_ffi/conv_template.h similarity index 66% rename from cpp/json_ffi/config.h rename to cpp/json_ffi/conv_template.h index fe5e4e42e2..eeb348831c 100644 --- a/cpp/json_ffi/config.h +++ b/cpp/json_ffi/conv_template.h @@ -1,9 +1,5 @@ -#ifndef MLC_LLM_JSON_FFI_CONFIG_H -#define MLC_LLM_JSON_FFI_CONFIG_H - -#include -#include -#include +#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H +#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H #include #include @@ -22,35 +18,11 @@ namespace mlc { namespace llm { namespace json_ffi { -/****************** Model-defined generation config ******************/ - -class ModelDefinedGenerationConfigNode : public Object { - public: - double temperature; - double top_p; - double frequency_penalty; - double presence_penalty; - - static constexpr const char* _type_key = "mlc.json_ffi.ModelDefinedGenerationConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(ModelDefinedGenerationConfigNode, Object); -}; - -class ModelDefinedGenerationConfig : public ObjectRef { - public: - explicit ModelDefinedGenerationConfig(double temperature, double top_p, double frequency_penalty, - double presence_penalty); - - TVM_DEFINE_OBJECT_REF_METHODS(ModelDefinedGenerationConfig, ObjectRef, - ModelDefinedGenerationConfigNode); -}; - /****************** Conversation template ******************/ enum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION }; -MessagePlaceholders messagePlaceholderFromString(const std::string& role); +MessagePlaceholders MessagePlaceholderFromString(const std::string& role); class Message { public: @@ -144,29 +116,8 @@ struct Conversation { static std::optional FromJSON(const std::string& json_str, std::string* err); }; -/****************** JSON FFI engine config ******************/ - -class JSONFFIEngineConfigNode : public Object { - public: - String conv_template; - Map model_generation_cfgs; - - static constexpr const char* _type_key = "mlc.json_ffi.JSONFFIEngineConfig"; - static constexpr const bool _type_has_method_sequal_reduce = false; - static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(JSONFFIEngineConfigNode, Object); -}; - -class JSONFFIEngineConfig : public ObjectRef { - public: - explicit JSONFFIEngineConfig(String conv_template, - Map model_generation_cfgs); - - TVM_DEFINE_OBJECT_REF_METHODS(JSONFFIEngineConfig, ObjectRef, JSONFFIEngineConfigNode); -}; - } // namespace json_ffi } // namespace llm } // namespace mlc -#endif /* MLC_LLM_JSON_FFI_CONV_TEMPLATE_H */ +#endif // MLC_LLM_JSON_FFI_CONV_TEMPLATE_H diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index d5fc53b8fa..6b2676ee3f 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -4,6 +4,10 @@ #include #include +#include "../serve/model.h" +#include "../support/json_parser.h" +#include "../support/result.h" + namespace mlc { namespace llm { namespace json_ffi { @@ -83,13 +87,27 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request Array inputs = inputs_obj.value(); // generation_cfg - Optional generation_cfg = GenerationConfig::Create( - request_json_str, &err_, conv_template, this->model_generation_cfgs[request.model]); - if (!generation_cfg.defined()) { - return false; + Array stop_strs; + stop_strs.reserve(conv_template.stop_str.size()); + for (const std::string& stop_str : 
conv_template.stop_str) { + stop_strs.push_back(stop_str); + } + if (request.stop.has_value()) { + stop_strs.reserve(stop_strs.size() + request.stop.value().size()); + for (const std::string& stop_str : request.stop.value()) { + stop_strs.push_back(stop_str); + } } - Request engine_request(request_id, inputs, generation_cfg.value()); + GenerationConfig generation_cfg(request.n, request.temperature, request.top_p, + request.frequency_penalty, request.presence_penalty, + /*repetition_penalty=*/std::nullopt, request.logprobs, + request.top_logprobs, request.logit_bias, request.seed, + request.ignore_eos, request.max_tokens, std::move(stop_strs), + conv_template.stop_token_ids, /*response_format=*/std::nullopt, + this->default_generation_cfg_json_str_); + + Request engine_request(request_id, inputs, generation_cfg); this->engine_->AddRequest(engine_request); return true; @@ -122,22 +140,8 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &JSONFFIEngineImpl::ExitBackgroundLoop); TVM_MODULE_VTABLE_END(); - void InitBackgroundEngine(JSONFFIEngineConfig json_ffi_engine_config, EngineConfig engine_config, - Device device, Optional request_stream_callback, + void InitBackgroundEngine(Device device, Optional request_stream_callback, Optional trace_recorder) { - std::optional conv_template = - Conversation::FromJSON(json_ffi_engine_config->conv_template, &err_); - if (!conv_template.has_value()) { - LOG(FATAL) << "Invalid conversation template JSON: " << err_; - } - this->conv_template_ = conv_template.value(); - this->model_generation_cfgs = json_ffi_engine_config->model_generation_cfgs; - - // Todo(mlc-team): decouple InitBackgroundEngine into two functions - // by removing `engine_config` from arguments, after properly handling - // streamers. - this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model)); - CHECK(request_stream_callback.defined()) << "JSONFFIEngine requires request stream callback function, but it is not given."; this->request_stream_callback_ = request_stream_callback.value(); @@ -150,12 +154,31 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { }; request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - this->engine_->InitBackgroundEngine(device, std::move(request_stream_callback), - std::move(trace_recorder)); - this->engine_->Reload(std::move(engine_config)); + this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), + std::move(trace_recorder)); } - void Reload(EngineConfig engine_config) { this->engine_->Reload(std::move(engine_config)); } + void Reload(String engine_config_json_str) { + this->engine_->Reload(engine_config_json_str); + this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString(); + picojson::object engine_config_json = + json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString()); + + // Load conversation template. + Result model_config_json = + serve::Model::LoadModelConfig(json::Lookup(engine_config_json, "model")); + CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); + std::optional conv_template = Conversation::FromJSON( + json::Lookup(model_config_json.Unwrap(), "conv_template"), &err_); + if (!conv_template.has_value()) { + LOG(FATAL) << "Invalid conversation template JSON: " << err_; + } + this->conv_template_ = conv_template.value(); + // Create streamer. + // Todo(mlc-team): Create one streamer for each request, instead of a global one. 
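The lookup chain that Reload walks above (engine config, then model directory, then model config JSON, then its "conv_template" entry) can be sketched standalone with picojson. The file layout and field values below are assumptions for illustration, not taken from this patch:

#include <iostream>
#include <string>

#include "picojson.h"

int main() {
  // Stand-in for the model config JSON that Model::LoadModelConfig would return.
  const char* model_config_json = R"({
    "model_type": "llama",
    "conv_template": {"name": "llama-2"}
  })";
  picojson::value v;
  std::string err = picojson::parse(v, model_config_json);
  if (!err.empty()) {
    std::cerr << err << std::endl;
    return 1;
  }
  picojson::object conv_template =
      v.get<picojson::object>().at("conv_template").get<picojson::object>();
  std::cout << "conv template: " << conv_template.at("name").get<std::string>() << std::endl;
  return 0;
}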
+ this->streamer_ = + TextStreamer(Tokenizer::FromPath(json::Lookup(engine_config_json, "model"))); + } void Unload() { this->engine_->Unload(); } diff --git a/cpp/json_ffi/json_ffi_engine.h b/cpp/json_ffi/json_ffi_engine.h index d57384abb5..e805cb6e8a 100644 --- a/cpp/json_ffi/json_ffi_engine.h +++ b/cpp/json_ffi/json_ffi_engine.h @@ -12,7 +12,7 @@ #include "../serve/threaded_engine.h" #include "../streamer.h" -#include "config.h" +#include "conv_template.h" #include "openai_api_protocol.h" namespace mlc { @@ -49,7 +49,7 @@ class JSONFFIEngine { PackedFunc request_stream_callback_; TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request Conversation conv_template_; - Map model_generation_cfgs; + String default_generation_cfg_json_str_; }; } // namespace json_ffi diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 13f4b140ce..4547108eb5 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -5,7 +5,7 @@ */ #include "openai_api_protocol.h" -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 429050da3c..70ef2fb22f 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,7 +13,7 @@ #include #include -#include "config.h" +#include "conv_template.h" #include "picojson.h" namespace mlc { @@ -94,7 +94,7 @@ class ChatCompletionRequest { std::optional presence_penalty = std::nullopt; bool logprobs = false; int top_logprobs = 0; - std::optional> logit_bias = std::nullopt; + std::optional>> logit_bias = std::nullopt; std::optional max_tokens = std::nullopt; int n = 1; std::optional seed = std::nullopt; diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 8c2cf66a80..2daf1d0338 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -4,7 +4,7 @@ #include -#include "./json_parser.h" +#include "../support/json_parser.h" namespace mlc { namespace llm { @@ -39,6 +39,16 @@ ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& para return result; } +ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON( + const picojson::object& json) { + KVCacheMetadata kv_cache_metadata; + kv_cache_metadata.num_hidden_layers = json::Lookup(json, "num_hidden_layers"); + kv_cache_metadata.head_dim = json::Lookup(json, "head_dim"); + kv_cache_metadata.num_attention_heads = json::Lookup(json, "num_attention_heads"); + kv_cache_metadata.num_key_value_heads = json::Lookup(json, "num_key_value_heads"); + return kv_cache_metadata; +} + ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, const picojson::object& model_config) { ModelMetadata result; @@ -53,6 +63,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib result.attention_sink_size = json::Lookup(metadata, "attention_sink_size"); result.tensor_parallel_shards = json::Lookup(metadata, "tensor_parallel_shards"); + result.kv_cache_metadata = + KVCacheMetadata::FromJSON(json::Lookup(metadata, "kv_cache")); { std::vector& params = result.params; picojson::array json_params = json::Lookup(metadata, "params"); @@ -76,17 +88,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module, const 
picojson::object& model_config) { std::string json_str = ""; - try { - TypedPackedFunc pf = module.GetFunction("_metadata"); - if (pf == nullptr) { - // legacy path - // TODO: remove this after full SLMify - return ModelMetadata(); - } - json_str = pf(); - } catch (...) { - return ModelMetadata(); // TODO: add a warning message about legacy usecases - } + TypedPackedFunc pf = module.GetFunction("_metadata"); + json_str = pf(); picojson::object json = json::ParseToJsonObject(json_str); try { return ModelMetadata::FromJSON(json, model_config); diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index 2472cb7d36..ede06b6b3f 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -32,6 +32,14 @@ struct ModelMetadata { static Param FromJSON(const picojson::object& param_obj, const picojson::object& model_config); }; + struct KVCacheMetadata { + int64_t num_hidden_layers; + int64_t num_attention_heads; + int64_t num_key_value_heads; + int64_t head_dim; + static KVCacheMetadata FromJSON(const picojson::object& json); + }; + std::string model_type; std::string quantization; int64_t context_window_size; @@ -41,6 +49,7 @@ struct ModelMetadata { int64_t attention_sink_size; std::vector params; std::unordered_map memory_usage; + KVCacheMetadata kv_cache_metadata; static ModelMetadata FromJSON(const picojson::object& json_str, const picojson::object& model_config); diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 3bb809ad67..30a3617a8d 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -5,12 +5,14 @@ #include "config.h" #include +#include #include +#include #include #include "../json_ffi/openai_api_protocol.h" -#include "../metadata/json_parser.h" +#include "../support/json_parser.h" #include "data.h" namespace mlc { @@ -21,178 +23,174 @@ namespace serve { TVM_REGISTER_OBJECT_TYPE(GenerationConfigNode); -GenerationConfig::GenerationConfig(String config_json_str) { - picojson::value config_json; - std::string err = picojson::parse(config_json, config_json_str); - if (!err.empty()) { - LOG(FATAL) << err; - return; +GenerationConfig::GenerationConfig( + std::optional n, std::optional temperature, std::optional top_p, + std::optional frequency_penalty, std::optional presense_penalty, + std::optional repetition_penalty, std::optional logprobs, + std::optional top_logprobs, std::optional>> logit_bias, + std::optional seed, std::optional ignore_eos, std::optional max_tokens, + std::optional> stop_strs, std::optional> stop_token_ids, + std::optional response_format, Optional default_config_json_str) { + ObjectPtr obj = make_object(); + GenerationConfig default_config; + if (default_config_json_str.defined()) { + default_config = GenerationConfig(default_config_json_str.value(), NullOpt); + } else { + default_config = GenerationConfig(obj); } - ObjectPtr n = make_object(); - - picojson::object config = config_json.get(); - if (config.count("n")) { - CHECK(config["n"].is()); - n->n = config["n"].get(); - CHECK_GT(n->n, 0) << "\"n\" should be at least 1"; - } - if (config.count("temperature")) { - CHECK(config["temperature"].is()); - n->temperature = config["temperature"].get(); - } - if (config.count("top_p")) { - CHECK(config["top_p"].is()); - n->top_p = config["top_p"].get(); - } - if (config.count("frequency_penalty")) { - CHECK(config["frequency_penalty"].is()); - n->frequency_penalty = config["frequency_penalty"].get(); - CHECK(std::fabs(n->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; - } - if (config.count("presence_penalty")) { - 
CHECK(config["presence_penalty"].is()); - n->presence_penalty = config["presence_penalty"].get(); - CHECK(std::fabs(n->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; - } - if (config.count("repetition_penalty")) { - CHECK(config["repetition_penalty"].is()); - n->repetition_penalty = config["repetition_penalty"].get(); - CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; - } - if (config.count("logprobs")) { - CHECK(config["logprobs"].is()); - n->logprobs = config["logprobs"].get(); - } - if (config.count("top_logprobs")) { - CHECK(config["top_logprobs"].is()); - n->top_logprobs = config["top_logprobs"].get(); - CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5) - << "At most 5 top logprob tokens are supported"; - CHECK(n->top_logprobs == 0 || n->logprobs) - << "\"logprobs\" must be true to support \"top_logprobs\""; - } - if (config.count("logit_bias")) { - CHECK(config["logit_bias"].is() || config["logit_bias"].is()); - if (config["logit_bias"].is()) { - picojson::object logit_bias_json = config["logit_bias"].get(); - std::vector> logit_bias; - logit_bias.reserve(logit_bias_json.size()); - for (auto [token_id_str, bias] : logit_bias_json) { - CHECK(bias.is()); - double bias_value = bias.get(); - CHECK_LE(std::fabs(bias_value), 100.0) - << "Logit bias value should be in range [-100, 100]."; - logit_bias.emplace_back(std::stoi(token_id_str), bias_value); - } - n->logit_bias = std::move(logit_bias); - } + obj->n = n.value_or(default_config->n); + CHECK_GT(obj->n, 0) << "\"n\" should be at least 1"; + obj->temperature = temperature.value_or(default_config->temperature); + CHECK_GE(obj->temperature, 0) << "\"temperature\" should be non-negative"; + obj->top_p = top_p.value_or(default_config->top_p); + CHECK(obj->top_p >= 0 && obj->top_p <= 1) << "\"top_p\" should be in range [0, 1]"; + obj->frequency_penalty = frequency_penalty.value_or(default_config->frequency_penalty); + CHECK(std::fabs(obj->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; + obj->presence_penalty = presense_penalty.value_or(default_config->presence_penalty); + CHECK(std::fabs(obj->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; + obj->repetition_penalty = repetition_penalty.value_or(default_config->repetition_penalty); + CHECK(obj->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; + obj->logprobs = logprobs.value_or(default_config->logprobs); + obj->top_logprobs = top_logprobs.value_or(default_config->top_logprobs); + CHECK(obj->top_logprobs >= 0 && obj->top_logprobs <= 5) + << "At most 5 top logprob tokens are supported"; + CHECK(obj->top_logprobs == 0 || obj->logprobs) + << "\"logprobs\" must be true to support \"top_logprobs\""; + + obj->logit_bias = logit_bias.value_or(default_config->logit_bias); + for (auto [token_id_str, bias] : obj->logit_bias) { + CHECK_LE(std::fabs(bias), 100.0) << "Logit bias value should be in range [-100, 100]."; } - if (config.count("max_tokens")) { - if (config["max_tokens"].is()) { - n->max_tokens = config["max_tokens"].get(); - } else { - CHECK(config["max_tokens"].is()) << "Unrecognized max_tokens"; - // "-1" means the generation will not stop until exceeding - // model capability or hit any stop criteria. - n->max_tokens = -1; - } + + obj->seed = seed.value_or(std::random_device{}()); + // "ignore_eos" is for benchmarking. Not the part of OpenAI API spec. 
+ obj->ignore_eos = ignore_eos.value_or(default_config->ignore_eos); + // "-1" means the generation will not stop until exceeding + // model capability or hit any stop criteria. + obj->max_tokens = max_tokens.value_or(-1); + + obj->stop_strs = stop_strs.value_or(default_config->stop_strs); + obj->stop_token_ids = stop_token_ids.value_or(default_config->stop_token_ids); + obj->response_format = response_format.value_or(default_config->response_format); + + data_ = std::move(obj); +} + +GenerationConfig::GenerationConfig(String config_json_str, + Optional default_config_json_str) { + picojson::object config = json::ParseToJsonObject(config_json_str); + ObjectPtr n = make_object(); + GenerationConfig default_config; + if (default_config_json_str.defined()) { + default_config = GenerationConfig(default_config_json_str.value(), NullOpt); + } else { + default_config = GenerationConfig(n); } - if (config.count("seed")) { - if (config["seed"].is()) { - n->seed = config["seed"].get(); - } else { - CHECK(config["seed"].is()) << "Unrecognized seed"; - n->seed = std::random_device{}(); + + n->n = json::LookupOrDefault(config, "n", default_config->n); + CHECK_GT(n->n, 0) << "\"n\" should be at least 1"; + n->temperature = + json::LookupOrDefault(config, "temperature", default_config->temperature); + CHECK_GE(n->temperature, 0) << "\"temperature\" should be non-negative"; + n->top_p = json::LookupOrDefault(config, "top_p", default_config->top_p); + CHECK(n->top_p >= 0 && n->top_p <= 1) << "\"top_p\" should be in range [0, 1]"; + n->frequency_penalty = + json::LookupOrDefault(config, "frequency_penalty", default_config->frequency_penalty); + CHECK(std::fabs(n->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!"; + n->presence_penalty = + json::LookupOrDefault(config, "presence_penalty", default_config->presence_penalty); + CHECK(std::fabs(n->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!"; + n->repetition_penalty = json::LookupOrDefault(config, "repetition_penalty", + default_config->repetition_penalty); + CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!"; + n->logprobs = json::LookupOrDefault(config, "logprobs", default_config->logprobs); + n->top_logprobs = + json::LookupOrDefault(config, "top_logprobs", default_config->top_logprobs); + CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5) + << "At most 5 top logprob tokens are supported"; + CHECK(n->top_logprobs == 0 || n->logprobs) + << "\"logprobs\" must be true to support \"top_logprobs\""; + + std::optional logit_bias_obj = + json::LookupOptional(config, "logit_bias"); + if (logit_bias_obj.has_value()) { + std::vector> logit_bias; + logit_bias.reserve(logit_bias_obj.value().size()); + for (auto [token_id_str, bias] : logit_bias_obj.value()) { + CHECK(bias.is()); + double bias_value = bias.get(); + CHECK_LE(std::fabs(bias_value), 100.0) << "Logit bias value should be in range [-100, 100]."; + logit_bias.emplace_back(std::stoi(token_id_str), bias_value); } + n->logit_bias = std::move(logit_bias); } else { - n->seed = std::random_device{}(); + n->logit_bias = default_config->logit_bias; } - if (config.count("stop_strs")) { - CHECK(config["stop_strs"].is()) - << "Invalid stop_strs. Stop strs should be an array of strings"; - picojson::array stop_strs_arr = config["stop_strs"].get(); + + n->seed = json::LookupOrDefault(config, "seed", std::random_device{}()); + // "ignore_eos" is for benchmarking. Not the part of OpenAI API spec. 
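json::LookupOrDefault, presumably provided by the relocated ../support/json_parser.h, reads a key when present and otherwise returns the supplied fallback, which is how a request JSON overrides the default generation config field by field. A self-contained picojson approximation follows; the helper below is a hypothetical stand-in, not the real mlc-llm implementation:

#include <iostream>
#include <string>

#include "picojson.h"

// Hypothetical stand-in for json::LookupOrDefault<double>.
double LookupOrDefaultDouble(const picojson::object& obj, const std::string& key,
                             double fallback) {
  auto it = obj.find(key);
  if (it == obj.end() || it->second.is<picojson::null>()) {
    return fallback;
  }
  return it->second.get<double>();
}

int main() {
  picojson::value v;
  std::string err = picojson::parse(v, R"({"temperature": 0.4})");
  if (!err.empty()) {
    std::cerr << err << std::endl;
    return 1;
  }
  picojson::object config = v.get<picojson::object>();
  // The request value wins; a missing key falls back to the default config.
  std::cout << LookupOrDefaultDouble(config, "temperature", 0.7) << std::endl;  // 0.4
  std::cout << LookupOrDefaultDouble(config, "top_p", 0.95) << std::endl;       // 0.95
  return 0;
}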
+ n->ignore_eos = json::LookupOrDefault(config, "ignore_eos", default_config->ignore_eos); + // "-1" means the generation will not stop until exceeding + // model capability or hit any stop criteria. + n->max_tokens = json::LookupOrDefault(config, "max_tokens", -1); + + std::optional stop_strs_arr = + json::LookupOptional(config, "stop_strs"); + if (stop_strs_arr.has_value()) { Array stop_strs; - stop_strs.reserve(stop_strs_arr.size()); - for (const picojson::value& v : stop_strs_arr) { + stop_strs.reserve(stop_strs_arr.value().size()); + for (const picojson::value& v : stop_strs_arr.value()) { CHECK(v.is()) << "Invalid stop string in stop_strs"; stop_strs.push_back(v.get()); } n->stop_strs = std::move(stop_strs); + } else { + n->stop_strs = default_config->stop_strs; } - if (config.count("stop_token_ids")) { - CHECK(config["stop_token_ids"].is()) - << "Invalid stop_token_ids. Stop tokens should be an array of integers"; - picojson::array stop_token_ids_arr = config["stop_token_ids"].get(); + std::optional stop_token_ids_arr = + json::LookupOptional(config, "stop_token_ids"); + if (stop_token_ids_arr.has_value()) { std::vector stop_token_ids; - stop_token_ids.reserve(stop_token_ids_arr.size()); - for (const picojson::value& v : stop_token_ids_arr) { + stop_token_ids.reserve(stop_token_ids_arr.value().size()); + for (const picojson::value& v : stop_token_ids_arr.value()) { CHECK(v.is()) << "Invalid stop token in stop_token_ids"; stop_token_ids.push_back(v.get()); } n->stop_token_ids = std::move(stop_token_ids); + } else { + n->stop_token_ids = default_config->stop_token_ids; } - // Params for benchmarking. Not the part of openai spec. - if (config.count("ignore_eos")) { - CHECK(config["ignore_eos"].is()); - n->ignore_eos = config["ignore_eos"].get(); - } - - if (config.count("response_format")) { - CHECK(config["response_format"].is()); - picojson::object response_format_json = config["response_format"].get(); + std::optional response_format_obj = + json::LookupOptional(config, "response_format"); + if (response_format_obj.has_value()) { ResponseFormat response_format; - if (response_format_json.count("type")) { - CHECK(response_format_json["type"].is()); - response_format.type = response_format_json["type"].get(); - } - if (response_format_json.count("schema")) { - if (response_format_json["schema"].is()) { - response_format.schema = NullOpt; - } else { - CHECK(response_format_json["schema"].is()); - response_format.schema = response_format_json["schema"].get(); - } + response_format.type = json::LookupOrDefault(response_format_obj.value(), "type", + response_format.type); + std::optional schema = + json::LookupOptional(response_format_obj.value(), "schema"); + if (schema.has_value()) { + response_format.schema = schema.value(); } n->response_format = response_format; + } else { + n->response_format = default_config->response_format; } data_ = std::move(n); } -Optional GenerationConfig::Create( - const std::string& json_str, std::string* err, const Conversation& conv_template, - const ModelDefinedGenerationConfig& model_defined_gen_config) { - std::optional optional_json_obj = json::LoadJSONFromString(json_str, err); - if (!err->empty() || !optional_json_obj.has_value()) { - return NullOpt; - } - picojson::object& json_obj = optional_json_obj.value(); +GenerationConfig GenerationConfig::GetDefaultFromModelConfig( + const picojson::object& model_config_json) { ObjectPtr n = make_object(); - - n->temperature = - json::LookupOrDefault(json_obj, "temperature", 
model_defined_gen_config->temperature); - n->top_p = json::LookupOrDefault(json_obj, "top_p", model_defined_gen_config->top_p); - n->frequency_penalty = json::LookupOrDefault(json_obj, "frequency_penalty", - model_defined_gen_config->frequency_penalty); - n->presence_penalty = json::LookupOrDefault(json_obj, "presence_penalty", - model_defined_gen_config->presence_penalty); - n->logprobs = json::LookupOrDefault(json_obj, "logprobs", false); - n->top_logprobs = static_cast(json::LookupOrDefault(json_obj, "top_logprobs", 0)); - n->ignore_eos = json::LookupOrDefault(json_obj, "ignore_eos", false); - - // Copy stop str from conversation template to generation config - for (auto& stop_str : conv_template.stop_str) { - n->stop_strs.push_back(stop_str); - } - for (auto& stop_token_id : conv_template.stop_token_ids) { - n->stop_token_ids.push_back(stop_token_id); - } - - GenerationConfig gen_config; - gen_config.data_ = std::move(n); - return gen_config; + n->temperature = json::LookupOrDefault(model_config_json, "temperature", n->temperature); + n->top_p = json::LookupOrDefault(model_config_json, "top_p", n->top_p); + n->frequency_penalty = + json::LookupOrDefault(model_config_json, "frequency_penalty", n->frequency_penalty); + n->presence_penalty = + json::LookupOrDefault(model_config_json, "presence_penalty", n->presence_penalty); + return GenerationConfig(n); } String GenerationConfigNode::AsJSONString() const { @@ -243,87 +241,638 @@ String GenerationConfigNode::AsJSONString() const { TVM_REGISTER_OBJECT_TYPE(EngineConfigNode); -EngineConfig::EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, - int max_history_size, KVStateKind kv_state_kind, - SpeculativeMode speculative_mode, int spec_draft_length) { +EngineConfig EngineConfig::FromJSONAndInferredConfig( + const picojson::object& json, const InferrableEngineConfig& inferred_config) { + CHECK(inferred_config.max_num_sequence.has_value()); + CHECK(inferred_config.max_total_sequence_length.has_value()); + CHECK(inferred_config.max_single_sequence_length.has_value()); + CHECK(inferred_config.prefill_chunk_size.has_value()); + CHECK(inferred_config.max_history_size.has_value()); + CHECK(inferred_config.kv_state_kind.has_value()); ObjectPtr n = make_object(); - n->model = std::move(model); - n->model_lib_path = std::move(model_lib_path); - n->additional_models = std::move(additional_models); - n->additional_model_lib_paths = std::move(additional_model_lib_paths); - n->kv_cache_page_size = kv_cache_page_size; - n->max_num_sequence = max_num_sequence; - n->max_total_sequence_length = max_total_sequence_length; - n->max_single_sequence_length = max_single_sequence_length; - n->prefill_chunk_size = prefill_chunk_size; - n->max_history_size = max_history_size; - n->kv_state_kind = kv_state_kind; - n->spec_draft_length = spec_draft_length; - n->speculative_mode = speculative_mode; - data_ = std::move(n); + + // - Get models and model libs. 
+ n->model = json::Lookup(json, "model"); + n->model_lib = json::Lookup(json, "model_lib"); + std::vector additional_models; + std::vector additional_model_libs; + picojson::array additional_models_arr = + json::LookupOrDefault(json, "additional_models", picojson::array()); + picojson::array additional_model_libs_arr = + json::LookupOrDefault(json, "additional_model_libs", picojson::array()); + CHECK_EQ(additional_models_arr.size(), additional_model_libs_arr.size()) + << "The number of additional model libs does not match the number of additional models"; + int num_additional_models = additional_models_arr.size(); + additional_models.reserve(num_additional_models); + additional_model_libs.reserve(num_additional_models); + for (int i = 0; i < num_additional_models; ++i) { + additional_models.push_back(json::Lookup(additional_models_arr, i)); + additional_model_libs.push_back(json::Lookup(additional_model_libs_arr, i)); + } + n->additional_models = additional_models; + n->additional_model_libs = additional_model_libs; + n->mode = EngineModeFromString(json::Lookup(json, "mode")); + + // - Other fields with default value. + n->gpu_memory_utilization = + json::LookupOrDefault(json, "gpu_memory_utilization", n->gpu_memory_utilization); + n->kv_cache_page_size = + json::LookupOrDefault(json, "kv_cache_page_size", n->kv_cache_page_size); + n->speculative_mode = SpeculativeModeFromString(json::LookupOrDefault( + json, "speculative_mode", SpeculativeModeToString(n->speculative_mode))); + n->spec_draft_length = + json::LookupOrDefault(json, "spec_draft_length", n->spec_draft_length); + n->verbose = json::LookupOrDefault(json, "verbose", n->verbose); + + // - Fields from the inferred engine config. + n->max_num_sequence = inferred_config.max_num_sequence.value(); + n->max_total_sequence_length = inferred_config.max_total_sequence_length.value(); + n->max_single_sequence_length = inferred_config.max_single_sequence_length.value(); + n->prefill_chunk_size = inferred_config.prefill_chunk_size.value(); + n->max_history_size = inferred_config.max_history_size.value(); + n->kv_state_kind = inferred_config.kv_state_kind.value(); + + return EngineConfig(n); } -EngineConfig EngineConfig::FromJSONString(const std::string& json_str) { +Result>> +EngineConfig::GetModelsAndModelLibsFromJSONString(const std::string& json_str) { + using TResult = Result>>; picojson::value config_json; std::string err = picojson::parse(config_json, json_str); if (!err.empty()) { - LOG(FATAL) << err; + return TResult::Error(err); } - // Get json fields. + // Get the models and model libs from JSON. 
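For reference, the engine-config JSON consumed here (and by FromJSONAndInferredConfig above) might look like the sketch below. Only the keys come from this patch; every value, including the string forms of mode and speculative_mode, is illustrative:

#include <iostream>
#include <string>

#include "picojson.h"

int main() {
  const char* engine_config_json = R"({
    "model": "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
    "model_lib": "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
    "additional_models": [],
    "additional_model_libs": [],
    "mode": "local",
    "gpu_memory_utilization": 0.85,
    "kv_cache_page_size": 16,
    "speculative_mode": "disable",
    "spec_draft_length": 4,
    "verbose": false
  })";
  picojson::value v;
  std::string err = picojson::parse(v, engine_config_json);
  if (!err.empty()) {
    std::cerr << err << std::endl;
    return 1;
  }
  const picojson::object& config = v.get<picojson::object>();
  // The first (model, model_lib) pair; additional models would be appended after it.
  std::cout << config.at("model").get<std::string>() << " -> "
            << config.at("model_lib").get<std::string>() << std::endl;
  return 0;
}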
picojson::object config = config_json.get(); String model = json::Lookup(config, "model"); - String model_lib_path = json::Lookup(config, "model_lib_path"); - std::vector additional_models; - std::vector additional_model_lib_paths; - int kv_cache_page_size = json::Lookup(config, "kv_cache_page_size"); - int max_num_sequence = json::Lookup(config, "max_num_sequence"); - int max_total_sequence_length = json::Lookup(config, "max_total_sequence_length"); - int max_single_sequence_length = json::Lookup(config, "max_single_sequence_length"); - int prefill_chunk_size = json::Lookup(config, "prefill_chunk_size"); - int max_history_size = json::Lookup(config, "max_history_size"); - KVStateKind kv_state_kind = - static_cast(json::Lookup(config, "kv_state_kind")); - SpeculativeMode speculative_mode = - static_cast(json::Lookup(config, "speculative_mode")); - int spec_draft_length = json::Lookup(config, "spec_draft_length"); - + String model_lib = json::Lookup(config, "model_lib"); picojson::array additional_models_arr = - json::Lookup(config, "additional_models"); - picojson::array additional_model_lib_paths_arr = - json::Lookup(config, "additional_model_lib_paths"); - CHECK_EQ(additional_models_arr.size(), additional_model_lib_paths_arr.size()) - << "The number of additional model lib paths does not match the number of additional models"; + json::LookupOrDefault(config, "additional_models", picojson::array()); + picojson::array additional_model_libs_arr = + json::LookupOrDefault(config, "additional_model_libs", picojson::array()); + if (additional_models_arr.size() != additional_model_libs_arr.size()) { + return TResult::Error( + "The number of additional model libs does not match the number of additional models"); + } + int num_additional_models = additional_models_arr.size(); - additional_models.reserve(num_additional_models); - additional_model_lib_paths.reserve(num_additional_models); + std::vector> models_and_model_libs; + models_and_model_libs.reserve(num_additional_models + 1); + models_and_model_libs.emplace_back(model, model_lib); for (int i = 0; i < num_additional_models; ++i) { - additional_models.push_back(json::Lookup(additional_models_arr, i)); - additional_model_lib_paths.push_back( - json::Lookup(additional_model_lib_paths_arr, i)); + models_and_model_libs.emplace_back(json::Lookup(additional_models_arr, i), + json::Lookup(additional_model_libs_arr, i)); + } + return TResult::Ok(models_and_model_libs); +} + +String EngineConfigNode::AsJSONString() const { + picojson::object config; + + // - Models and model libs + config["model"] = picojson::value(this->model); + config["model_lib"] = picojson::value(this->model_lib); + picojson::array additional_models_arr; + picojson::array additional_model_libs_arr; + additional_models_arr.reserve(this->additional_models.size()); + additional_model_libs_arr.reserve(this->additional_models.size()); + for (int i = 0; i < static_cast(this->additional_models.size()); ++i) { + additional_models_arr.push_back(picojson::value(this->additional_models[i])); + additional_model_libs_arr.push_back(picojson::value(this->additional_model_libs[i])); } + config["additional_models"] = picojson::value(additional_models_arr); + config["additional_model_libs"] = picojson::value(additional_model_libs_arr); + + // - Other fields + config["mode"] = picojson::value(EngineModeToString(this->mode)); + config["gpu_memory_utilization"] = picojson::value(this->gpu_memory_utilization); + config["kv_cache_page_size"] = picojson::value(static_cast(this->kv_cache_page_size)); + 
config["max_num_sequence"] = picojson::value(static_cast(this->max_num_sequence)); + config["max_total_sequence_length"] = + picojson::value(static_cast(this->max_total_sequence_length)); + config["max_single_sequence_length"] = + picojson::value(static_cast(this->max_single_sequence_length)); + config["prefill_chunk_size"] = picojson::value(static_cast(this->prefill_chunk_size)); + config["max_history_size"] = picojson::value(static_cast(this->max_history_size)); + config["kv_state_kind"] = picojson::value(KVStateKindToString(this->kv_state_kind)); + config["speculative_mode"] = picojson::value(SpeculativeModeToString(this->speculative_mode)); + config["spec_draft_length"] = picojson::value(static_cast(this->spec_draft_length)); + config["verbose"] = picojson::value(static_cast(this->verbose)); + + return picojson::value(config).serialize(true); +} + +/****************** InferrableEngineConfig ******************/ + +/*! \brief The class for config limitation from models. */ +struct ModelConfigLimits { + int64_t model_max_single_sequence_length; + int64_t model_max_prefill_chunk_size; + int64_t model_max_batch_size; +}; + +/*! \brief Convert the bytes to megabytes, keeping 3 decimals. */ +inline std::string BytesToMegabytesString(double bytes) { + std::string str; + str.resize(20); + std::sprintf(&str[0], "%.3f", bytes / 1024 / 1024); + str.resize(std::strlen(str.c_str())); + return str; +} - return EngineConfig(std::move(model), std::move(model_lib_path), additional_models, - additional_model_lib_paths, kv_cache_page_size, max_num_sequence, - max_total_sequence_length, max_single_sequence_length, prefill_chunk_size, - max_history_size, kv_state_kind, speculative_mode, spec_draft_length); +/*! + * \brief Get the upper bound of single sequence length, prefill size and batch size + * from model config. + */ +Result GetModelConfigLimits(const std::vector& model_configs) { + int64_t model_max_single_sequence_length = std::numeric_limits::max(); + int64_t model_max_prefill_chunk_size = std::numeric_limits::max(); + int64_t model_max_batch_size = std::numeric_limits::max(); + for (int i = 0; i < static_cast(model_configs.size()); ++i) { + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + // - The maximum single sequence length is the minimum context window size among all models. + int64_t runtime_context_window_size = + json::Lookup(model_configs[i], "context_window_size"); + int64_t compile_time_context_window_size = + json::Lookup(compile_time_model_config, "context_window_size"); + if (runtime_context_window_size > compile_time_context_window_size) { + return Result::Error( + "Model " + std::to_string(i) + "'s runtime context window size (" + + std::to_string(runtime_context_window_size) + + ") is larger than the context window size used at compile time (" + + std::to_string(compile_time_context_window_size) + ")."); + } + if (runtime_context_window_size == -1 && compile_time_context_window_size != -1) { + return Result::Error( + "Model " + std::to_string(i) + + "'s runtime context window size (infinite) is larger than the context " + "window size used at compile time (" + + std::to_string(compile_time_context_window_size) + ")."); + } + if (runtime_context_window_size != -1) { + model_max_single_sequence_length = + std::min(model_max_single_sequence_length, runtime_context_window_size); + } + // - The maximum prefill chunk size is the minimum prefill chunk size among all models. 
+ int64_t runtime_prefill_chunk_size = + json::Lookup(model_configs[i], "prefill_chunk_size"); + int64_t compile_time_prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + if (runtime_prefill_chunk_size > compile_time_prefill_chunk_size) { + return Result::Error( + "Model " + std::to_string(i) + "'s runtime prefill chunk size (" + + std::to_string(runtime_prefill_chunk_size) + + ") is larger than the prefill chunk size used at compile time (" + + std::to_string(compile_time_prefill_chunk_size) + ")."); + } + model_max_prefill_chunk_size = + std::min(model_max_prefill_chunk_size, runtime_prefill_chunk_size); + // - The maximum batch size is the minimum max batch size among all models. + model_max_batch_size = std::min( + model_max_batch_size, json::Lookup(compile_time_model_config, "max_batch_size")); + } + ICHECK_NE(model_max_prefill_chunk_size, std::numeric_limits::max()); + ICHECK_NE(model_max_batch_size, std::numeric_limits::max()); + return Result::Ok( + {model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size}); } -TVM_REGISTER_GLOBAL("mlc.serve.EngineConfig") - .set_body_typed([](String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, int max_history_size, - int kv_state_kind, int speculative_mode, int spec_draft_length) { - return EngineConfig(std::move(model), std::move(model_lib_path), std::move(additional_models), - std::move(additional_model_lib_paths), kv_cache_page_size, - max_num_sequence, max_total_sequence_length, max_single_sequence_length, - prefill_chunk_size, max_history_size, KVStateKind(kv_state_kind), - SpeculativeMode(speculative_mode), spec_draft_length); - }); +/*! \brief The class for memory usage estimation result. */ +struct MemUsageEstimationResult { + double total_memory_bytes; + double kv_cache_memory_bytes; + double temp_memory_bytes; + InferrableEngineConfig inferred_config; +}; + +Result EstimateMemoryUsageOnMode( + EngineMode mode, Device device, double gpu_memory_utilization, int64_t params_bytes, + int64_t temp_buffer_bytes, + const std::vector& model_configs, // + const std::vector& model_metadata, // + ModelConfigLimits model_config_limits, // + InferrableEngineConfig init_config, bool verbose) { + std::ostringstream os; + InferrableEngineConfig inferred_config = init_config; + // - 1. max_mum_sequence + if (!init_config.max_num_sequence.has_value()) { + if (mode == EngineMode::kLocal) { + inferred_config.max_num_sequence = + std::min(static_cast(4), model_config_limits.model_max_batch_size); + } else if (mode == EngineMode::kInteractive) { + inferred_config.max_num_sequence = 1; + } else { + inferred_config.max_num_sequence = model_config_limits.model_max_batch_size; + } + os << "max batch size will be set to " << inferred_config.max_num_sequence.value() << ", "; + } else { + os << "max batch size " << inferred_config.max_num_sequence.value() + << " is specified by user, "; + } + int64_t max_num_sequence = inferred_config.max_num_sequence.value(); + // - 2. 
max_single_sequence_length + if (!init_config.max_single_sequence_length.has_value()) { + inferred_config.max_single_sequence_length = + model_config_limits.model_max_single_sequence_length; + } else { + inferred_config.max_single_sequence_length = + std::min(inferred_config.max_single_sequence_length.value(), + model_config_limits.model_max_single_sequence_length); + } + // - 3. infer the maximum total sequence length that can fit GPU memory. + double kv_bytes_per_token = 0; + double kv_aux_workspace_bytes = 0; + double model_workspace_bytes = 0; + double logit_processor_workspace_bytes = 0; + ICHECK_EQ(model_configs.size(), model_metadata.size()); + int num_models = model_configs.size(); + for (int i = 0; i < num_models; ++i) { + // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation). + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + int64_t vocab_size = json::Lookup(compile_time_model_config, "vocab_size"); + int64_t prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + // - Calculate KV cache memory usage. + int64_t num_layers = model_metadata[i].kv_cache_metadata.num_hidden_layers; + int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim; + int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads; + int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads; + int64_t hidden_size = head_dim * num_qo_heads; + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25; + kv_aux_workspace_bytes += + (max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 + + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024; + model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2; + logit_processor_workspace_bytes += + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125; + } + // Get single-card GPU size. + TVMRetValue rv; + DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv); + int64_t gpu_size_bytes = rv; + // Compute the maximum total sequence length under the GPU memory budget. + int64_t model_max_total_sequence_length = + static_cast((gpu_size_bytes * gpu_memory_utilization // + - params_bytes // + - temp_buffer_bytes // + - kv_aux_workspace_bytes // + - model_workspace_bytes // + - logit_processor_workspace_bytes) / + kv_bytes_per_token); + if (model_max_total_sequence_length <= 0) { + if (verbose) { + LOG(INFO) << "temp_buffer = " << BytesToMegabytesString(temp_buffer_bytes); + LOG(INFO) << "kv_aux workspace = " << BytesToMegabytesString(kv_aux_workspace_bytes); + LOG(INFO) << "model workspace = " << BytesToMegabytesString(model_workspace_bytes); + LOG(INFO) << "logit processor workspace = " + << BytesToMegabytesString(logit_processor_workspace_bytes); + } + return Result::Error( + "Insufficient GPU memory error: " + "The available single GPU memory is " + + BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) + + " MB, " + "which is less than the sum of model weight size (" + + BytesToMegabytesString(params_bytes) + " MB) and temporary buffer size (" + + BytesToMegabytesString(temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes + + logit_processor_workspace_bytes) + + " MB).\n" + "1. You can set a larger \"gpu_memory_utilization\" value.\n" + "2. 
If the model weight size is too large, please enable tensor parallelism by passing " + "`--tensor-parallel-shards $NGPU` to `mlc_llm gen_config` or use quantization.\n" + "3. If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` " + "in `mlc_llm gen_config`."); + } + if (device.device_type == DLDeviceType::kDLMetal) { + // NOTE: Metal runtime has severe performance issues with large buffers. + // To work around the issue, we limit the KV cache capacity to 32768. + model_max_total_sequence_length = + std::min(model_max_total_sequence_length, static_cast(32768)); + } + // Compute the total memory usage except the KV cache part. + double total_mem_usage_except_kv_cache = + (params_bytes + temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes + + logit_processor_workspace_bytes); + + // - 4. max_total_sequence_length + if (!init_config.max_total_sequence_length.has_value()) { + if (mode == EngineMode::kLocal) { + inferred_config.max_total_sequence_length = std::min( + {model_max_total_sequence_length, model_config_limits.model_max_single_sequence_length, + static_cast(8192)}); + } else if (mode == EngineMode::kInteractive) { + inferred_config.max_total_sequence_length = std::min( + model_max_total_sequence_length, model_config_limits.model_max_single_sequence_length); + } else { + inferred_config.max_total_sequence_length = + std::min(model_max_total_sequence_length, + max_num_sequence * model_config_limits.model_max_single_sequence_length); + } + os << "max KV cache token capacity will be set to " + << inferred_config.max_total_sequence_length.value() << ", "; + } else { + os << "max KV cache token capacity " << inferred_config.max_total_sequence_length.value() + << " is specified by user, "; + } + // - 5. prefill_chunk_size + if (!init_config.prefill_chunk_size.has_value()) { + if (mode == EngineMode::kLocal || mode == EngineMode::kInteractive) { + inferred_config.prefill_chunk_size = + std::min({model_config_limits.model_max_prefill_chunk_size, + inferred_config.max_total_sequence_length.value(), + model_config_limits.model_max_single_sequence_length}); + } else { + inferred_config.prefill_chunk_size = model_config_limits.model_max_prefill_chunk_size; + } + os << "prefill chunk size will be set to " << inferred_config.prefill_chunk_size.value() + << ". "; + } else { + os << "prefill chunk size " << inferred_config.prefill_chunk_size.value() + << " is specified by user. "; + } + + // - Print logging message + if (verbose) { + LOG(INFO) << "Under mode \"" << EngineModeToString(mode) << "\", " << os.str(); + } + + return Result::Ok( + {total_mem_usage_except_kv_cache + + inferred_config.max_total_sequence_length.value() * kv_bytes_per_token, + kv_bytes_per_token * inferred_config.max_total_sequence_length.value() + + kv_aux_workspace_bytes, + model_workspace_bytes + logit_processor_workspace_bytes + temp_buffer_bytes, + inferred_config}); +} + +Result InferrableEngineConfig::InferForKVCache( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose) { + // - Check if max_history_size is not set. 
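To make the budget arithmetic in EstimateMemoryUsageOnMode concrete before the KV-cache-specific inference below, here is the same computation with made-up Llama-7B-like shapes on a 24 GB card. All numbers are illustrative, and reading the 4-byte factor as fp16 K plus V (and the 1.25 term as paging bookkeeping) is an interpretation rather than something this patch states:

#include <cstdio>

int main() {
  const double num_layers = 32, head_dim = 128, num_kv_heads = 32;
  // Per token: K and V at 2 bytes each per layer, plus a small bookkeeping term.
  double kv_bytes_per_token = head_dim * num_kv_heads * num_layers * 4 + 1.25;
  double gpu_bytes = 24.0 * 1024 * 1024 * 1024;     // 24 GB card
  double budget_bytes = 0.85 * gpu_bytes;           // gpu_memory_utilization = 0.85
  double params_bytes = 13.5 * 1024 * 1024 * 1024;  // fp16 7B weights
  double other_bytes = 1.5 * 1024 * 1024 * 1024;    // temp buffer + workspaces
  double max_total_sequence_length =
      (budget_bytes - params_bytes - other_bytes) / kv_bytes_per_token;
  std::printf("kv bytes/token = %.0f, max KV token capacity ~ %.0f\n",
              kv_bytes_per_token, max_total_sequence_length);
  return 0;
}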
+ if (init_config.max_history_size.has_value() && init_config.max_history_size.value() != 0) { + return Result::Error( + "KV cache does not support max_history_size, while it is set to " + + std::to_string(init_config.max_history_size.value()) + " in the input EngineConfig"); + } + // - Get the upper bound of single sequence length, prefill size and batch size + // from model config. + Result model_config_limits_res = GetModelConfigLimits(model_configs); + if (model_config_limits_res.IsErr()) { + return Result::Error(model_config_limits_res.UnwrapErr()); + } + ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap(); + // - Get total model parameter size and temporary in-function buffer + // size in bytes on single GPU. + int64_t params_bytes = 0; + int64_t temp_buffer_bytes = 0; + for (const ModelMetadata& metadata : model_metadata) { + for (const ModelMetadata::Param& param : metadata.params) { + int64_t param_size = param.dtype.bytes(); + for (int64_t v : param.shape) { + ICHECK_GE(v, 0); + param_size *= v; + } + params_bytes += param_size; + } + for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) { + temp_buffer_bytes = std::max(temp_buffer_bytes, temp_buffer_size); + } + } + // Magnify the temp buffer by a factor of 2 for safety. + temp_buffer_bytes *= 2; + + // - Infer the engine config and estimate memory usage for each mode. + Result local_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kLocal, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + Result interactive_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kInteractive, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + Result server_mode_estimation_result = EstimateMemoryUsageOnMode( + EngineMode::kServer, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes, + model_configs, model_metadata, model_config_limits, init_config, verbose); + // - Pick the estimation result according to the mode. + std::string mode_name; + Result final_estimation_result; + if (mode == EngineMode::kLocal) { + final_estimation_result = std::move(local_mode_estimation_result); + } else if (mode == EngineMode::kInteractive) { + final_estimation_result = std::move(interactive_mode_estimation_result); + } else { + final_estimation_result = std::move(server_mode_estimation_result); + } + if (final_estimation_result.IsErr()) { + return Result::Error(final_estimation_result.UnwrapErr()); + } + // - Print log message. + MemUsageEstimationResult final_estimation = final_estimation_result.Unwrap(); + InferrableEngineConfig inferred_config = std::move(final_estimation.inferred_config); + if (verbose) { + LOG(INFO) << "The actual engine mode is \"" << EngineModeToString(mode) + << "\". So max batch size is " << inferred_config.max_num_sequence.value() + << ", max KV cache token capacity is " + << inferred_config.max_total_sequence_length.value() << ", prefill chunk size is " + << inferred_config.prefill_chunk_size.value() << "."; + LOG(INFO) << "Estimated total single GPU memory usage: " + << BytesToMegabytesString(final_estimation.total_memory_bytes) + << " MB (Parameters: " << BytesToMegabytesString(params_bytes) + << " MB. KVCache: " << BytesToMegabytesString(final_estimation.kv_cache_memory_bytes) + << " MB. 
Temporary buffer: " + << BytesToMegabytesString(final_estimation.temp_memory_bytes) + << " MB). The actual usage might be slightly larger than the estimated number."; + } + + inferred_config.kv_state_kind = KVStateKind::kKVCache; + inferred_config.max_history_size = 0; + return Result::Ok(inferred_config); +} + +Result InferrableEngineConfig::InferForRNNState( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose) { + // - Check max_single_sequence_length is not set. + if (init_config.max_single_sequence_length.has_value()) { + return Result::Error( + "RNN state does not support max_single_sequence_length, while it is set to " + + std::to_string(init_config.max_single_sequence_length.value()) + + " in the input EngineConfig"); + } + // - Get the upper bound of single sequence length, prefill size and batch size + // from model config. + Result model_config_limits_res = GetModelConfigLimits(model_configs); + if (model_config_limits_res.IsErr()) { + return Result::Error(model_config_limits_res.UnwrapErr()); + } + ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap(); + + std::ostringstream os; + InferrableEngineConfig inferred_config = init_config; + // - 1. prefill_chunk_size + if (!init_config.prefill_chunk_size.has_value()) { + inferred_config.prefill_chunk_size = + std::min(model_config_limits.model_max_prefill_chunk_size, static_cast(4096)); + os << "prefill chunk size will be set to " << inferred_config.prefill_chunk_size.value() + << ", "; + } else { + os << "prefill chunk size " << inferred_config.prefill_chunk_size.value() + << " is specified by user, "; + } + // - 2. max_batch_size + if (!init_config.max_num_sequence.has_value()) { + inferred_config.max_num_sequence = + mode == EngineMode::kInteractive + ? 1 + : std::min(static_cast(4), model_config_limits.model_max_batch_size); + os << "max batch size will be set to " << inferred_config.max_num_sequence.value() << ", "; + } else { + os << "max batch size " << inferred_config.max_num_sequence.value() + << " is specified by user, "; + } + int64_t max_num_sequence = inferred_config.max_num_sequence.value(); + // - 3. max_total_sequence_length + if (!init_config.max_total_sequence_length.has_value()) { + inferred_config.max_total_sequence_length = 32768; + os << "max RNN state token capacity will be set to " + << inferred_config.max_total_sequence_length.value() << ". "; + } else { + os << "max RNN state token capacity " << inferred_config.max_total_sequence_length.value() + << " is specified by user. "; + } + + // - Extra logging message + if (mode == EngineMode::kLocal) { + os << "We choose small max batch size and RNN state capacity to use less GPU memory."; + } else if (mode == EngineMode::kInteractive) { + os << "We fix max batch size to 1 for interactive single sequence use."; + } else { + os << "We use as much GPU memory as possible (within the limit of gpu_memory_utilization)."; + } + if (verbose) { + LOG(INFO) << "Under mode \"" << EngineModeToString(mode) << "\", " << os.str(); + } + + // - Get total model parameter size and temporary in-function buffer + // size in bytes on single GPU. 
+ int64_t params_bytes = 0; + int64_t temp_buffer_bytes = 0; + for (const ModelMetadata& metadata : model_metadata) { + for (const ModelMetadata::Param& param : metadata.params) { + int64_t param_size = param.dtype.bytes(); + for (int64_t v : param.shape) { + ICHECK_GE(v, 0); + param_size *= v; + } + params_bytes += param_size; + } + for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) { + temp_buffer_bytes += temp_buffer_size; + } + } + // - 4. max_history_size + double rnn_state_base_bytes = 0; // The memory usage for rnn state when history = 1. + double model_workspace_bytes = 0; + double logit_processor_workspace_bytes = 0; + ICHECK_EQ(model_configs.size(), model_metadata.size()); + int num_models = model_configs.size(); + for (int i = 0; i < num_models; ++i) { + // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation). + picojson::object compile_time_model_config = + json::Lookup(model_configs[i], "model_config"); + int64_t vocab_size = json::Lookup(compile_time_model_config, "vocab_size"); + int64_t prefill_chunk_size = + json::Lookup(compile_time_model_config, "prefill_chunk_size"); + int64_t head_size = json::Lookup(compile_time_model_config, "head_size"); + int64_t num_heads = json::Lookup(compile_time_model_config, "num_heads"); + int64_t num_layers = json::Lookup(compile_time_model_config, "num_hidden_layers"); + int64_t hidden_size = json::Lookup(compile_time_model_config, "hidden_size"); + // - Calculate RNN state memory usage. + rnn_state_base_bytes += (max_num_sequence * hidden_size * num_layers * 2 * 2 + + max_num_sequence * num_heads * head_size * head_size * num_layers * 2); + model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 + + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2; + logit_processor_workspace_bytes += + max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125; + } + // Get single-card GPU size. + TVMRetValue rv; + DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv); + int64_t gpu_size_bytes = rv; + // Compute the maximum history size length under the GPU memory budget. + int64_t model_max_history_size = static_cast((gpu_size_bytes * gpu_memory_utilization // + - params_bytes // + - temp_buffer_bytes // + - model_workspace_bytes // + - logit_processor_workspace_bytes) / + rnn_state_base_bytes); + if (model_max_history_size <= 0) { + return Result::Error( + "Insufficient GPU memory error: " + "The available single GPU memory is " + + BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) + + " MB, " + "which is less than the sum of model weight size (" + + BytesToMegabytesString(params_bytes) + " MB) and temporary buffer size (" + + BytesToMegabytesString( + (temp_buffer_bytes + model_workspace_bytes + logit_processor_workspace_bytes)) + + " MB). " + "If the model weight size is too large, please use quantization. " + "If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` in " + "`mlc_llm gen_config`."); + } + if (!init_config.max_history_size.has_value()) { + inferred_config.max_history_size = model_max_history_size; + } else { + inferred_config.max_history_size = + std::min(inferred_config.max_history_size.value(), model_max_history_size); + } + if (verbose) { + LOG(INFO) << "The actual engine mode is \"" << EngineModeToString(mode) + << "\". 
So max batch size is " << inferred_config.max_num_sequence.value() + << ", max RNN state token capacity is " + << inferred_config.max_total_sequence_length.value() << ", prefill chunk size is " + << inferred_config.prefill_chunk_size.value() << "."; + LOG(INFO) << "Estimated total single GPU memory usage: " + << BytesToMegabytesString(params_bytes + temp_buffer_bytes + + inferred_config.max_history_size.value() * + rnn_state_base_bytes) + << " MB (Parameters: " << BytesToMegabytesString(params_bytes) << " MB. RNN state: " + << BytesToMegabytesString(inferred_config.max_history_size.value() * + rnn_state_base_bytes) + << " MB. Temporary buffer: " + << BytesToMegabytesString(model_workspace_bytes + logit_processor_workspace_bytes + + temp_buffer_bytes) + << " MB). The actual usage might be slightly larger than the estimated number."; + } + + inferred_config.kv_state_kind = KVStateKind::kRNNState; + return Result::Ok(inferred_config); +} + +/****************** Config utils ******************/ + +Result ModelsUseKVCache(const std::vector& model_configs) { + ICHECK_GE(model_configs.size(), 1); + std::string model_type = json::Lookup(model_configs[0], "model_type"); + bool use_kv_cache = model_type.find("rwkv") == std::string::npos; + for (int i = 1; i < static_cast(model_configs.size()); ++i) { + if ((json::Lookup(model_configs[i], "model_type").find("rwkv") == + std::string::npos) != use_kv_cache) { + return Result::Error( + "Invalid models in EngineConfig. Models must be all RNN model or none model is RNN " + "model."); + } + } + return Result::Ok(use_kv_cache); +} } // namespace serve } // namespace llm diff --git a/cpp/serve/config.h b/cpp/serve/config.h index fd76dd49f0..8437232d37 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -5,13 +5,15 @@ #ifndef MLC_LLM_SERVE_CONFIG_H_ #define MLC_LLM_SERVE_CONFIG_H_ +#include #include #include #include #include -#include "../json_ffi/config.h" +#include "../metadata/model.h" +#include "../support/result.h" namespace mlc { namespace llm { @@ -19,7 +21,6 @@ namespace serve { using namespace tvm; using namespace tvm::runtime; -using namespace mlc::llm::json_ffi; /****************** GenerationConfig ******************/ @@ -60,22 +61,51 @@ class GenerationConfigNode : public Object { class GenerationConfig : public ObjectRef { public: - explicit GenerationConfig(String config_json_str); + TVM_DLL explicit GenerationConfig( + std::optional n, std::optional temperature, std::optional top_p, + std::optional frequency_penalty, std::optional presense_penalty, + std::optional repetition_penalty, std::optional logprobs, + std::optional top_logprobs, std::optional>> logit_bias, + std::optional seed, std::optional ignore_eos, std::optional max_tokens, + std::optional> stop_strs, std::optional> stop_token_ids, + std::optional response_format, Optional default_config_json_str); - /*! - * \brief Create a generation config from a ChatCompletionRequest. - * If the request does not contain a generation config, the model-defined - * generation config will be used. - */ - static Optional Create( - const std::string& json_str, std::string* err, const Conversation& conv_template, - const ModelDefinedGenerationConfig& model_defined_gen_config); + TVM_DLL explicit GenerationConfig(String config_json_str, + Optional default_config_json_str); + + /*! \brief Get the default generation config from the model config. 
*/ + TVM_DLL static GenerationConfig GetDefaultFromModelConfig(const picojson::object& json); TVM_DEFINE_OBJECT_REF_METHODS(GenerationConfig, ObjectRef, GenerationConfigNode); }; /****************** Engine config ******************/ +/*! + * \brief The engine mode in MLC LLM. + * We provide three preset modes: "local", "interactive" and "server". + * The default mode is "local". + * The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + * and "prefill_chunk_size" when they are not explicitly specified. + * 1. Mode "local" refers to the local server deployment which has low + * request concurrency. So the max batch size will be set to 4, and max + * total sequence length and prefill chunk size are set to the context + * window size (or sliding window size) of the model. + * 2. Mode "interactive" refers to the interactive use of server, which + * has at most 1 concurrent request. So the max batch size will be set to 1, + * and max total sequence length and prefill chunk size are set to the context + * window size (or sliding window size) of the model. + * 3. Mode "server" refers to the large server use case which may handle + * many concurrent request and want to use GPU memory as much as possible. + * In this mode, we will automatically infer the largest possible max batch + * size and max total sequence length. + */ +enum class EngineMode : int { + kLocal = 0, + kInteractive = 1, + kServer = 2, +}; + /*! \brief The speculative mode. */ enum class SpeculativeMode : int { /*! \brief Disable speculative decoding. */ @@ -87,11 +117,13 @@ enum class SpeculativeMode : int { }; /*! \brief The kind of cache. */ -enum KVStateKind { - kAttention = 0, +enum class KVStateKind : int { + kKVCache = 0, kRNNState = 1, }; +class InferrableEngineConfig; + /*! \brief The configuration of engine execution config. */ class EngineConfigNode : public Object { public: @@ -99,44 +131,61 @@ class EngineConfigNode : public Object { /*! \brief The path to the model directory. */ String model; - /*! \brief The path to the model library. */ - String model_lib_path; + /*! \brief The path or identifier to the model library. */ + String model_lib; /*! \brief The path to the additional models' directories. */ Array additional_models; /*! \brief The path to the additional models' libraries. */ - Array additional_model_lib_paths; + Array additional_model_libs; /*************** KV cache config and engine capacities ***************/ + /*! + * \brief The engine mode in MLC LLM. + * \sa EngineMode + */ + EngineMode mode = EngineMode::kLocal; + /*! + * \brief A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + * It is used to infer to maximum possible KV cache capacity. + * When it is unspecified, it defaults to 0.85. + * Under mode "local" or "interactive", the actual memory usage may be + * significantly smaller than this number. Under mode "server", the actual + * memory usage may be slightly larger than this number. + */ + float gpu_memory_utilization = 0.85; /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ - int kv_cache_page_size; + int kv_cache_page_size = 16; /*! * \brief The maximum number of sequences that are allowed to be * processed by the KV cache at any time. */ - int max_num_sequence; + int max_num_sequence = 4; /*! \brief The maximum length allowed for a single sequence in the engine. */ - int max_total_sequence_length; + int max_total_sequence_length = 4096; /*! 
* \brief The maximum total number of tokens whose KV data are allowed * to exist in the KV cache at any time. */ - int max_single_sequence_length; + int max_single_sequence_length = 4096; /*! \brief The maximum total sequence length in a prefill. */ - int prefill_chunk_size; + int prefill_chunk_size = 1024; /*! \brief The maximum history size for RNN state. KV cache does not need this. */ - int max_history_size; + int max_history_size = 0; /*! \brief The kind of cache. Whether it's KV cache or RNN state. */ - KVStateKind kv_state_kind; + KVStateKind kv_state_kind = KVStateKind::kKVCache; /*************** Speculative decoding ***************/ /*! \brief The speculative mode. */ - SpeculativeMode speculative_mode; + SpeculativeMode speculative_mode = SpeculativeMode::kDisable; /*! \brief The number of tokens to generate in speculative proposal (draft). */ int spec_draft_length = 4; - String AsJSONString() const; + /*************** Debug ***************/ + bool verbose = false; + + TVM_DLL String AsJSONString() const; static constexpr const char* _type_key = "mlc.serve.EngineConfig"; static constexpr const bool _type_has_method_sequal_reduce = false; @@ -146,19 +195,98 @@ class EngineConfigNode : public Object { class EngineConfig : public ObjectRef { public: - explicit EngineConfig(String model, String model_lib_path, Array additional_models, - Array additional_model_lib_paths, int kv_cache_page_size, - int max_num_sequence, int max_total_sequence_length, - int max_single_sequence_length, int prefill_chunk_size, - int max_history_size, KVStateKind kv_state_kind, - SpeculativeMode speculative_mode, int spec_draft_length); + /*! \brief Create EngineConfig from JSON object and inferred config. */ + TVM_DLL static EngineConfig FromJSONAndInferredConfig( + const picojson::object& json, const InferrableEngineConfig& inferred_config); - /*! \brief Create EngineConfig from JSON string. */ - static EngineConfig FromJSONString(const std::string& json_str); + /*! + * \brief Get all the models and model libs from the JSON string for engine initialization. + * \return The parsed models/model libs from config or error message. + */ + TVM_DLL static Result>> + GetModelsAndModelLibsFromJSONString(const std::string& json_str); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode); }; +/*! \brief A subset of engine config that is inferrable. */ +struct InferrableEngineConfig { + std::optional max_num_sequence; + std::optional max_total_sequence_length; + std::optional max_single_sequence_length; + std::optional prefill_chunk_size; + std::optional max_history_size; + std::optional kv_state_kind; + + /*! \brief Infer the config for KV cache from a given initial config. */ + TVM_DLL static Result InferForKVCache( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose); + /*! \brief Infer the config for RNN state from a given initial config. */ + TVM_DLL static Result InferForRNNState( + EngineMode mode, Device device, double gpu_memory_utilization, + const std::vector& model_configs, + const std::vector& model_metadata, InferrableEngineConfig init_config, + bool verbose); +}; + +/****************** Config utils ******************/ + +/*! \brief Check if the models use KV cache or RNN state. */ +Result ModelsUseKVCache(const std::vector& model_configs); + +inline std::string EngineModeToString(EngineMode mode) { + return mode == EngineMode::kLocal ? 
"local" + : mode == EngineMode::kInteractive ? "interactive" + : "server"; +} + +inline EngineMode EngineModeFromString(const std::string& mode) { + if (mode == "local") { + return EngineMode::kLocal; + } else if (mode == "interactive") { + return EngineMode::kInteractive; + } else if (mode == "server") { + return EngineMode::kServer; + } else { + LOG(FATAL) << "Invalid engine mode string: " << mode; + } +} + +inline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) { + return speculative_mode == SpeculativeMode::kDisable ? "disable" + : speculative_mode == SpeculativeMode::kSmallDraft ? "small_draft" + : "eagle"; +} + +inline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) { + if (speculative_mode == "disable") { + return SpeculativeMode::kDisable; + } else if (speculative_mode == "small_draft") { + return SpeculativeMode::kSmallDraft; + } else if (speculative_mode == "eagle") { + return SpeculativeMode::kEagle; + } else { + LOG(FATAL) << "Invalid speculative mode string: " << speculative_mode; + } +} + +inline std::string KVStateKindToString(KVStateKind kv_state_kind) { + return kv_state_kind == KVStateKind::kKVCache ? "kv_cache" : "rnn_State"; +} + +inline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) { + if (kv_state_kind == "kv_cache") { + return KVStateKind::kKVCache; + } else if (kv_state_kind == "rnn_state") { + return KVStateKind::kRNNState; + } else { + LOG(FATAL) << "Invalid kv state kind string: " << kv_state_kind; + } +} + } // namespace serve } // namespace llm } // namespace mlc diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 297eba8b10..6fd6188562 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -17,6 +17,8 @@ #include #include +#include "../support/json_parser.h" +#include "../support/result.h" #include "../tokenizers.h" #include "engine_actions/action.h" #include "engine_actions/action_commons.h" @@ -45,61 +47,71 @@ class EngineImpl : public Engine { public: /********************** Engine Management **********************/ - explicit EngineImpl(EngineConfig engine_config, DLDevice device, - Optional request_stream_callback, - Optional trace_recorder) { - // Step 1. Initialize metadata and singleton states inside the engine - this->estate_->Reset(); - // Being "-1" means there is no limit on single sequence length. - if (engine_config->max_single_sequence_length == -1) { - engine_config->max_single_sequence_length = std::numeric_limits::max(); + static Result Create(const std::string& engine_config_json_str, + DLDevice device, + Optional request_stream_callback, + Optional trace_recorder) { + using TResult = Result; + std::unique_ptr n = std::make_unique(); + + // - Read the models and model libs from the EngineConfig JSON string. + Result>> models_and_model_libs_res = + EngineConfig::GetModelsAndModelLibsFromJSONString(engine_config_json_str); + if (models_and_model_libs_res.IsErr()) { + return TResult::Error(models_and_model_libs_res.UnwrapErr()); } - this->request_stream_callback_ = std::move(request_stream_callback); - this->trace_recorder_ = trace_recorder; - - // Step 2. Initialize each model independently. - // Create the logit processor and sampler. - this->models_.clear(); - this->model_workspaces_.clear(); - + std::vector> models_and_model_libs = + models_and_model_libs_res.Unwrap(); + ICHECK_GE(models_and_model_libs.size(), 1); + // - Initialize singleton states inside the engine. 
+ n->estate_->Reset(); + n->request_stream_callback_ = std::move(request_stream_callback); + n->trace_recorder_ = trace_recorder; + n->device_ = device; + // - Load model config, create a shared disco session when tensor + // parallelism is enabled. std::vector model_configs; - model_configs.push_back(Model::LoadModelConfig(engine_config->model)); - for (const auto& model_path : engine_config->additional_models) { - model_configs.push_back(Model::LoadModelConfig(model_path)); + for (int i = 0; i < static_cast(models_and_model_libs.size()); ++i) { + const auto& [model_str, model_lib] = models_and_model_libs[i]; + Result model_config_res = Model::LoadModelConfig(model_str); + if (model_config_res.IsErr()) { + return TResult::Error("Model " + std::to_string(i) + + " has invalid mlc-chat-config.json: " + model_config_res.UnwrapErr()); + } + model_configs.push_back(model_config_res.Unwrap()); } - - Optional session = CreateDiscoSession(model_configs, device); - - auto f_create_model = [this, &engine_config, &device, &trace_recorder, &model_configs, - &session](const String& model_path, const String& model_lib_path, - int model_index) { - Model model = Model::Create(model_lib_path, std::move(model_path), model_configs[model_index], - device, engine_config->max_num_sequence, session, + Optional session = n->CreateDiscoSession(model_configs, device); + // - Initialize each model independently. + n->models_.clear(); + for (int i = 0; i < static_cast(models_and_model_libs.size()); ++i) { + const auto& [model_str, model_lib] = models_and_model_libs[i]; + Model model = Model::Create(model_lib, model_str, model_configs[i], device, session, /*trace_enabled=*/trace_recorder.defined()); + n->models_.push_back(model); + } + // - Automatically infer the missing fields in EngineConfig JSON strings + // and get the final EngineConfig. + Result engine_config_res = + n->AutoDecideEngineConfig(engine_config_json_str, model_configs); + if (engine_config_res.IsErr()) { + return TResult::Error(engine_config_res.UnwrapErr()); + } + EngineConfig engine_config = engine_config_res.Unwrap(); + // - Load model weights, create KV cache and workspace. 
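The IsErr()/UnwrapErr() plumbing used throughout EngineImpl::Create follows the Result<T, E> type introduced in cpp/support/result.h at the end of this patch. A minimal sketch of the same Ok/Error pattern, using a stripped-down stand-in so it compiles without TVM's ICHECK:

#include <cassert>
#include <iostream>
#include <optional>
#include <string>
#include <utility>

// Stripped-down stand-in for mlc::llm::Result<T>; the real class additionally
// guards against unwrapping twice.
template <typename T, typename E = std::string>
class SimpleResult {
 public:
  static SimpleResult Ok(T value) {
    SimpleResult r;
    r.ok_ = std::move(value);
    return r;
  }
  static SimpleResult Error(E error) {
    SimpleResult r;
    r.err_ = std::move(error);
    return r;
  }
  bool IsErr() const { return err_.has_value(); }
  T Unwrap() { assert(ok_.has_value()); return std::move(*ok_); }
  E UnwrapErr() { assert(err_.has_value()); return std::move(*err_); }

 private:
  std::optional<T> ok_;
  std::optional<E> err_;
};

SimpleResult<int> ParsePositive(int x) {
  if (x <= 0) return SimpleResult<int>::Error("value must be positive");
  return SimpleResult<int>::Ok(x);
}

int main() {
  SimpleResult<int> res = ParsePositive(-3);
  if (res.IsErr()) {
    // Mirrors `return TResult::Error(res.UnwrapErr());` in the engine code.
    std::cout << "error: " << res.UnwrapErr() << "\n";
    return 1;
  }
  std::cout << "ok: " << res.Unwrap() << "\n";
  return 0;
}
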
+ n->model_workspaces_.clear(); + for (const Model& model : n->models_) { + model->LoadParams(); + model->SetMaxNumSequence(engine_config->max_num_sequence); + model->SetPrefillChunkSize(engine_config->prefill_chunk_size); model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence, engine_config->max_total_sequence_length, engine_config->prefill_chunk_size, engine_config->max_history_size, engine_config->kv_state_kind); - CHECK_GE(model->GetMaxWindowSize(), engine_config->max_single_sequence_length) - << "The window size of the model, " << model->GetMaxWindowSize() - << ", is smaller than the pre-defined max single sequence length, " - << engine_config->max_single_sequence_length; - this->models_.push_back(model); - this->model_workspaces_.push_back( + n->model_workspaces_.push_back( ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); - }; - - f_create_model(engine_config->model, engine_config->model_lib_path, /*model_index=*/0); - CHECK_EQ(engine_config->additional_models.size(), - engine_config->additional_model_lib_paths.size()) - << "The additional model and lib path list has mismatched size."; - for (int i = 0; i < static_cast(engine_config->additional_models.size()); ++i) { - f_create_model(engine_config->additional_models[i], - engine_config->additional_model_lib_paths[i], /*model_index=*/i + 1); } - - // Step 3. Initialize tokenizer and grammar - this->tokenizer_ = Tokenizer::FromPath(engine_config->model); + // - Initialize tokenizer and grammar + n->tokenizer_ = Tokenizer::FromPath(engine_config->model); std::string token_table_postproc_method; if (model_configs[0].count("token_table_postproc_method") == 0) { // Backward compatibility: use "byte_fallback" by default @@ -108,73 +120,77 @@ class EngineImpl : public Engine { token_table_postproc_method = model_configs[0].at("token_table_postproc_method").get(); } - this->token_table_ = - Tokenizer::PostProcessTokenTable(tokenizer_->TokenTable(), token_table_postproc_method); - this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_); - - // Step 4. Initialize engine actions that represent state transitions. + n->token_table_ = + Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method); + n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_); + // - Create the logit processor and sampler, and + // the DraftTokenWorkspaceManager for speculative decoding. 
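The code below sizes the logit processor and sampler for max_num_sequence * (spec_draft_length + 1) tokens when speculative decoding is enabled, since each sequence can contribute up to its draft tokens plus one more position during batch verification. A small worked instance of that sizing, with illustrative numbers only:

#include <iostream>

int main() {
  // Illustrative numbers; the engine reads these from EngineConfig.
  int max_num_sequence = 4;   // default max batch size in "local" mode
  int spec_draft_length = 4;  // default draft length

  int max_num_tokens = max_num_sequence;
  bool speculative_enabled = true;
  if (speculative_enabled) {
    // Each sequence may need logits for its draft tokens plus one extra
    // token during verification.
    max_num_tokens *= spec_draft_length + 1;
  }
  std::cout << max_num_tokens << "\n";  // 4 * (4 + 1) = 20
  return 0;
}
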
int max_num_tokens = engine_config->max_num_sequence; DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr}; if (engine_config->speculative_mode != SpeculativeMode::kDisable) { max_num_tokens *= engine_config->spec_draft_length + 1; - draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); + draft_token_workspace_manager = + n->models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens); draft_token_workspace_manager->AllocWorkspace( - &model_workspaces_[0], + &n->model_workspaces_[0], /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle); } LogitProcessor logit_processor = - this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); - Sampler sampler = this->models_[0]->CreateSampler( - max_num_tokens, static_cast(this->models_.size()), trace_recorder); + n->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder); + Sampler sampler = n->models_[0]->CreateSampler( + max_num_tokens, static_cast(n->models_.size()), trace_recorder); + // - Initialize engine actions that represent state transitions. if (engine_config->speculative_mode != SpeculativeMode::kDisable) { // Speculative decoding is only possible for more than one model. - ICHECK_GT(this->models_.size(), 1U); + ICHECK_GT(n->models_.size(), 1U); switch (engine_config->speculative_mode) { case SpeculativeMode::kEagle: - this->actions_ = { - EngineAction::EagleNewRequestPrefill(this->models_, // + n->actions_ = { + EngineAction::EagleNewRequestPrefill(n->models_, // logit_processor, // sampler, // - this->model_workspaces_, // + n->model_workspaces_, // draft_token_workspace_manager, // engine_config, // - this->trace_recorder_), - EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - this->trace_recorder_, - engine_config->spec_draft_length), - EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - engine_config, this->trace_recorder_)}; + n->trace_recorder_), + EngineAction::EagleBatchDraft(n->models_, logit_processor, sampler, + n->model_workspaces_, draft_token_workspace_manager, + n->trace_recorder_, engine_config->spec_draft_length), + EngineAction::EagleBatchVerify(n->models_, logit_processor, sampler, + n->model_workspaces_, draft_token_workspace_manager, + engine_config, n->trace_recorder_)}; break; default: - this->actions_ = { - EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDraft(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - this->trace_recorder_), - EngineAction::BatchVerify(this->models_, logit_processor, sampler, - this->model_workspaces_, draft_token_workspace_manager, - engine_config, this->trace_recorder_)}; + n->actions_ = { + EngineAction::NewRequestPrefill(n->models_, // + logit_processor, // + sampler, // + n->model_workspaces_, // + engine_config, // + n->trace_recorder_), + EngineAction::BatchDraft(n->models_, logit_processor, sampler, n->model_workspaces_, + draft_token_workspace_manager, n->trace_recorder_), + EngineAction::BatchVerify(n->models_, logit_processor, sampler, n->model_workspaces_, + draft_token_workspace_manager, engine_config, + n->trace_recorder_)}; } } else { - this->actions_ = {EngineAction::NewRequestPrefill(this->models_, // - logit_processor, // - 
sampler, // - this->model_workspaces_, // - engine_config, // - this->trace_recorder_), - EngineAction::BatchDecode(this->models_, logit_processor, sampler, - this->trace_recorder_)}; + n->actions_ = { + EngineAction::NewRequestPrefill(n->models_, // + logit_processor, // + sampler, // + n->model_workspaces_, // + engine_config, // + n->trace_recorder_), + EngineAction::BatchDecode(n->models_, logit_processor, sampler, n->trace_recorder_)}; } - // Step 4. Automatically set the threading backend max concurrency. - this->engine_config_ = engine_config; - SetThreadMaxConcurrency(); + // - Automatically set the threading backend max concurrency. + n->engine_config_ = engine_config; + n->SetThreadMaxConcurrency(); + // - Get the default generation config from the first model. + GenerationConfig default_generation_cfg = + GenerationConfig::GetDefaultFromModelConfig(model_configs[0]); + return TResult::Ok({std::move(n), std::move(engine_config), std::move(default_generation_cfg)}); } void Reset() final { @@ -321,7 +337,8 @@ class EngineImpl : public Engine { } /************** Utility Functions **************/ - Optional CreateDiscoSession(std::vector model_configs, Device device) { + Optional CreateDiscoSession(const std::vector& model_configs, + Device device) { const auto& base_model_config = model_configs[0]; auto f_get_num_shards = [](const picojson::object& model_config) -> int { @@ -373,6 +390,95 @@ class EngineImpl : public Engine { } private: + Result AutoDecideEngineConfig(const std::string& engine_config_json_str, + const std::vector& model_configs) { + using TResult = Result; + picojson::value config_json; + std::string err = picojson::parse(config_json, engine_config_json_str); + if (!err.empty()) { + return TResult::Error(err); + } + picojson::object config = config_json.get(); + ObjectPtr n = make_object(); + + // - Get the engine mode and maximum GPU utilization for inference. + EngineMode mode = EngineModeFromString(json::Lookup(config, "mode")); + double gpu_memory_utilization = + json::LookupOrDefault(config, "gpu_memory_utilization", n->gpu_memory_utilization); + bool verbose = json::LookupOrDefault(config, "verbose", n->verbose); + + // - Get the config fields that can be automatically inferred. + std::optional max_num_sequence = + json::LookupOptional(config, "max_num_sequence"); + std::optional max_total_sequence_length = + json::LookupOptional(config, "max_total_sequence_length"); + std::optional max_single_sequence_length = + json::LookupOptional(config, "max_single_sequence_length"); + std::optional prefill_chunk_size = + json::LookupOptional(config, "prefill_chunk_size"); + std::optional max_history_size = + json::LookupOptional(config, "max_history_size"); + std::optional kv_state_kind_str = + json::LookupOptional(config, "kv_state_kind"); + std::optional kv_state_kind; + if (kv_state_kind_str.has_value()) { + kv_state_kind = KVStateKindFromString(kv_state_kind_str.value()); + } + InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length, + max_single_sequence_length, prefill_chunk_size, + max_history_size, kv_state_kind}; + + // - Get the model metadata. + std::vector model_metadata; + for (const Model& model : models_) { + model_metadata.push_back(model->GetMetadata()); + } + // - Select from kv cache or RNN state. 
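AutoDecideEngineConfig above treats a field as unspecified when json::LookupOptional returns std::nullopt, which the helper (added to cpp/support/json_parser.h later in this patch) does for both a missing key and an explicit JSON null. A condensed sketch of that behavior with a local helper over picojson (PICOJSON_USE_INT64 is assumed, as in the real build):

#define PICOJSON_USE_INT64
#include <picojson.h>

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Condensed version of json::LookupOptional: a missing key or a JSON null
// both map to std::nullopt. The real helper additionally CHECKs that a
// present value has the expected type; here we assume it does.
template <typename T>
std::optional<T> LookupOptional(const picojson::object& json, const std::string& key) {
  auto it = json.find(key);
  if (it == json.end() || it->second.is<picojson::null>()) {
    return std::nullopt;
  }
  return it->second.get<T>();
}

int main() {
  picojson::value v;
  picojson::parse(v, std::string(R"({"max_num_sequence": 8, "max_history_size": null})"));
  const picojson::object& config = v.get<picojson::object>();

  // Present with a value: engaged optional.
  std::cout << LookupOptional<int64_t>(config, "max_num_sequence").value() << "\n";  // 8
  // Explicit null and missing key: both "unspecified", left for inference.
  std::cout << LookupOptional<int64_t>(config, "max_history_size").has_value() << "\n";    // 0
  std::cout << LookupOptional<int64_t>(config, "prefill_chunk_size").has_value() << "\n";  // 0
  return 0;
}
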
+ Result use_kv_cache = ModelsUseKVCache(model_configs); + if (use_kv_cache.IsErr()) { + return TResult::Error(use_kv_cache.UnwrapErr()); + } + KVStateKind inferred_kv_state_kind; + Result inferrable_cfg_res; + if (use_kv_cache.Unwrap()) { + inferred_kv_state_kind = KVStateKind::kKVCache; + // - Check if the kv state kind from config is valid. + if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { + return TResult::Error( + "Invalid kv state kind in EngineConfig. The models use KV cache, but RNN state is " + "specified in EngineConfig."); + } + // - Infer configuration. + inferrable_cfg_res = InferrableEngineConfig::InferForKVCache( + mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, + verbose); + } else { + inferred_kv_state_kind = KVStateKind::kRNNState; + // - Check if the kv state kind from config is valid. + if (kv_state_kind.has_value() && kv_state_kind.value() != inferred_kv_state_kind) { + return TResult::Error( + "Invalid kv state kind in EngineConfig. The models use RNN state, but KV cache is " + "specified in EngineConfig."); + } + // - Infer configuration. + inferrable_cfg_res = InferrableEngineConfig::InferForRNNState( + mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg, + verbose); + } + + if (inferrable_cfg_res.IsErr()) { + return TResult::Error(inferrable_cfg_res.UnwrapErr()); + } + inferrable_cfg = inferrable_cfg_res.Unwrap(); + ICHECK(inferrable_cfg.max_num_sequence.has_value()); + ICHECK(inferrable_cfg.max_total_sequence_length.has_value()); + ICHECK(inferrable_cfg.max_single_sequence_length.has_value()); + ICHECK(inferrable_cfg.prefill_chunk_size.has_value()); + ICHECK(inferrable_cfg.max_history_size.has_value()); + ICHECK(inferrable_cfg.kv_state_kind.has_value()); + return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg)); + } + /*! \brief Set the maximum threading backend concurrency. */ void SetThreadMaxConcurrency() { int host_cpu_usage = 1; @@ -408,6 +514,8 @@ class EngineImpl : public Engine { GrammarInitContextStorage grammar_init_context_storage_; // Models Array models_; + // Device that the models run on. + Device device_; // Workspace of each model. std::vector model_workspaces_; // Request stream callback function @@ -418,12 +526,12 @@ class EngineImpl : public Engine { Optional trace_recorder_; }; -std::unique_ptr Engine::Create(EngineConfig engine_config, Device device, - Optional request_stream_callback, - Optional trace_recorder) { - return std::make_unique(std::move(engine_config), device, - std::move(request_stream_callback), - std::move(trace_recorder)); +Result Engine::Create(const std::string& engine_config_json_str, + Device device, + Optional request_stream_callback, + Optional trace_recorder) { + return EngineImpl::Create(engine_config_json_str, device, std::move(request_stream_callback), + std::move(trace_recorder)); } /*! \brief Clear global memory manager */ @@ -445,13 +553,21 @@ class EngineModule : public ModuleNode { TVM_MODULE_VTABLE_ENTRY("reset", &EngineModule::Reset); TVM_MODULE_VTABLE_ENTRY("get_request_stream_callback", &EngineModule::GetRequestStreamCallback); TVM_MODULE_VTABLE_ENTRY("set_request_stream_callback", &EngineModule::SetRequestStreamCallback); + TVM_MODULE_VTABLE_ENTRY("get_default_generation_config", + &EngineModule::GetDefaultGenerationConfigJSONString); TVM_MODULE_VTABLE_END(); /*! \brief Initialize the engine with config and other fields. 
*/ - void Init(EngineConfig engine_config, Device device, Optional request_stream_callback, + void Init(const std::string& engine_config_json_str, Device device, + Optional request_stream_callback, Optional trace_recorder) { - this->engine_ = Engine::Create(std::move(engine_config), device, - std::move(request_stream_callback), std::move(trace_recorder)); + Result output_res = + Engine::Create(engine_config_json_str, device, std::move(request_stream_callback), + std::move(trace_recorder)); + CHECK(output_res.IsOk()) << output_res.UnwrapErr(); + EngineCreationOutput output = output_res.Unwrap(); + this->engine_ = std::move(output.reloaded_engine); + this->default_generation_cfg_json_str_ = output.default_generation_cfg->AsJSONString(); } /*! \brief Construct an EngineModule. */ static tvm::runtime::Module Create() { return Module(make_object()); } @@ -473,6 +589,12 @@ class EngineModule : public ModuleNode { void Reset() { return GetEngine()->Reset(); } /*! \brief Redirection to `Engine::Stats` */ String Stats() { return GetEngine()->Stats(); } + /*! \brief Return the default generation config string. */ + String GetDefaultGenerationConfigJSONString() { + CHECK(!default_generation_cfg_json_str_.empty()) + << "The default generation config has not been set."; + return default_generation_cfg_json_str_; + } private: Engine* GetEngine() { @@ -481,6 +603,7 @@ class EngineModule : public ModuleNode { } std::unique_ptr engine_ = nullptr; + String default_generation_cfg_json_str_; }; TVM_REGISTER_GLOBAL("mlc.serve.create_engine").set_body_typed(EngineModule::Create); diff --git a/cpp/serve/engine.h b/cpp/serve/engine.h index 2fc0a4d730..7bbe942227 100644 --- a/cpp/serve/engine.h +++ b/cpp/serve/engine.h @@ -21,6 +21,18 @@ using namespace tvm::runtime; typedef TypedPackedFunc)> FRequestStreamCallback; +class Engine; + +/*! + * \brief The output of engine creation, including the created engine and + * the default generation config for requests. + */ +struct EngineCreationOutput { + std::unique_ptr reloaded_engine; + EngineConfig completed_engine_config; + GenerationConfig default_generation_cfg; +}; + /*! * \brief The engine interface for request serving in MLC LLM. * The engine can run one or multiple LLM models internally for @@ -50,15 +62,16 @@ class Engine { /*! * \brief Create an engine in unique pointer. - * \param engine_config The engine config. + * \param engine_config_json_str The serialized JSON string of the engine config. * \param device The device where the run models. * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. - * \return The created Engine in pointer. + * \return The created Engine in pointer, and the default generation config. */ - static std::unique_ptr Create(EngineConfig engine_config, Device device, - Optional request_stream_callback, - Optional trace_recorder); + static Result Create(const std::string& engine_config_json_str, + Device device, + Optional request_stream_callback, + Optional trace_recorder); /*! \brief Reset the engine, clean up all running data and statistics. 
*/ virtual void Reset() = 0; diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index 55ab0a1dff..a0ae4d98f3 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -5,8 +5,8 @@ #include "grammar_parser.h" -#include "../../metadata/json_parser.h" #include "../../support/encoding.h" +#include "../../support/json_parser.h" #include "grammar_builder.h" namespace mlc { diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index be76b40e2e..0bd4126b40 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -12,6 +12,7 @@ #include +#include "../support/json_parser.h" #include "config.h" #include "logit_processor.h" @@ -26,13 +27,13 @@ class ModelImpl; TVM_REGISTER_OBJECT_TYPE(ModelObj); Model Model::Create(String reload_lib_path, String model_path, const picojson::object& model_config, - DLDevice device, int max_num_sequence, const Optional& session, - bool trace_enabled) { - return Model(make_object(reload_lib_path, model_path, model_config, device, - max_num_sequence, session, trace_enabled)); + DLDevice device, const Optional& session, bool trace_enabled) { + return Model(make_object(reload_lib_path, model_path, model_config, device, session, + trace_enabled)); } -picojson::object Model::LoadModelConfig(const String& model_path) { +Result Model::LoadModelConfig(const String& model_path) { + using TResult = Result; picojson::object model_config; std::ifstream config_istream((model_path + "/mlc-chat-config.json").c_str()); std::ostringstream config_ostream; @@ -42,10 +43,10 @@ picojson::object Model::LoadModelConfig(const String& model_path) { picojson::value config_json; std::string err = picojson::parse(config_json, config_str); if (!err.empty()) { - LOG(FATAL) << err; + return TResult::Error(err); } picojson::object config = config_json.get(); - return config; + return TResult::Ok(config); } class ModelImpl : public ModelObj { @@ -55,34 +56,21 @@ class ModelImpl : public ModelObj { * \sa Model::Create */ explicit ModelImpl(String reload_lib_path, String model_path, picojson::object model_config, - DLDevice device, int max_num_sequence, const Optional& session, - bool trace_enabled) - : device_(device) { + DLDevice device, const Optional& session, bool trace_enabled) + : model_(model_path), device_(device) { // Step 1. Process model config json string. LoadModelConfigJSON(model_config); // Step 2. Initialize vm, we use the packed function mechanism // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. this->ft_.Init(reload_lib_path, device_, model_config, session); - // Step 3. Load params in nd-array cache. - this->params_ = ft_.LoadParams(model_path, device_); - // Step 4. Set max_num_sequence - this->max_num_sequence_ = max_num_sequence; - // Step 5. Reset + // Step 3. Reset this->Reset(); - // Step 6. Initialize the shared NDArray. - Device device_host{DLDeviceType::kDLCPU, 0}; - memory::Allocator* allocator = - memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); - ICHECK_NOTNULL(allocator); - token_ids_storage_ = memory::Storage( - allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); - this->logit_pos_arr_ = NDArray::Empty({max_num_sequence}, DataType::Int(32), device_host); - // Step 7. Set model type - if (model_config["model_type"].get().find("rwkv") != std::string::npos) { + // Step 4. 
Set model type + if (json::Lookup(model_config, "model_type").find("rwkv") != std::string::npos) { this->kind = KVStateKind::kRNNState; } else { - this->kind = KVStateKind::kAttention; + this->kind = KVStateKind::kKVCache; } } @@ -104,6 +92,7 @@ class ModelImpl : public ModelObj { } ICHECK_EQ(token_ids_nd->ndim, 1); ICHECK_EQ(token_ids_nd->shape[0], num_tokens); + ICHECK_NE(prefill_chunk_size_, -1); auto token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, "token_ids", {prefill_chunk_size_}); ObjectRef embeddings = ft_.embed_func_(token_ids_dref_or_nd, params_); @@ -249,6 +238,7 @@ class ModelImpl : public ModelObj { ShapeTuple embedding_shape{1, total_length, hidden_size_}; embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape); } + ICHECK_NE(max_num_sequence_, -1); ObjectRef logit_pos_dref_or_nd = ft_.CopyToWorker0(logit_pos_nd, "logit_pos", {max_num_sequence_}); // args: embeddings, logit_pos, kv_cache, params @@ -576,7 +566,7 @@ class ModelImpl : public ModelObj { void CreateKVCache(int page_size, int max_num_sequence, int max_total_sequence_length, int prefill_chunk_size, int max_history_size, KVStateKind kv_state_kind) final { - if (kv_state_kind == KVStateKind::kAttention) { + if (kv_state_kind == KVStateKind::kKVCache) { IntTuple max_num_sequence_tuple{max_num_sequence}; IntTuple max_total_sequence_length_tuple{max_total_sequence_length}; IntTuple prefill_chunk_size_tuple{prefill_chunk_size}; @@ -619,6 +609,8 @@ class ModelImpl : public ModelObj { /************** Raw Info Query **************/ + ModelMetadata GetMetadata() const final { return ft_.model_metadata_; } + int GetNumAvailablePages() const final { if (this->kind == KVStateKind::kRNNState) { // RNNState does not introduce new page at runtime @@ -639,14 +631,32 @@ class ModelImpl : public ModelObj { /*********************** Utilities ***********************/ + void LoadParams() final { this->params_ = ft_.LoadParams(model_, device_); } + + void SetMaxNumSequence(int max_num_sequence) final { + this->max_num_sequence_ = max_num_sequence; + this->logit_pos_arr_ = + NDArray::Empty({max_num_sequence}, DataType::Int(32), Device{DLDeviceType::kDLCPU, 0}); + } + + void SetPrefillChunkSize(int prefill_chunk_size) final { + this->prefill_chunk_size_ = prefill_chunk_size; + Device device_host{DLDeviceType::kDLCPU, 0}; + memory::Allocator* allocator = + memory::MemoryManager::GetOrCreateAllocator(device_host, memory::AllocatorType::kNaive); + ICHECK_NOTNULL(allocator); + token_ids_storage_ = memory::Storage( + allocator->Alloc(device_host, {prefill_chunk_size_}, DataType::Int(32)), allocator); + } + LogitProcessor CreateLogitProcessor(int max_num_token, - Optional trace_recorder) { + Optional trace_recorder) final { return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); } Sampler CreateSampler(int max_num_sample, int num_models, - Optional trace_recorder) { + Optional trace_recorder) final { if (Sampler::SupportGPUSampler(device_)) { return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_, std::move(trace_recorder)); @@ -660,11 +670,6 @@ class ModelImpl : public ModelObj { return num_shards_ > 1 ? num_shards_ : 0; } - int GetMaxWindowSize() const final { - // Being "-1" means there is no limit on the window size. - return max_window_size_ != -1 ? max_window_size_ : std::numeric_limits::max(); - } - ObjectRef AllocEmbeddingTensor() final { // Allocate the embedding tensor. 
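Both this constructor and ModelsUseKVCache in config.cc classify a model as RNN-backed purely by whether its model_type string contains "rwkv". A minimal sketch of that heuristic in isolation (the enum mirrors KVStateKind; the model_type values are illustrative, and the real check reads the field from mlc-chat-config.json):

#include <iostream>
#include <string>
#include <vector>

enum class CacheKind { kKVCache, kRNNState };

// Same substring heuristic used by ModelImpl and ModelsUseKVCache: any
// model_type containing "rwkv" is treated as an RNN-state model.
CacheKind ClassifyModelType(const std::string& model_type) {
  return model_type.find("rwkv") != std::string::npos ? CacheKind::kRNNState
                                                      : CacheKind::kKVCache;
}

int main() {
  std::vector<std::string> model_types = {"llama", "rwkv5", "mistral"};
  for (const std::string& model_type : model_types) {
    std::cout << model_type << " -> "
              << (ClassifyModelType(model_type) == CacheKind::kRNNState ? "rnn_state"
                                                                        : "kv_cache")
              << "\n";
  }
  return 0;
}
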
ObjectRef embedding = ft_.alloc_embedding_tensor_func_(); @@ -678,6 +683,7 @@ class ModelImpl : public ModelObj { NDArray embedding_nd = Downcast(embedding); embedding_shape = embedding_nd.Shape(); } + ICHECK_NE(prefill_chunk_size_, -1); ICHECK_EQ(embedding_shape.size(), 2); ICHECK_GE(embedding_shape[0], prefill_chunk_size_); this->hidden_size_ = embedding_shape[1]; @@ -697,8 +703,9 @@ class ModelImpl : public ModelObj { hidden_states_nd = Downcast(hidden_states); } ShapeTuple hidden_states_shape = hidden_states_nd.Shape(); + ICHECK_NE(prefill_chunk_size_, -1); ICHECK_EQ(hidden_states_shape.size(), 2); - ICHECK_EQ(hidden_states_shape[0], prefill_chunk_size_); + ICHECK_GE(hidden_states_shape[0], prefill_chunk_size_); this->hidden_size_ = hidden_states_shape[1]; this->hidden_states_dtype_ = hidden_states_nd->dtype; return hidden_states; @@ -731,6 +738,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.gather_hidden_states_func_(input, indices_device, dst_view); return dst_view; @@ -741,6 +749,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos", {max_num_sequence_}); ft_.scatter_hidden_states_func_(input, indices_device, *dst); } @@ -752,6 +761,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); ft_.gather_probs_func_(input, indices_device, dst_view); @@ -763,6 +773,7 @@ class ModelImpl : public ModelObj { NDArray indices_nd = logit_pos_arr_.CreateView({static_cast(indices.size())}, DataType::Int(32)); indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int)); + ICHECK_NE(max_num_sequence_, -1); ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, "logit_pos_local", {max_num_sequence_}, /*local_only=*/true); ft_.scatter_probs_func_(input, indices_device, *dst); @@ -776,50 +787,22 @@ class ModelImpl : public ModelObj { private: /*! \brief Load model configuration from JSON. 
*/ - picojson::object LoadModelConfigJSON(picojson::object config) { - if (config.count("context_window_size")) { - CHECK(config["context_window_size"].is()); - this->max_window_size_ = config["context_window_size"].get(); - } else { - LOG(FATAL) << "Key \"context_window_size\" not found."; - } - if (config.count("sliding_window_size")) { - CHECK(config["sliding_window_size"].is()); - this->sliding_window_size_ = config["sliding_window_size"].get(); - CHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0) - << "Sliding window should be either -1 (which means disabled) of positive"; - } - if (config.count("attention_sink_size")) { - CHECK(config["attention_sink_size"].is()); - this->attention_sink_size_ = config["attention_sink_size"].get(); - this->attention_sink_size_ = std::max(this->attention_sink_size_, 0); - } - if (config.count("tensor_parallel_shards")) { - CHECK(config["tensor_parallel_shards"].is()); - this->num_shards_ = config["tensor_parallel_shards"].get(); - } else { - LOG(FATAL) << "Key \"tensor_parallel_shards\" not found."; - } - if (config.count("prefill_chunk_size")) { - CHECK(config["prefill_chunk_size"].is()); - this->prefill_chunk_size_ = config["prefill_chunk_size"].get(); - } else { - LOG(FATAL) << "Key \"prefill_chunk_size\" not found."; - } - if (config.count("vocab_size")) { - CHECK(config["vocab_size"].is()); - this->vocab_size_ = config["vocab_size"].get(); - } else { - LOG(FATAL) << "Key \"vocab_size\" not found."; - } - - return config; + void LoadModelConfigJSON(const picojson::object& config) { + this->sliding_window_size_ = + json::LookupOrDefault(config, "sliding_window_size", this->sliding_window_size_); + CHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0) + << "Sliding window should be either -1 (which means disabled) of positive"; + this->attention_sink_size_ = + json::LookupOrDefault(config, "attention_sink_size", this->attention_sink_size_); + this->attention_sink_size_ = std::max(this->attention_sink_size_, 0); + this->num_shards_ = json::Lookup(config, "tensor_parallel_shards"); + this->vocab_size_ = json::Lookup(config, "vocab_size"); } //---------------------------- // Model configurations //---------------------------- - int max_window_size_ = -1; + std::string model_; int sliding_window_size_ = -1; int attention_sink_size_ = 0; int num_shards_ = -1; diff --git a/cpp/serve/model.h b/cpp/serve/model.h index f587969bfb..1ac4e4001c 100644 --- a/cpp/serve/model.h +++ b/cpp/serve/model.h @@ -12,6 +12,7 @@ #include #include "../base.h" +#include "../support/result.h" #include "config.h" #include "draft_token_workspace_manager.h" #include "event_trace_recorder.h" @@ -254,6 +255,9 @@ class ModelObj : public Object { /************** Raw Info Query **************/ + /*! \brief Return the metadata JSON object of the model. */ + virtual ModelMetadata GetMetadata() const = 0; + /*! \brief Get the number of available pages in KV cache. */ virtual int GetNumAvailablePages() const = 0; @@ -262,6 +266,21 @@ class ModelObj : public Object { /*********************** Utilities ***********************/ + /*! \brief Load the model's weight parameters, which is not loaded at construction time. */ + virtual void LoadParams() = 0; + + /*! + * \brief Set the maximum number of sequences to be processed for the model, + * which is not initialized at construction time. + */ + virtual void SetMaxNumSequence(int max_num_sequence) = 0; + + /*! + * \brief Set the prefill chunk size for the model, + * which is not initialized at construction time. 
+ */ + virtual void SetPrefillChunkSize(int prefill_chunk_size) = 0; + /*! \brief Create a logit processor from this model. */ virtual LogitProcessor CreateLogitProcessor(int max_num_token, Optional trace_recorder) = 0; @@ -279,9 +298,6 @@ class ModelObj : public Object { */ virtual int EstimateHostCPURequirement() const = 0; - /*! \brief Get the max window size of the model. "-1" means infinite length. */ - virtual int GetMaxWindowSize() const = 0; - /*! \brief Allocate an embedding tensor with the prefill chunk size. */ virtual ObjectRef AllocEmbeddingTensor() = 0; @@ -331,22 +347,20 @@ class Model : public ObjectRef { * \param model_path The path to the model weight parameters. * \param model_config The model config json object. * \param device The device to run the model on. - * \param max_num_sequence The maximum number of sequences to be processed * \param session The session to run the model on. * \param trace_enabled A boolean indicating whether tracing is enabled. * \return The created runtime module. */ TVM_DLL static Model Create(String reload_lib_path, String model_path, const picojson::object& model_config, DLDevice device, - int max_num_sequence, const Optional& session, - bool trace_enabled); + const Optional& session, bool trace_enabled); /*! * Load the model config from the given model path. * \param model_path The path to the model weight parameters. * \return The model config json object. */ - static picojson::object LoadModelConfig(const String& model_path); + TVM_DLL static Result LoadModelConfig(const String& model_path); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj); }; diff --git a/cpp/serve/request.cc b/cpp/serve/request.cc index 8ecd20b18e..bd955ec846 100644 --- a/cpp/serve/request.cc +++ b/cpp/serve/request.cc @@ -67,9 +67,11 @@ Request Request::FromUntokenized(const Request& request, const Tokenizer& tokeni } TVM_REGISTER_GLOBAL("mlc.serve.Request") - .set_body_typed([](String id, Array inputs, String generation_cfg_json) { + .set_body_typed([](String id, Array inputs, String generation_cfg_json_str, + Optional default_generation_cfg_json_str) { return Request(std::move(id), std::move(inputs), - GenerationConfig(std::move(generation_cfg_json))); + GenerationConfig(std::move(generation_cfg_json_str), + std::move(default_generation_cfg_json_str))); }); TVM_REGISTER_GLOBAL("mlc.serve.RequestGetInputs").set_body_typed([](Request request) { diff --git a/cpp/serve/threaded_engine.cc b/cpp/serve/threaded_engine.cc index 080853d465..8c3cadd358 100644 --- a/cpp/serve/threaded_engine.cc +++ b/cpp/serve/threaded_engine.cc @@ -13,6 +13,7 @@ #include #include +#include "../support/result.h" #include "engine.h" #include "request.h" @@ -36,8 +37,8 @@ enum class InstructionKind : int { /*! \brief The implementation of ThreadedEngine. 
*/ class ThreadedEngineImpl : public ThreadedEngine { public: - void InitBackgroundEngine(Device device, Optional request_stream_callback, - Optional trace_recorder) final { + void InitThreadedEngine(Device device, Optional request_stream_callback, + Optional trace_recorder) final { device_ = device; CHECK(request_stream_callback.defined()) << "ThreadedEngine requires request stream callback function, but it is not given."; @@ -45,17 +46,23 @@ class ThreadedEngineImpl : public ThreadedEngine { trace_recorder_ = trace_recorder; } - void Reload(EngineConfig engine_config) final { + void Reload(String engine_config_json_str) final { bool need_notify = false; { std::lock_guard lock(background_loop_mutex_); - instruction_queue_.emplace_back(InstructionKind::kReloadEngine, std::move(engine_config)); + instruction_queue_.emplace_back(InstructionKind::kReloadEngine, + std::move(engine_config_json_str)); ++pending_request_operation_cnt_; need_notify = engine_waiting_; } if (need_notify) { background_loop_cv_.notify_one(); } + { + std::unique_lock lock(reload_unload_mutex_); + reload_finished_ = false; + reload_unload_cv_.wait(lock, [this] { return reload_finished_; }); + } } void Unload() final { @@ -69,6 +76,11 @@ class ThreadedEngineImpl : public ThreadedEngine { if (need_notify) { background_loop_cv_.notify_one(); } + { + std::unique_lock lock(reload_unload_mutex_); + unload_finished_ = false; + reload_unload_cv_.wait(lock, [this] { return unload_finished_; }); + } } void Reset() final { @@ -140,7 +152,7 @@ class ThreadedEngineImpl : public ThreadedEngine { EngineUnloadImpl(); } else if (kind == InstructionKind::kReloadEngine) { EngineUnloadImpl(); - EngineReloadImpl(Downcast(arg)); + EngineReloadImpl(Downcast(arg)); } else if (kind == InstructionKind::kResetEngine) { if (background_engine_ != nullptr) { background_engine_->Reset(); @@ -199,7 +211,23 @@ class ThreadedEngineImpl : public ThreadedEngine { request_stream_callback_cv_.notify_one(); } - /************** Debug/Profile **************/ + /************** Query/Profile/Debug **************/ + + String GetDefaultGenerationConfigJSONString() const final { + CHECK(!default_generation_cfg_json_str_.empty()) + << "The default generation config has not been set."; + return default_generation_cfg_json_str_; + }; + + String GetCompleteEngineConfigJSONString() const final { + CHECK(!complete_engine_config_json_str_.empty()) << "The engine config has not been set."; + return complete_engine_config_json_str_; + }; + + String Stats() final { + std::lock_guard lock(background_loop_mutex_); + return background_engine_->Stats(); + } void DebugCallFuncOnAllAllWorker(const String& func_name) final { bool need_notify = false; @@ -214,13 +242,8 @@ class ThreadedEngineImpl : public ThreadedEngine { } } - String Stats() final { - std::lock_guard lock(background_loop_mutex_); - return background_engine_->Stats(); - } - private: - void EngineReloadImpl(EngineConfig engine_config) { + void EngineReloadImpl(const std::string& engine_config_json_str) { auto frequest_stream_callback_wrapper = [this](TVMArgs args, TVMRetValue* ret) { ICHECK_EQ(args.size(), 1); Array delta_outputs = args[0]; @@ -237,8 +260,19 @@ class ThreadedEngineImpl : public ThreadedEngine { }; Optional request_stream_callback = PackedFunc(frequest_stream_callback_wrapper); - background_engine_ = Engine::Create(std::move(engine_config), device_, - std::move(request_stream_callback), trace_recorder_); + Result output_res = Engine::Create( + engine_config_json_str, device_, 
std::move(request_stream_callback), trace_recorder_); + CHECK(output_res.IsOk()) << output_res.UnwrapErr(); + EngineCreationOutput output = output_res.Unwrap(); + background_engine_ = std::move(output.reloaded_engine); + default_generation_cfg_json_str_ = output.default_generation_cfg->AsJSONString(); + complete_engine_config_json_str_ = output.completed_engine_config->AsJSONString(); + { + // Wake up the thread waiting for reload finish. + std::lock_guard lock(reload_unload_mutex_); + reload_finished_ = true; + reload_unload_cv_.notify_one(); + } } void EngineUnloadImpl() { @@ -250,6 +284,14 @@ class ThreadedEngineImpl : public ThreadedEngine { tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear"); ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear"; (*fclear_memory_manager)(); + default_generation_cfg_json_str_ = ""; + complete_engine_config_json_str_ = ""; + } + { + // Wake up the thread waiting for unload finish. + std::lock_guard lock(reload_unload_mutex_); + unload_finished_ = true; + reload_unload_cv_.notify_one(); } } @@ -261,13 +303,19 @@ class ThreadedEngineImpl : public ThreadedEngine { PackedFunc request_stream_callback_; /*! \brief Event trace recorder. */ Optional trace_recorder_; + /*! \brief The complete engine config JSON string. */ + String complete_engine_config_json_str_; + /*! \brief The default generation config JSON string. */ + String default_generation_cfg_json_str_; /*! \brief The mutex ensuring only one thread can access critical regions. */ std::mutex background_loop_mutex_; std::mutex request_stream_callback_mutex_; + std::mutex reload_unload_mutex_; /*! \brief The condition variable preventing threaded engine from spinning. */ std::condition_variable background_loop_cv_; std::condition_variable request_stream_callback_cv_; + std::condition_variable reload_unload_cv_; /*! \brief A boolean flag denoting if the engine needs to exit background loop. */ std::atomic exit_now_ = false; @@ -303,13 +351,17 @@ class ThreadedEngineImpl : public ThreadedEngine { bool engine_waiting_ = false; /*! \brief A boolean flag indicating if the stream callback loop is waiting. */ bool stream_callback_waiting_ = false; + /*! \brief A boolean indicating if the engine reload has finished. */ + bool reload_finished_ = false; + /*! \brief A boolean indicating if the engine unload has finished. */ + bool unload_finished_ = false; }; /*! \brief The implementation of ThreadedEngine. 
*/ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { public: TVM_MODULE_VTABLE_BEGIN("mlc.serve.async_threaded_engine"); - TVM_MODULE_VTABLE_ENTRY("init_background_engine", &ThreadedEngineImpl::InitBackgroundEngine); + TVM_MODULE_VTABLE_ENTRY("init_threaded_engine", &ThreadedEngineImpl::InitThreadedEngine); TVM_MODULE_VTABLE_ENTRY("reload", &ThreadedEngineImpl::Reload); TVM_MODULE_VTABLE_ENTRY("add_request", &ThreadedEngineImpl::AddRequest); TVM_MODULE_VTABLE_ENTRY("abort_request", &ThreadedEngineImpl::AbortRequest); @@ -317,9 +369,13 @@ class ThreadedEngineModule : public ThreadedEngineImpl, public ModuleNode { TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop", &ThreadedEngineImpl::RunBackgroundStreamBackLoop); TVM_MODULE_VTABLE_ENTRY("exit_background_loop", &ThreadedEngineImpl::ExitBackgroundLoop); + TVM_MODULE_VTABLE_ENTRY("get_default_generation_config", + &ThreadedEngineImpl::GetDefaultGenerationConfigJSONString); + TVM_MODULE_VTABLE_ENTRY("get_complete_engine_config", + &ThreadedEngineImpl::GetCompleteEngineConfigJSONString); + TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_ENTRY("debug_call_func_on_all_worker", &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker); - TVM_MODULE_VTABLE_ENTRY("stats", &ThreadedEngineImpl::Stats); TVM_MODULE_VTABLE_END(); }; diff --git a/cpp/serve/threaded_engine.h b/cpp/serve/threaded_engine.h index d0f2ebe2d7..b6afdcbb7c 100644 --- a/cpp/serve/threaded_engine.h +++ b/cpp/serve/threaded_engine.h @@ -39,14 +39,14 @@ class ThreadedEngine { * \param request_stream_callback The request stream callback function to. * \param trace_recorder Event trace recorder for requests. */ - virtual void InitBackgroundEngine(Device device, Optional request_stream_callback, - Optional trace_recorder) = 0; + virtual void InitThreadedEngine(Device device, Optional request_stream_callback, + Optional trace_recorder) = 0; /*! * \brief Reload the engine with the new engine config. - * \param engine_config The engine config. + * \param engine_config_json_str The engine config JSON string. */ - virtual void Reload(EngineConfig engine_config) = 0; + virtual void Reload(String engine_config_json_str) = 0; /*! \brief Unload the background engine. */ virtual void Unload() = 0; @@ -73,13 +73,19 @@ class ThreadedEngine { /*! \brief Abort the input request (specified by id string) from engine. */ virtual void AbortRequest(const String& request_id) = 0; - /************** Debug/Profile **************/ + /************** Query/Profile/Debug **************/ - /*! \brief Call the given global function on all workers. Only for debug purpose. */ - virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; + /*! \brief Return the default generation config JSON string. */ + virtual String GetDefaultGenerationConfigJSONString() const = 0; + + /*! \brief Return the complete engine config JSON string. */ + virtual String GetCompleteEngineConfigJSONString() const = 0; /*! \brief Print the statistics of the engine. */ virtual String Stats() = 0; + + /*! \brief Call the given global function on all workers. Only for debug purpose. 
*/ + virtual void DebugCallFuncOnAllAllWorker(const String& func_name) = 0; }; } // namespace serve diff --git a/cpp/metadata/json_parser.h b/cpp/support/json_parser.h similarity index 92% rename from cpp/metadata/json_parser.h rename to cpp/support/json_parser.h index 99a284fc42..f71757435a 100644 --- a/cpp/metadata/json_parser.h +++ b/cpp/support/json_parser.h @@ -2,8 +2,8 @@ * \file json_parser.h * \brief Helps to parse JSON strings and objects. */ -#ifndef MLC_LLM_CPP_JSON_PARSER_H_ -#define MLC_LLM_CPP_JSON_PARSER_H_ +#ifndef MLC_LLM_SUPPORT_JSON_PARSER_H_ +#define MLC_LLM_SUPPORT_JSON_PARSER_H_ #include #include @@ -165,6 +165,17 @@ inline ValueType LookupOrDefault(const picojson::object& json, const std::string return it->second.get(); } +template +inline std::optional LookupOptional(const picojson::object& json, + const std::string& key) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return std::nullopt; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; @@ -209,4 +220,4 @@ inline picojson::object ParseToJsonObject(const std::string& json_str) { } // namespace llm } // namespace mlc -#endif // MLC_LLM_CPP_JSON_PARSER_H_ +#endif // MLC_LLM_SUPPORT_JSON_PARSER_H_ diff --git a/cpp/support/result.h b/cpp/support/result.h new file mode 100644 index 0000000000..c6def39525 --- /dev/null +++ b/cpp/support/result.h @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file result.h + * \brief The header for the Result class in MLC LLM. + */ +#ifndef MLC_LLM_SUPPORT_RESULT_H_ +#define MLC_LLM_SUPPORT_RESULT_H_ + +#include + +#include +#include + +namespace mlc { +namespace llm { + +/*! + * \brief The result class in MLC LLM. + * Each instance is either an okay value or an error. + * \tparam T The okay value type of the result. + * \tparam E The error type of the result. + */ +template +class Result { + public: + /*! \brief Create a result with an okay value. */ + static Result Ok(T value) { + Result result; + result.ok_value_ = std::move(value); + return result; + } + /*! \brief Create a result with an error value. */ + static Result Error(E error) { + Result result; + result.err_value_ = std::move(error); + return result; + } + /*! \brief Check if the result is okay or not. */ + bool IsOk() const { return ok_value_.has_value(); } + /*! \brief Check if the result is an error or not. */ + bool IsErr() const { return err_value_.has_value(); } + /*! + * \brief Unwrap the result and return the okay value. + * Throwing exception if it is an error. + * \note This function returns the ok value by moving, so a Result can be unwrapped only once. + */ + T Unwrap() { + ICHECK(ok_value_.has_value()) << "Cannot unwrap result on an error value."; + ICHECK(!unwrapped_) << "Cannot unwrap a Result instance twice."; + unwrapped_ = true; + return std::move(ok_value_.value()); + } + /*! + * \brief Unwrap the result and return the error value. + * Throwing exception if it is an okay value. + * \note This function returns the error value by moving, so a Result can be unwrapped only once. + */ + E UnwrapErr() { + ICHECK(err_value_.has_value()) << "Cannot unwrap result on an okay value."; + ICHECK(!unwrapped_) << "Cannot unwrap a Result instance twice."; + unwrapped_ = true; + return std::move(err_value_.value()); + } + + private: + /*! 
\brief A boolean flag indicating whether the result has already been unwrapped. */ + bool unwrapped_ = false; + /*! \brief The internal optional okay value. */ + std::optional<T> ok_value_; + /*! \brief The internal optional error value. */ + std::optional<E> err_value_; +}; + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_SUPPORT_RESULT_H_ diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index 4706e09811..560ca17255 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -285,7 +285,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") >>> cm.generate("hi") 'Hi! How can I assist you today?' @@ -312,7 +312,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") >>> cm.generate("hi") 'Hi! How can I assist you today?' @@ -340,7 +340,7 @@ We can check the output with the commands below: python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") + model_lib="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") >>> cm.generate("hi") 'Hi! How can I assist you today?' diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst index aa65256fd6..1518f5145a 100644 --- a/docs/compilation/convert_weights.rst +++ b/docs/compilation/convert_weights.rst @@ -177,6 +177,6 @@ Running the distributed models are similar to running prebuilt model weights and python >>> from mlc_llm import ChatModule >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ - model_lib_path="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend + model_lib="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend >>> cm.generate("hi") 'Hi! How can I assist you today?' diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index a7ebe28d6d..f978581707 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -92,13 +92,13 @@ For models other than the prebuilt ones we provided: Once you have the model locally compiled with a model library and model weights, to run ``mlc_llm``, simply - Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model`` -- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib-path`` +- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib`` ..
code:: shell mlc_llm chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ --device "cuda:0" --overrides context_window_size=1024 \ - --model-lib-path dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so + --model-lib dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so # Metal on macOS: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-metal.so # Same rule applies for other platforms diff --git a/docs/deploy/ide_integration.rst b/docs/deploy/ide_integration.rst index 866dfa3cbe..7e0735d8e0 100644 --- a/docs/deploy/ide_integration.rst +++ b/docs/deploy/ide_integration.rst @@ -112,7 +112,7 @@ You can now locally deploy your compiled model with the MLC serve module. To fin python -m mlc_llm.serve.server \ --model dist/CodeLlama-7b-hf-q4f16_1-MLC \ - --model-lib-path ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so + --model-lib ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so Configure the IDE Extension --------------------------- diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 75a5cdbdc7..2bcf7997d3 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -273,7 +273,7 @@ We simply specify the Huggingface link as ``model_url``, while reusing the ``mod "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", "model_lib": "mistral_q3f16_1", - "model_lib_path": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", + "model_lib": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", "estimated_vram_bytes": 3316000000 } ] @@ -421,7 +421,6 @@ rounded up to MB. "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", "model_id": "phi-2-q4f16_1", "model_lib": "phi_msft_q4f16_1", - "model_lib_path": "lib/phi-2/phi-2-q4f16_1-iphone.tar", "estimated_vram_bytes": 3043000000 } ] diff --git a/docs/deploy/python_chat_module.rst b/docs/deploy/python_chat_module.rst index 5776e29138..14e9f3ed03 100644 --- a/docs/deploy/python_chat_module.rst +++ b/docs/deploy/python_chat_module.rst @@ -95,7 +95,7 @@ file ``sample_mlc_llm.py`` and paste the following lines: # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -106,7 +106,7 @@ file ``sample_mlc_llm.py`` and paste the following lines: # Here WizardMath reuses Mistral's model library # cm = ChatModule( # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" - # model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" + # model_lib="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" # ) # Generate a response for a given prompt @@ -200,7 +200,7 @@ We provide an example below. 
cm = ChatModule( chat_config=chat_config, model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -275,7 +275,7 @@ We provide an example below. cm = ChatModule( chat_config=chat_config, model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -320,7 +320,7 @@ We provide an example below. # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst index 89c60ac422..2ef4d5bd23 100644 --- a/docs/deploy/python_engine.rst +++ b/docs/deploy/python_engine.rst @@ -219,15 +219,15 @@ you can construct a :class:`mlc_llm.MLCEngine` as follows: **Specify Model Library Path.** Further, if you build the model library on your own, -you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. +you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib``. .. code:: python from mlc_llm import MLCEngine model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) + model_lib = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib=model_lib) The same applies to :class:`mlc_llm.AsyncMLCEngine`. diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 07d39dbfad..a82c914004 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -28,7 +28,7 @@ This section provides a quick start guide to work with MLC-LLM REST API. To laun .. code:: bash - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] + mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB] where ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process `. Information about other arguments can be found under :ref:`Launch the server ` section. @@ -66,14 +66,14 @@ To launch the MLC Server for MLC-LLM, run the following command in your terminal .. 
code:: bash - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] + mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] MODEL The model folder after compiling with MLC-LLM build process. The parameter can either be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model folder. In the former case, we will use the provided name to search for the model folder over possible paths. ---model-lib-path A field to specify the full path to the model library file to use (e.g. a ``.so`` file). +--model-lib A field to specify the full path to the model library file to use (e.g. a ``.so`` file). --device The description of the device to run on. User should provide a string in the form of 'device_name:device_id' or 'device_name', where 'device_name' is one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst index 29060d5a60..bcba8f631e 100644 --- a/docs/get_started/introduction.rst +++ b/docs/get_started/introduction.rst @@ -240,20 +240,20 @@ Below is an example command of compiling model libraries in MLC LLM: .. code:: bash - export $MODEL_LIB_PATH=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. - # ".dll" for Windows. - # ".wasm" for web. - # ".tar" for iPhone/Android. - mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB_PATH + export MODEL_LIB=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. + # ".dll" for Windows. + # ".wasm" for web. + # ".tar" for iPhone/Android. + mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB At runtime, we need to specify this model library path to use it.
For example, # For Python API model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) + model_lib = "models/phi-2/lib.so" + engine = MLCEngine(model, model_lib=model_lib) :ref:`compile-model-libraries` introduces the model compilation command in detail, where you can find instructions and example commands to compile model to different diff --git a/examples/python/sample_mlc_chat.py b/examples/python/sample_mlc_chat.py index de00e84ff6..f4e49bb2bd 100644 --- a/examples/python/sample_mlc_chat.py +++ b/examples/python/sample_mlc_chat.py @@ -7,7 +7,7 @@ # Create a ChatModule instance cm = ChatModule( model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" + model_lib="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so", # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} @@ -18,7 +18,7 @@ # Here WizardMath reuses Mistral's model library # cm = ChatModule( # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" -# model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" +# model_lib="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" # ) # Generate a response for a given prompt diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 24ad8faecf..2efc3ec9b9 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -442,7 +442,7 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi if field_name == "model_lib": warn_msg = ( 'WARNING: Do not override "model_lib" in ChatConfig. ' - "This override will be ignored. Please use ChatModule.model_lib_path to " + "This override will be ignored. Please use ChatModule.model_lib to " "override the full model library path instead." ) warnings.warn(warn_msg) @@ -493,7 +493,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments model: str, model_path: str, chat_config: ChatConfig, - model_lib_path: Optional[str], + model_lib: Optional[str], device_name: str, config_file_path: str, ) -> str: @@ -507,7 +507,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments Model path found by `_get_model_path`. chat_config : ChatConfig Chat config after potential overrides. Returned by ``_get_chat_config``. - model_lib_path : Optional[str] + model_lib : Optional[str] User's input. Supposedly a full path to model library. Prioritized to use. device_name : str User's input. Used to construct the library model file name. @@ -516,20 +516,20 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments Returns ------- - model_lib_path : str + model_lib : str The path pointing to the model library we find. Raises ------ FileNotFoundError: if we cannot find a valid model library file. """ - # 1. Use user's model_lib_path if provided - if model_lib_path is not None: - if os.path.isfile(model_lib_path): - logger.info("Using library model: %s", model_lib_path) - return model_lib_path + # 1. 
Use user's model_lib if provided + if model_lib is not None: + if os.path.isfile(model_lib): + logger.info("Using library model: %s", model_lib) + return model_lib raise FileNotFoundError( - f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\n" + f"The `model_lib` you passed in is not a file: {model_lib}.\n" f"Please refer to {_PYTHON_GET_STARTED_TUTORIAL_URL} as tutorial on model loading." ) @@ -584,7 +584,7 @@ def _get_lib_module_path( # pylint: disable=too-many-arguments err_msg += f"- {candidate}\n" err_msg += ( "If you would like to directly specify the model library path, you may " - "consider passing in the `ChatModule.model_lib_path` parameter.\n" + "consider passing in the `ChatModule.model_lib` parameter.\n" f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example " "on how to load a model." ) @@ -654,12 +654,12 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _inspect_model_lib_metadata_memory_usage(model_lib_path, config_file_path): +def _inspect_model_lib_metadata_memory_usage(model_lib, config_file_path): cmd = [ sys.executable, "-m", "mlc_llm.cli.model_metadata", - model_lib_path, + model_lib, "--memory-only", "--mlc-chat-config", config_file_path, @@ -716,7 +716,7 @@ class ChatModule: # pylint: disable=too-many-instance-attributes A ``ChatConfig`` instance partially filled. Will be used to override the ``mlc-chat-config.json``. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. @@ -727,7 +727,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: str = "auto", chat_config: Optional[ChatConfig] = None, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, ): # 0. Get device: # Retrieve device_name and device_id (if any, default 0) from device arg @@ -768,12 +768,12 @@ def __init__( # pylint: disable=too-many-arguments self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 4. Look up model library - if model_lib_path is not None: - self.model_lib_path = _get_lib_module_path( + if model_lib is not None: + self.model_lib = _get_lib_module_path( model, self.model_path, self.chat_config, - model_lib_path, + model_lib, self.device.MASK2STR[self.device.device_type], self.config_file_path, ) @@ -781,20 +781,20 @@ def __init__( # pylint: disable=too-many-arguments logger.info("Now compiling model lib on device...") from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - self.model_lib_path = str( + self.model_lib = str( jit.jit( model_path=Path(self.model_path), chat_config=asdict(self.chat_config), device=self.device, ) ) - _inspect_model_lib_metadata_memory_usage(self.model_lib_path, self.config_file_path) + _inspect_model_lib_metadata_memory_usage(self.model_lib, self.config_file_path) # 5. 
Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template ) - self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str) + self._reload(self.model_lib, self.model_path, user_chat_config_json_str) def generate( self, diff --git a/python/mlc_llm/cli/bench.py b/python/mlc_llm/cli/bench.py index 26b74b1f10..0e42048ff2 100644 --- a/python/mlc_llm/cli/bench.py +++ b/python/mlc_llm/cli/bench.py @@ -1,4 +1,5 @@ """Command line entrypoint of benchmark.""" + from mlc_llm.help import HELP from mlc_llm.interface.bench import bench from mlc_llm.interface.chat import ChatConfigOverride @@ -45,10 +46,10 @@ def main(argv): help=HELP["generate_length"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) bench( @@ -58,5 +59,5 @@ def main(argv): opt=parsed.opt, overrides=parsed.overrides, generate_length=parsed.generate_length, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, ) diff --git a/python/mlc_llm/cli/benchmark.py b/python/mlc_llm/cli/benchmark.py index 72c86fab03..aa22bae68c 100644 --- a/python/mlc_llm/cli/benchmark.py +++ b/python/mlc_llm/cli/benchmark.py @@ -1,4 +1,5 @@ """A command line tool for benchmarking a chat model.""" + import argparse from pathlib import Path @@ -74,7 +75,7 @@ def main(): model=args.model, device=args.device, chat_config=ChatConfig(tensor_parallel_shards=args.tensor_parallel_shards), - model_lib_path=args.model_lib, + model_lib=args.model_lib, ) prompt = _load_prompt(args.prompt) output = chat_module.benchmark_generate(prompt, generate_length=args.generate_length) diff --git a/python/mlc_llm/cli/chat.py b/python/mlc_llm/cli/chat.py index 13c83a64ec..34fb5daa09 100644 --- a/python/mlc_llm/cli/chat.py +++ b/python/mlc_llm/cli/chat.py @@ -1,4 +1,5 @@ """Command line entrypoint of chat.""" + from mlc_llm.help import HELP from mlc_llm.interface.chat import ChatConfigOverride, chat from mlc_llm.support.argparse import ArgumentParser @@ -32,10 +33,10 @@ def main(argv): help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) chat( @@ -43,5 +44,5 @@ def main(argv): device=parsed.device, opt=parsed.opt, overrides=parsed.overrides, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, ) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 6663a0c230..9ba0e01e3d 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -4,7 +4,6 @@ from mlc_llm.help import HELP from mlc_llm.interface.serve import serve -from mlc_llm.serve.config import SpeculativeMode from mlc_llm.support.argparse import ArgumentParser @@ -24,10 +23,10 @@ def main(argv): help=HELP["device_deploy"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parser.add_argument( "--mode", @@ -44,18 +43,16 @@ def main(argv): "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] ) 
parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) - parser.add_argument( - "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] - ) + parser.add_argument("--max-history-size", type=int, help=HELP["max_history_size_serve"]) parser.add_argument( "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] ) parser.add_argument( "--speculative-mode", type=str, - choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], - default="DISABLE", - help=HELP["speculative_mode_serve"], + choices=["disable", "small_draft", "eable"], + default="disable", + help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) parser.add_argument( "--spec-draft-length", type=int, default=4, help=HELP["spec_draft_length_serve"] @@ -97,7 +94,7 @@ def main(argv): serve( model=parsed.model, device=parsed.device, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, mode=parsed.mode, additional_models=parsed.additional_models, max_batch_size=parsed.max_batch_size, @@ -105,7 +102,7 @@ def main(argv): prefill_chunk_size=parsed.prefill_chunk_size, max_history_size=parsed.max_history_size, gpu_memory_utilization=parsed.gpu_memory_utilization, - speculative_mode=SpeculativeMode[parsed.speculative_mode], + speculative_mode=parsed.speculative_mode, spec_draft_length=parsed.spec_draft_length, enable_tracing=parsed.enable_tracing, host=parsed.host, diff --git a/python/mlc_llm/help.py b/python/mlc_llm/help.py index 86930fa5ea..f6ef6c38af 100644 --- a/python/mlc_llm/help.py +++ b/python/mlc_llm/help.py @@ -25,9 +25,9 @@ A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. """.strip(), - "model_lib_path": """ + "model_lib": """ The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use -the provided ``model`` to search over possible paths. It the model lib path is not found, it will be +the provided ``model`` to search over possible paths. It the model lib is not found, it will be compiled in a JIT manner. """.strip(), "model_type": """ @@ -186,8 +186,8 @@ When engine is enabled with speculative decoding, additional models are needed. The way of specifying additional models is: "--additional-models model_path_1 model_path_2 ..." or -"--additional-models model_path_1:model_lib_path_1 model_path_2 ...". -When the model lib path of a model is not given, JIT model compilation will be activated +"--additional-models model_path_1:model_lib_1 model_path_2 ...". +When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. """, "gpu_memory_utilization_serve": """ @@ -199,10 +199,10 @@ """, "speculative_mode_serve": """ The speculative decoding mode. Right now three options are supported: - - DISABLE, where speculative decoding is not enabled, - - SMALL_DRAFT, denoting the normal speculative decoding (small draft) style, - - EAGLE, denoting the eagle-style speculative decoding. -The default mode is "DISABLE". + - "disable", where speculative decoding is not enabled, + - "small_draft", denoting the normal speculative decoding (small draft) style, + - "eagle", denoting the eagle-style speculative decoding. +The default mode is "disable". """, "spec_draft_length_serve": """ The number of draft tokens to generate in speculative proposal. The default values is 4. 
diff --git a/python/mlc_llm/interface/bench.py b/python/mlc_llm/interface/bench.py index 6a7d833447..baa350df05 100644 --- a/python/mlc_llm/interface/bench.py +++ b/python/mlc_llm/interface/bench.py @@ -1,4 +1,5 @@ """Python entrypoint of benchmark.""" + from typing import Optional from mlc_llm.chat_module import ChatConfig, ChatModule @@ -13,7 +14,7 @@ def bench( # pylint: disable=too-many-arguments opt: str, overrides: ChatConfigOverride, generate_length: int, - model_lib_path: Optional[str], + model_lib: Optional[str], ): """run the benchmarking""" # Set up chat config @@ -21,7 +22,7 @@ def bench( # pylint: disable=too-many-arguments # Apply overrides config = overrides.apply(config) # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + cm = ChatModule(model, device, chat_config=config, model_lib=model_lib) output = cm.benchmark_generate(prompt, generate_length=generate_length) print(f"Generated text:\n{output}\n") diff --git a/python/mlc_llm/interface/chat.py b/python/mlc_llm/interface/chat.py index 9c0763a6ef..75985ec27a 100644 --- a/python/mlc_llm/interface/chat.py +++ b/python/mlc_llm/interface/chat.py @@ -1,4 +1,5 @@ """Python entrypoint of chat.""" + import dataclasses from typing import List, Optional, Union @@ -121,7 +122,7 @@ def chat( device: str, opt: str, overrides: ChatConfigOverride, - model_lib_path: Optional[str], + model_lib: Optional[str], ): """chat with a model.""" # Set up chat config and generate config @@ -130,7 +131,7 @@ def chat( # Apply overrides config = overrides.apply(config) # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) + cm = ChatModule(model, device, chat_config=config, model_lib=model_lib) _print_help_str() cm._process_system_prompts() # pylint: disable=protected-access diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index 40fa9fdda8..d1cde12678 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -8,7 +8,6 @@ from mlc_llm.protocol import error_protocol from mlc_llm.serve import engine -from mlc_llm.serve.config import SpeculativeMode from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints from mlc_llm.serve.server import ServerContext @@ -16,7 +15,7 @@ def serve( model: str, device: str, - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], additional_models: List[str], max_batch_size: Optional[int], @@ -24,7 +23,7 @@ def serve( prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, + speculative_mode: Literal["disable", "small_draft", "eagle"], spec_draft_length: int, enable_tracing: bool, host: str, @@ -39,7 +38,7 @@ def serve( async_engine = engine.AsyncMLCEngine( model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py index 0c604a2ef3..237319a926 100644 --- a/python/mlc_llm/json_ffi/engine.py +++ b/python/mlc_llm/json_ffi/engine.py @@ -1,6 +1,5 @@ # pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import json import queue import threading from typing import Any, Callable, Dict, 
Iterator, List, Literal, Optional, Union @@ -11,8 +10,6 @@ from mlc_llm.serve import engine_utils from mlc_llm.serve.engine_base import ( EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, _parse_models, _process_model_args, detect_device, @@ -20,32 +17,6 @@ from mlc_llm.tokenizer import Tokenizer -# TODO(mlc-team): further minimize the JSONFFIEngine -# construction to not depend on any config and directly pass in JSON -# model defined generation config should be read from the JSONFFIEngine via Reload -def create_model_defined_generation_config( - temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) - - -# TODO(mlc-team): further minimize the JSONFFIEngine -# Engine config should be passed as json str -# and backend should have good default -# only model and model_lib should be mandatory -def create_json_ffi_engine_config( - conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( - conv_template, model_generation_cfgs - ) - - class EngineState: sync_queue: queue.Queue @@ -70,27 +41,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model: str, device: Union[str, tvm.runtime.Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, max_history_size: Optional[int] = None, prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, gpu_memory_utilization: Optional[float] = None, ) -> None: # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) + model_args = _process_model_args(models, device)[0] # TODO(mlc-team) Remove the model config parsing, estimation below # in favor of a simple direct passing of parameters into backend. @@ -103,33 +70,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # since we won't have similar logics in android/iOS # # - Load the raw model config into dict - self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. 
- ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + model_info.model_lib = model_args[i][1] # - Initialize engine state and engine. self.state = EngineState() @@ -151,43 +93,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals } self.tokenizer = Tokenizer(model_args[0][0]) - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - - self.json_ffi_engine_config = create_json_ffi_engine_config( - conv_template=self.conv_template.model_dump_json(), - model_generation_cfgs={ - model.model: create_model_defined_generation_config( - temperature=model_config["temperature"], - top_p=model_config["top_p"], - frequency_penalty=model_config["frequency_penalty"], - presence_penalty=model_config["presence_penalty"], - ) - for model, model_config in zip(models, self.model_config_dicts) - }, - ) - - self._ffi["init_background_engine"]( - self.json_ffi_engine_config, - self.engine_config, - device, - self.state.get_request_stream_callback(), - None, - ) - def _background_loop(): self._ffi["run_background_loop"]() @@ -203,6 +108,26 @@ def _background_stream_back_loop(): self._background_stream_back_loop_thread.start() self._terminated = False + self.engine_config = EngineConfig( + model=model_args[0][0], + model_lib=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + verbose=False, + ) + + self._ffi["init_background_engine"](device, self.state.get_request_stream_callback(), None) + self._ffi["reload"](self.engine_config.asjson()) + def terminate(self): self._terminated = True self._ffi["exit_background_loop"]() @@ -301,7 +226,7 @@ def _handle_chat_completion( raise exception def _test_reload(self): - self._ffi["reload"](self.engine_config) + self._ffi["reload"](self.engine_config.asjson()) def _test_reset(self): self._ffi["reset"]() diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 4a5168f971..9a0a724ea1 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -107,9 +107,9 @@ class CompletionRequest(BaseModel): @field_validator("frequency_penalty", "presence_penalty") @classmethod - def 
check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" - if penalty_value < -2 or penalty_value > 2: + if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") return penalty_value @@ -221,7 +221,7 @@ class ChatCompletionRequest(BaseModel): @field_validator("frequency_penalty", "presence_penalty") @classmethod - def check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") @@ -386,7 +386,7 @@ def openai_api_get_unsupported_fields( def openai_api_get_generation_config( - request: Union[CompletionRequest, ChatCompletionRequest], model_config: Dict[str, Any] + request: Union[CompletionRequest, ChatCompletionRequest] ) -> Dict[str, Any]: """Create the generation config from the given request.""" from ..serve.config import ResponseFormat # pylint: disable=import-outside-toplevel @@ -407,17 +407,6 @@ def openai_api_get_generation_config( ] for arg_name in arg_names: kwargs[arg_name] = getattr(request, arg_name) - - # If per-request generation config values are missing, try loading from model config. - # If still not found, then use the default OpenAI API value - if kwargs["temperature"] is None: - kwargs["temperature"] = model_config.get("temperature", 1.0) - if kwargs["top_p"] is None: - kwargs["top_p"] = model_config.get("top_p", 1.0) - if kwargs["frequency_penalty"] is None: - kwargs["frequency_penalty"] = model_config.get("frequency_penalty", 0.0) - if kwargs["presence_penalty"] is None: - kwargs["presence_penalty"] = model_config.get("presence_penalty", 0.0) if kwargs["max_tokens"] is None: # Setting to -1 means the generation will not stop until # exceeding model capability or hit any stop criteria. diff --git a/python/mlc_llm/protocol/protocol_utils.py b/python/mlc_llm/protocol/protocol_utils.py index 3005909bbd..f4273d0302 100644 --- a/python/mlc_llm/protocol/protocol_utils.py +++ b/python/mlc_llm/protocol/protocol_utils.py @@ -23,14 +23,13 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: def get_generation_config( request: RequestProtocol, - model_config: Dict[str, Any], extra_stop_token_ids: Optional[List[int]] = None, extra_stop_str: Optional[List[str]] = None, ) -> GenerationConfig: """Create the generation config in MLC LLM out from the input request protocol.""" kwargs: Dict[str, Any] if isinstance(request, (OpenAICompletionRequest, OpenAIChatCompletionRequest)): - kwargs = openai_api_get_generation_config(request, model_config) + kwargs = openai_api_get_generation_config(request) else: raise RuntimeError("Cannot reach here") diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 59358c1646..ec6899ea26 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. 
import base -from .config import EngineConfig, GenerationConfig, SpeculativeMode +from .config import EngineConfig, GenerationConfig from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 6b808ac37b..916403839a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,14 +1,9 @@ """Configuration dataclasses used in MLC LLM serving""" -import enum import json from dataclasses import asdict, dataclass, field from typing import Dict, List, Literal, Optional -import tvm - -from . import _ffi_api - @dataclass class ResponseFormat: @@ -43,19 +38,19 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes n : int How many chat completion choices to generate for each input message. - temperature : float + temperature : Optional[float] The value that applies to logits and modulates the next token probabilities. - top_p : float + top_p : Optional[float] In sampling, only the most probable tokens with probabilities summed up to `top_p` are kept for sampling. - frequency_penalty : float + frequency_penalty : Optional[float] Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - presence_penalty : float + presence_penalty : Optional[float] Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. @@ -101,10 +96,10 @@ class GenerationConfig: # pylint: disable=too-many-instance-attributes """ n: int = 1 - temperature: float = 0.8 - top_p: float = 0.95 - frequency_penalty: float = 0.0 - presence_penalty: float = 0.0 + temperature: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None repetition_penalty: float = 1.0 logprobs: bool = False top_logprobs: int = 0 @@ -128,26 +123,8 @@ def from_json(json_str: str) -> "GenerationConfig": return GenerationConfig(**json.loads(json_str)) -class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods - """Possible kinds of KV state.""" - - ATTENTION = 0 - RNNSTATE = 1 - - -class SpeculativeMode(enum.IntEnum): - """The speculative mode.""" - - # Disable speculative decoding. - DISABLE = 0 - # The normal speculative decoding (small draft) mode. - SMALL_DRAFT = 1 - # The eagle-style speculative decoding. - EAGLE = 2 - - -@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access -class EngineConfig(tvm.runtime.Object): +@dataclass +class EngineConfig: # pylint: disable=too-many-instance-attributes """The class of MLCEngine execution configuration. Parameters @@ -155,74 +132,103 @@ class EngineConfig(tvm.runtime.Object): model : str The path to the model directory. - model_lib_path : str + model_lib : str The path to the model library. additional_models : List[str] The path to the additional models' directories. - additional_model_lib_paths : List[str] + additional_model_libs : List[str] The path to the additional models' libraries. + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". 
+ The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of the server, which + has at most 1 concurrent request. So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent requests and wants to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatically inferred values. + + gpu_memory_utilization : float + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer the maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.85. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. + kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. - max_num_sequence : int + max_num_sequence : Optional[int] The maximum number of sequences that are allowed to be processed by the KV cache at any time. - max_total_sequence_length : int + max_total_sequence_length : Optional[int] The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. - max_single_sequence_length : int + max_single_sequence_length : Optional[int] The maximum length allowed for a single sequence in the engine. - prefill_chunk_size : int + prefill_chunk_size : Optional[int] The maximum total sequence length in a prefill. - max_history_size: int + max_history_size: Optional[int] The maximum history size for RNN state to roll back. - kv_state_kind: KVStateKind + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] The kind of cache. - speculative_mode : SpeculativeMode + speculative_mode : Literal["disable", "small_draft", "eagle"] The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). + + verbose : bool + A boolean indicating whether to print logging info in engine.
""" - def __init__( # pylint: disable=too-many-arguments - self, - model: str, - model_lib_path: str, - additional_models: List[str], - additional_model_lib_paths: List[str], - kv_cache_page_size: int, - max_num_sequence: int, - max_total_sequence_length: int, - max_single_sequence_length: int, - prefill_chunk_size: int, - max_history_size: int, - kv_state_kind: KVStateKind, - speculative_mode: SpeculativeMode, - spec_draft_length: int, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member - model, - model_lib_path, - additional_models, - additional_model_lib_paths, - kv_cache_page_size, - max_num_sequence, - max_total_sequence_length, - max_single_sequence_length, - prefill_chunk_size, - max_history_size, - kv_state_kind, - speculative_mode, - spec_draft_length, - ) + model: str + model_lib: str + additional_models: List[str] = field(default_factory=list) + additional_model_libs: List[str] = field(default_factory=list) + mode: Literal["local", "interactive", "server"] = "local" + gpu_memory_utilization: Optional[float] = None + kv_cache_page_size: int = 16 + max_num_sequence: Optional[int] = None + max_total_sequence_length: Optional[int] = None + max_single_sequence_length: Optional[int] = None + prefill_chunk_size: Optional[int] = None + max_history_size: Optional[int] = None + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] = None + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable" + spec_draft_length: int = 4 + verbose: bool = True + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "EngineConfig": + """Construct a config from JSON string.""" + return EngineConfig(**json.loads(json_str)) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 413c856db1..8b63a65130 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -22,7 +22,7 @@ from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import GenerationConfig from mlc_llm.serve.request import Request from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -63,8 +63,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -72,8 +72,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -112,8 +112,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: 
bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -122,8 +122,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -161,8 +161,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -171,8 +171,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -240,8 +240,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -249,8 +249,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -289,8 +289,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -299,8 +299,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -336,8 +336,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -346,8 +346,8 @@ def create( # pylint: 
disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -417,8 +417,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -427,8 +427,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -467,8 +467,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -478,8 +478,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -515,8 +515,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -526,8 +526,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -596,8 +596,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -606,8 +606,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: 
Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -646,8 +646,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -657,8 +657,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -694,8 +694,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -705,8 +705,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -758,7 +758,7 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. @@ -767,10 +767,10 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + If the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. @@ -798,8 +798,8 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): The model paths and (optional) model library paths of additional models (other than the main model). When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will + Each string in the list is either in form "model_path" or "model_path:model_lib". + When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. max_batch_size : Optional[int] @@ -827,15 +827,20 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): significantly smaller than this number.
Under mode "server", the actual memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + speculative_mode : Literal["disable", "small_draft", "eagle"] + The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. + + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments @@ -843,7 +848,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -851,15 +856,16 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, + verbose: bool = True, ) -> None: super().__init__( "async", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, @@ -870,6 +876,7 @@ def __init__( # pylint: disable=too-many-arguments speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, + verbose=verbose, ) self.chat = Chat(weakref.ref(self)) self.completions = AsyncCompletion(weakref.ref(self)) @@ -889,8 +896,8 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -899,8 +906,8 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -1012,8 +1019,8 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, 
logit_bias: Optional[Dict[int, float]] = None, @@ -1023,8 +1030,8 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -1194,7 +1201,6 @@ async def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) @@ -1264,7 +1270,9 @@ async def _generate( # Create the request with the given id, input data, generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = Request( + request_id, input_data, generation_config, self.default_generation_cfg_json_str + ) # Create the unique async request stream of the request. stream = engine_base.AsyncRequestStream() @@ -1309,7 +1317,7 @@ class MLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. @@ -1318,10 +1326,10 @@ class MLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + If the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. @@ -1349,8 +1357,8 @@ class MLCEngine(engine_base.MLCEngineBase): The model paths and (optional) model library paths of additional models (other than the main model). When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will + Each string in the list is either in form "model_path" or "model_path:model_lib". + When the model lib of a model is not given, JIT model compilation will be activated to compile the model automatically. max_batch_size : Optional[int] @@ -1375,15 +1383,20 @@ class MLCEngine(engine_base.MLCEngineBase): significantly smaller than this number. Under mode "server", the actual memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + speculative_mode : Literal["disable", "small_draft", "eagle"] + The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding.
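For reference, a minimal sketch of the resulting constructor surface (not part of this patch; the model directory and library paths below are placeholders, and the draft model used for speculative decoding is assumed to be compiled separately):

from mlc_llm.serve.engine import MLCEngine

# Sketch only: `model_lib` replaces the former `model_lib_path` keyword, and
# `speculative_mode` is now a plain string literal instead of the removed
# SpeculativeMode enum.
engine = MLCEngine(
    "dist/Llama-2-7b-chat-hf-q4f16_1-MLC",        # main model directory (placeholder)
    device="cuda",
    model_lib="dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so",  # omit to trigger JIT compilation
    mode="interactive",
    speculative_mode="small_draft",
    spec_draft_length=4,
    # draft model given in "model_path:model_lib" form (placeholder paths)
    additional_models=["dist/draft-model-MLC:dist/libs/draft-model-cuda.so"],
    verbose=False,
)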
+ + spec_draft_length : int + The number of tokens to generate in speculative proposal (draft). enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments @@ -1391,7 +1404,7 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -1399,15 +1412,16 @@ def __init__( # pylint: disable=too-many-arguments prefill_chunk_size: Optional[int] = None, max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, + verbose: bool = True, ) -> None: super().__init__( "sync", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, additional_models=additional_models, max_batch_size=max_batch_size, @@ -1418,6 +1432,7 @@ def __init__( # pylint: disable=too-many-arguments speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, enable_tracing=enable_tracing, + verbose=verbose, ) self.chat = Chat(weakref.ref(self)) self.completions = Completion(weakref.ref(self)) @@ -1437,8 +1452,8 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -1447,8 +1462,8 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, @@ -1560,8 +1575,8 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -1571,8 +1586,8 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals stop: Optional[Union[str, List[str]]] = None, stream: bool = False, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, @@ -1737,7 +1752,6 @@ def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, ) @@ -1804,7 +1818,9 @@ def _generate( # pylint: disable=too-many-locals # Create the request with the given id, input data, 
generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = Request( + request_id, input_data, generation_config, self.default_generation_cfg_json_str + ) # Record the stream in the tracker self.state.sync_output_queue = queue.Queue() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 7f3f7e1331..e0d7160ece 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -6,7 +6,6 @@ import asyncio import json import queue -import subprocess import sys import threading from dataclasses import asdict, dataclass @@ -20,12 +19,7 @@ from mlc_llm.protocol import openai_api_protocol, protocol_utils from mlc_llm.protocol.conversation_protocol import Conversation from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import ( - EngineConfig, - GenerationConfig, - KVStateKind, - SpeculativeMode, -) +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -49,25 +43,25 @@ class ModelInfo: or a full path to a model directory (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - model_lib_path : Optional[str] + model_lib : Optional[str] The path to the compiled library of the model. E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" """ model: str - model_lib_path: Optional[str] = None + model_lib: Optional[str] = None def _parse_models( - model: str, model_lib_path: Optional[str], additional_models: Optional[List[str]] + model: str, model_lib: Optional[str], additional_models: Optional[List[str]] ) -> List[ModelInfo]: - """Parse the specified model paths and model lib paths. + """Parse the specified model paths and model libs. Return a list of ModelInfo, which is a wrapper class of the model path + lib path. Each additional model is expected to follow the format of either - "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB_PATH}". + "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB}". 
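To illustrate the two accepted forms, a sketch with placeholder paths (not part of this patch):

# Main model passed without a lib (resolved later via JIT compilation);
# additional draft model passed in "model_path:model_lib" form.
models = _parse_models(
    model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
    model_lib=None,
    additional_models=["dist/draft-model-MLC:dist/libs/draft-model-cuda.so"],
)
# models is roughly:
#   [ModelInfo(model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", model_lib=None),
#    ModelInfo(model="dist/draft-model-MLC", model_lib="dist/libs/draft-model-cuda.so")]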
""" - models = [ModelInfo(model, model_lib_path)] + models = [ModelInfo(model, model_lib)] if additional_models is not None: for additional_model in additional_models: splits = additional_model.split(":", maxsplit=1) @@ -95,30 +89,30 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: assert isinstance(chat_config.conv_template, Conversation) conversation = chat_config.conv_template - if model.model_lib_path is not None: - # do model lib search if the model lib path is provided + if model.model_lib is not None: + # do model lib search if the model lib is provided # error out if file not found - model_lib_path = _get_lib_module_path( + model_lib = _get_lib_module_path( model=model.model, model_path=model_path, chat_config=chat_config, - model_lib_path=model.model_lib_path, + model_lib=model.model_lib, device_name=device.MASK2STR[device.device_type], config_file_path=config_file_path, ) else: # TODO(mlc-team) add logging information - # Run jit if model_lib_path is not provided + # Run jit if model_lib is not provided from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - model_lib_path = str( + model_lib = str( jit.jit( model_path=Path(model_path), chat_config=asdict(chat_config), device=device, ) ) - return model_path, model_lib_path + return model_path, model_lib model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] @@ -126,618 +120,43 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, float, float, int]: - """Estimate the memory usage and the max total sequence length (capacity) - that the KV cache can support. - """ - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. 
- model_config = model_config_dict["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 - ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.85 - - model_max_total_sequence_length = int( - ( - int(gpu_size_bytes) * gpu_memory_utilization - - params_bytes - - temp_func_bytes - - kv_aux_workspace_bytes - - model_workspace_bytes - - logit_processor_workspace_bytes - ) - / kv_bytes_per_token - ) - if model_max_total_sequence_length <= 0: - raise ValueError( - f"The model weight size {params_bytes} may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." - ) - - if device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - model_max_total_sequence_length = min(model_max_total_sequence_length, 32768) - - total_mem_usage_except_kv_cache = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - ) - return ( - total_mem_usage_except_kv_cache, - params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - int(model_max_total_sequence_length), - ) - - -def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, int]: - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 - - rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 - param_bytes = 0.0 - temp_func_bytes = 0.0 - model_workspace_bytes = 0.0 - logit_processor_workspace_bytes = 0.0 - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. 
- cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - param_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - model_config = model_config_dict["model_config"] - vocab_size = model_config_dict["vocab_size"] - head_size = model_config["head_size"] - num_heads = model_config["num_heads"] - num_layers = model_config["num_hidden_layers"] - hidden_size = model_config["hidden_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - - rnn_state_base_bytes += ( - max_num_sequence * hidden_size * num_layers * 2 * 2 - + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 - ) - - max_history_size = int( - ( - gpu_size_bytes * gpu_memory_utilization - - logit_processor_workspace_bytes - - model_workspace_bytes - - param_bytes - - temp_func_bytes - ) - / rnn_state_base_bytes - ) - if max_history_size < 1: - raise ValueError( - f"Memory required by models may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." - ) - - return ( - param_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - rnn_state_base_bytes, - max_history_size, - ) - - -def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: - """Read the model config dictionaries, and return the maximum single - sequence length the models can support, the maximum prefill chunk - size the models can support, and the max batch size the models can support. - - Returns - ------- - model_max_single_sequence_length : int - The maximum single sequence length the models can support. - model_max_prefill_chunk_size : int - The maximum prefill chunk size the models can support. - model_max_batch_size : int - The max batch size the models can support. 
- """ - model_max_single_sequence_length = int(1e9) - model_max_prefill_chunk_size = int(1e9) - model_max_batch_size = int(1e9) - for i, config in enumerate(model_config_dicts): - runtime_context_window_size = config["context_window_size"] - compile_time_context_window_size = config["model_config"]["context_window_size"] - if runtime_context_window_size > compile_time_context_window_size: - raise ValueError( - f"Model {i}'s runtime context window size ({runtime_context_window_size}) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size == -1 and compile_time_context_window_size != -1: - raise ValueError( - f"Model {i}'s runtime context window size (infinite) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size != -1: - model_max_single_sequence_length = min( - model_max_single_sequence_length, runtime_context_window_size - ) - - runtime_prefill_chunk_size = config["prefill_chunk_size"] - compile_time_prefill_chunk_size = config["model_config"]["prefill_chunk_size"] - if runtime_prefill_chunk_size > compile_time_prefill_chunk_size: - raise ValueError( - f"Model {i}'s runtime prefill chunk size ({runtime_prefill_chunk_size}) is " - "larger than the prefill chunk size used at compile time " - f"({compile_time_prefill_chunk_size})" - ) - model_max_prefill_chunk_size = min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) - - model_max_batch_size = min(model_max_batch_size, config["model_config"]["max_batch_size"]) - - assert model_max_prefill_chunk_size != int(1e9) - assert model_max_batch_size != int(1e9) - return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size - - -def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, KVStateKind, int]: - """Initialize the KV cache config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - model_max_single_sequence_length - """ - ( - model_max_single_sequence_length, - model_max_prefill_chunk_size, - model_max_batch_size, - ) = _get_model_config_limit(model_config_dicts) - - def infer_args_under_mode( - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: - logging_msg = "" - # - max_batch_size - if max_batch_size is None: - max_batch_size = ( - min(4, model_max_batch_size) - if mode == "local" - else (1 if mode == "interactive" else model_max_batch_size) - ) - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " - # - infer the maximum total sequence length that can fit GPU memory. 
- ( - total_mem_usage_except_kv_cache, - model_params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - temp_workspace_bytes, - model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - # - max_total_sequence_length - if max_total_sequence_length is None: - if mode == "local": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length, 8192 - ) - elif mode == "interactive": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length - ) - else: - max_total_sequence_length = min( - model_max_total_sequence_length, - max_batch_size * model_max_single_sequence_length, - ) - logging_msg += f"max KV cache token capacity is set to {max_total_sequence_length}, " - else: - logging_msg += ( - f"max KV cache token capacity {max_total_sequence_length} is specified by user. " - ) - # - prefill_chunk_size - if prefill_chunk_size is None: - if mode in ["local", "interactive"]: - prefill_chunk_size = min( - model_max_prefill_chunk_size, - model_max_total_sequence_length, - model_max_single_sequence_length, - ) - else: - prefill_chunk_size = model_max_prefill_chunk_size - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - - if mode == "local": - logging_msg += ( - "We choose small max batch size and KV cache capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" - " limit of gpu_memory_utilization)." - ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - # - Construct the KV cache config - # - Estimate total GPU memory usage on single GPU. - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - KVStateKind.ATTENTION, - ), [ - total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, - model_params_bytes, - kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, - temp_workspace_bytes, - ] - - # - Infer KV cache config and estimate memory usage for each mode. - local_kv_cache_config, local_mem_usage_list = infer_args_under_mode( - "local", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - interactive_kv_cache_config, interactive_mem_usage_list = infer_args_under_mode( - "interactive", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - server_kv_cache_config, server_mem_usage_list = infer_args_under_mode( - "server", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - - # - Select the config based on the actual mode. +def _print_engine_mode_logging_msg(mode: Literal["local", "interactive", "server"]) -> None: + """Print the logging info for engine mode selection.""" if mode == "local": - kv_cache_config = local_kv_cache_config - mem_usage_list = local_mem_usage_list + logger.info( + "The selected engine mode is %s. 
" + "We choose small max batch size and KV cache capacity to use less GPU memory.", + green(mode), + ) elif mode == "interactive": - kv_cache_config = interactive_kv_cache_config - mem_usage_list = interactive_mem_usage_list - else: - kv_cache_config = server_kv_cache_config - mem_usage_list = server_mem_usage_list - - logger.info( - 'The actual engine mode is "%s". So max batch size is %s, ' - "max KV cache token capacity is %s, prefill chunk size is %s.", - green(mode), - green(str(kv_cache_config[0])), - green(str(kv_cache_config[1])), - green(str(kv_cache_config[2])), - ) - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - *list(mem_usage / 1024 / 1024 for mem_usage in mem_usage_list), - ) - # - Final messages - override_msg = "Please override the arguments if you have particular values to set." - if mode in ["local", "interactive"]: logger.info( - 'Please switch to mode "server" if you want to use more GPU memory ' - "and support more concurrent requests. %s", - override_msg, + "The selected engine mode is %s. " + "We fix max batch size to 1 for interactive single sequence use.", + green(mode), ) else: logger.info( - 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' - "or do not have many concurrent requests to process. %s", - override_msg, + "The selected engine mode is %s. " + "We use as much GPU memory as possible (within the limit " + "of gpu_memory_utilization).", + green(mode), ) - return *kv_cache_config, model_max_single_sequence_length - - -def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, KVStateKind, int]: - """Initialize the RNN state config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - max_history_size - """ - logging_msg = "" - prefill_chunk_size = 0 - - if prefill_chunk_size is None: - prefill_chunk_size = min( - config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 - for config in model_config_dicts - ) - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - if max_batch_size is None: - max_batch_size = 1 if mode == "interactive" else 4 - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " - - if mode == "local": - logging_msg += ( - "We choose small max batch size and RNN state capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." 
+ if mode != "local": + logger.info( + "If you have low concurrent requests and want to use less GPU memory, " + 'please select mode "local".' ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - ( - model_param_bytes, - model_temp_bytes, - model_rnn_state_base_bytes, - model_max_history_size, - ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - if max_history_size is None: - max_history_size = model_max_history_size - else: - max_history_size = min(max_history_size, model_max_history_size) - max_total_sequence_length = 32768 - prefill_chunk_size = 0 - kind = KVStateKind.RNNSTATE - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, - model_param_bytes / 1024 / 1024, - max_history_size * model_rnn_state_base_bytes / 1024 / 1024, - model_temp_bytes / 1024 / 1024, - ) - - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kind, - max_history_size, - ) - - -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, int, int, KVStateKind]: - """Initialize the cache config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - max_single_sequence_length - - max_history_size - - kv_state_kind - """ - if all("rwkv" not in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_single_sequence_length, - ) = _infer_kv_cache_config_for_kv_cache( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, + if mode != "interactive": + logger.info( + "If you don't have concurrent requests and only use the engine interactively, " + 'please select mode "interactive".' ) - max_history_size = 0 # KV cache doesn't need this - elif all("rwkv" in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_history_size, - ) = _infer_kv_cache_config_for_rnn_state( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, + if mode != "server": + logger.info( + "If you have high concurrent requests and want to maximize the GPU memory utilization, " + 'please select mode "server".' 
) - max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this - else: - raise ValueError("The models should be either all KV cache models or all RNN state models.") - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) - - -def _infer_generation_config( - model_config_dicts: List[Dict[str, Any]] -) -> List[Tuple[float, float, float, float]]: - """Infer the generation config from the model config dictionaries. - The returned four floats are: - - temperature - - top_p - - frequency_penalty - - presence_penalty - """ - generation_configs = [] - - for model_config in model_config_dicts: - temperature = model_config.get("temperature", 1.0) - top_p = model_config.get("top_p", 1.0) - frequency_penalty = model_config.get("frequency_penalty", 0.0) - presence_penalty = model_config.get("presence_penalty", 0.0) - generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) - - return generation_configs @dataclass @@ -1000,7 +419,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals kind: Literal["async", "sync"], model: str, device: Union[str, tvm.runtime.Device], - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], additional_models: Optional[List[str]], max_batch_size: Optional[int], @@ -1008,12 +427,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals prefill_chunk_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, + speculative_mode: Literal["disable", "small_draft", "eagle"], spec_draft_length: int, enable_tracing: bool, + verbose: bool, ) -> None: # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, Device) @@ -1026,31 +446,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. + if verbose: + _print_engine_mode_logging_msg(mode) # - Initialize engine state and engine. 
self.state = EngineState(enable_tracing) @@ -1063,35 +465,20 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "run_background_loop", "run_background_stream_back_loop", "reload", - "init_background_engine", + "init_threaded_engine", "exit_background_loop", - "debug_call_func_on_all_worker", + "get_default_generation_config", + "get_complete_engine_config", "stats", + "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) - self._ffi["init_background_engine"]( + self._ffi["init_threaded_engine"]( device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) - self._ffi["reload"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - ) def _background_loop(): self._ffi["run_background_loop"]() @@ -1108,6 +495,31 @@ def _background_stream_back_loop(): self._background_stream_back_loop_thread.start() self._terminated = False + self._ffi["reload"]( + EngineConfig( + model=model_args[0][0], + model_lib=model_args[0][1], + additional_models=[model_arg[0] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, + kv_cache_page_size=16, + max_num_sequence=max_batch_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + verbose=verbose, + ).asjson() + ) + self.default_generation_cfg_json_str: str = self._ffi["get_default_generation_config"]() + self.engine_config = EngineConfig.from_json(self._ffi["get_complete_engine_config"]()) + self.max_input_sequence_length = min( + self.engine_config.max_single_sequence_length, + self.engine_config.max_total_sequence_length, + ) + def terminate(self): """Terminate the engine.""" self._terminated = True @@ -1215,7 +627,6 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments # Process generation config. Create request id. generation_cfg = protocol_utils.get_generation_config( request, - model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1336,11 +747,10 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments return response, num_completion_tokens -def process_completion_request( # pylint: disable=too-many-arguments +def process_completion_request( request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, - model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: @@ -1392,7 +802,7 @@ def process_completion_request( # pylint: disable=too-many-arguments assert isinstance(prompt, list) # Process generation config. Create request id. 
- generation_cfg = protocol_utils.get_generation_config(request, model_config) + generation_cfg = protocol_utils.get_generation_config(request) # - Echo back the prompt. echo_response = None diff --git a/python/mlc_llm/serve/request.py b/python/mlc_llm/serve/request.py index 5c2d8ad196..44cdcd292c 100644 --- a/python/mlc_llm/serve/request.py +++ b/python/mlc_llm/serve/request.py @@ -1,6 +1,6 @@ """The request class in MLC LLM serving""" -from typing import List, Union +from typing import List, Optional, Union import tvm._ffi from tvm.runtime import Object @@ -28,6 +28,11 @@ class Request(Object): generation_config : GenerationConfig The sampling configuration which may contain temperature, top_p, repetition_penalty, max_gen_len, etc. + + default_generation_config_json_str : Optional[str] + The JSON string of the default generation config. + When a field in the input generation_config is not defined, + we use the value in the default generation config. """ def __init__( @@ -35,6 +40,7 @@ def __init__( request_id: str, inputs: Union[Data, List[Data]], generation_config: GenerationConfig, + default_generation_config_json_str: Optional[str] = None, ): if not isinstance(inputs, list): inputs = [inputs] @@ -43,6 +49,7 @@ def __init__( request_id, inputs, generation_config.asjson(), + default_generation_config_json_str, ) @property diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 1d17f8e66a..dcecd25795 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -11,8 +11,6 @@ import requests from tvm.runtime import Device -from mlc_llm.serve.config import SpeculativeMode - class PopenServer: # pylint: disable=too-many-instance-attributes """The wrapper of MLC LLM server, which runs the server in @@ -23,14 +21,14 @@ def __init__( # pylint: disable=too-many-arguments model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, max_total_sequence_length: Optional[int] = None, prefill_chunk_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, enable_tracing: bool = False, host: str = "127.0.0.1", @@ -38,7 +36,7 @@ def __init__( # pylint: disable=too-many-arguments ) -> None: """Please check out `python/mlc_llm/cli/serve.py` for the server arguments.""" self.model = model - self.model_lib_path = model_lib_path + self.model_lib = model_lib self.device = device self.mode = mode self.additional_models = additional_models @@ -59,8 +57,8 @@ def start(self) -> None: # pylint: disable=too-many-branches """ cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] - if self.model_lib_path is not None: - cmd += ["--model-lib-path", self.model_lib_path] + if self.model_lib is not None: + cmd += ["--model-lib", self.model_lib] cmd += ["--device", self.device] if self.mode is not None: cmd += ["--mode", self.mode] @@ -72,10 +70,10 @@ def start(self) -> None: # pylint: disable=too-many-branches cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] if self.prefill_chunk_size is not None: cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] - if self.speculative_mode != 
SpeculativeMode.DISABLE: + if self.speculative_mode != "disable": cmd += [ "--speculative-mode", - self.speculative_mode.name, + self.speculative_mode, "--spec-draft-length", str(self.spec_draft_length), ] diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 1be841cb08..39b09b36ce 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -14,10 +14,10 @@ import tvm from mlc_llm.serve import data -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import EngineConfig, GenerationConfig from mlc_llm.serve.engine_base import ( - _infer_kv_cache_config, _parse_models, + _print_engine_mode_logging_msg, _process_model_args, detect_device, ) @@ -58,13 +58,6 @@ class SyncMLCEngine: Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. - - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. - request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] The provided callback function to handle the generation output. It has the signature of `(str, data.TokenData, bool) -> None`, @@ -80,11 +73,11 @@ class SyncMLCEngine: the `set_request_stream_callback` method. Otherwise, the engine will raise exception. - engine_config : Optional[EngineConfig] - The Engine execution configuration. - enable_tracing : bool A boolean indicating if to enable event logging for requests. + + verbose : bool + A boolean indicating whether to print logging info in engine. """ def __init__( # pylint: disable=too-many-arguments,too-many-locals @@ -92,7 +85,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model: str, device: Union[str, tvm.runtime.Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", additional_models: Optional[List[str]] = None, max_batch_size: Optional[int] = None, @@ -101,12 +94,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals max_history_size: Optional[int] = None, gpu_memory_utilization: Optional[float] = None, enable_tracing: bool = False, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, + speculative_mode: Literal["disable", "small_draft", "eagle"] = "disable", spec_draft_length: int = 4, + verbose: bool = True, request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, ): # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, tvm.runtime.Device) @@ -119,31 +113,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. 
- ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. + if verbose: + _print_engine_mode_logging_msg(mode) self._ffi = _create_tvm_module( "mlc.serve.create_engine", @@ -156,6 +132,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "reset", "get_request_stream_callback", "set_request_stream_callback", + "get_default_generation_config", ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None @@ -163,23 +140,25 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self._ffi["init"]( EngineConfig( model=model_args[0][0], - model_lib_path=model_args[0][1], + model_lib=model_args[0][1], additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], + additional_model_libs=[model_arg[1] for model_arg in model_args[1:]], + mode=mode, + gpu_memory_utilization=gpu_memory_utilization, kv_cache_page_size=16, max_num_sequence=max_batch_size, max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, prefill_chunk_size=prefill_chunk_size, max_history_size=max_history_size, - kv_state_kind=kv_state_kind, speculative_mode=speculative_mode, spec_draft_length=spec_draft_length, - ), + verbose=verbose, + ).asjson(), device, request_stream_callback, self.trace_recorder, ) + self.default_generation_cfg_json_str: str = self._ffi["get_default_generation_config"]() self.tokenizer = Tokenizer(model_args[0][0]) def generate( # pylint: disable=too-many-locals @@ -304,6 +283,7 @@ def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data request_id=str(req_id), inputs=input_data, generation_config=generation_cfg, + default_generation_config_json_str=self.default_generation_cfg_json_str, ) ) diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 4f1cfe103d..8ff370e9d9 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -144,7 +144,7 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public dc = DebugChat( model="./dist/Llama-2-7b-chat-hf-q4f16_1-MLC", debug_dir=Path("./debug-llama-2"), - model_lib_path="./dist/llama-2-7b-chat-q4f16_1-metal.so", + model_lib="./dist/llama-2-7b-chat-q4f16_1-metal.so", ) dc.generate("hello world", 3) """ @@ -152,7 +152,7 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public def __init__( # pylint: disable=too-many-arguments self, model: str, - model_lib_path: str, + model_lib: str, debug_dir: Path, device: Optional[str] = "auto", chat_config: Optional[ChatConfig] = None, @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-arguments folder. In the former case, we will use the provided name to search for the model folder over possible paths. - model_lib_path : str + model_lib : str The full path to the model library file to use (e.g. a ``.so`` file). 
        debug_dir: Path
@@ -213,7 +213,7 @@ def instrument(
             debug_instrument if debug_instrument else DefaultDebugInstrument(debug_dir / "prefill")
         )
         self.mod, self.params, self.metadata = _get_tvm_module(
-            model, model_lib_path, self.device, self.instrument
+            model, model_lib, self.device, self.instrument
         )
         self.model_path, self.config_file_path = _get_model_path(model)
         self.chat_config = _get_chat_config(self.config_file_path, chat_config)
@@ -427,7 +427,7 @@ def main():
         required=True,
     )
     parser.add_argument(
-        "--model-lib-path",
+        "--model-lib",
         type=str,
         help="The full path to the model library file to use (e.g. a ``.so`` file).",
         required=True,
@@ -447,7 +447,7 @@ def main():
     parsed = parser.parse_args()
     dc = DebugChat(
         model=parsed.model,
-        model_lib_path=parsed.model_lib_path,
+        model_lib=parsed.model_lib,
         debug_dir=Path(parsed.debug_dir),
         device=parsed.device,
     )
diff --git a/python/mlc_llm/testing/debug_compare.py b/python/mlc_llm/testing/debug_compare.py
index b3487e3e48..d257d0f3b0 100644
--- a/python/mlc_llm/testing/debug_compare.py
+++ b/python/mlc_llm/testing/debug_compare.py
@@ -139,7 +139,7 @@ def get_instrument(args):
     if args.cmp_device is None:
         assert args.cmp_lib_path is None, "cmp_lib_path must be None if cmp_device is None"
         args.cmp_device = args.device
-        args.cmp_lib_path = args.model_lib_path
+        args.cmp_lib_path = args.model_lib

     if args.cmp_device == "iphone":
         assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone"
@@ -194,7 +194,7 @@ def main():
         required=True,
     )
     parser.add_argument(
-        "--model-lib-path",
+        "--model-lib",
         type=str,
         help="The full path to the model library file to use (e.g. a ``.so`` file).",
         required=True,
@@ -230,7 +230,7 @@ def main():
     instrument = get_instrument(parsed)
     debug_chat = DebugChat(
         model=parsed.model,
-        model_lib_path=parsed.model_lib_path,
+        model_lib=parsed.model_lib,
         debug_dir=Path(parsed.debug_dir),
         device=parsed.device,
         debug_instrument=instrument,
diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs
index b90549d06c..e8c1893a98 100644
--- a/rust/src/chat_module.rs
+++ b/rust/src/chat_module.rs
@@ -213,24 +213,24 @@ fn get_chat_config(config_file_path: &Path) -> result::Result
 fn get_lib_module_path(
-    model: &str, model_path: &Path, chat_config: &ChatConfig, model_lib_path: Option<&str>, device_name: &str,
+    model: &str, model_path: &Path, chat_config: &ChatConfig, model_lib: Option<&str>, device_name: &str,
     config_file_path: &Path,
 ) -> PathBuf {
-    // 1. Use user's model_lib_path if provided
-    if let Some(lib_path) = model_lib_path {
+    // 1. Use user's model_lib if provided
+    if let Some(lib_path) = model_lib {
         let path = Path::new(lib_path);
         if path.is_file() {
             info!("Using library model: {:?}", path);
             return path.to_path_buf();
         } else {
-            panic!("The `model_lib_path` you passed in is not a file: {:?}.", lib_path);
+            panic!("The `model_lib` you passed in is not a file: {:?}.", lib_path);
         }
     }
@@ -290,7 +290,7 @@ fn get_lib_module_path(
     }
     err_msg += &format!(
         "If you would like to directly specify the model library path, you may \
-         consider passing in the `ChatModule.model_lib_path` parameter."
+         consider passing in the `ChatModule.model_lib` parameter."
     );

     panic!("{}", err_msg);
@@ -323,7 +323,7 @@ pub struct ChatModule {
 }

 impl ChatModule {
-    pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result {
+    pub fn new(model: &str, device: &str, model_lib: Option<&str>) -> Result {
         let device_err_msg = format!(
             "Invalid device name: {}. Please enter the device in the form \
             'device_name:device_id' or 'device_name', where 'device_name' needs to be \
@@ -362,11 +362,11 @@ impl ChatModule {
         let chat_config = get_chat_config(&config_file_path).unwrap();

         // 4. Look up the model library
-        let model_lib_path = get_lib_module_path(
+        let model_lib = get_lib_module_path(
             model,
             &model_path,
             &chat_config,
-            model_lib_path,
+            model_lib,
             device_name,
             &config_file_path,
         );
@@ -375,7 +375,7 @@ impl ChatModule {
             chat_module: m,
             chat_config,
         };
-        let model_lib_str = model_lib_path.as_path().display().to_string();
+        let model_lib_str = model_lib.as_path().display().to_string();
         let model_path_str = model_path.as_path().display().to_string();
         chat_mod.reload(&model_lib_str, &model_path_str, "").unwrap();
         Ok(chat_mod)
diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py
index c52571b522..b438c2a352 100644
--- a/tests/python/json_ffi/test_json_ffi_engine.py
+++ b/tests/python/json_ffi/test_json_ffi_engine.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
+from typing import Dict, List, Optional

 from mlc_llm.json_ffi import JSONFFIEngine

@@ -120,12 +120,10 @@ def test_reload_reset_unload():
 def test_function_calling():
     model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC"
-    model_lib_path = (
-        "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so"
-    )
+    model_lib = "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so"
     engine = JSONFFIEngine(
         model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         max_total_sequence_length=1024,
     )
diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py
index c89a9e2c38..da9b486476 100644
--- a/tests/python/serve/evaluate_engine.py
+++ b/tests/python/serve/evaluate_engine.py
@@ -10,14 +10,14 @@
 def _parse_args():
     args = argparse.ArgumentParser()
-    args.add_argument("--model-lib-path", type=str)
+    args.add_argument("--model-lib", type=str)
     args.add_argument("--device", type=str, default="auto")
     args.add_argument("--batch-size", type=int, default=80)
     args.add_argument("--max-total-seq-length", type=int)
     args.add_argument("--seed", type=int, default=0)

     parsed = args.parse_args()
-    parsed.model = os.path.dirname(parsed.model_lib_path)
+    parsed.model = os.path.dirname(parsed.model_lib)
     assert parsed.batch_size % 16 == 0
     return parsed

@@ -44,7 +44,7 @@ def benchmark(args: argparse.Namespace):
     engine = SyncMLCEngine(
         model=args.model,
         device=args.device,
-        model_lib_path=args.model_lib_path,
+        model_lib=args.model_lib,
         mode="server",
         max_batch_size=args.batch_size,
         max_total_sequence_length=args.max_total_seq_length,
diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py
index e425494231..1ba0d096e8 100644
--- a/tests/python/serve/server/conftest.py
+++ b/tests/python/serve/server/conftest.py
@@ -9,15 +9,15 @@
 @pytest.fixture(scope="session")
 def served_model() -> Tuple[str, str]:
-    model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB")
-    if model_lib_path is None:
+    model_lib = os.environ.get("MLC_SERVE_MODEL_LIB")
+    if model_lib is None:
         raise ValueError(
             'Environment variable "MLC_SERVE_MODEL_LIB" not found. '
             "Please set it to model lib compiled by MLC LLM "
             "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)."
         )

-    model = os.path.dirname(model_lib_path)
-    return model, model_lib_path
+    model = os.path.dirname(model_lib)
+    return model, model_lib


 @pytest.fixture(scope="session")
@@ -25,7 +25,7 @@ def launch_server(served_model):  # pylint: disable=redefined-outer-name
     """A pytest session-level fixture which launches the server in a subprocess."""
     server = PopenServer(
         model=served_model[0],
-        model_lib_path=served_model[1],
+        model_lib=served_model[1],
         enable_tracing=True,
     )
diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py
index e4f64d2ce4..db2d601f11 100644
--- a/tests/python/serve/server/test_server.py
+++ b/tests/python/serve/server/test_server.py
@@ -1287,14 +1287,14 @@ def test_debug_dump_event_trace(

 if __name__ == "__main__":
-    model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB")
-    if model_lib_path is None:
+    model_lib = os.environ.get("MLC_SERVE_MODEL_LIB")
+    if model_lib is None:
         raise ValueError(
             'Environment variable "MLC_SERVE_MODEL_LIB" not found. '
             "Please set it to model lib compiled by MLC LLM "
             "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)."
         )
-    MODEL = (os.path.dirname(model_lib_path), model_lib_path)
+    MODEL = (os.path.dirname(model_lib), model_lib)

     test_openai_v1_models(MODEL, None)
diff --git a/tests/python/serve/server/test_server_function_call.py b/tests/python/serve/server/test_server_function_call.py
index 3fff27b938..b55fe10455 100644
--- a/tests/python/serve/server/test_server_function_call.py
+++ b/tests/python/serve/server/test_server_function_call.py
@@ -195,15 +195,15 @@ def test_openai_v1_chat_completion_function_call(

 if __name__ == "__main__":
-    model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB")
-    if model_lib_path is None:
+    model_lib = os.environ.get("MLC_SERVE_MODEL_LIB")
+    if model_lib is None:
         raise ValueError(
             'Environment variable "MLC_SERVE_MODEL_LIB" not found. '
             "Please set it to model lib compiled by MLC LLM "
             "(e.g., `./dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so`) "
             "which supports function calls."
         )
-    MODEL = (os.path.dirname(model_lib_path), model_lib_path)
+    MODEL = (os.path.dirname(model_lib), model_lib)

     for msg in CHAT_COMPLETION_MESSAGES:
         test_openai_v1_chat_completion_function_call(MODEL, None, stream=False, messages=msg)
diff --git a/tests/python/serve/server/test_server_image.py b/tests/python/serve/server/test_server_image.py
index 9b016224e4..d1a79c5445 100644
--- a/tests/python/serve/server/test_server_image.py
+++ b/tests/python/serve/server/test_server_image.py
@@ -239,8 +239,8 @@ def test_openai_v1_chat_completions(

 if __name__ == "__main__":
-    model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB")
-    if model_lib_path is None:
+    model_lib = os.environ.get("MLC_SERVE_MODEL_LIB")
+    if model_lib is None:
         raise ValueError(
             'Environment variable "MLC_SERVE_MODEL_LIB" not found. '
             "Please set it to model lib compiled by MLC LLM "
@@ -249,9 +249,9 @@ def test_openai_v1_chat_completions(
     model = os.environ.get("MLC_SERVE_MODEL")
     if model is None:
-        MODEL = (os.path.dirname(model_lib_path), model_lib_path)
+        MODEL = (os.path.dirname(model_lib), model_lib)
     else:
-        MODEL = (model, model_lib_path)
+        MODEL = (model, model_lib)

     for msg in CHAT_COMPLETION_MESSAGES:
         test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg)
diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py
index cea421cd95..06d2196d67 100644
--- a/tests/python/serve/test_radix_tree.py
+++ b/tests/python/serve/test_radix_tree.py
@@ -1,6 +1,3 @@
-from tvm import TVMError
-from tvm.runtime import ShapeTuple
-
 from mlc_llm.serve import PagedRadixTree
diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py
index 6e3835238a..2c431ebcf5 100644
--- a/tests/python/serve/test_serve_async_engine.py
+++ b/tests/python/serve/test_serve_async_engine.py
@@ -22,10 +22,10 @@
 async def test_engine_generate():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
@@ -79,10 +79,10 @@ async def generate_task(
 async def test_chat_completion():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
@@ -131,10 +131,10 @@ async def generate_task(prompt: str, request_id: str):
 async def test_chat_completion_non_stream():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
@@ -182,10 +182,10 @@ async def generate_task(prompt: str, request_id: str):
 async def test_completion():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
@@ -234,10 +234,10 @@ async def generate_task(prompt: str, request_id: str):
 async def test_completion_non_stream():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py
index c3963af613..926aa87f60 100644
--- a/tests/python/serve/test_serve_async_engine_spec.py
+++ b/tests/python/serve/test_serve_async_engine_spec.py
@@ -3,7 +3,7 @@
 import asyncio
 from typing import List

-from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode
+from mlc_llm.serve import AsyncMLCEngine, GenerationConfig

 prompts = [
     "What is the meaning of life?",
@@ -22,17 +22,15 @@ async def test_engine_generate():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    small_model_lib_path = (
-        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     async_engine = AsyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.SMALL_DRAFT,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="small_draft",
     )

     num_requests = 10
diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py
index 37d1833b14..dc67f3c91e 100644
--- a/tests/python/serve/test_serve_engine.py
+++ b/tests/python/serve/test_serve_engine.py
@@ -31,11 +31,11 @@
 ]


-def create_engine(model: str, model_lib_path: str):
+def create_engine(model: str, model_lib: str):
     if "rwkv" in model:
         return MLCEngine(
             model=model,
-            model_lib_path=model_lib_path,
+            model_lib=model_lib,
             mode="server",
             max_batch_size=8,
             max_history_size=1,
@@ -43,15 +43,15 @@ def create_engine(model: str, model_lib_path: str):
     else:
         return MLCEngine(
             model=model,
-            model_lib_path=model_lib_path,
+            model_lib=model_lib,
             mode="server",
             max_total_sequence_length=4096,
         )


-@pytest.mark.parametrize("model,model_lib_path", test_models)
-def test_engine_generate(model: str, model_lib_path: str):
-    engine = create_engine(model, model_lib_path)
+@pytest.mark.parametrize("model,model_lib", test_models)
+def test_engine_generate(model: str, model_lib: str):
+    engine = create_engine(model, model_lib)

     num_requests = 10
     max_tokens = 256
@@ -81,10 +81,10 @@ def test_engine_generate(model: str, model_lib_path: str):
     del engine


-@pytest.mark.parametrize("model,model_lib_path", test_models)
-def test_chat_completion(model: str, model_lib_path: str):
+@pytest.mark.parametrize("model,model_lib", test_models)
+def test_chat_completion(model: str, model_lib: str):
     # Create engine
-    engine = create_engine(model, model_lib_path)
+    engine = create_engine(model, model_lib)

     num_requests = 2
     max_tokens = 64
@@ -119,9 +119,9 @@ def test_chat_completion(model: str, model_lib_path: str):
     del engine


-@pytest.mark.parametrize("model,model_lib_path", test_models)
-def test_chat_completion_non_stream(model: str, model_lib_path: str):
-    engine = create_engine(model, model_lib_path)
+@pytest.mark.parametrize("model,model_lib", test_models)
+def test_chat_completion_non_stream(model: str, model_lib: str):
+    engine = create_engine(model, model_lib)

     num_requests = 2
     max_tokens = 64
@@ -155,9 +155,9 @@ def test_chat_completion_non_stream(model: str, model_lib_path: str):
     del engine


-@pytest.mark.parametrize("model,model_lib_path", test_models)
-def test_completion(model: str, model_lib_path: str):
-    engine = create_engine(model, model_lib_path)
+@pytest.mark.parametrize("model,model_lib", test_models)
+def test_completion(model: str, model_lib: str):
+    engine = create_engine(model, model_lib)

     num_requests = 2
     max_tokens = 128
@@ -192,9 +192,9 @@ def test_completion(model: str, model_lib_path: str):
     del engine


-@pytest.mark.parametrize("model,model_lib_path", test_models)
-def test_completion_non_stream(model: str, model_lib_path: str):
-    engine = create_engine(model, model_lib_path)
+@pytest.mark.parametrize("model,model_lib", test_models)
+def test_completion_non_stream(model: str, model_lib: str):
+    engine = create_engine(model, model_lib)

     num_requests = 2
     max_tokens = 128
@@ -229,9 +229,9 @@ def test_completion_non_stream(model: str, model_lib_path: str):

 if __name__ == "__main__":
-    for model, model_lib_path in test_models:
-        test_engine_generate(model, model_lib_path)
-        test_chat_completion(model, model_lib_path)
-        test_chat_completion_non_stream(model, model_lib_path)
-        test_completion(model, model_lib_path)
-        test_completion_non_stream(model, model_lib_path)
+    for model, model_lib in test_models:
+        test_engine_generate(model, model_lib)
+        test_chat_completion(model, model_lib)
+        test_chat_completion_non_stream(model, model_lib)
+        test_completion(model, model_lib)
+        test_completion_non_stream(model, model_lib)
diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py
index b764c62cd2..2b3ce29c7f 100644
--- a/tests/python/serve/test_serve_engine_grammar.py
+++ b/tests/python/serve/test_serve_engine_grammar.py
@@ -17,12 +17,12 @@
     "Generate a JSON with 5 elements:",
 ]
 model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-model_lib_path = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so"
+model_lib = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so"


 def test_batch_generation_with_grammar():
     # Create engine
-    engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server")
+    engine = SyncMLCEngine(model=model_path, model_lib=model_lib, mode="server")

     prompt_len = len(prompts_list)
     prompts = prompts_list * 3
@@ -69,7 +69,7 @@ def test_batch_generation_with_grammar():

 def test_batch_generation_with_schema():
     # Create engine
-    engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server")
+    engine = SyncMLCEngine(model=model_path, model_lib=model_lib, mode="server")

     prompt = (
         "Generate a json containing three fields: an integer field named size, a "
@@ -121,7 +121,7 @@ class Schema(BaseModel):

 async def run_async_engine():
     # Create engine
-    async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server")
+    async_engine = AsyncMLCEngine(model=model_path, model_lib=model_lib, mode="server")

     prompts = prompts_list * 20
diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py
index 59e8c97196..01bb1967e0 100644
--- a/tests/python/serve/test_serve_engine_image.py
+++ b/tests/python/serve/test_serve_engine_image.py
@@ -12,10 +12,10 @@ def get_test_image(config) -> data.ImageData:
 def test_engine_generate():
     # Create engine
     model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params"
-    model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so"
+    model_lib = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py
index 33c06b1c5e..3f1fa5107c 100644
--- a/tests/python/serve/test_serve_engine_spec.py
+++ b/tests/python/serve/test_serve_engine_spec.py
@@ -4,13 +4,7 @@

 import numpy as np

-from mlc_llm.serve import (
-    GenerationConfig,
-    Request,
-    RequestStreamOutput,
-    SpeculativeMode,
-    data,
-)
+from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data
 from mlc_llm.serve.sync_engine import SyncMLCEngine

 prompts = [
@@ -85,18 +79,16 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    small_model_lib_path = (
-        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.SMALL_DRAFT,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="small_draft",
         request_stream_callback=fcallback,
     )
@@ -153,18 +145,16 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC"
-    small_model_lib_path = (
-        "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.EAGLE,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="eagle",
         spec_draft_length=2,
         request_stream_callback=fcallback,
     )
@@ -236,19 +226,17 @@ def step(self) -> None:

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    small_model_lib_path = (
-        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     timer = CallbackTimer()
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.SMALL_DRAFT,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="small_draft",
         request_stream_callback=timer.callback_getter(),
     )
@@ -322,19 +310,19 @@ def step(self) -> None:

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC"
-    small_model_lib_path = (
+    small_model_lib = (
         "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so"
     )
     timer = CallbackTimer()
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.EAGLE,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="eagle",
         request_stream_callback=timer.callback_getter(),
     )
@@ -379,19 +367,17 @@ def compare_output_text(output_text1, output_text2):

 def test_engine_generate(compare_precision=False):
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    small_model_lib_path = (
-        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.SMALL_DRAFT,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="small_draft",
     )

     num_requests = 10
@@ -405,7 +391,7 @@ def test_engine_generate(compare_precision=False):
     )
     engine_single_model = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )
@@ -441,18 +427,18 @@ def test_engine_eagle_generate():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC"
-    small_model_lib_path = (
+    small_model_lib = (
         "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so"
     )
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
-        speculative_mode=SpeculativeMode.EAGLE,
+        additional_models=[small_model + ":" + small_model_lib],
+        speculative_mode="eagle",
     )

     num_requests = 10
@@ -493,10 +479,10 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC"
-    model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so"
+    model_lib = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
         request_stream_callback=fcallback,
@@ -556,24 +542,22 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC"
-    model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so"
+    model_lib = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so"
     small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    small_model_lib_path = (
-        "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     # If Flashinfer allows head_dim < 128, we can test this model
     # small_model = "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC"
-    # small_model_lib_path = (
+    # small_model_lib = (
     #     "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so"
     # )
     spec_engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
+        additional_models=[small_model + ":" + small_model_lib],
         spec_draft_length=6,
-        speculative_mode=SpeculativeMode.SMALL_DRAFT,
+        speculative_mode="small_draft",
         request_stream_callback=fcallback,
     )
@@ -631,19 +615,17 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so"
     small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC"
-    small_model_lib_path = (
-        "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so"
-    )
+    small_model_lib = "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so"
     spec_engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
-        additional_models=[small_model + ":" + small_model_lib_path],
+        additional_models=[small_model + ":" + small_model_lib],
         spec_draft_length=6,
-        speculative_mode=SpeculativeMode.EAGLE,
+        speculative_mode="eagle",
         request_stream_callback=fcallback,
     )
diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py
index f68f48b7c5..8c574f875f 100644
--- a/tests/python/serve/test_serve_sync_engine.py
+++ b/tests/python/serve/test_serve_sync_engine.py
@@ -79,10 +79,10 @@ def fcallback(delta_outputs: List[RequestStreamOutput]):

     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         request_stream_callback=fcallback,
     )
@@ -155,10 +155,10 @@ def step(self) -> None:

     # Create engine
     timer = CallbackTimer()
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         request_stream_callback=timer.callback_getter(),
     )
@@ -236,10 +236,10 @@ def step(self) -> None:

     # Create engine
     timer = CallbackTimer()
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         request_stream_callback=timer.callback_getter(),
     )
@@ -322,10 +322,10 @@ def all_finished(self) -> bool:

     # Create engine
     timer = CallbackTimer()
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         request_stream_callback=timer.callback_getter(),
     )
@@ -364,10 +364,10 @@ def all_finished(self) -> bool:
 def test_engine_generate():
     # Create engine
     model = "dist/Llama-2-7b-chat-hf-q0f16-MLC"
-    model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
+    model_lib = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so"
     engine = SyncMLCEngine(
         model=model,
-        model_lib_path=model_lib_path,
+        model_lib=model_lib,
         mode="server",
         max_total_sequence_length=4096,
     )