Merge branch 'mlc-ai:main' into main
JackWeiw authored Jul 2, 2024
2 parents e97204a + 0575b92 commit b64f91c
Showing 10 changed files with 291 additions and 28 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tokenizers-cpp
10 changes: 8 additions & 2 deletions android/mlc4j/CMakeLists.txt
@@ -45,23 +45,27 @@ add_library(model_android STATIC IMPORTED)
set_target_properties(model_android PROPERTIES IMPORTED_LOCATION ${ANDROID_BIN_DIR}/lib/libmodel_android.a)

add_library(tvm4j_runtime_packed SHARED ${TVM_SOURCE_DIR}/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc)
set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} TVM_SOURCE_DIR=${TVM_SOURCE_DIR})

target_include_directories(tvm4j_runtime_packed PUBLIC
${JNI_INCLUDE_DIRS}
${JNI_HEADER}
${ANDROID_DIR}/src/cpp
${TVM_SOURCE_DIR}/3rdparty/dlpack/include
${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include
${TVM_SOURCE_DIR}/3rdparty/OpenCL-Headers
${TVM_SOURCE_DIR}/3rdparty/picojson
${TVM_SOURCE_DIR}/include
)
target_compile_definitions(tvm4j_runtime_packed PUBLIC ${MLC_LLM_COMPILE_DEFS})
target_compile_definitions(tvm4j_runtime_packed PUBLIC TVM_RELAX_VM_ENABLE_PROFILER=0)

set(MLC_ENABLE_SENTENCEPIECE_TOKENIZER OFF)
target_link_libraries(tvm4j_runtime_packed
sentencepiece-static
tokenizers_c
tokenizers_cpp
log
-Wl,--whole-archive
tvm_runtime
mlc_llm_static
model_android
-Wl,--no-whole-archive
@@ -70,5 +74,7 @@ target_link_libraries(tvm4j_runtime_packed
target_compile_definitions(tvm4j_runtime_packed PUBLIC TVM4J_ANDROID)
add_dependencies(tvm4j_runtime_packed tvm4j_core)

target_compile_definitions(mlc_llm_objs PUBLIC MLC_SINGLE_GPU_ONLY)

install_jar(tvm4j_core output)
install(TARGETS tvm4j_runtime_packed LIBRARY DESTINATION output/${ANDROID_ABI})
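
Note on the --whole-archive wrapping above: TVM and MLC LLM register their packed functions through global objects whose constructors run at load time, and nothing else references those objects by name, so without --whole-archive the linker would drop the containing object files from the static archives and the registrations would silently disappear. A minimal sketch of that side-effect-only registration pattern (the registry type and names below are simplified stand-ins, not the real TVM_REGISTER_GLOBAL machinery):

// Minimal sketch of side-effect-only static registration, assuming a toy registry
// in place of TVM's real TVM_REGISTER_GLOBAL. The object `reg_hello` is never
// referenced by name; only its constructor's side effect matters, which is exactly
// the kind of symbol --whole-archive keeps the linker from discarding.
#include <functional>
#include <iostream>
#include <map>
#include <string>

std::map<std::string, std::function<void()>>& Registry() {
  static std::map<std::string, std::function<void()>> reg;
  return reg;
}

struct Registrar {
  Registrar(const std::string& name, std::function<void()> fn) {
    Registry()[name] = std::move(fn);
  }
};

// Lives in some object file inside a static archive; reachable only via its side effect.
static Registrar reg_hello("mlc.demo.hello", [] { std::cout << "hello from registry\n"; });

int main() {
  // Look the function up by name, the way packed functions are resolved at runtime.
  Registry().at("mlc.demo.hello")();
  return 0;
}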
39 changes: 38 additions & 1 deletion android/mlc4j/src/cpp/tvm_runtime.h
@@ -5,7 +5,44 @@
#include <dlfcn.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>

#define STRINGIFY_MACRO(x) STR(x)
#define STR(x) #x
#define EXPAND(x) x
#define CONCAT(n1, n2) STRINGIFY_MACRO(EXPAND(n1) EXPAND(n2))

// clang-format off
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/c_runtime_api.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/container.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/cpu_device_api.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/file_utils.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/library_module.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/logging.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/module.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/ndarray.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/object.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/opencl/opencl_device_api.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/opencl/opencl_module.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/opencl/texture_pool.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/profiling.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/registry.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/source_utils.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/system_library.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/thread_pool.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/threading_backend.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/workspace_pool.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/memory/memory_manager.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/nvtx.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/builtin.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/bytecode.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/executable.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/kv_state.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/ndarray_cache_support.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/paged_kv_cache.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/rnn_state.cc)
#include CONCAT(TVM_SOURCE_DIR,/src/runtime/relax_vm/vm.cc)
// clang-format on

static_assert(TVM_LOG_CUSTOMIZE == 1, "TVM_LOG_CUSTOMIZE must be 1");

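
The STR/STRINGIFY_MACRO pair above is the standard two-level idiom for turning the value of a -D macro into a string literal; CONCAT then glues the TVM source directory (passed from CMakeLists.txt as TVM_SOURCE_DIR=${TVM_SOURCE_DIR}) onto each relative path, so the listed runtime .cc files can be #include'd directly and the TVM runtime is compiled as one amalgamated translation unit. A minimal sketch of the stringify half, with a hypothetical fallback value so it compiles standalone:

// Minimal sketch of the two-level stringify idiom; in the real build TVM_SOURCE_DIR
// comes from the compiler command line (-DTVM_SOURCE_DIR=<path>). The fallback below
// is only so this snippet compiles on its own.
#include <cstdio>

#define STR(x) #x
#define STRINGIFY_MACRO(x) STR(x)  // the extra level expands x before it is stringized

#ifndef TVM_SOURCE_DIR
#define TVM_SOURCE_DIR /path/to/tvm  // hypothetical fallback value
#endif

int main() {
  // With the indirection, the macro's value is stringized, e.g. "/path/to/tvm";
  // without it, STR(TVM_SOURCE_DIR) yields the literal text "TVM_SOURCE_DIR".
  std::printf("%s\n", STRINGIFY_MACRO(TVM_SOURCE_DIR));
  std::printf("%s\n", STR(TVM_SOURCE_DIR));
  return 0;
}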
3 changes: 3 additions & 0 deletions cpp/loader/multi_gpu_loader.cc
@@ -2,6 +2,7 @@
* \file multi_gpu_loader.cc
* \brief Implementation of a multi-GPU loader with loading-time sharding.
*/
#ifndef MLC_SINGLE_GPU_ONLY
#include <picojson.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/disco/builtin.h>
@@ -265,3 +266,5 @@ TVM_REGISTER_GLOBAL("mlc.loader.LoadMultiGPUPresharded").set_body_typed(LoadMult
} // namespace loader
} // namespace llm
} // namespace mlc

#endif // MLC_SINGLE_GPU_ONLY
12 changes: 9 additions & 3 deletions cpp/serve/engine.cc
@@ -675,6 +675,7 @@ class EngineImpl : public Engine {

Optional<Session> session = NullOpt;
if (num_shards > 1) {
#ifndef MLC_SINGLE_GPU_ONLY
constexpr const char* f_create_process_pool = "runtime.disco.create_process_pool";
if (Registry::Get(f_create_process_pool) == nullptr) {
LOG(FATAL) << "Cannot find process launcher `" << f_create_process_pool << "`. "
@@ -695,6 +696,9 @@
}
session = Session::ProcessSession(num_shards, f_create_process_pool, "mlc_llm.cli.worker");
session.value()->InitCCL(ccl, ShapeTuple(device_ids));
#else
LOG(FATAL) << "MLC_SINGLE_GPU_ONLY is specified. Multi-GPU is not enabled.";
#endif // MLC_SINGLE_GPU_ONLY
}
return {session, num_shards};
}
@@ -782,9 +786,11 @@ class EngineImpl : public Engine {
for (Model model : models_) {
host_cpu_usage += model->EstimateHostCPURequirement();
}
int max_concurrency = tvm::runtime::threading::MaxConcurrency();
tvm::runtime::threading::SetMaxConcurrency(
std::min(std::max(max_concurrency - host_cpu_usage, 1), engine_config_->max_num_sequence));
if (host_cpu_usage > 1) {
int max_concurrency = tvm::runtime::threading::MaxConcurrency();
tvm::runtime::threading::SetMaxConcurrency(std::min(
std::max(max_concurrency - host_cpu_usage, 1), engine_config_->max_num_sequence));
}
}

/*! \brief Create a grammar init context according to the response format. If the response format
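
MLC_SINGLE_GPU_ONLY, defined for the Android build in android/mlc4j/CMakeLists.txt, compiles out both the multi_gpu_loader.cc translation unit and the disco-based multi-shard session setup here, so a request for more than one shard fails fast at runtime instead of silently falling back to a single device. A minimal sketch of the guard pattern (Session and CreateSession are simplified stand-ins, not the real engine/disco API):

// Minimal sketch of the MLC_SINGLE_GPU_ONLY guard; `Session` and `CreateSession` are
// simplified placeholders for the real disco session types.
#include <cstdlib>
#include <iostream>
#include <optional>

struct Session {};  // hypothetical placeholder for a multi-GPU worker session

std::optional<Session> CreateSession(int num_shards) {
  std::optional<Session> session;
  if (num_shards > 1) {
#ifndef MLC_SINGLE_GPU_ONLY
    // Multi-GPU build: spawn worker processes and initialize collective communication.
    session = Session{};
#else
    // Single-GPU-only build (e.g. Android): refuse loudly rather than ignore the shards.
    std::cerr << "MLC_SINGLE_GPU_ONLY is specified. Multi-GPU is not enabled.\n";
    std::abort();
#endif  // MLC_SINGLE_GPU_ONLY
  }
  return session;
}

int main() {
  CreateSession(1);  // single shard: no session is created on either build
  return 0;
}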
63 changes: 57 additions & 6 deletions cpp/serve/engine_actions/batch_prefill_base.cc
@@ -36,6 +36,21 @@ BatchPrefillBaseActionObj::BatchPrefillBaseActionObj(Array<Model> models,
*/
std::vector<BatchPrefillBaseActionObj::PrefillInput>
BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
// Preempt request state entries when decode cannot apply.
std::vector<RequestStateEntry> running_rsentries;
{
NVTXScopedRange nvtx_scope("BatchDecode getting requests");
running_rsentries = GetRunningRequestStateEntries(estate);
while (!(running_rsentries.size() <= models_[0]->GetNumAvailablePages())) {
if (estate->prefix_cache->TryFreeMemory()) continue;
RequestStateEntry preempted =
PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_);
if (preempted.same_as(running_rsentries.back())) {
running_rsentries.pop_back();
}
}
}

if (estate->waiting_queue.empty()) {
// No request to prefill.
return {};
@@ -44,13 +59,20 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
std::vector<std::vector<PrefillInput>> prefill_inputs_for_all_models;
prefill_inputs_for_all_models.reserve(models_.size());

int num_decode_inputs = static_cast<int>(running_rsentries.size());

// We first collect the inputs that can be prefilled for each model.
// Then we make a reduction to return the maximum common inputs.
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
std::vector<PrefillInput> prefill_inputs;
// - Try to prefill pending requests.
// - Try to prefill pending requests, in addition to reserved decode requests.
int total_input_length = 0;
int total_required_pages = 0;
int total_required_pages = num_decode_inputs;
// Reserve decode requests first.
for (const RequestStateEntry& rsentry : running_rsentries) {
prefill_inputs.push_back({rsentry, rsentry->mstates[i]->num_tokens_for_next_decode, 0});
total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode;
}
int num_available_pages = models_[i]->GetNumAvailablePages();
int num_running_rsentries = GetRunningRequestStateEntries(estate).size();
int current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();
@@ -177,7 +199,8 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));
}

if (num_prefill_inputs == 0) {
// If all inputs are decode inputs, since no prefill inputs can be added, skip prefill action
if (num_prefill_inputs == num_decode_inputs) {
return {};
}

@@ -259,6 +282,17 @@ bool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_r
std::pair<Array<Data>, int> BatchPrefillBaseActionObj::ChunkPrefillInputData(
const RequestModelState& mstate, int max_prefill_length) {
if (mstate->inputs.empty()) {
// If the request is a hybrid decode request
ICHECK(mstate->num_tokens_for_next_decode > 0);
int num_tokens = mstate->num_tokens_for_next_decode;
mstate->num_tokens_for_next_decode = 0;
std::vector<int32_t> decode_tokens;
decode_tokens.reserve(num_tokens);
for (auto begin = mstate->committed_tokens.end() - num_tokens;
begin != mstate->committed_tokens.end(); ++begin) {
decode_tokens.push_back(begin->GetTokenId());
}
return {{TokenData(decode_tokens)}, num_tokens};
}
ICHECK(!mstate->inputs.empty());
std::vector<Data> inputs;
@@ -378,11 +412,14 @@ std::vector<Request> BatchPrefillBaseActionObj::RemoveProcessedRequests(
break;
}
}
if (!pending_state_exists) {
if (!pending_state_exists &&
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request) !=
estate->waiting_queue.end()) {
auto it =
std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request);
ICHECK(it != estate->waiting_queue.end());
estate->waiting_queue.erase(it);
if (it != estate->waiting_queue.end()) {
estate->waiting_queue.erase(it);
}
}
}
return processed_requests;
Expand All @@ -393,6 +430,20 @@ void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(
const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results) {
auto tnow = std::chrono::high_resolution_clock::now();
for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
// If the request is a hybrid decode request
if (rsentries_for_sample[i]->status == RequestStateStatus::kAlive &&
rsentries_for_sample[i]->child_indices.empty() &&
rsentries_for_sample[i]->mstates[0]->inputs.empty()) {
for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
CHECK(!mstate->require_retokenization_in_next_decode);
mstate->CommitToken(sample_results[i]);
// live update the output metrics
rsentries_for_sample[i]->rstate->metrics.completion_tokens += 1;
rsentries_for_sample[i]->rstate->metrics.prefill_end_time_point = tnow;
}
continue;
}

// Update all model states of the request state entry.
for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {
mstate->CommitToken(sample_results[i]);
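
The hybrid path added above lets running decode requests ride along in a prefill batch: a request whose input queue is empty but whose num_tokens_for_next_decode is positive has its trailing committed tokens replayed as one small token chunk, and UpdateRequestStateEntriesWithSampleResults later commits the sampled token for such entries instead of treating them as fresh prefills. A minimal sketch of the token-collection step (Token and CollectDecodeTokens are simplified stand-ins for the real RequestModelState types):

// Minimal sketch of the hybrid-decode input construction; `Token` and the function
// below are simplified stand-ins, not the engine's actual data structures.
#include <cassert>
#include <cstdint>
#include <vector>

struct Token {
  int32_t id;
  int32_t GetTokenId() const { return id; }
};

// Replay the trailing `num_tokens` committed tokens, oldest first, so they can be fed
// to the model as one small chunk alongside genuine prefill inputs.
std::vector<int32_t> CollectDecodeTokens(const std::vector<Token>& committed_tokens,
                                         int num_tokens) {
  assert(num_tokens > 0 && num_tokens <= static_cast<int>(committed_tokens.size()));
  std::vector<int32_t> decode_tokens;
  decode_tokens.reserve(num_tokens);
  for (auto it = committed_tokens.end() - num_tokens; it != committed_tokens.end(); ++it) {
    decode_tokens.push_back(it->GetTokenId());
  }
  return decode_tokens;
}

int main() {
  std::vector<Token> committed = {{11}, {12}, {13}, {14}};
  std::vector<int32_t> chunk = CollectDecodeTokens(committed, 2);  // yields {13, 14}
  return chunk.size() == 2 ? 0 : 1;
}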
4 changes: 2 additions & 2 deletions docs/get_started/quick_start.rst
@@ -185,6 +185,6 @@ What to Do Next
- :ref:`deploy-android`
- :ref:`deploy-ide-integration`

- `Convert model weight to MLC format <convert-weights-via-MLC>`_, if you want to run your own models.
- `Compile model libraries <compile-model-libraries>`_, if you want to deploy to web/iOS/Android or control the model optimizations.
- :ref:`convert-weights-via-MLC`, if you want to run your own models.
- :ref:`compile-model-libraries`, if you want to deploy to web/iOS/Android or control the model optimizations.
- Report any problem or ask any question: open new issues in our `GitHub repo <https://github.com/mlc-ai/mlc-llm/issues>`_.
6 changes: 4 additions & 2 deletions python/mlc_llm/model/gemma/gemma_model.py
@@ -22,7 +22,6 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
"""Configuration of the Gemma model."""

hidden_size: int
hidden_act: str
intermediate_size: int
attention_bias: bool
num_attention_heads: int
@@ -31,6 +30,7 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    hidden_activation: Optional[str] = None
    position_embedding_base: int = 0
    context_window_size: int = 0
    prefill_chunk_size: int = 0
@@ -39,7 +39,9 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"):
        if self.hidden_activation is None:
            self.hidden_activation = self.kwargs.get("hidden_act", None)
        if self.hidden_activation not in ("gelu", "gelu_pytorch_tanh"):
            raise ValueError("Only GeLU is supported as the activation for gemma.")
        if self.attention_bias:
            raise ValueError('Only "False" attention_bias is supported for gemma')
(Diffs for the remaining 2 changed files are not shown.)