diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 0d1aaf8ce8..0000000000 --- a/.gitmodules +++ /dev/null @@ -1,9 +0,0 @@ -[submodule "3rdparty/argparse"] - path = 3rdparty/argparse - url = https://github.com/p-ranav/argparse -[submodule "3rdparty/tokenizers-cpp"] - path = 3rdparty/tokenizers-cpp - url = https://github.com/mlc-ai/tokenizers-cpp -[submodule "3rdparty/googletest"] - path = 3rdparty/googletest - url = https://github.com/google/googletest.git diff --git a/3rdparty/argparse b/3rdparty/argparse deleted file mode 160000 index 557948f123..0000000000 --- a/3rdparty/argparse +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 557948f1236db9e27089959de837cc23de6c6bbd diff --git a/3rdparty/googletest b/3rdparty/googletest deleted file mode 160000 index 4580469122..0000000000 --- a/3rdparty/googletest +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 45804691223635953f311cf31a10c632553bbfc3 diff --git a/3rdparty/tokenizers-cpp b/3rdparty/tokenizers-cpp deleted file mode 160000 index 27dbe17d72..0000000000 --- a/3rdparty/tokenizers-cpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 27dbe17d7268801ec720569167af905c88d3db50 diff --git a/CMakeLists.txt b/CMakeLists.txt deleted file mode 100644 index 7f0dd7ef24..0000000000 --- a/CMakeLists.txt +++ /dev/null @@ -1,176 +0,0 @@ -cmake_minimum_required(VERSION 3.18) -project(mlc_llm C CXX) - -include(CheckCXXCompilerFlag) -if(MSVC) - set(CMAKE_CXX_FLAGS "/fp:fast ${CMAKE_CXX_FLAGS}") -else() - set(CMAKE_CXX_FLAGS "-ffast-math ${CMAKE_CXX_FLAGS}") -endif() - -if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake) - include(${CMAKE_BINARY_DIR}/config.cmake) -else() - if(EXISTS ${CMAKE_SOURCE_DIR}/config.cmake) - include(${CMAKE_SOURCE_DIR}/config.cmake) - endif() -endif() - -if(NOT CMAKE_BUILD_TYPE) - set( - CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Build type" FORCE - ) - message(STATUS "Setting default build type to " ${CMAKE_BUILD_TYPE}) -endif(NOT CMAKE_BUILD_TYPE) - -option(MLC_HIDE_PRIVATE_SYMBOLS "Hide private symbols" ON) - -if (MLC_LLM_INSTALL_STATIC_LIB) - set(BUILD_STATIC_RUNTIME ON) -endif() - -set(MLC_VISIBILITY_FLAG "") -if (MLC_HIDE_PRIVATE_SYMBOLS) - set(HIDE_PRIVATE_SYMBOLS ON) - if (NOT MSVC) - set(MLC_VISIBILITY_FLAG "-fvisibility=hidden") - endif() - message(STATUS "Hide private symbols") -endif() - -option(BUILD_CPP_TEST "Build cpp unittests" OFF) - -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) - -# tvm runtime config: minimize runtime components -set(USE_RPC OFF) -set(USE_MICRO OFF) -set(USE_GRAPH_EXECUTOR OFF) -set(USE_GRAPH_EXECUTOR_DEBUG OFF) -set(USE_AOT_EXECUTOR OFF) -set(USE_PROFILER OFF) -set(USE_GTEST OFF) -set(USE_LIBBACKTRACE OFF) -set(BUILD_DUMMY_LIBTVM ON) -if (NOT DEFINED TVM_HOME) - if(DEFINED ENV{TVM_HOME}) - set(TVM_HOME "$ENV{TVM_HOME}") - else() - set(TVM_HOME 3rdparty/tvm) - endif(DEFINED ENV{TVM_HOME}) -endif (NOT DEFINED TVM_HOME) -message(STATUS "TVM_HOME: ${TVM_HOME}") -add_subdirectory(${TVM_HOME} tvm EXCLUDE_FROM_ALL) - -set(MLC_LLM_RUNTIME_LINKER_LIB "") -set(TOKENZIER_CPP_PATH 3rdparty/tokenizers-cpp) -add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL) - - -tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc) -add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS}) - -set( - MLC_LLM_INCLUDES - ${TVM_HOME}/include - ${TVM_HOME}/3rdparty/dlpack/include - ${TVM_HOME}/3rdparty/dmlc-core/include - ${TVM_HOME}/3rdparty/picojson -) - -set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} DMLC_USE_LOGGING_LIBRARY=) -set(MLC_LLM_COMPILE_DEFS 
${MLC_LLM_COMPILE_DEFS} __STDC_FORMAT_MACROS=1) -set(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} PICOJSON_USE_INT64) - -target_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES}) -target_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS}) -target_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include) -target_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS) - -add_library(mlc_llm SHARED $) -add_library(mlc_llm_static STATIC $) -add_dependencies(mlc_llm_static tokenizers_cpp sentencepiece-static tokenizers_c tvm_runtime) -set_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm) - -target_link_libraries(mlc_llm PUBLIC tvm_runtime) -target_link_libraries(mlc_llm PRIVATE tokenizers_cpp) - -find_library( - FLASH_ATTN_LIBRARY flash_attn - HINTS ${TVM_HOME}/*/3rdparty/libflash_attn/src -) - -if (FLASH_ATTN_LIBRARY STREQUAL "FLASH_ATTN_LIBRARY-NOTFOUND") - message(WARNING "Cannot find libflash_attn. The model must not have been built with --use-flash-attn-mqa option.") -else () - target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY}) -endif() - -if(CMAKE_BUILD_TYPE STREQUAL "Debug") - target_compile_definitions(mlc_llm PRIVATE "TVM_LOG_DEBUG") - target_compile_definitions(mlc_llm_objs PRIVATE "TVM_LOG_DEBUG") - target_compile_definitions(mlc_llm_static PRIVATE "TVM_LOG_DEBUG") -endif() - -if (BUILD_CPP_TEST) - message(STATUS "Building cpp unittests") - add_subdirectory(3rdparty/googletest) - file(GLOB_RECURSE MLC_LLM_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/cpp/*unittest.cc) - add_executable(mlc_llm_cpp_tests ${MLC_LLM_TEST_SRCS}) - target_include_directories(mlc_llm_cpp_tests PRIVATE ${MLC_LLM_INCLUDES}) - target_include_directories(mlc_llm_cpp_tests PRIVATE ${PROJECT_SOURCE_DIR}/cpp) - target_include_directories(mlc_llm_cpp_tests PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) - target_link_libraries(mlc_llm_cpp_tests PUBLIC mlc_llm gtest gtest_main) -endif(BUILD_CPP_TEST) - -if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(mlc_llm PRIVATE log) - target_link_libraries(tokenizers_cpp PRIVATE log) -endif() - -add_library(mlc_llm_module SHARED $) -target_link_libraries(mlc_llm_module PUBLIC tvm) -target_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp) - - -set_property(TARGET mlc_llm_module APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}") -set_property(TARGET mlc_llm APPEND PROPERTY LINK_OPTIONS "${MLC_VISIBILITY_FLAG}") - -find_program(CARGO_EXECUTABLE cargo) - -if(NOT CARGO_EXECUTABLE) - message(FATAL_ERROR "Cargo is not found! 
Please install cargo.") -endif() - -# when this option is on, -# we install all static lib deps into lib -if (MLC_LLM_INSTALL_STATIC_LIB) - install(TARGETS - mlc_llm_static - tokenizers_cpp - sentencepiece-static - tvm_runtime - LIBRARY DESTINATION lib${LIB_SUFFIX} - ) - # tokenizers need special handling as it builds from rust - if(MSVC) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.lib - DESTINATION lib${LIB_SUFFIX} - ) - else() - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.a - DESTINATION lib${LIB_SUFFIX} - ) - endif() -else() - install(TARGETS tvm_runtime mlc_llm mlc_llm_module - mlc_llm_static - tokenizers_cpp - sentencepiece-static - RUNTIME_DEPENDENCY_SET tokenizers_c - RUNTIME DESTINATION bin - LIBRARY DESTINATION lib${LIB_SUFFIX} - ) -endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md deleted file mode 100644 index 3f70fac24e..0000000000 --- a/CONTRIBUTORS.md +++ /dev/null @@ -1,6 +0,0 @@ -MLC LLM Contributors -==================== - - -## List of Contributors -- [Full List of Contributors](https://github.com/mlc-ai/mlc-llm/graphs/contributors) diff --git a/README.md b/README.md deleted file mode 100644 index 88e3abd07d..0000000000 --- a/README.md +++ /dev/null @@ -1,240 +0,0 @@ -[discord-url]: https://discord.gg/9Xpy2HGBuD - -# MLC LLM - -[Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] - -**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. - -**Universal deployment.** MLC LLM supports the following platforms and hardware: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-|              | AMD GPU                 | NVIDIA GPU            | Apple GPU | Intel GPU       |
-|--------------|-------------------------|-----------------------|-----------|-----------------|
-| Linux / Win  | ✅ Vulkan, ROCm          | ✅ Vulkan, CUDA        | N/A       | ✅ Vulkan        |
-| macOS        | ✅ Metal (dGPU)          | N/A                   | ✅ Metal   | ✅ Metal (iGPU)  |
-| Web Browser  | ✅ WebGPU and WASM       |                       |           |                 |
-| iOS / iPadOS | ✅ Metal on Apple A-series GPU |                 |           |                 |
-| Android      | ✅ OpenCL on Adreno GPU  | ✅ OpenCL on Mali GPU  |           |                 |
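
To make the universal-deployment matrix above concrete, here is a minimal sketch of how the same chat-completion code (the Python API introduced in the Quick Start below) would target different backends. It assumes the `device` argument of `mlc_llm.MLCEngine` selects the backend (values such as `"auto"`, `"cuda"`, `"vulkan"`, `"metal"`, `"opencl"`); treat that argument and its values as an assumption drawn from the MLC LLM docs, not something stated in this README.

```python
from mlc_llm import MLCEngine

# The same 4-bit quantized Llama-3 weights are reused on every platform.
model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"

# Assumption: `device` picks the backend; "auto" lets MLC LLM detect whatever
# is available on the machine (CUDA, ROCm, Vulkan, Metal, OpenCL, ...).
engine = MLCEngine(model, device="auto")  # e.g. device="vulkan" on an AMD GPU

response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    model=model,
)
print(response.choices[0].message.content)

engine.terminate()
```

Only the `device` string would change between the rows of the table; the rest of the OpenAI-style call stays identical.
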
- - -## Quick Start - -We introduce the quick start examples of chat CLI, Python API and REST server here to use MLC LLM. -We use 4-bit quantized 8B Llama-3 model for demonstration purpose. -The pre-quantized Llama-3 weights is available at https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC. -You can also try out unquantized Llama-3 model by replacing `q4f16_1` to `q0f16` in the examples below. -Please visit our [documentation](https://llm.mlc.ai/docs/index.html) for detailed quick start and introduction. - -### Installation - -MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). -It is always recommended to install it in an isolated conda virtual environment. - -To verify the installation, activate your virtual environment, run - -```bash -python -c "import mlc_llm; print(mlc_llm.__path__)" -``` - -You are expected to see the installation path of MLC LLM Python package. - -### Chat CLI - -We can try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. - -```bash -mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC -``` - -It may take 1-2 minutes for the first time running this command. -After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. - -``` -You can use the following special commands: -/help print the special commands -/exit quit the cli -/stats print out the latest stats (token/sec) -/reset restart a fresh chat -/set [overrides] override settings in the generation config. For example, - `/set temperature=0.5;max_gen_len=100;stop=end,stop` - Note: Separate stop words in the `stop` option with commas (,). -Multi-line input: Use escape+enter to start a new line. - -user: What's the meaning of life -assistant: -What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. - -The concept of the meaning of life has been debated and... -``` - -### Python API - -We can run the Llama-3 model with the chat completion Python API of MLC LLM. -You can save the code below into a Python file and run it. - -```python -from mlc_llm import MLCEngine - -# Create engine -model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" -engine = MLCEngine(model) - -# Run chat completion in OpenAI API. -for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, -): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) -print("\n") - -engine.terminate() -``` - -**The Python API of `mlc_llm.MLCEngine` fully aligns with OpenAI API**. -You can use MLCEngine in the same way of using -[OpenAI's Python package](https://github.com/openai/openai-python?tab=readme-ov-file#usage) -for both synchronous and asynchronous generation. - -If you would like to do concurrent asynchronous generation, you can use `mlc_llm.AsyncMLCEngine` instead. - -### REST Server - -We can launch a REST server to serve the 4-bit quantized Llama-3 model for OpenAI chat completion requests. -The server has fully OpenAI API completeness. - -```bash -mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC -``` - -The server is hooked at `http://127.0.0.1:8000` by default, and you can use `--host` and `--port` -to set a different host and port. 
-When the server is ready (showing `INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)`), -we can open a new shell and send a cURL request via the following command: - -```bash -curl -X POST \ - -H "Content-Type: application/json" \ - -d '{ - "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} - ] - }' \ - http://127.0.0.1:8000/v1/chat/completions -``` - -## Universal Deployment APIs - -MLC LLM provides multiple sets of APIs across platforms and environments. These include -* [Python API](https://llm.mlc.ai/docs/deploy/python_engine.html) -* [OpenAI-compatible Rest-API](https://llm.mlc.ai/docs/deploy/rest.html) -* [C++ API](https://llm.mlc.ai/docs/deploy/cli.html) -* [JavaScript API](https://llm.mlc.ai/docs/deploy/javascript.html) and [Web LLM](https://github.com/mlc-ai/web-llm) -* [Swift API for iOS App](https://llm.mlc.ai/docs/deploy/ios.html) -* [Java API and Android App](https://llm.mlc.ai/docs/deploy/android.html) - -## Citation - -Please consider citing our project if you find it useful: - -```bibtex -@software{mlc-llm, - author = {MLC team}, - title = {{MLC-LLM}}, - url = {https://github.com/mlc-ai/mlc-llm}, - year = {2023} -} -``` - -The underlying techniques of MLC LLM include: - -
- References (Click to expand) - - ```bibtex - @inproceedings{tensorir, - author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi}, - title = {TensorIR: An Abstraction for Automatic Tensorized Program Optimization}, - year = {2023}, - isbn = {9781450399166}, - publisher = {Association for Computing Machinery}, - address = {New York, NY, USA}, - url = {https://doi.org/10.1145/3575693.3576933}, - doi = {10.1145/3575693.3576933}, - booktitle = {Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2}, - pages = {804–817}, - numpages = {14}, - keywords = {Tensor Computation, Machine Learning Compiler, Deep Neural Network}, - location = {Vancouver, BC, Canada}, - series = {ASPLOS 2023} - } - - @inproceedings{metaschedule, - author = {Shao, Junru and Zhou, Xiyou and Feng, Siyuan and Hou, Bohan and Lai, Ruihang and Jin, Hongyi and Lin, Wuwei and Masuda, Masahiro and Yu, Cody Hao and Chen, Tianqi}, - booktitle = {Advances in Neural Information Processing Systems}, - editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh}, - pages = {35783--35796}, - publisher = {Curran Associates, Inc.}, - title = {Tensor Program Optimization with Probabilistic Programs}, - url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/e894eafae43e68b4c8dfdacf742bcbf3-Paper-Conference.pdf}, - volume = {35}, - year = {2022} - } - - @inproceedings{tvm, - author = {Tianqi Chen and Thierry Moreau and Ziheng Jiang and Lianmin Zheng and Eddie Yan and Haichen Shen and Meghan Cowan and Leyuan Wang and Yuwei Hu and Luis Ceze and Carlos Guestrin and Arvind Krishnamurthy}, - title = {{TVM}: An Automated {End-to-End} Optimizing Compiler for Deep Learning}, - booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)}, - year = {2018}, - isbn = {978-1-939133-08-3}, - address = {Carlsbad, CA}, - pages = {578--594}, - url = {https://www.usenix.org/conference/osdi18/presentation/chen}, - publisher = {USENIX Association}, - month = oct, - } - ``` -
- -## Links - -- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic -walkthrough of our approaches. -- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. -- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. - diff --git a/android/.gitignore b/android/.gitignore deleted file mode 100644 index 002b05d2be..0000000000 --- a/android/.gitignore +++ /dev/null @@ -1,19 +0,0 @@ -app/src/main/jni/*.h -app/src/main/jni/*.cc -app/src/main/obj - -*.iml -.gradle -/local.properties -/.idea/caches -/.idea/libraries -/.idea/modules.xml -/.idea/workspace.xml -/.idea/navEditor.xml -/.idea/assetWizardSettings.xml -.DS_Store -/build -/captures -.externalNativeBuild -.cxx -local.properties diff --git a/android/README.md b/android/README.md deleted file mode 100644 index 502eb53c7b..0000000000 --- a/android/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# MLC-LLM Android - -[Documentation page](https://llm.mlc.ai/docs/deploy/android.html) diff --git a/android/app/.gitignore b/android/app/.gitignore deleted file mode 100644 index 558f311c28..0000000000 --- a/android/app/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/build -/src/main/libs \ No newline at end of file diff --git a/android/app/build.gradle b/android/app/build.gradle deleted file mode 100644 index 1fd30e3985..0000000000 --- a/android/app/build.gradle +++ /dev/null @@ -1,73 +0,0 @@ -plugins { - id 'com.android.application' - id 'org.jetbrains.kotlin.android' -} - -android { - namespace 'ai.mlc.mlcchat' - compileSdk 34 - - defaultConfig { - applicationId "ai.mlc.mlcchat" - minSdk 26 - targetSdk 33 - versionCode 1 - versionName "1.0" - - testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - vectorDrawables { - useSupportLibrary true - } - } - - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' - } - } - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 - } - kotlinOptions { - jvmTarget = '1.8' - } - buildFeatures { - compose true - } - composeOptions { - kotlinCompilerExtensionVersion '1.4.3' - } - packagingOptions { - resources { - excludes += '/META-INF/{AL2.0,LGPL2.1}' - } - } -} - -dependencies { - implementation project(":library") - implementation 'androidx.core:core-ktx:1.10.1' - implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1' - implementation 'androidx.activity:activity-compose:1.7.1' - implementation platform('androidx.compose:compose-bom:2022.10.00') - implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1' - implementation 'androidx.compose.ui:ui' - implementation 'androidx.compose.ui:ui-graphics' - implementation 'androidx.compose.ui:ui-tooling-preview' - implementation 'androidx.compose.material3:material3:1.1.0' - implementation 'androidx.compose.material:material-icons-extended' - implementation 'androidx.appcompat:appcompat:1.6.1' - implementation 'androidx.navigation:navigation-compose:2.5.3' - implementation 'com.google.code.gson:gson:2.10.1' - implementation fileTree(dir: 'src/main/libs', include: ['*.aar', '*.jar'], exclude: []) - testImplementation 'junit:junit:4.13.2' - androidTestImplementation 'androidx.test.ext:junit:1.1.5' - androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' - androidTestImplementation 
platform('androidx.compose:compose-bom:2022.10.00') - androidTestImplementation 'androidx.compose.ui:ui-test-junit4' - debugImplementation 'androidx.compose.ui:ui-tooling' - debugImplementation 'androidx.compose.ui:ui-test-manifest' - -} \ No newline at end of file diff --git a/android/app/proguard-rules.pro b/android/app/proguard-rules.pro deleted file mode 100644 index 481bb43481..0000000000 --- a/android/app/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. -# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. -#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml deleted file mode 100644 index caad13bf69..0000000000 --- a/android/app/src/main/AndroidManifest.xml +++ /dev/null @@ -1,41 +0,0 @@ - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/android/app/src/main/ic_launcher-playstore.png b/android/app/src/main/ic_launcher-playstore.png deleted file mode 100644 index 3c16fd65fd..0000000000 Binary files a/android/app/src/main/ic_launcher-playstore.png and /dev/null differ diff --git a/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt deleted file mode 100644 index 6a3bf4a211..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt +++ /dev/null @@ -1,760 +0,0 @@ -package ai.mlc.mlcchat - -import ai.mlc.mlcllm.ChatModule -import android.app.Application -import android.content.ClipData -import android.content.ClipboardManager -import android.content.Context -import android.os.Environment -import android.widget.Toast -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.toMutableStateList -import androidx.lifecycle.AndroidViewModel -import androidx.lifecycle.viewModelScope -import com.google.gson.Gson -import com.google.gson.annotations.SerializedName -import kotlinx.coroutines.launch -import java.io.File -import java.io.FileOutputStream -import java.net.URL -import java.nio.channels.Channels -import java.util.UUID -import java.util.concurrent.Executors -import kotlin.concurrent.thread - -class AppViewModel(application: Application) : AndroidViewModel(application) { - val modelList = emptyList().toMutableStateList() - val chatState = ChatState() - val modelSampleList = emptyList().toMutableStateList() - private var showAlert = mutableStateOf(false) - private var alertMessage = mutableStateOf("") - private var appConfig = AppConfig( - emptyList().toMutableList(), - emptyList().toMutableList() - ) - private val application = getApplication() - private val appDirFile = application.getExternalFilesDir("") - private val gson = Gson() - private val modelIdSet = emptySet().toMutableSet() - - companion object { - const val AppConfigFilename = "app-config.json" - const val ModelConfigFilename 
= "mlc-chat-config.json" - const val ParamsConfigFilename = "ndarray-cache.json" - const val ModelUrlSuffix = "resolve/main/" - } - - init { - loadAppConfig() - } - - fun isShowingAlert(): Boolean { - return showAlert.value - } - - fun errorMessage(): String { - return alertMessage.value - } - - fun dismissAlert() { - require(showAlert.value) - showAlert.value = false - } - - fun copyError() { - require(showAlert.value) - val clipboard = - application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager - clipboard.setPrimaryClip(ClipData.newPlainText("MLCChat", errorMessage())) - } - - private fun issueAlert(error: String) { - showAlert.value = true - alertMessage.value = error - } - - fun requestDeleteModel(modelId: String) { - deleteModel(modelId) - issueAlert("Model: $modelId has been deleted") - } - - - private fun loadAppConfig() { - val appConfigFile = File(appDirFile, AppConfigFilename) - val jsonString: String = if (!appConfigFile.exists()) { - application.assets.open(AppConfigFilename).bufferedReader().use { it.readText() } - } else { - appConfigFile.readText() - } - appConfig = gson.fromJson(jsonString, AppConfig::class.java) - appConfig.modelLibs = emptyList().toMutableList() - modelList.clear() - modelIdSet.clear() - modelSampleList.clear() - for (modelRecord in appConfig.modelList) { - appConfig.modelLibs.add(modelRecord.modelLib) - val modelDirFile = File(appDirFile, modelRecord.modelId) - val modelConfigFile = File(modelDirFile, ModelConfigFilename) - if (modelConfigFile.exists()) { - val modelConfigString = modelConfigFile.readText() - val modelConfig = gson.fromJson(modelConfigString, ModelConfig::class.java) - modelConfig.modelId = modelRecord.modelId - modelConfig.modelLib = modelRecord.modelLib - modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes - addModelConfig(modelConfig, modelRecord.modelUrl, true) - } else { - downloadModelConfig( - if (modelRecord.modelUrl.endsWith("/")) modelRecord.modelUrl else "${modelRecord.modelUrl}/", - modelRecord, - true - ) - } - } - } - - private fun updateAppConfig(action: () -> Unit) { - action() - val jsonString = gson.toJson(appConfig) - val appConfigFile = File(appDirFile, AppConfigFilename) - appConfigFile.writeText(jsonString) - } - - private fun addModelConfig(modelConfig: ModelConfig, modelUrl: String, isBuiltin: Boolean) { - require(!modelIdSet.contains(modelConfig.modelId)) - modelIdSet.add(modelConfig.modelId) - modelList.add( - ModelState( - modelConfig, - modelUrl + if (modelUrl.endsWith("/")) "" else "/", - File(appDirFile, modelConfig.modelId) - ) - ) - if (!isBuiltin) { - updateAppConfig { - appConfig.modelList.add( - ModelRecord( - modelUrl, - modelConfig.modelId, - modelConfig.estimatedVramBytes, - modelConfig.modelLib - ) - ) - } - } - } - - private fun deleteModel(modelId: String) { - val modelDirFile = File(appDirFile, modelId) - modelDirFile.deleteRecursively() - require(!modelDirFile.exists()) - modelIdSet.remove(modelId) - modelList.removeIf { modelState -> modelState.modelConfig.modelId == modelId } - updateAppConfig { - appConfig.modelList.removeIf { modelRecord -> modelRecord.modelId == modelId } - } - } - - private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean { - if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true - viewModelScope.launch { - issueAlert("Model lib ${modelConfig.modelLib} is not supported.") - } - return false - } - - - private fun downloadModelConfig( - modelUrl: String, - modelRecord: ModelRecord, - isBuiltin: Boolean - ) { - 
thread(start = true) { - try { - val url = URL("${modelUrl}${ModelUrlSuffix}${ModelConfigFilename}") - val tempId = UUID.randomUUID().toString() - val tempFile = File( - application.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS), - tempId - ) - url.openStream().use { - Channels.newChannel(it).use { src -> - FileOutputStream(tempFile).use { fileOutputStream -> - fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE) - } - } - } - require(tempFile.exists()) - viewModelScope.launch { - try { - val modelConfigString = tempFile.readText() - val modelConfig = gson.fromJson(modelConfigString, ModelConfig::class.java) - modelConfig.modelId = modelRecord.modelId - modelConfig.modelLib = modelRecord.modelLib - modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes - if (modelIdSet.contains(modelConfig.modelId)) { - tempFile.delete() - issueAlert("${modelConfig.modelId} has been used, please consider another local ID") - return@launch - } - if (!isModelConfigAllowed(modelConfig)) { - tempFile.delete() - return@launch - } - val modelDirFile = File(appDirFile, modelConfig.modelId) - val modelConfigFile = File(modelDirFile, ModelConfigFilename) - tempFile.copyTo(modelConfigFile, overwrite = true) - tempFile.delete() - require(modelConfigFile.exists()) - addModelConfig(modelConfig, modelUrl, isBuiltin) - } catch (e: Exception) { - viewModelScope.launch { - issueAlert("Add model failed: ${e.localizedMessage}") - } - } - } - } catch (e: Exception) { - viewModelScope.launch { - issueAlert("Download model config failed: ${e.localizedMessage}") - } - } - - } - } - - inner class ModelState( - val modelConfig: ModelConfig, - private val modelUrl: String, - private val modelDirFile: File - ) { - var modelInitState = mutableStateOf(ModelInitState.Initializing) - private var paramsConfig = ParamsConfig(emptyList()) - val progress = mutableStateOf(0) - val total = mutableStateOf(1) - val id: UUID = UUID.randomUUID() - private val remainingTasks = emptySet().toMutableSet() - private val downloadingTasks = emptySet().toMutableSet() - private val maxDownloadTasks = 3 - private val gson = Gson() - - - init { - switchToInitializing() - } - - private fun switchToInitializing() { - val paramsConfigFile = File(modelDirFile, ParamsConfigFilename) - if (paramsConfigFile.exists()) { - loadParamsConfig() - switchToIndexing() - } else { - downloadParamsConfig() - } - } - - private fun loadParamsConfig() { - val paramsConfigFile = File(modelDirFile, ParamsConfigFilename) - require(paramsConfigFile.exists()) - val jsonString = paramsConfigFile.readText() - paramsConfig = gson.fromJson(jsonString, ParamsConfig::class.java) - } - - private fun downloadParamsConfig() { - thread(start = true) { - val url = URL("${modelUrl}${ModelUrlSuffix}${ParamsConfigFilename}") - val tempId = UUID.randomUUID().toString() - val tempFile = File(modelDirFile, tempId) - url.openStream().use { - Channels.newChannel(it).use { src -> - FileOutputStream(tempFile).use { fileOutputStream -> - fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE) - } - } - } - require(tempFile.exists()) - val paramsConfigFile = File(modelDirFile, ParamsConfigFilename) - tempFile.renameTo(paramsConfigFile) - require(paramsConfigFile.exists()) - viewModelScope.launch { - loadParamsConfig() - switchToIndexing() - } - } - } - - fun handleStart() { - switchToDownloading() - } - - fun handlePause() { - switchToPausing() - } - - fun handleClear() { - require( - modelInitState.value == ModelInitState.Downloading || - modelInitState.value == 
ModelInitState.Paused || - modelInitState.value == ModelInitState.Finished - ) - switchToClearing() - } - - private fun switchToClearing() { - if (modelInitState.value == ModelInitState.Paused) { - modelInitState.value = ModelInitState.Clearing - clear() - } else if (modelInitState.value == ModelInitState.Finished) { - modelInitState.value = ModelInitState.Clearing - if (chatState.modelName.value == modelConfig.modelId) { - chatState.requestTerminateChat { clear() } - } else { - clear() - } - } else { - modelInitState.value = ModelInitState.Clearing - } - } - - fun handleDelete() { - require( - modelInitState.value == ModelInitState.Downloading || - modelInitState.value == ModelInitState.Paused || - modelInitState.value == ModelInitState.Finished - ) - switchToDeleting() - } - - private fun switchToDeleting() { - if (modelInitState.value == ModelInitState.Paused) { - modelInitState.value = ModelInitState.Deleting - delete() - } else if (modelInitState.value == ModelInitState.Finished) { - modelInitState.value = ModelInitState.Deleting - if (chatState.modelName.value == modelConfig.modelId) { - chatState.requestTerminateChat { delete() } - } else { - delete() - } - } else { - modelInitState.value = ModelInitState.Deleting - } - } - - private fun switchToIndexing() { - modelInitState.value = ModelInitState.Indexing - progress.value = 0 - total.value = modelConfig.tokenizerFiles.size + paramsConfig.paramsRecords.size - for (tokenizerFilename in modelConfig.tokenizerFiles) { - val file = File(modelDirFile, tokenizerFilename) - if (file.exists()) { - ++progress.value - } else { - remainingTasks.add( - DownloadTask( - URL("${modelUrl}${ModelUrlSuffix}${tokenizerFilename}"), - file - ) - ) - } - } - for (paramsRecord in paramsConfig.paramsRecords) { - val file = File(modelDirFile, paramsRecord.dataPath) - if (file.exists()) { - ++progress.value - } else { - remainingTasks.add( - DownloadTask( - URL("${modelUrl}${ModelUrlSuffix}${paramsRecord.dataPath}"), - file - ) - ) - } - } - if (progress.value < total.value) { - switchToPaused() - } else { - switchToFinished() - } - } - - private fun switchToDownloading() { - modelInitState.value = ModelInitState.Downloading - for (downloadTask in remainingTasks) { - if (downloadingTasks.size < maxDownloadTasks) { - handleNewDownload(downloadTask) - } else { - return - } - } - } - - private fun handleNewDownload(downloadTask: DownloadTask) { - require(modelInitState.value == ModelInitState.Downloading) - require(!downloadingTasks.contains(downloadTask)) - downloadingTasks.add(downloadTask) - thread(start = true) { - val tempId = UUID.randomUUID().toString() - val tempFile = File(modelDirFile, tempId) - downloadTask.url.openStream().use { - Channels.newChannel(it).use { src -> - FileOutputStream(tempFile).use { fileOutputStream -> - fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE) - } - } - } - require(tempFile.exists()) - tempFile.renameTo(downloadTask.file) - require(downloadTask.file.exists()) - viewModelScope.launch { - handleFinishDownload(downloadTask) - } - } - } - - private fun handleNextDownload() { - require(modelInitState.value == ModelInitState.Downloading) - for (downloadTask in remainingTasks) { - if (!downloadingTasks.contains(downloadTask)) { - handleNewDownload(downloadTask) - break - } - } - } - - private fun handleFinishDownload(downloadTask: DownloadTask) { - remainingTasks.remove(downloadTask) - downloadingTasks.remove(downloadTask) - ++progress.value - require( - modelInitState.value == ModelInitState.Downloading || - 
modelInitState.value == ModelInitState.Pausing || - modelInitState.value == ModelInitState.Clearing || - modelInitState.value == ModelInitState.Deleting - ) - if (modelInitState.value == ModelInitState.Downloading) { - if (remainingTasks.isEmpty()) { - if (downloadingTasks.isEmpty()) { - switchToFinished() - } - } else { - handleNextDownload() - } - } else if (modelInitState.value == ModelInitState.Pausing) { - if (downloadingTasks.isEmpty()) { - switchToPaused() - } - } else if (modelInitState.value == ModelInitState.Clearing) { - if (downloadingTasks.isEmpty()) { - clear() - } - } else if (modelInitState.value == ModelInitState.Deleting) { - if (downloadingTasks.isEmpty()) { - delete() - } - } - } - - private fun clear() { - val files = modelDirFile.listFiles { dir, name -> - !(dir == modelDirFile && name == ModelConfigFilename) - } - require(files != null) - for (file in files) { - file.deleteRecursively() - require(!file.exists()) - } - val modelConfigFile = File(modelDirFile, ModelConfigFilename) - require(modelConfigFile.exists()) - switchToIndexing() - } - - private fun delete() { - modelDirFile.deleteRecursively() - require(!modelDirFile.exists()) - requestDeleteModel(modelConfig.modelId) - } - - private fun switchToPausing() { - modelInitState.value = ModelInitState.Pausing - } - - private fun switchToPaused() { - modelInitState.value = ModelInitState.Paused - } - - - private fun switchToFinished() { - modelInitState.value = ModelInitState.Finished - } - - fun startChat() { - chatState.requestReloadChat( - modelConfig, - modelDirFile.absolutePath, - ) - } - - } - - inner class ChatState { - val messages = emptyList().toMutableStateList() - val report = mutableStateOf("") - val modelName = mutableStateOf("") - private var modelChatState = mutableStateOf(ModelChatState.Ready) - @Synchronized get - @Synchronized set - private val backend = ChatModule() - private var modelLib = "" - private var modelPath = "" - private val executorService = Executors.newSingleThreadExecutor() - - private fun mainResetChat() { - executorService.submit { - callBackend { backend.resetChat() } - viewModelScope.launch { - clearHistory() - switchToReady() - } - } - } - - private fun clearHistory() { - messages.clear() - report.value = "" - } - - - private fun switchToResetting() { - modelChatState.value = ModelChatState.Resetting - } - - private fun switchToGenerating() { - modelChatState.value = ModelChatState.Generating - } - - private fun switchToReloading() { - modelChatState.value = ModelChatState.Reloading - } - - private fun switchToReady() { - modelChatState.value = ModelChatState.Ready - } - - private fun switchToFailed() { - modelChatState.value = ModelChatState.Falied - } - - private fun callBackend(callback: () -> Unit): Boolean { - try { - callback() - } catch (e: Exception) { - viewModelScope.launch { - val stackTrace = e.stackTraceToString() - val errorMessage = e.localizedMessage - appendMessage( - MessageRole.Bot, - "MLCChat failed\n\nStack trace:\n$stackTrace\n\nError message:\n$errorMessage" - ) - switchToFailed() - } - return false - } - return true - } - - fun requestResetChat() { - require(interruptable()) - interruptChat( - prologue = { - switchToResetting() - }, - epilogue = { - mainResetChat() - } - ) - } - - private fun interruptChat(prologue: () -> Unit, epilogue: () -> Unit) { - // prologue runs before interruption - // epilogue runs after interruption - require(interruptable()) - if (modelChatState.value == ModelChatState.Ready) { - prologue() - epilogue() - } else if 
(modelChatState.value == ModelChatState.Generating) { - prologue() - executorService.submit { - viewModelScope.launch { epilogue() } - } - } else { - require(false) - } - } - - fun requestTerminateChat(callback: () -> Unit) { - require(interruptable()) - interruptChat( - prologue = { - switchToTerminating() - }, - epilogue = { - mainTerminateChat(callback) - } - ) - } - - private fun mainTerminateChat(callback: () -> Unit) { - executorService.submit { - callBackend { backend.unload() } - viewModelScope.launch { - clearHistory() - switchToReady() - callback() - } - } - } - - private fun switchToTerminating() { - modelChatState.value = ModelChatState.Terminating - } - - - fun requestReloadChat(modelConfig: ModelConfig, modelPath: String) { - - if (this.modelName.value == modelConfig.modelId && this.modelLib == modelConfig.modelLib && this.modelPath == modelPath) { - return - } - require(interruptable()) - interruptChat( - prologue = { - switchToReloading() - }, - epilogue = { - mainReloadChat(modelConfig, modelPath) - } - ) - } - - private fun mainReloadChat(modelConfig: ModelConfig, modelPath: String) { - clearHistory() - this.modelName.value = modelConfig.modelId - this.modelLib = modelConfig.modelLib - this.modelPath = modelPath - executorService.submit { - viewModelScope.launch { - Toast.makeText(application, "Initialize...", Toast.LENGTH_SHORT).show() - } - if (!callBackend { - backend.unload() - backend.reload( - modelConfig.modelLib, - modelPath - ) - }) return@submit - viewModelScope.launch { - Toast.makeText(application, "Ready to chat", Toast.LENGTH_SHORT).show() - switchToReady() - } - } - } - - fun requestGenerate(prompt: String) { - require(chatable()) - switchToGenerating() - executorService.submit { - appendMessage(MessageRole.User, prompt) - appendMessage(MessageRole.Bot, "") - if (!callBackend { backend.prefill(prompt) }) return@submit - while (!backend.stopped()) { - if (!callBackend { - backend.decode() - val newText = backend.message - viewModelScope.launch { updateMessage(MessageRole.Bot, newText) } - }) return@submit - if (modelChatState.value != ModelChatState.Generating) return@submit - } - val runtimeStats = backend.runtimeStatsText() - viewModelScope.launch { - report.value = runtimeStats - if (modelChatState.value == ModelChatState.Generating) switchToReady() - } - } - } - - private fun appendMessage(role: MessageRole, text: String) { - messages.add(MessageData(role, text)) - } - - - private fun updateMessage(role: MessageRole, text: String) { - messages[messages.size - 1] = MessageData(role, text) - } - - fun chatable(): Boolean { - return modelChatState.value == ModelChatState.Ready - } - - fun interruptable(): Boolean { - return modelChatState.value == ModelChatState.Ready - || modelChatState.value == ModelChatState.Generating - || modelChatState.value == ModelChatState.Falied - } - } -} - -enum class ModelInitState { - Initializing, - Indexing, - Paused, - Downloading, - Pausing, - Clearing, - Deleting, - Finished -} - -enum class ModelChatState { - Generating, - Resetting, - Reloading, - Terminating, - Ready, - Falied -} - -enum class MessageRole { - Bot, - User -} - -data class DownloadTask(val url: URL, val file: File) - -data class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID()) - -data class AppConfig( - @SerializedName("model_libs") var modelLibs: MutableList, - @SerializedName("model_list") val modelList: MutableList, -) - -data class ModelRecord( - @SerializedName("model_url") val modelUrl: String, - 
@SerializedName("model_id") val modelId: String, - @SerializedName("estimated_vram_bytes") val estimatedVramBytes: Long?, - @SerializedName("model_lib") val modelLib: String -) - -data class ModelConfig( - @SerializedName("model_lib") var modelLib: String, - @SerializedName("model_id") var modelId: String, - @SerializedName("estimated_vram_bytes") var estimatedVramBytes: Long?, - @SerializedName("tokenizer_files") val tokenizerFiles: List, - @SerializedName("context_window_size") val contextWindowSize: Int, - @SerializedName("prefill_chunk_size") val prefillChunkSize: Int, -) - -data class ParamsRecord( - @SerializedName("dataPath") val dataPath: String -) - -data class ParamsConfig( - @SerializedName("records") val paramsRecords: List -) \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt deleted file mode 100644 index 9f581ab313..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/ChatView.kt +++ /dev/null @@ -1,220 +0,0 @@ -package ai.mlc.mlcchat - -import androidx.compose.foundation.background -import androidx.compose.foundation.gestures.detectTapGestures -import androidx.compose.foundation.layout.Arrangement -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.IntrinsicSize -import androidx.compose.foundation.layout.Row -import androidx.compose.foundation.layout.aspectRatio -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.fillMaxWidth -import androidx.compose.foundation.layout.height -import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.layout.widthIn -import androidx.compose.foundation.layout.wrapContentHeight -import androidx.compose.foundation.layout.wrapContentWidth -import androidx.compose.foundation.lazy.LazyColumn -import androidx.compose.foundation.lazy.items -import androidx.compose.foundation.lazy.rememberLazyListState -import androidx.compose.foundation.shape.RoundedCornerShape -import androidx.compose.foundation.text.selection.SelectionContainer -import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.ArrowBack -import androidx.compose.material.icons.filled.Replay -import androidx.compose.material.icons.filled.Send -import androidx.compose.material3.Divider -import androidx.compose.material3.ExperimentalMaterial3Api -import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.OutlinedTextField -import androidx.compose.material3.Scaffold -import androidx.compose.material3.Text -import androidx.compose.material3.TopAppBar -import androidx.compose.material3.TopAppBarDefaults -import androidx.compose.runtime.Composable -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.rememberCoroutineScope -import androidx.compose.runtime.saveable.rememberSaveable -import androidx.compose.runtime.setValue -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.input.pointer.pointerInput -import androidx.compose.ui.platform.LocalFocusManager -import androidx.compose.ui.text.style.TextAlign -import androidx.compose.ui.unit.dp -import androidx.navigation.NavController -import kotlinx.coroutines.launch - -@ExperimentalMaterial3Api -@Composable -fun ChatView( - navController: NavController, chatState: 
AppViewModel.ChatState -) { - val localFocusManager = LocalFocusManager.current - Scaffold(topBar = { - TopAppBar( - title = { - Text( - text = "MLCChat: " + chatState.modelName.value.split("-")[0], - color = MaterialTheme.colorScheme.onPrimary - ) - }, - colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary), - navigationIcon = { - IconButton( - onClick = { navController.popBackStack() }, - enabled = chatState.interruptable() - ) { - Icon( - imageVector = Icons.Filled.ArrowBack, - contentDescription = "back home page", - tint = MaterialTheme.colorScheme.onPrimary - ) - } - }, - actions = { - IconButton( - onClick = { chatState.requestResetChat() }, - enabled = chatState.interruptable() - ) { - Icon( - imageVector = Icons.Filled.Replay, - contentDescription = "reset the chat", - tint = MaterialTheme.colorScheme.onPrimary - ) - } - }) - }, modifier = Modifier.pointerInput(Unit) { - detectTapGestures(onTap = { - localFocusManager.clearFocus() - }) - }) { paddingValues -> - Column( - modifier = Modifier - .fillMaxSize() - .padding(paddingValues) - .padding(horizontal = 10.dp) - ) { - val lazyColumnListState = rememberLazyListState() - val coroutineScope = rememberCoroutineScope() - Text( - text = chatState.report.value, - textAlign = TextAlign.Center, - modifier = Modifier - .fillMaxWidth() - .wrapContentHeight() - .padding(top = 5.dp) - ) - Divider(thickness = 1.dp, modifier = Modifier.padding(vertical = 5.dp)) - LazyColumn( - modifier = Modifier.weight(9f), - verticalArrangement = Arrangement.spacedBy(5.dp, alignment = Alignment.Bottom), - state = lazyColumnListState - ) { - coroutineScope.launch { - lazyColumnListState.animateScrollToItem(chatState.messages.size) - } - items( - items = chatState.messages, - key = { message -> message.id }, - ) { message -> - MessageView(messageData = message) - } - item { - // place holder item for scrolling to the bottom - } - } - Divider(thickness = 1.dp, modifier = Modifier.padding(top = 5.dp)) - SendMessageView(chatState = chatState) - } - } -} - -@Composable -fun MessageView(messageData: MessageData) { - SelectionContainer { - if (messageData.role == MessageRole.Bot) { - Row( - horizontalArrangement = Arrangement.Start, - modifier = Modifier.fillMaxWidth() - ) { - Text( - text = messageData.text, - textAlign = TextAlign.Left, - color = MaterialTheme.colorScheme.onSecondaryContainer, - modifier = Modifier - .wrapContentWidth() - .background( - color = MaterialTheme.colorScheme.secondaryContainer, - shape = RoundedCornerShape(5.dp) - ) - .padding(5.dp) - .widthIn(max = 300.dp) - ) - - } - } else { - Row( - horizontalArrangement = Arrangement.End, - modifier = Modifier.fillMaxWidth() - ) { - Text( - text = messageData.text, - textAlign = TextAlign.Right, - color = MaterialTheme.colorScheme.onPrimaryContainer, - modifier = Modifier - .wrapContentWidth() - .background( - color = MaterialTheme.colorScheme.primaryContainer, - shape = RoundedCornerShape(5.dp) - ) - .padding(5.dp) - .widthIn(max = 300.dp) - ) - - } - } - } -} - -@ExperimentalMaterial3Api -@Composable -fun SendMessageView(chatState: AppViewModel.ChatState) { - val localFocusManager = LocalFocusManager.current - Row( - horizontalArrangement = Arrangement.spacedBy(5.dp), - verticalAlignment = Alignment.CenterVertically, - modifier = Modifier - .height(IntrinsicSize.Max) - .fillMaxWidth() - .padding(bottom = 5.dp) - ) { - var text by rememberSaveable { mutableStateOf("") } - OutlinedTextField( - value = text, - onValueChange = { text = it }, - label = { 
Text(text = "Input") }, - modifier = Modifier - .weight(9f), - ) - IconButton( - onClick = { - localFocusManager.clearFocus() - chatState.requestGenerate(text) - text = "" - }, - modifier = Modifier - .aspectRatio(1f) - .weight(1f), - enabled = (text != "" && chatState.chatable()) - ) { - Icon( - imageVector = Icons.Filled.Send, - contentDescription = "send message", - ) - } - } -} diff --git a/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt b/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt deleted file mode 100644 index c586869324..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt +++ /dev/null @@ -1,29 +0,0 @@ -package ai.mlc.mlcchat - -import ai.mlc.mlcchat.ui.theme.MLCChatTheme -import android.os.Bundle -import androidx.activity.ComponentActivity -import androidx.activity.compose.setContent -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.material3.ExperimentalMaterial3Api -import androidx.compose.material3.Surface -import androidx.compose.ui.Modifier - - -class MainActivity : ComponentActivity() { - - @ExperimentalMaterial3Api - override fun onCreate(savedInstanceState: Bundle?) { - super.onCreate(savedInstanceState) - setContent { - Surface( - modifier = Modifier - .fillMaxSize() - ) { - MLCChatTheme { - NavView() - } - } - } - } -} \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt b/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt deleted file mode 100644 index fe897ce0cf..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/NavView.kt +++ /dev/null @@ -1,18 +0,0 @@ -package ai.mlc.mlcchat - -import androidx.compose.material3.ExperimentalMaterial3Api -import androidx.compose.runtime.Composable -import androidx.lifecycle.viewmodel.compose.viewModel -import androidx.navigation.compose.NavHost -import androidx.navigation.compose.composable -import androidx.navigation.compose.rememberNavController - -@ExperimentalMaterial3Api -@Composable -fun NavView(appViewModel: AppViewModel = viewModel()) { - val navController = rememberNavController() - NavHost(navController = navController, startDestination = "home") { - composable("home") { StartView(navController, appViewModel) } - composable("chat") { ChatView(navController, appViewModel.chatState) } - } -} \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt deleted file mode 100644 index a58129efba..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/StartView.kt +++ /dev/null @@ -1,251 +0,0 @@ -package ai.mlc.mlcchat - -import androidx.compose.foundation.gestures.detectTapGestures -import androidx.compose.foundation.layout.Arrangement -import androidx.compose.foundation.layout.Column -import androidx.compose.foundation.layout.Row -import androidx.compose.foundation.layout.aspectRatio -import androidx.compose.foundation.layout.fillMaxSize -import androidx.compose.foundation.layout.fillMaxWidth -import androidx.compose.foundation.layout.height -import androidx.compose.foundation.layout.padding -import androidx.compose.foundation.layout.width -import androidx.compose.foundation.layout.wrapContentHeight -import androidx.compose.foundation.lazy.LazyColumn -import androidx.compose.foundation.lazy.items -import androidx.compose.foundation.text.selection.SelectionContainer -import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.outlined.Chat -import androidx.compose.material.icons.outlined.Delete 
-import androidx.compose.material.icons.outlined.Download -import androidx.compose.material.icons.outlined.Pause -import androidx.compose.material.icons.outlined.Schedule -import androidx.compose.material3.AlertDialog -import androidx.compose.material3.Divider -import androidx.compose.material3.ExperimentalMaterial3Api -import androidx.compose.material3.Icon -import androidx.compose.material3.IconButton -import androidx.compose.material3.LinearProgressIndicator -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.OutlinedTextField -import androidx.compose.material3.Scaffold -import androidx.compose.material3.Text -import androidx.compose.material3.TextButton -import androidx.compose.material3.TopAppBar -import androidx.compose.material3.TopAppBarDefaults -import androidx.compose.runtime.Composable -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.saveable.rememberSaveable -import androidx.compose.runtime.setValue -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.input.pointer.pointerInput -import androidx.compose.ui.platform.LocalFocusManager -import androidx.compose.ui.text.style.TextAlign -import androidx.compose.ui.unit.dp -import androidx.navigation.NavController - - -@ExperimentalMaterial3Api -@Composable -fun StartView( - navController: NavController, - appViewModel: AppViewModel -) { - val localFocusManager = LocalFocusManager.current - Scaffold( - topBar = { - TopAppBar( - title = { Text(text = "MLCChat", color = MaterialTheme.colorScheme.onPrimary) }, - colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary) - ) - }, - modifier = Modifier.pointerInput(Unit) { - detectTapGestures(onTap = { - localFocusManager.clearFocus() - }) - } - ) - { paddingValues -> - Column( - modifier = Modifier - .fillMaxSize() - .padding(paddingValues) - .padding(horizontal = 10.dp) - ) { - Text(text = "Model List", modifier = Modifier.padding(top = 10.dp)) - LazyColumn() { - items(items = appViewModel.modelList, - key = { modelState -> modelState.id } - ) { modelState -> - ModelView( - navController = navController, - modelState = modelState, - appViewModel = appViewModel - ) - } - } - } - if (appViewModel.isShowingAlert()) { - AlertDialog( - onDismissRequest = { appViewModel.dismissAlert() }, - onConfirmation = { appViewModel.copyError() }, - error = appViewModel.errorMessage() - ) - } - } -} - -@ExperimentalMaterial3Api -@Composable -fun AlertDialog( - onDismissRequest: () -> Unit, - onConfirmation: () -> Unit, - error: String, -) { - AlertDialog( - title = { Text(text = "Error") }, - text = { Text(text = error) }, - onDismissRequest = { onDismissRequest() }, - confirmButton = { - TextButton(onClick = { onConfirmation() }) { Text("Copy") } - }, - dismissButton = { - TextButton(onClick = { onDismissRequest() }) { Text("Dismiss") } - } - ) -} - -@Composable -fun ModelView( - navController: NavController, - modelState: AppViewModel.ModelState, - appViewModel: AppViewModel -) { - var isDeletingModel by rememberSaveable { mutableStateOf(false) } - Column( - verticalArrangement = Arrangement.SpaceBetween, - modifier = Modifier - .wrapContentHeight() - ) { - Row( - horizontalArrangement = Arrangement.spacedBy(5.dp), - verticalAlignment = Alignment.CenterVertically, - modifier = Modifier - .fillMaxWidth() - .wrapContentHeight() - ) { - Text( - text = modelState.modelConfig.modelId, - textAlign = TextAlign.Left, - 
modifier = Modifier - .wrapContentHeight() - .weight(8f) - ) - Divider( - modifier = Modifier - .height(20.dp) - .width(1.dp) - ) - if (modelState.modelInitState.value == ModelInitState.Paused) { - IconButton( - onClick = { modelState.handleStart() }, modifier = Modifier - .aspectRatio(1f) - .weight(1f) - ) { - Icon( - imageVector = Icons.Outlined.Download, - contentDescription = "start downloading", - ) - } - - } else if (modelState.modelInitState.value == ModelInitState.Downloading) { - IconButton( - onClick = { modelState.handlePause() }, modifier = Modifier - .aspectRatio(1f) - .weight(1f) - ) { - Icon( - imageVector = Icons.Outlined.Pause, - contentDescription = "pause downloading", - ) - } - } else if (modelState.modelInitState.value == ModelInitState.Finished) { - IconButton( - onClick = { - modelState.startChat() - navController.navigate("chat") - }, - enabled = appViewModel.chatState.interruptable(), - modifier = Modifier - .aspectRatio(1f) - .weight(1f) - ) { - Icon( - imageVector = Icons.Outlined.Chat, - contentDescription = "start chatting", - ) - } - } else { - IconButton( - enabled = false, onClick = {}, modifier = Modifier - .aspectRatio(1f) - .weight(1f) - ) { - Icon( - imageVector = Icons.Outlined.Schedule, - contentDescription = "pending", - ) - } - } - if (modelState.modelInitState.value == ModelInitState.Downloading || - modelState.modelInitState.value == ModelInitState.Paused || - modelState.modelInitState.value == ModelInitState.Finished - ) { - IconButton( - onClick = { isDeletingModel = true }, - modifier = Modifier - .aspectRatio(1f) - .weight(1f) - ) { - Icon( - imageVector = Icons.Outlined.Delete, - contentDescription = "start downloading", - tint = MaterialTheme.colorScheme.error - ) - } - } - } - LinearProgressIndicator( - progress = modelState.progress.value.toFloat() / modelState.total.value, - modifier = Modifier.fillMaxWidth() - ) - if (isDeletingModel) { - Row( - horizontalArrangement = Arrangement.End, - verticalAlignment = Alignment.CenterVertically, - modifier = Modifier - .fillMaxWidth() - .wrapContentHeight() - ) { - TextButton(onClick = { isDeletingModel = false }) { - Text(text = "cancel") - } - TextButton(onClick = { - isDeletingModel = false - modelState.handleClear() - }) { - Text(text = "clear data", color = MaterialTheme.colorScheme.error) - } - TextButton(onClick = { - isDeletingModel = false - modelState.handleDelete() - }) { - Text(text = "delete model", color = MaterialTheme.colorScheme.error) - } - } - } - } -} - diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt deleted file mode 100644 index 75a3557baa..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt +++ /dev/null @@ -1,44 +0,0 @@ -package ai.mlc.mlcchat.ui.theme - -import androidx.compose.ui.graphics.Color - -val Blue10 = Color(0xFF000F5E) -val Blue20 = Color(0xFF001E92) -val Blue30 = Color(0xFF002ECC) -val Blue40 = Color(0xFF1546F6) -val Blue80 = Color(0xFFB8C3FF) -val Blue90 = Color(0xFFDDE1FF) - -val DarkBlue10 = Color(0xFF00036B) -val DarkBlue20 = Color(0xFF000BA6) -val DarkBlue30 = Color(0xFF1026D3) -val DarkBlue40 = Color(0xFF3648EA) -val DarkBlue80 = Color(0xFFBBC2FF) -val DarkBlue90 = Color(0xFFDEE0FF) - -val Yellow10 = Color(0xFF261900) -val Yellow20 = Color(0xFF402D00) -val Yellow30 = Color(0xFF5C4200) -val Yellow40 = Color(0xFF7A5900) -val Yellow80 = Color(0xFFFABD1B) -val Yellow90 = Color(0xFFFFDE9C) - -val Red10 = Color(0xFF410001) -val Red20 = Color(0xFF680003) 
-val Red30 = Color(0xFF930006) -val Red40 = Color(0xFFBA1B1B) -val Red80 = Color(0xFFFFB4A9) -val Red90 = Color(0xFFFFDAD4) - -val Grey10 = Color(0xFF191C1D) -val Grey20 = Color(0xFF2D3132) -val Grey80 = Color(0xFFC4C7C7) -val Grey90 = Color(0xFFE0E3E3) -val Grey95 = Color(0xFFEFF1F1) -val Grey99 = Color(0xFFFBFDFD) - -val BlueGrey30 = Color(0xFF45464F) -val BlueGrey50 = Color(0xFF767680) -val BlueGrey60 = Color(0xFF90909A) -val BlueGrey80 = Color(0xFFC6C5D0) -val BlueGrey90 = Color(0xFFE2E1EC) \ No newline at end of file diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt deleted file mode 100644 index cbc61567b7..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt +++ /dev/null @@ -1,107 +0,0 @@ -package ai.mlc.mlcchat.ui.theme - -import android.app.Activity -import android.os.Build -import androidx.compose.foundation.isSystemInDarkTheme -import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.darkColorScheme -import androidx.compose.material3.dynamicDarkColorScheme -import androidx.compose.material3.dynamicLightColorScheme -import androidx.compose.material3.lightColorScheme -import androidx.compose.runtime.Composable -import androidx.compose.runtime.SideEffect -import androidx.compose.ui.graphics.Color -import androidx.compose.ui.graphics.toArgb -import androidx.compose.ui.platform.LocalContext -import androidx.compose.ui.platform.LocalView -import androidx.core.view.WindowCompat - -private val DarkColorScheme = darkColorScheme( - primary = Blue80, - onPrimary = Blue20, - primaryContainer = Blue30, - onPrimaryContainer = Blue90, - inversePrimary = Blue40, - secondary = DarkBlue80, - onSecondary = DarkBlue20, - secondaryContainer = DarkBlue30, - onSecondaryContainer = DarkBlue90, - tertiary = Yellow80, - onTertiary = Yellow20, - tertiaryContainer = Yellow30, - onTertiaryContainer = Yellow90, - error = Red80, - onError = Red20, - errorContainer = Red30, - onErrorContainer = Red90, - background = Grey10, - onBackground = Grey90, - surface = Grey10, - onSurface = Grey80, - inverseSurface = Grey90, - inverseOnSurface = Grey20, - surfaceVariant = BlueGrey30, - onSurfaceVariant = BlueGrey80, - outline = BlueGrey60 -) - -private val LightColorScheme = lightColorScheme( - primary = Blue40, - onPrimary = Color.White, - primaryContainer = Blue90, - onPrimaryContainer = Blue10, - inversePrimary = Blue80, - secondary = DarkBlue40, - onSecondary = Color.White, - secondaryContainer = DarkBlue90, - onSecondaryContainer = DarkBlue10, - tertiary = Yellow40, - onTertiary = Color.White, - tertiaryContainer = Yellow90, - onTertiaryContainer = Yellow10, - error = Red40, - onError = Color.White, - errorContainer = Red90, - onErrorContainer = Red10, - background = Grey99, - onBackground = Grey10, - surface = Grey99, - onSurface = Grey10, - inverseSurface = Grey20, - inverseOnSurface = Grey95, - surfaceVariant = BlueGrey90, - onSurfaceVariant = BlueGrey30, - outline = BlueGrey50 -) - -@Composable -fun MLCChatTheme( - darkTheme: Boolean = isSystemInDarkTheme(), - // Dynamic color is available on Android 12+ - dynamicColor: Boolean = true, - content: @Composable () -> Unit -) { - val colorScheme = when { - dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { - val context = LocalContext.current - if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) - } - - darkTheme -> DarkColorScheme - else -> LightColorScheme - } - val view = 
LocalView.current - if (!view.isInEditMode) { - SideEffect { - val window = (view.context as Activity).window - window.statusBarColor = colorScheme.primary.toArgb() - WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme - } - } - - MaterialTheme( - colorScheme = colorScheme, - typography = Typography, - content = content - ) -} \ No newline at end of file
diff --git a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt b/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt deleted file mode 100644 index 30e70c2b7e..0000000000 --- a/android/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt +++ /dev/null @@ -1,34 +0,0 @@ -package ai.mlc.mlcchat.ui.theme - -import androidx.compose.material3.Typography -import androidx.compose.ui.text.TextStyle -import androidx.compose.ui.text.font.FontFamily -import androidx.compose.ui.text.font.FontWeight -import androidx.compose.ui.unit.sp - -// Set of Material typography styles to start with -val Typography = Typography( - bodyLarge = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Normal, - fontSize = 16.sp, - lineHeight = 24.sp, - letterSpacing = 0.5.sp - ) - /* Other default text styles to override - titleLarge = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Normal, - fontSize = 22.sp, - lineHeight = 28.sp, - letterSpacing = 0.sp - ), - labelSmall = TextStyle( - fontFamily = FontFamily.Default, - fontWeight = FontWeight.Medium, - fontSize = 11.sp, - lineHeight = 16.sp, - letterSpacing = 0.5.sp - ) - */ -) \ No newline at end of file
diff --git a/android/app/src/main/res/drawable/ic_android_black_24dp.xml b/android/app/src/main/res/drawable/ic_android_black_24dp.xml deleted file mode 100644 index fe51230740..0000000000 --- a/android/app/src/main/res/drawable/ic_android_black_24dp.xml +++ /dev/null @@ -1,5 +0,0 @@ (vector drawable markup omitted)
diff --git a/android/app/src/main/res/drawable/mlc_logo_108.xml b/android/app/src/main/res/drawable/mlc_logo_108.xml deleted file mode 100644 index d5307e0979..0000000000 --- a/android/app/src/main/res/drawable/mlc_logo_108.xml +++ /dev/null @@ -1,11 +0,0 @@ (vector drawable markup omitted)
diff --git a/android/app/src/main/res/values/colors.xml b/android/app/src/main/res/values/colors.xml deleted file mode 100644 index f8c6127d32..0000000000 --- a/android/app/src/main/res/values/colors.xml +++ /dev/null @@ -1,10 +0,0 @@ - - - #FFBB86FC - #FF6200EE - #FF3700B3 - #FF03DAC5 - #FF018786 - #FF000000 - #FFFFFFFF - \ No newline at end of file
diff --git a/android/app/src/main/res/values/strings.xml b/android/app/src/main/res/values/strings.xml deleted file mode 100644 index a6b10f5b60..0000000000 --- a/android/app/src/main/res/values/strings.xml +++ /dev/null @@ -1,3 +0,0 @@ - - MLCChat - \ No newline at end of file
diff --git a/android/app/src/main/res/values/themes.xml b/android/app/src/main/res/values/themes.xml deleted file mode 100644 index a16e9d4b0e..0000000000 --- a/android/app/src/main/res/values/themes.xml +++ /dev/null @@ -1,6 +0,0 @@ (theme XML markup omitted)
diff --git a/docs/_static/img/project-structure.svg b/docs/_static/img/project-structure.svg deleted file mode 100644 index e4ad7db6b1..0000000000 --- a/docs/_static/img/project-structure.svg +++ /dev/null @@ -1,1189 +0,0 @@ (SVG diagram markup omitted)
diff --git a/docs/_static/img/project-workflow.svg b/docs/_static/img/project-workflow.svg deleted file mode 100644 index eac1313a44..0000000000 --- a/docs/_static/img/project-workflow.svg +++ /dev/null @@ -1,1173 +0,0 @@ (SVG diagram markup omitted)
diff --git a/docs/community/faq.rst b/docs/community/faq.rst deleted file mode 100644 index 4bc6f9deb8..0000000000 --- a/docs/community/faq.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. _FAQ: - -Frequently Asked Questions -========================== - -This is a list of Frequently Asked Questions (FAQ) about MLC-LLM. Feel free to suggest new entries! - -... How can I customize the temperature and repetition penalty of models? - Please check our :ref:`configure-mlc-chat-json` tutorial. - -... Which quantization algorithm does MLC-LLM use? - Please check our :doc:`/compilation/configure_quantization` tutorial. - -... Why do I encounter the error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation? - This happens if you compiled TVM-Unity from source and didn't hide LLVM symbols in the cmake configuration. - Please follow the instructions in the :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity with LLVM symbols hidden, or use our pre-built MLC-LLM :doc:`pip wheels <../install/mlc_llm>`.
diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst deleted file mode 100644 index 33e8982543..0000000000 --- a/docs/community/guideline.rst +++ /dev/null @@ -1,125 +0,0 @@ -.. _community_guide: - -Community Guideline -=================== - -.. contents:: - :depth: 2 - :local: - -Welcome to the MLC-LLM community! Just like you, all of us are in awe of the immense power of large language models. -Our goal for MLC-LLM is to foster a project that is driven by an open-source community, working together to democratize -this technology and make it accessible across various devices. We are thrilled to have you as part of our -community and eagerly anticipate your valuable contributions. - - -.. _community_discussion: - -Participate in Community Discussions ------------------------------------- - -We encourage open discussions. If you encounter a bug or have a feature request, please file an issue in MLC-LLM's -GitHub `issue tracker `__. You are encouraged to tag the issue with labels -such as "bug," "feature request," or "iOS" so that the relevant developers can quickly notice your concern. - -Additionally, we have set up a `Discord server `__ for online discussions. -While we encourage participation in the Discord server, we also recommend creating a GitHub issue even if the -topic has been discussed there.
This ensures that the discussion is archived and searchable for future reference. - -Before submitting an issue, we kindly ask you to check our :doc:`/community/faq` to see if your question has already been answered. - -.. _contribute-to-mlc-llm: - -Contribute to MLC-LLM ---------------------- - -.. _fork-and-create-pull-requests: - -Fork and Create Pull Requests -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Ready to contribute to MLC-LLM? Awesome! We are excited to see you are ready to contribute your code. -The standard way to make changes to MLC-LLM code base is through creating a `pull-request `__, -and we will review your code and merge it to the code base when it is ready. - -The first step to becoming a developer is to `fork `__ the repository to your own -github account, you will notice a repository under ``https://github.com/username/mlc-llm`` where ``username`` is your github user name. - -You can clone your fork to your local machine and commit changes, or edit the contents of your fork (in the case you are just fixing typos) -on GitHub directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository. - -.. _contribute-new-models: - -Contribute New Models to MLC-LLM -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page. Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together! - -* If you add a new model variant to MLC-LLM by following our :doc:`/compilation/define_new_models` tutorial. - Please create a pull request to add your model architecture (currently model architectures are placed under - `relax_models `__ folder). - -.. _coding-styles: - -Coding Styles -^^^^^^^^^^^^^ - -For python codes, we generally follow the `PEP8 style guide `__. -The python comments follow `NumPy style `__ python docstrings. -To make things easy, you can use `black `__ to automatically format your python code. - -.. code:: bash - - pip install black - black your_python_file.py - -For C++ codes, we generally follow the `Google C++ style guide `__. -The C++ comments should be `Doxygen compatible `__. -Fo your convenience, you can use `clang-format `__ to automatically format your C++ code. - -.. code:: bash - - clang-format -i your_cpp_file.cpp - -.. _general-development-process: - -General Development Process ---------------------------- - -Everyone in the community is welcome to send patches, documents, and propose new directions to the project. -The key guideline here is to enable everyone in the community to get involved and participate in the decision and development. -We encourage public discussion in different channels, so that everyone in the community can participate -and get informed in developments. - -Code reviews are one of the key ways to ensure the quality of the code. High-quality code reviews prevent technical debt -for long-term and are crucial to the success of the project. A pull request needs to be reviewed before it gets merged. 
-A committer who has expertise in the corresponding area will moderate the pull request and merge the code when -it is ready. The corresponding committer could request multiple reviewers who are familiar with the area of the code. -We encourage contributors to request code reviews themselves and help review each other's code -- remember that everyone -is volunteering their time to the community, and a high-quality code review costs as much effort as the actual code -contribution, so you are more likely to get your own code reviewed quickly if you do others the same favor. - -The community should strive to reach a consensus on technical decisions through discussion. We expect committers to -moderate technical discussions in a diplomatic way, and to provide suggestions with clear technical reasoning when necessary. - - -.. _roles-committers: - -Committers -^^^^^^^^^^ - -Committers are individuals who are granted write access to the project. A committer is usually responsible for -a certain area or several areas of the code, where they oversee the code review process. -The area of contribution can take many forms, including code contributions and code reviews, documentation, education, and outreach. -The review of a pull request will be assigned to committers who have recently contributed to the area the PR belongs to. -Committers are essential for a high-quality and healthy project. The community actively looks for new committers -from contributors. Each existing committer can nominate new committers to MLC projects. - -.. _roles-contributors: - -Contributors -^^^^^^^^^^^^ -We also welcome contributors who are not yet ready to become committers. Everyone who contributes to -the project (in the form of code, bug fixes, documentation, tutorials, etc.) is a contributor. -We maintain a `page `__ to acknowledge contributors; -please let us know if you have contributed to the project and your name is not included in the list.
diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst deleted file mode 100644 index 4706e09811..0000000000 --- a/docs/compilation/compile_models.rst +++ /dev/null @@ -1,1055 +0,0 @@ -.. _compile-model-libraries: - -Compile Model Libraries -======================= - -To run a model with MLC LLM on any platform, you need: - -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `__.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -If you are simply adding a model variant, following :ref:`convert-weights-via-MLC` suffices. - -This page describes how to compile a model library with MLC LLM. Model compilation optimizes -the model inference for a given platform, allowing users to bring their own new model -architectures, use different quantization modes, and customize the overall model -optimization flow. - -We compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for all platforms. - -.. note:: - Before you proceed, make sure you have followed :ref:`install-tvm-unity`, a required - backend to compile models with MLC LLM. - - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain - the CLI app / Python API that can be used to chat with the compiled model. - Finally, we strongly recommend you read :ref:`project-overview` first to get - familiar with the high-level terminology. - -.. contents:: Table of Contents - :depth: 1 - :local: - -0. Verify Installation ----------------------- - -**Step 1.
Verify mlc_llm** - -We use the python package ``mlc_llm`` to compile models. This can be installed by -following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: - -.. code:: bash - - $ mlc_llm --help - # You should see help information with this line - usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} - -.. note:: - If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. - -**Step 2. Verify TVM** - -To compile models, you also need to follow :ref:`install-tvm-unity`. -Here we verify ``tvm`` quickly with command line (for full verification, see :ref:`tvm-unity-validate`): - -.. code:: bash - - $ python -c "import tvm; print(tvm.__file__)" - /some-path/lib/python3.11/site-packages/tvm/__init__.py - -1. Clone from HF and convert_weight ------------------------------------ - -This replicates :ref:`convert-weights-via-MLC`, see that page for more details. - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - -2. Generate mlc-chat-config and compile ---------------------------------------- - -A model library is specified by: - - - The model architecture (e.g. ``llama-2``, ``gpt-neox``) - - Quantization (e.g. ``q4f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) - -All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. - -.. code:: shell - - # Create output directory for the model library compiled - mkdir dist/libs - -.. tabs:: - - .. group-tab:: Linux - CUDA - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device cuda -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so - - - .. group-tab:: Metal - - For M-chip Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so - - Cross-Compiling for Intel Mac on M-chip Mac: - - .. code:: shell - - # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib - - For Intel Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib - - - .. group-tab:: Vulkan - - For Linux: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so - - For Windows: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.dll - - .. group-tab:: iOS/iPadOS - - You need a Mac to compile models for it. - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ - --conv-template redpajama_chat --context-window-size 768 \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device iphone -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar - - .. note:: - If it runs into error - - .. code:: text - - Compilation error: - xcrun: error: unable to find utility "metal", not a developer tool or in PATH - xcrun: error: unable to find utility "metallib", not a developer tool or in PATH - - , please check and make sure you have Command Line Tools for Xcode installed correctly. - You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed with the model compiling. - - .. group-tab:: Android - - .. code:: shell - - # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \ - --conv-template redpajama_chat --context-window-size 768 \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device android -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar - - .. group-tab:: WebGPU - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm - - .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. - Otherwise, it would run into error - - .. code:: text - - RuntimeError: Cannot find libraries: wasm_runtime.bc - - .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. - Otherwise, you may run into issues like: - - .. code:: text - - TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from - 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. - -.. note:: - - For the ``conv-template``, `conv_template.cc `__ - contains a full list of conversation templates that MLC provides. If the model you are adding - requires a new conversation template, you would need to add your own. - Follow `this PR `__ as an example. - However, adding your own template would require you :ref:`build mlc_llm from source ` - in order for it to be recognized by the runtime. - - For more details, please see :ref:`configure-mlc-chat-json`. - -3. Verify output and chat -------------------------- - -By executing the compile command above, we generate the model weights, model lib, and a chat config. -We can check the output with the commands below: - -.. tabs:: - - .. group-tab:: Linux - CUDA - - .. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so # ===> the model library - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - We can now chat with the model using the command line interface (CLI) app or the Python API. - - .. code:: shell - - python - >>> from mlc_llm import ChatModule - >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") - >>> cm.generate("hi") - 'Hi! How can I assist you today?' - - .. group-tab:: Metal - - .. 
code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so # ===> the model library (will be -metal_x86_64.dylib for Intel Mac) - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - We can now chat with the model using the command line interface (CLI) app or the Python API. - - .. code:: shell - - python - >>> from mlc_llm import ChatModule - >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so") - >>> cm.generate("hi") - 'Hi! How can I assist you today?' - - - .. group-tab:: Vulkan - - .. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so # ===> the model library (will be .dll for Windows) - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - We can now chat with the model using the command line interface (CLI) app or the Python API. - - .. code:: shell - - python - >>> from mlc_llm import ChatModule - >>> cm = ChatModule(model="./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", \ - model_lib_path="./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so", device="vulkan") - >>> cm.generate("hi") - 'Hi! How can I assist you today?' - - .. group-tab:: iOS/iPadOS - - .. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar # ===> the model library - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - The model lib ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar`` - will be packaged as a static library into the iOS app. Checkout :ref:`deploy-ios` for more details. - - .. group-tab:: Android - - .. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar # ===> the model library - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - The model lib ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar`` - will be packaged as a static library into the android app. Checkout :ref:`deploy-android` for more details. - - .. group-tab:: WebGPU - - .. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm # ===> the model library - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - - To use this in WebGPU runtime, checkout :ref:`webllm-runtime`. 
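If you would rather script this check than eyeball the directory listings above, the following is a minimal Python sketch (standard library only; the paths are the illustrative ones used on this page, and the model library name/suffix is platform-specific) that verifies the expected artifacts exist before you start chatting:

.. code:: python

   from pathlib import Path

   # Illustrative paths from the example above; adjust the library name/suffix for your platform.
   weight_dir = Path("dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC")
   model_lib = Path("dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so")

   # Files produced by convert_weight / gen_config, as listed above.
   expected = ["mlc-chat-config.json", "ndarray-cache.json", "tokenizer.json", "tokenizer_config.json"]
   missing = [name for name in expected if not (weight_dir / name).exists()]

   assert model_lib.exists(), f"model library not found: {model_lib}"
   assert not missing, f"missing files in {weight_dir}: {missing}"
   assert any(weight_dir.glob("params_shard_*.bin")), "no parameter shards found"
   print("All compiled artifacts are in place.")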
- -Compile Commands for More Models --------------------------------- - -This section lists compile commands for more models that you can try out. Note that this can be easily -generalized to any model variant, as long as mlc-llm supports the architecture. - -.. tabs:: - - .. tab:: Model: Llama-2-7B - - Please `request for access `_ to the Llama-2 weights from Meta first. - After granted access, first create directory ``dist/models`` and download the model to the directory. - For example, you can run the following code: - - .. code:: shell - - mkdir -p dist/models && cd dist/models - git lfs install - git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf - cd ../.. - - Then convert the HF weights into MLC-compatible weights. Note that all platforms - can share the same compiled/quantized weights. - - .. code:: shell - - mlc_llm convert_weight ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC - - Afterwards, run the following command to generate mlc config and compile the model. - - .. code:: shell - - # Create output directory for the model library compiled - mkdir dist/libs - - .. tabs:: - - .. tab:: Target: CUDA - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so - - .. tab:: Metal - - For M-chip Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal.so - - Cross-Compiling for Intel Mac on M-chip Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib - - For Intel Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal_x86_64.dylib - - .. tab:: Vulkan - - For Linux: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.so - - For Windows: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.dll - - .. tab:: WebGPU - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --context-window-size 2048 --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device webgpu -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-webgpu.wasm - - .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. - Otherwise, it would run into error - - .. code:: text - - RuntimeError: Cannot find libraries: wasm_runtime.bc - - .. tab:: iPhone/iPad - - You need a Mac to compile models for it. - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device iphone -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-iphone.tar - - .. tab:: Android - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \ - --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device android -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-android.tar - - .. tab:: Mistral-7B-Instruct-v0.2 - - Note that Mistral uses sliding window attention (SWA). Thus, instead of specifying - ``context-window-size``, we specify ``sliding-window-size``. - - First create directory ``dist/models`` and download the model to the directory. - For example, you can run the following code: - - .. code:: shell - - mkdir -p dist/models && cd dist/models - git lfs install - git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 - cd ../.. - - Then convert the HF weights into MLC-compatible weights. Note that all platforms - can share the same compiled/quantized weights. - - .. code:: shell - - mlc_llm convert_weight ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC - - Afterwards, run the following command to generate mlc config and compile the model. - - .. code:: shell - - # Create output directory for the model library compiled - mkdir dist/libs - - .. 
tabs:: - - .. tab:: Target: CUDA - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device cuda -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so - - .. tab:: Metal - - For M-chip Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal.so - - - For Intel Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal_x86_64.dylib - - .. tab:: Vulkan - - For Linux: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.so - - For Windows: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.dll - - .. tab:: WebGPU - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --prefill-chunk-size 1024 --conv-template mistral_default \ - -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device webgpu -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-webgpu.wasm - - .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. - Otherwise, it would run into error - - .. code:: text - - RuntimeError: Cannot find libraries: wasm_runtime.bc - - .. 
note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. - Otherwise, you may run into issues like: - - .. code:: text - - TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from - 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. - - .. tab:: iPhone/iPad - - You need a Mac to compile models for it. - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 \ - -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device iphone -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-iphone.tar - - .. tab:: Android - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \ - --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \ - --device android -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar - - .. tab:: Other models - - First create directory ``dist/models`` and download the model to the directory. - For example, you can run the following code: - - .. code:: shell - - mkdir -p dist/models && cd dist/models - git lfs install - git clone https://huggingface.co/DISTRIBUTOR/HF_MODEL - cd ../.. - - Then convert the HF weights into MLC-compatible weights. Note that all platforms - can share the same compiled/quantized weights. - - .. code:: shell - - mlc_llm convert_weight ./dist/models/HF_MODEL/ --quantization q4f16_1 -o dist/OUTPUT-MLC - - Afterwards, run the following command to generate mlc config and compile the model. - - .. code:: shell - - # Create output directory for the model library compiled - mkdir dist/libs - - .. tabs:: - - .. tab:: Target: CUDA - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device cuda -o dist/libs/OUTPUT-cuda.so - - .. tab:: Metal - - For M-chip Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal.so - - - For Intel Mac: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. 
compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal_x86_64.dylib - - .. tab:: Vulkan - - For Linux: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.so - - For Windows: - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.dll - - .. tab:: WebGPU - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device webgpu -o dist/libs/OUTPUT-webgpu.wasm - - .. note:: - To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`. - Otherwise, it would run into error - - .. code:: text - - RuntimeError: Cannot find libraries: wasm_runtime.bc - - .. note:: - For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage. - Otherwise, you may run into issues like: - - .. code:: text - - TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from - 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. - - .. tab:: iPhone/iPad - - You need a Mac to compile models for it. - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ - --context-window-size 768 -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device iphone -o dist/libs/OUTPUT-iphone.tar - - .. tab:: Android - - .. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \ - --context-window-size 768 -o dist/OUTPUT-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device android -o dist/libs/OUTPUT-android.tar - -For each model and each backend, the above only provides the most recommended build command (which is the most optimized). -You can also try with different argument values (e.g., different quantization modes, context window size, etc.), -whose build results affect runtime memory requirement, and it is possible that they may not run as -fast and robustly as the provided one when running the model. - -.. 
note:: - Uing 3-bit quantization usually can be overly aggressive and only works for limited settings. - If you encounter issues where the compiled model does not perform as expected, - consider utilizing a higher number of bits for quantization (e.g., 4-bit quantization). - -If you are interested in distributing the model besides local execution, please checkout :ref:`distribute-compiled-models`. - - -.. _compile-command-specification: - -Compile Command Specification ------------------------------ - -As you have seen in the section above, the model compilation is split into three steps: convert weights, generate -``mlc-chat-config.json``, and compile the model. This section describes the list of options that can be used -during compilation. - -1. Convert Weight -^^^^^^^^^^^^^^^^^ - -Weight conversion command follows the pattern below: - -.. code:: text - - mlc_llm convert_weight \ - CONFIG \ - --quantization QUANTIZATION_MODE \ - [--model-type MODEL_TYPE] \ - [--device DEVICE] \ - [--source SOURCE] \ - [--source-format SOURCE_FORMAT] \ - --output OUTPUT - -Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. - ---CONFIG It can be one of the following: - - 1. Path to a HuggingFace model directory that contains a ``config.json`` or - 2. Path to ``config.json`` in HuggingFace format, or - 3. The name of a pre-defined model architecture. - - A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary - size, the number of layers, the hidden size, number of attention heads, etc. - Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. - - A HuggingFace directory often contains a ``config.json`` which defines the model architecture, - the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, - as well as an optional ``generation_config.json`` provides additional default configuration for - text generation. - Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. - - For existing pre-defined model architecture, see ``MODEL_PRESETS`` - `here `_. - ---quantization QUANTIZATION_MODE The quantization mode we use to compile. - - See :ref:`quantization_mode` for more information. - Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and - ``q4f16_awq``. - - We encourage you to use 4-bit quantization, as the text generated by 3-bit - quantized models may have bad quality depending on the model. - ---model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``config.json``. - ---device DEVICE The device used to do quantization such as "cuda" or "cuda:0". Will detect from - local available GPUs if not specified. - ---source SOURCE The path to original model weight, infer from ``config`` if missing. - ---source-format SOURCE_FORMAT The format of source model weight, infer from ``config`` if missing. - ---output OUTPUT The output directory to save the quantized model weight. - Will create ``params_shard_*.bin`` and ```ndarray-cache.json``` in this directory. - -2. Generate MLC Chat Config -^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In order to compile a model, we first need to generate the ``mlc-chat-config.json``. This file contains specifications -like ``context-window-size`` and ``sliding-window-size``, among others that can alter the model compiled. We also process -tokenizers in this step. - -Config generation command follows the pattern below: - -.. 
code:: text - - mlc_llm gen_config \ - CONFIG \ - --quantization QUANTIZATION_MODE \ - [--model-type MODEL_TYPE] \ - --conv-template CONV_TEMPLATE \ - [--context-window-size CONTEXT_WINDOW_SIZE] \ - [--sliding-window-size SLIDING_WINDOW_SIZE] \ - [--prefill-chunk-size PREFILL_CHUNK_SIZE] \ - [--tensor-parallel-shard TENSOR_PARALLEL_SHARDS] \ - --output OUTPUT - -Note that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. - ---CONFIG It can be one of the following: - - 1. Path to a HuggingFace model directory that contains a ``config.json`` or - 2. Path to ``config.json`` in HuggingFace format, or - 3. The name of a pre-defined model architecture. - - A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary - size, the number of layers, the hidden size, number of attention heads, etc. - Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. - - A HuggingFace directory often contains a ``config.json`` which defines the model architecture, - the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, - as well as an optional ``generation_config.json`` provides additional default configuration for - text generation. - Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. - - For existing pre-defined model architecture, see ``MODEL_PRESETS`` - `here `_. - ---quantization QUANTIZATION_MODE The quantization mode we use to compile. - - See :ref:`quantization_mode` for more information. - Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and - ``q4f16_awq``. - - We encourage you to use 4-bit quantization, as the text generated by 3-bit - quantized models may have bad quality depending on the model. - ---model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``config.json``. - ---conv-template CONV_TEMPLATE Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model - For existing pre-defined templates, see ``CONV_TEMPLATES`` - `here `_. - ---context-window-size CONTEXT_WINDOW_SIZE Option to provide the maximum sequence length supported by the model. - This is usually explicitly shown as context length or context window in the model card. - If this option is not set explicitly, by default, - it will be determined by ``context_window_size`` or ``max_position_embeddings`` in ``config.json``, - and the latter is usually inaccurate for some models. - ---sliding-window-size SLIDING_WINDOW (Experimental) The sliding window size in sliding window attention (SWA). - This optional field overrides the ``sliding_window`` in ``config.json`` for - those models that use SWA. Currently only useful when compiling mistral-based models. - This flag subjects to future refactoring. - ---prefill-chunk-size PREFILL_CHUNK_SIZE (Experimental) The chunk size during prefilling. By default, - the chunk size is the same as ``context_window_size`` or ``sliding_window_size``. - This flag subjects to future refactoring. - ---tensor-parallel-shard TENSOR_PARALLEL_SHARDS Number of shards to split the model into in tensor parallelism multi-gpu inference. - ---output OUTPUT The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer configuration. - -3. Compile Model Library -^^^^^^^^^^^^^^^^^^^^^^^^ - -After generating ``mlc-chat-config.json``, we can compile the model into a model library (files ending in ``.so``, ``.tar``, etc. 
that contains -the inference logic of a model). - -Model compilation command follows the pattern below: - -.. code:: text - - mlc_llm compile \ - MODEL \ - [--quantization QUANTIZATION_MODE] \ - [--model-type MODEL_TYPE] \ - [--device DEVICE] \ - [--host HOST] \ - [--opt OPT] \ - [--system-lib-prefix SYSTEM_LIB_PREFIX] \ - --output OUTPUT \ - [--overrides OVERRIDES] - -Note that ``MODEL`` is a positional argument. Arguments wrapped with ``[ ]`` are optional. - ---MODEL A path to ``mlc-chat-config.json``, or an MLC model directory that contains ``mlc-chat-config.json``. - ---quantization QUANTIZATION_MODE The quantization mode we use to compile. If unprovided, will infer from ``MODEL``. - - See :ref:`quantization_mode` for more information. - Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and - ``q4f16_awq``. - - We encourage you to use 4-bit quantization, as the text generated by 3-bit - quantized models may have bad quality depending on the model. - ---model-type MODEL_TYPE Model architecture such as "llama". If not set, it is inferred from ``mlc-chat-config.json``. - ---device DEVICE The GPU device to compile the model to. If not set, it is inferred from GPUs available locally. - ---host HOST The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS. - Examples of the LLVM triple: - - 1) iPhones: arm64-apple-ios; - 2) ARM64 Android phones: aarch64-linux-android; - 3) WebAssembly: wasm32-unknown-unknown-wasm; - 4) Windows: x86_64-pc-windows-msvc; - 5) ARM macOS: arm64-apple-darwin. - ---opt OPT Optimization flags. MLC LLM maintains a predefined set of optimization flags, - denoted as ``O0``, ``O1``, ``O2``, ``O3``, where ``O0`` means no optimization, ``O2`` - means majority of them, and ``O3`` represents extreme optimization that could - potentially break the system. - - Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. - ``--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"``. - ---system-lib-prefix SYSTEM_LIB_PREFIX Adding a prefix to all symbols exported. Similar to ``objcopy --prefix-symbols``. - This is useful when compiling multiple models into a single library to avoid symbol - conflicts. Different from objcopy, this takes no effect for shared library. - - ---output OUTPUT The path to the output file. The suffix determines if the output file is a shared library or - objects. Available suffixes: - - 1) Linux: .so (shared), .tar (objects); - 2) macOS: .dylib (shared), .tar (objects); - 3) Windows: .dll (shared), .tar (objects); - 4) Android, iOS: .tar (objects); - 5) Web: .wasm (web assembly). - ---overrides OVERRIDES Model configuration override. Configurations to override ``mlc-chat-config.json``. Supports - ``context_window_size``, ``prefill_chunk_size``, ``sliding_window``, ``max_batch_size`` and - ``tensor_parallel_shards``. Meanwhile, model config could be explicitly specified via details - knobs, e.g. ``--overrides "context_window_size=1024;prefill_chunk_size=128"``. 
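Putting these options together, a hypothetical invocation that combines several of the optional flags above could look like the following (the device, flag values, and output path are illustrative only and should be adapted to your setup):

.. code:: shell

   mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \
       --device cuda \
       --opt "cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0" \
       --overrides "context_window_size=1024;prefill_chunk_size=128" \
       -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so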
diff --git a/docs/compilation/configure_quantization.rst b/docs/compilation/configure_quantization.rst deleted file mode 100644 index d66f8416fc..0000000000 --- a/docs/compilation/configure_quantization.rst +++ /dev/null @@ -1,22 +0,0 @@ -🚧 Configure Quantization -========================= - -Quantization Algorithm ----------------------- - -The default quantization algorithm used in MLC-LLM is grouping quantization method discussed in the papers `The case for 4-bit precision: k-bit Inference Scaling Laws `__ and `LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models `__. - -.. _quantization_mode: - -Quantization Mode ------------------ - -In MLC-LLM we use a short code that indicates the quantization mode to use. - -The format of the code is ``qAfB(_id)``, where ``A`` represents the number -of bits for storing weights and ``B`` represents the number of bits for storing activations. -The ``_id`` is an integer identifier to distinguish different quantization algorithms (e.g. symmetric, non-symmetric, AWQ, etc). - -Currently, available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and ``q4f16_awq`` (not stable). - -More details to come. \ No newline at end of file diff --git a/docs/compilation/convert_weights.rst b/docs/compilation/convert_weights.rst deleted file mode 100644 index aa65256fd6..0000000000 --- a/docs/compilation/convert_weights.rst +++ /dev/null @@ -1,182 +0,0 @@ -.. _convert-weights-via-MLC: - -Convert Weights via MLC -======================= - -To run a model with MLC LLM in any platform, you need: - -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC `_.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -In many cases, we only need to convert weights and reuse existing model library. -This page demonstrates adding a model variant with ``mlc_llm convert_weight``, which -takes a hugginface model as input and converts/quantizes into MLC-compatible weights. - -Specifically, we add RedPjama-INCITE-**Instruct**-3B-v1, while MLC already -provides a model library for RedPjama-INCITE-**Chat**-3B-v1, which we can reuse. - -This can be extended to, e.g.: - -- Add ``OpenHermes-Mistral`` when MLC already supports Mistral -- Add ``Llama-2-uncensored`` when MLC already supports Llama-2 - -.. note:: - Before you proceed, make sure you followed :ref:`install-tvm-unity`, a required - backend to compile models with MLC LLM. - - Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-chat-module` to obtain - the CLI app / Python API that can be used to chat with the compiled model. - Finally, we strongly recommend you to read :ref:`project-overview` first to get - familiarized with the high-level terminologies. - -.. contents:: Table of Contents - :depth: 1 - :local: - -.. _verify_installation_for_compile: - -0. Verify installation ----------------------- - -**Step 1. Verify mlc_llm** - -We use the python package ``mlc_llm`` to compile models. This can be installed by -following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: - -.. code:: bash - - $ mlc_llm --help - # You should see help information with this line - usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} - -.. 
note:: - If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. - -**Step 2. Verify TVM** - -To compile models, you also need to follow :ref:`install-tvm-unity`. -Here we verify ``tvm`` quickly with command line (for full verification, see :ref:`tvm-unity-validate`): - -.. code:: bash - - $ python -c "import tvm; print(tvm.__file__)" - /some-path/lib/python3.11/site-packages/tvm/__init__.py - - -1. Clone from HF and convert_weight ------------------------------------ - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. See :ref:`compile-command-specification` -for specification of ``convert_weight``. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ - --quantization q4f16_1 \ - -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC - -.. _generate_mlc_chat_config: - -2. Generate MLC Chat Config ---------------------------- - -Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. -See :ref:`compile-command-specification` for specification of ``gen_config``. - -.. code:: shell - - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/ - - -.. note:: - The file ``mlc-chat-config.json`` is crucial in both model compilation - and runtime chatting. Here we only care about the latter case. - - You can **optionally** customize - ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). - You can also simply use the default configuration. - - `conv_template.cc `__ - contains a full list of conversation templates that MLC provides. If the model you are adding - requires a new conversation template, you would need to add your own. - Follow `this PR `__ as an example. However, - adding your own template would require you :ref:`build mlc_llm from source ` in order for it - to be recognized by the runtime. - -By now, you should have the following files. - -.. code:: shell - - ~/mlc-llm > ls dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - -.. _distribute-compiled-models: - -(Optional) 3. Upload weights to HF ----------------------------------- - -Optionally, you can upload what we have to huggingface. - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo - cd my-redpajama3b-weight-huggingface-repo - cp path/to/mlc-llm/dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/* . - git add . && git commit -m "Add redpajama-3b instruct model weights" - git push origin main - -This would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC -`_, but -for **Instruct** instead of **Chat**. - -Good job, you have successfully distributed the model you compiled. 
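Before or after uploading, you can optionally sanity-check that every weight shard referenced by
the converted output is actually present on disk. The snippet below is only a rough sketch and
assumes the usual ``ndarray-cache.json`` layout, i.e. a top-level ``records`` list whose entries
carry a ``dataPath`` field:

.. code:: shell

   python -c "
   import json, os
   d = 'dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC'
   records = json.load(open(os.path.join(d, 'ndarray-cache.json')))['records']
   missing = [r['dataPath'] for r in records
              if not os.path.isfile(os.path.join(d, r['dataPath']))]
   print('missing shards:', missing or 'none')
   "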
-Next, we will talk about how we can consume the model weights in applications. - -Download the Distributed Models and Run in Python -------------------------------------------------- - -Running the distributed models are similar to running prebuilt model weights and libraries in :ref:`Model Prebuilts`. - -.. code:: shell - - # Clone prebuilt libs so we can reuse them: - mkdir -p dist/ - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - # Or download the model library (only needed if we do not reuse the model lib): - cd dist/prebuilt_libs - wget url-to-my-model-lib - cd ../.. - - # Download the model weights - cd dist - git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC - cd .. - - # Run the model in Python; note that we reuse `-Chat` model library - python - >>> from mlc_llm import ChatModule - >>> cm = ChatModule(model="dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC", \ - model_lib_path="dist/prebuilt_libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so") # Adjust based on backend - >>> cm.generate("hi") - 'Hi! How can I assist you today?' diff --git a/docs/compilation/define_new_models.rst b/docs/compilation/define_new_models.rst deleted file mode 100644 index 4c73864104..0000000000 --- a/docs/compilation/define_new_models.rst +++ /dev/null @@ -1,25 +0,0 @@ -Define New Model Architectures -============================== - -This page guides you how to add a new model architecture in MLC. - -This notebook (runnable in Colab) should contain all necessary information to add a model in -MLC LLM: -https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_add_new_model_architecture_in_tvm_nn_module.ipynb - -In the notebook, we leverage ``tvm.nn.module`` to define a model in MLC LLM. We also use ``JIT`` -(just-in-time compilation) to debug the implementation. - -You can also refer to the PRs below on specific examples of adding a model architecture in MLC LLM: - -- `GPTNeoX PR `_ -- `GPT-2 PR `_ -- `Mistral PR `_ - -.. note:: - - As mentioned in :ref:`Model Prebuilts`, when adding a model variant that has - its architecture already supported in mlc-llm , you **only need to convert weights** - (e.g. adding ``CodeLlama`` when MLC supports ``llama-2``; adding ``OpenHermes Mistral`` - when MLC supports ``mistral``). On the other hand, a new model architecture - (or inference logic) requires more work (following the tutorial above). \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index 7743ef2985..0000000000 --- a/docs/conf.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -import os -import sys - -import tlcpack_sphinx_addon - -# -- General configuration ------------------------------------------------ - -sys.path.insert(0, os.path.abspath("../python")) -sys.path.insert(0, os.path.abspath("../")) -autodoc_mock_imports = ["torch"] - -# General information about the project. -project = "mlc-llm" -author = "MLC LLM Contributors" -copyright = "2023, %s" % author - -# Version information. 
- -version = "0.1.0" -release = "0.1.0" - -extensions = [ - "sphinx_tabs.tabs", - "sphinx_toolbox.collapse", - "sphinxcontrib.httpdomain", - "sphinx.ext.autodoc", - "sphinx.ext.napoleon", - "sphinx_reredirects", -] - -redirects = {"get_started/try_out": "../index.html#getting-started"} - -source_suffix = [".rst"] - -language = "en" - -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - -# -- Options for HTML output ---------------------------------------------- - -# The theme is set by the make target -import sphinx_rtd_theme - -html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - -templates_path = [] - -html_static_path = [] - -footer_copyright = "© 2023 MLC LLM" -footer_note = " " - -html_logo = "_static/img/mlc-logo-with-text-landscape.svg" - -html_theme_options = { - "logo_only": True, -} - -header_links = [ - ("Home", "https://llm.mlc.ai/"), - ("Github", "https://github.com/mlc-ai/mlc-llm"), - ("Discord Server", "https://discord.gg/9Xpy2HGBuD"), -] - -header_dropdown = { - "name": "Other Resources", - "items": [ - ("MLC Course", "https://mlc.ai/"), - ("MLC Blog", "https://blog.mlc.ai/"), - ("Web LLM", "https://webllm.mlc.ai/"), - ], -} - -html_context = { - "footer_copyright": footer_copyright, - "footer_note": footer_note, - "header_links": header_links, - "header_dropdown": header_dropdown, - "display_github": True, - "github_user": "mlc-ai", - "github_repo": "mlc-llm", - "github_version": "main/docs/", - "theme_vcs_pageview_mode": "edit", - # "header_logo": "/path/to/logo", - # "header_logo_link": "", - # "version_selecter": "", -} - - -# add additional overrides -templates_path += [tlcpack_sphinx_addon.get_templates_path()] -html_static_path += [tlcpack_sphinx_addon.get_static_path()] diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst deleted file mode 100644 index a9b2fcb18f..0000000000 --- a/docs/deploy/android.rst +++ /dev/null @@ -1,187 +0,0 @@ -.. _deploy-android: - -Android App -=========== - -.. contents:: Table of Contents - :local: - :depth: 2 - -Demo App --------- - -The demo APK below is built for Samsung S23 with Snapdragon 8 Gen 2 chip. - -.. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png - :width: 135 - :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk - -Prerequisite ------------- - -**Rust** (`install `__) is needed to cross-compile HuggingFace tokenizers to Android. Make sure rustc, cargo, and rustup are available in ``$PATH``. - -**Android Studio** (`install `__) with NDK and CMake. To install NDK and CMake, in the Android Studio welcome page, click "Projects → SDK Manager → SDK Tools". Set up the following environment variables: - -- ``ANDROID_NDK`` so that ``$ANDROID_NDK/build/cmake/android.toolchain.cmake`` is available. -- ``TVM_NDK_CC`` that points to NDK's clang compiler. - -.. 
code-block:: bash - - # Example on macOS - ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 - TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android24-clang - # Example on Windows - ANDROID_NDK: $HOME/Library/Android/sdk/ndk/25.2.9519653 - TVM_NDK_CC: $ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang - -**JDK**, such as OpenJDK >= 17, to compile Java bindings of TVM Unity runtime. It could be installed via Homebrew on macOS, apt on Ubuntu or other package managers. Set up the following environment variable: - -- ``JAVA_HOME`` so that Java is available in ``$JAVA_HOME/bin/java``. - -Please ensure that the JDK versions for Android Studio and JAVA_HOME are the same. We recommended setting the `JAVA_HOME` to the JDK bundled with Android Studio. e.g. `export JAVA_HOME=/Applications/Android\ Studio.app/Contents/jbr/Contents/Home` for macOS. - -**TVM Unity runtime** is placed under `3rdparty/tvm `__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable: - -- ``TVM_HOME`` so that its headers are available under ``$TVM_HOME/include/tvm/runtime``. - -(Optional) **TVM Unity compiler** Python package (:ref:`install ` or :ref:`build from source `). It is *NOT* required if models are prebuilt, but to compile PyTorch models from HuggingFace in the following section, the compiler is a must-dependency. - -.. note:: - ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. - -Check if **environment variable** are properly set as the last check. One way to ensure this is to place them in ``$HOME/.zshrc``, ``$HOME/.bashrc`` or environment management tools. - -.. code-block:: bash - - source $HOME/.cargo/env # Rust - export ANDROID_NDK=... # Android NDK toolchain - export TVM_NDK_CC=... # Android NDK clang - export JAVA_HOME=... # Java - export TVM_HOME=... # TVM Unity runtime - -Compile PyTorch Models from HuggingFace ---------------------------------------- - -To deploy models on Android with reasonable performance, one has to cross-compile to and fully utilize mobile GPUs using TVM Unity. MLC provides a few pre-compiled models, or one could compile the models on their own. - -**Cloning MLC LLM from GitHub**. Download MLC LLM via the following command: - -.. code-block:: bash - - git clone --recursive https://github.com/mlc-ai/mlc-llm/ - ^^^^^^^^^^^ - cd ./mlc-llm/ - -.. note:: - ❗ The ``--recursive`` flag is necessary to download submodules like `3rdparty/tvm `__. If you see any file missing during compilation, please double check if git submodules are properly cloned. - -**Download the PyTorch model** using Git Large File Storage (LFS), and by default, under ``./dist/models/``: - -.. code-block:: bash - - MODEL_NAME=Llama-2-7b-chat-hf - QUANTIZATION=q4f16_1 - - git lfs install - git clone https://huggingface.co/meta-llama/$MODEL_NAME \ - ./dist/models/ - -**Compile Android-capable models**. Install TVM Unity compiler as a Python package, and then compile the model for android using the following commands: - -.. 
code-block:: bash - - # convert weights - mlc_llm convert_weight ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION -o dist/$MODEL_NAME-$QUANTIZATION-MLC/ - - # create mlc-chat-config.json - mlc_llm gen_config ./dist/models/$MODEL_NAME/ --quantization $QUANTIZATION \ - --conv-template llama-2 --context-window-size 768 -o dist/${MODEL_NAME}-${QUANTIZATION}-MLC/ - - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/mlc-chat-config.json \ - --device android -o ./dist/${MODEL_NAME}-${QUANTIZATION}-MLC/${MODEL_NAME}-${QUANTIZATION}-android.tar - -This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION-MLC`` which contains the necessary components to run the model, as explained below. - -.. note:: - ❗ To run 7B models like llama-2-7B, Mistral-7B, it is recommended to use smaller values of parameter ``--context-window-size`` (``--sliding-window-size`` and ``--prefill-chunk-size`` for sliding window attention) to reduce the memory footprint of the model. Default configurations for certains models can be found under the Android tab in the `Compile Models `_ section. - -**Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}-MLC``, and the result consists of 3 major components: - -- Runtime configuration: It configures conversation templates including system prompts, repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` alongside with tokenizer configurations. -- Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_1-android.tar``. -- Model weights: the model weights are sharded as ``params_shard_*.bin`` and the metadata is stored in ``ndarray-cache.json`` - -Create Android Project using Compiled Models --------------------------------------------- - -The source code for MLC LLM is available under ``android/``, including scripts to build dependencies. Enter the directory first: - -.. code-block:: bash - - cd ./android/library - -**Build necessary dependencies.** Configure the list of models the app comes with using the JSON file ``app-config.json`` which contains two properties `model_list` and `model_lib_path_for_prepare_libs` ``model_lib_path_for_prepare_libs`` contains list of model library paths under `./dist/` that will be bundled with the apk. The ``model_list`` property contains data for models that are not bundled with the apk, but downloaded from the internet at run-time. Each model defined in `model_list` contain the following fields: - -``model_url`` - (Required) URL to the repo containing the weights. - -``model_id`` - (Required) Unique local identifier to identify the model. - -``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. - -``estimated_vram_bytes`` - (Optional) Estimated requirements of VRAM to run the model. - -To change the configuration, edit ``app-config.json``: - -.. 
code-block:: bash - - vim ./src/main/assets/app-config.json - -Then bundle the android library ``${MODEL_NAME}-${QUANTIZATION}-android.tar`` compiled from ``mlc_llm compile`` in the previous steps, with TVM Unity's Java runtime by running the commands below: - -.. code-block:: bash - - ./prepare_libs.sh - -which generates the two files below: - -.. code-block:: bash - - >>> find ./build/output -type f - ./build/output/arm64-v8a/libtvm4j_runtime_packed.so - ./build/output/tvm4j_core.jar - -The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``, while ``tvm4j_core.jar`` is a lightweight (~60 kb) `Java binding `_ to it. - -**Build the Android app**. Open folder ``./android`` as an Android Studio Project. Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone. - -.. note:: - ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. - -Incorporate Model Weights -------------------------- - -Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. - -**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/app/release/app-release.apk``. - -**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: - -.. code-block:: bash - - adb devices - -**Install the APK and weights to your phone**. Run the commands below replacing ``${MODEL_NAME}`` and ``${QUANTIZATION}`` with the actual model name (e.g. Llama-2-7b-chat-hf) and quantization format (e.g. q4f16_1). - -.. code-block:: bash - - adb install android/app/release/app-release.apk - adb push dist/${MODEL_NAME}-${QUANTIZATION}-MLC /data/local/tmp/${MODEL_NAME}-${QUANTIZATION}/ - adb shell "mkdir -p /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" - adb shell "mv /data/local/tmp/${MODEL_NAME}-${QUANTIZATION} /storage/emulated/0/Android/data/ai.mlc.mlcchat/files/" diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst deleted file mode 100644 index a7ebe28d6d..0000000000 --- a/docs/deploy/cli.rst +++ /dev/null @@ -1,104 +0,0 @@ -.. _deploy-cli: - -CLI -=============== - -MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. - -.. contents:: Table of Contents - :local: - :depth: 2 - -Option 1. Conda Prebuilt -~~~~~~~~~~~~~~~~~~~~~~~~ - -The prebuilt package supports Metal on macOS and Vulkan on Linux and Windows, and can be installed via Conda one-liner. - -To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source `. - -.. 
code:: shell - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly - mlc_llm chat -h - -.. note:: - The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source. - - -Option 2. Build MLC Runtime from Source -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We also provide options to build mlc runtime libraries and ``mlc_llm`` from source. -This step is useful if the prebuilt is unavailable on your platform, or if you would like to build a runtime -that supports other GPU runtime than the prebuilt version. We can build a customized version -of mlc chat runtime. You only need to do this if you choose not to use the prebuilt. - -First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). -Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. - -.. `|` adds a blank line - -| - -Run Models through MLCChat CLI -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Once ``mlc_llm`` is installed, you are able to run any MLC-compiled model on the command line. - -To run a model with MLC LLM in any platform, you can either: - -- Use off-the-shelf model prebuilts from the MLC Huggingface repo (see :ref:`Model Prebuilts` for details). -- Use locally compiled model weights and libraries following :doc:`the model compilation page `. - -**Option 1: Use model prebuilts** - -To run ``mlc_llm``, you can specify the Huggingface MLC prebuilt model repo path with the prefix ``HF://``. -For example, to run the MLC Llama 3 8B Q4F16_1 model (`Repo link `_), -simply use ``HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC``. The model weights and library will be downloaded -automatically from Huggingface. - -.. code:: shell - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device "cuda:0" --overrides context_window_size=1024 - -.. code:: - - You can use the following special commands: - /help print the special commands - /exit quit the cli - /stats print out the latest stats (token/sec) - /reset restart a fresh chat - /set [overrides] override settings in the generation config. For example, - `/set temperature=0.5;max_gen_len=100;stop=end,stop` - Note: Separate stop words in the `stop` option with commas (,). - Multi-line input: Use escape+enter to start a new line. - - user: What's the meaning of life - assistant: - What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. - - The concept of the meaning of life has been debated and... - - -**Option 2: Use locally compiled model weights and libraries** - -For models other than the prebuilt ones we provided: - -1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral``), - follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. -2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. - -Once you have the model locally compiled with a model library and model weights, to run ``mlc_llm``, simply - -- Specify the path to ``mlc-chat-config.json`` and the converted model weights to ``--model`` -- Specify the path to the compiled model library (e.g. a .so file) to ``--model-lib-path`` - -.. 
code:: shell - - mlc_llm chat dist/Llama-2-7b-chat-hf-q4f16_1-MLC \ - --device "cuda:0" --overrides context_window_size=1024 \ - --model-lib-path dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # CUDA on Linux: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so - # Metal on macOS: dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-metal.so - # Same rule applies for other platforms diff --git a/docs/deploy/ide_integration.rst b/docs/deploy/ide_integration.rst deleted file mode 100644 index 866dfa3cbe..0000000000 --- a/docs/deploy/ide_integration.rst +++ /dev/null @@ -1,179 +0,0 @@ -.. _deploy-ide-integration: - -Code Completion IDE Integration -=============================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -MLC LLM has now support for code completion on multiple IDEs. This means you can easily integrate an LLM with coding capabilities with your IDE through the MLC LLM :ref:`deploy-rest-api`. Here we provide a step-by-step guide on how to do this. - -Convert Your Model Weights --------------------------- - -To run a model with MLC LLM in any platform, you need to convert your model weights to the MLC format (e.g. `CodeLlama-7b-hf-q4f16_1-MLC `__). You can always refer to :ref:`convert-weights-via-MLC` for in-depth details on how to convert your model weights. If you are using your own model weights, i.e., you finetuned the model on your personal codebase, it is important to follow these steps to convert the respective weights properly. However, it is also possible to download precompiled weights from the original models, available in the MLC format. See the full list of all precompiled weights `here `__. - -**Example:** - -.. code:: bash - - # convert model weights - mlc_llm convert_weight ./dist/models/CodeLlama-7b-hf \ - --quantization q4f16_1 \ - -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC - -Compile Your Model ------------------- - -Compiling the model architecture is the crucial step to optimize inference for a given platform. However, compilation relies on multiple settings that will impact the runtime. This configuration is specified inside the ``mlc-chat-config.json`` file, which can be generated by the ``gen_config`` command. You can learn more about the ``gen_config`` command `here `__. - -**Example:** - -.. code:: bash - - # generate mlc-chat-config.json - mlc_llm gen_config ./dist/models/CodeLlama-7b-hf \ - --quantization q4f16_1 --conv-template LM \ - -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC - -.. note:: - Make sure to set the ``--conv-template`` flag to ``LM``. This template is specifically tailored to perform vanilla LLM completion, generally adopted by code completion models. - -After generating the MLC model configuration file, we are all set to compile and create the model library. You can learn more about the ``compile`` command `here `__ - -**Example:** - -.. tabs:: - - .. group-tab:: Linux - CUDA - - .. code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device cuda -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so - - .. group-tab:: Metal - - For M-chip Mac: - - .. code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal.so - - Cross-Compiling for Intel Mac on M-chip Mac: - - .. 
code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device metal:x86-64 -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib - - For Intel Mac: - - .. code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib - - .. group-tab:: Vulkan - - For Linux: - - .. code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.so - - For Windows: - - .. code:: bash - - # compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \ - --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.dll - -.. note:: - The generated model library can be shared across multiple model variants, as long as the architecture and number of parameters does not change, e.g., same architecture, but different weights (your finetuned model). - -Setting up the Inference Entrypoint ------------------------------------ - -You can now locally deploy your compiled model with the MLC serve module. To find more details about the MLC LLM API visit our :ref:`deploy-rest-api` page. - -**Example:** - -.. code:: bash - - python -m mlc_llm.serve.server \ - --model dist/CodeLlama-7b-hf-q4f16_1-MLC \ - --model-lib-path ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so - -Configure the IDE Extension ---------------------------- - -After deploying the LLM we can easily connect the IDE with the MLC Rest API. In this guide, we will be using the Hugging Face Code Completion extension `llm-ls `__ which has support across multiple IDEs (e.g., `vscode `__, `intellij `__ and `nvim `__) to connect to an external OpenAI compatible API (i.e., our MLC LLM :ref:`deploy-rest-api`). - -After installing the extension on your IDE, open the ``settings.json`` extension configuration file: - -.. figure:: /_static/img/ide_code_settings.png - :width: 450 - :align: center - :alt: settings.json - -| - -Then, make sure to replace the following settings with the respective values: - -.. code:: javascript - - "llm.modelId": "dist/CodeLlama-7b-hf-q4f16_1-MLC" - "llm.url": "http://127.0.0.1:8000/v1/completions" - "llm.backend": "openai" - -This will enable the extension to send OpenAI compatible requests to the MLC Serve API. Also, feel free to tune the API parameters. Please refer to our :ref:`deploy-rest-api` documentation for more details about these API parameters. - -.. code:: javascript - - "llm.requestBody": { - "best_of": 1, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "logprobs": false, - "top_logprobs": 0, - "logit_bias": null, - "max_tokens": 128, - "seed": null, - "stop": null, - "suffix": null, - "temperature": 1.0, - "top_p": 1.0 - } - -The llm-ls extension supports a variety of different model code completion templates. Choose the one that best matches your model, i.e., the template with the correct tokenizer and Fill in the Middle tokens. - -.. figure:: /_static/img/ide_code_templates.png - :width: 375 - :align: center - :alt: llm-ls templates - -| - -After everything is all set, the extension will be ready to use the responses from the MLC Serve API to provide off-the-shelf code completion on your IDE. - -.. 
figure:: /_static/img/code_completion.png - :width: 700 - :align: center - :alt: IDE Code Completion - -| - -Conclusion ----------- - -Please, let us know if you have any questions. Feel free to open an issue on the `MLC LLM repo `__! diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst deleted file mode 100644 index 75a5cdbdc7..0000000000 --- a/docs/deploy/ios.rst +++ /dev/null @@ -1,491 +0,0 @@ -.. _deploy-ios: - -iOS App and Swift API -===================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from the source. -If you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a -developer seeking to integrate new features into the package, building the iOS package from the source is required. - -Use Pre-built iOS App ---------------------- -The MLC Chat app is now available in App Store at no cost. You can download and explore it by simply clicking the button below: - - .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg - :width: 135 - :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 - - -Build iOS App from Source -------------------------- - -This section shows how we can build the app from the source. - -Step 1. Install Build Dependencies -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -First and foremost, please clone the `MLC LLM GitHub repository `_. - -Please follow :doc:`/install/tvm` to install TVM Unity. -Note that we **do not** have to run `build.py` since we can use prebuilt weights. -We only need TVM Unity's utility to combine the libraries (`local-id-iphone.tar`) into a single library. - -We also need to have the following build dependencies: - -* CMake >= 3.24, -* Git and Git-LFS, -* `Rust and Cargo `_, which are required by Hugging Face's tokenizer. - - -Step 2. Download Prebuilt Weights and Library -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You also need to obtain a copy of the MLC-LLM source code -by cloning the `MLC LLM GitHub repository `_. -To simplify the build, we will use prebuilt model -weights and libraries here. Run the following command -in the root directory of the MLC-LLM. - -.. code:: bash - - mkdir -p dist/prebuilt - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib - - cd dist/prebuilt - git lfs install - git clone https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - cd ../.. - -Validate that the files and directories exist: - -.. code:: bash - - >>> ls -l ./dist/prebuilt/lib/*/*-iphone.tar - ./dist/prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar - ./dist/prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar - ... - - >>> ls -l ./dist/prebuilt/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - # chat config: - mlc-chat-config.json - # model weights: - ndarray-cache.json - params_shard_*.bin - ... - - -Step 3. Build Auxiliary Components -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -**Tokenizer and runtime** - -In addition to the model itself, a lightweight runtime and tokenizer are -required to actually run the LLM. You can build and organize these -components by following these steps: - -.. code:: bash - - git submodule update --init --recursive - cd ./ios - ./prepare_libs.sh - -This will create a ``./build`` folder that contains the following files. -Please make sure all the following files exist in ``./build/``. - -.. 
code:: bash - - >>> ls ./build/lib/ - libmlc_llm.a # A lightweight interface to interact with LLM, tokenizer, and TVM Unity runtime - libmodel_iphone.a # The compiled model lib - libsentencepiece.a # SentencePiece tokenizer - libtokenizers_cpp.a # Huggingface tokenizer - libtvm_runtime.a # TVM Unity runtime - -**Add prepackage model** - -We can also *optionally* add prepackage weights into the app, -run the following command under the ``./ios`` directory: - -.. code:: bash - - cd ./ios - open ./prepare_params.sh # make sure builtin_list only contains "RedPajama-INCITE-Chat-3B-v1-q4f16_1" - ./prepare_params.sh - -The outcome should be as follows: - -.. code:: bash - - >>> ls ./dist/ - RedPajama-INCITE-Chat-3B-v1-q4f16_1 - -Step 4. Build iOS App -^^^^^^^^^^^^^^^^^^^^^ - -Open ``./ios/MLCChat.xcodeproj`` using Xcode. Note that you will need an -Apple Developer Account to use Xcode, and you may be prompted to use -your own developer team credential and product bundle identifier. - -Ensure that all the necessary dependencies and configurations are -correctly set up in the Xcode project. - -Once you have made the necessary changes, build the iOS app using Xcode. -If you have an Apple Silicon Mac, you can select target "My Mac (designed for iPad)" -to run on your Mac. You can also directly run it on your iPad or iPhone. - -.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg - :align: center - :width: 60% - -| - -Customize the App ------------------ - -We can customize the iOS app in several ways. -`MLCChat/app-config.json `_ -controls the list of local and remote models to be packaged into the app, given a local path or a URL respectively. Only models in ``model_list`` will have their libraries brought into the app when running `./prepare_libs` to package them into ``libmodel_iphone.a``. Each model defined in `app-config.json` contain the following fields: - -``model_path`` - (Required if local model) Name of the local folder containing the weights. - -``model_url`` - (Required if remote model) URL to the repo containing the weights. - -``model_id`` - (Required) Unique local identifier to identify the model. - -``model_lib`` - (Required) Matches the system-lib-prefix, generally set during ``mlc_llm compile`` which can be specified using - ``--system-lib-prefix`` argument. By default, it is set to ``"${model_type}_${quantization}"`` e.g. ``gpt_neox_q4f16_1`` - for the RedPajama-INCITE-Chat-3B-v1 model. If the ``--system-lib-prefix`` argument is manually specified during - ``mlc_llm compile``, the ``model_lib`` field should be updated accordingly. - -``required_vram_bytes`` - (Required) Estimated requirements of VRAM to run the model. - -``model_lib_path_for_prepare_libs`` - (Required) List of paths to the model libraries in the app (respective ``.tar`` file in the ``binary-mlc-llm-libs`` - repo, relative path in the ``dist`` artifact folder or full path to the library). Only used while running - ``prepare_libs.sh`` to determine which model library to use during runtime. Useful when selecting a library with - different settings (e.g. ``prefill_chunk_size``, ``context_window_size``, and ``sliding_window_size``). - -Additionally, the app prepackages the models under ``./ios/dist``. -This built-in list can be controlled by editing ``prepare_params.sh``. -You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above. 
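A typical customize-and-rebuild loop therefore looks roughly like the sketch below (the
``MLCChat/app-config.json`` path follows the repository link above; adjust it if your checkout
differs):

.. code:: bash

   cd ./ios
   open ./MLCChat/app-config.json   # add or remove ModelRecord entries
   ./prepare_libs.sh                # repackage the listed model libraries into libmodel_iphone.a
   ./prepare_params.sh              # optionally re-bundle the weights named in builtin_list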
- - -Bring Your Own Model Variant ----------------------------- - -In cases where the model you are adding is simply a variant of an existing -model, we only need to convert weights and reuse existing model library. For instance: - -- Adding ``NeuralHermes`` when MLC already supports the ``Mistral`` architecture - - -In this section, we walk you through adding ``NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC`` to the MLC iOS app. -According to the model's ``config.json`` on `its Huggingface repo `_, -it reuses the Mistral model architecture. - -.. note:: - - This section largely replicates :ref:`convert-weights-via-MLC`. - See that page for more details. Note that the weights are shared across - all platforms in MLC. - -**Step 1 Clone from HF and convert_weight** - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. See :ref:`compile-command-specification` -for specification of ``convert_weight``. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/mlabonne/NeuralHermes-2.5-Mistral-7B - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ - --quantization q4f16_1 \ - -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC - -**Step 2 Generate MLC Chat Config** - -Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. -See :ref:`compile-command-specification` for specification of ``gen_config``. - -.. code:: shell - - mlc_llm gen_config ./dist/models/NeuralHermes-2.5-Mistral-7B/ \ - --quantization q3f16_1 --conv-template neural_hermes_mistral \ - -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC - -For the ``conv-template``, `conv_template.cc `__ -contains a full list of conversation templates that MLC provides. - -If the model you are adding requires a new conversation template, you would need to add your own. -Follow `this PR `__ as an example. -We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. - -For more details, please see :ref:`configure-mlc-chat-json`. - -**Step 3 Upload weights to HF** - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-mistral-weight-huggingface-repo - cd my-mistral-weight-huggingface-repo - cp path/to/mlc-llm/dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC/* . - git add . && git commit -m "Add mistral model weights" - git push origin main - -After successfully following all steps, you should end up with a Huggingface repo similar to -`NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC `__, -which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. - - -**Step 4 Register as a ModelRecord** - -Finally, we modify the code snippet for -`app-config.json `__ -pasted above. - -We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib`` for -``Mistral-7B``. - -.. code:: javascript - - "model_list": [ - // Other records here omitted... 
- { - // Substitute model_url with the one you created `my-huggingface-account/my-mistral-weight-huggingface-repo` - "model_url": "https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC", - "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_lib": "mistral_q3f16_1", - "model_lib_path": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", - "estimated_vram_bytes": 3316000000 - } - ] - - -Now, the app will use the ``NeuralHermes-Mistral`` model you just added. - - -Bring Your Own Model Library ----------------------------- - -A model library is specified by: - - - The model architecture (e.g. ``mistral``, ``phi-msft``) - - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) - -In cases where the model you want to run is not compatible with the provided MLC -prebuilt model libraries (e.g. having a different quantization, a different -metadata spec, or even a different model architecture), you need to build your -own model library. - -In this section, we walk you through adding ``phi-2`` to the iOS app. - -This section largely replicates :ref:`compile-model-libraries`. See that page for -more details, specifically the ``iOS`` option. - -**Step 0. Install dependencies** - -To compile model libraries for iOS, you need to :ref:`build mlc_llm from source `. - -**Step 1. Clone from HF and convert_weight** - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/microsoft/phi-2 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/phi-2/ \ - --quantization q4f16_1 \ - -o dist/phi-2-q4f16_1-MLC - -**Step 2. Generate mlc-chat-config and compile** - -A model library is specified by: - - - The model architecture (e.g. ``mistral``, ``phi-msft``) - - Quantization Scheme (e.g. ``q3f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iphone``, ``android``) - -All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. - -.. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/phi-2/ \ - --quantization q4f16_1 --conv-template phi-2 \ - -o dist/phi-2-q4f16_1-MLC/ - # 2. mkdir: create a directory to store the compiled model library - mkdir -p dist/libs - # 3. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/phi-2-q4f16_1-MLC/mlc-chat-config.json \ - --device iphone -o dist/libs/phi-2-q4f16_1-iphone.tar - -Given the compiled library, it is possible to calculate an upper bound for the VRAM -usage during runtime. This useful to better understand if a model is able to fit particular -hardware. -That information will be displayed at the end of the console log when the ``compile`` is executed. -It might look something like this: - -.. code:: shell - - [2024-04-25 03:19:56] INFO model_metadata.py:96: Total memory usage: 1625.73 MB (Parameters: 1492.45 MB. KVCache: 0.00 MB. 
Temporary buffer: 133.28 MB) - [2024-04-25 03:19:56] INFO model_metadata.py:105: To reduce memory usage, tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size` - [2024-04-25 03:19:56] INFO compile.py:198: Generated: dist/libs/phi-2-q4f16_1-iphone.tar - -.. note:: - When compiling larger models like ``Llama-2-7B``, you may want to add a lower chunk size - while prefilling prompts ``--prefill_chunk_size 128`` or even lower ``context_window_size``\ - to decrease memory usage. Otherwise, during runtime, you may run out of memory. - - -**Step 3. Distribute model library and model weights** - -After following the steps above, you should end up with: - -.. code:: shell - - ~/mlc-llm > ls dist/libs - phi-2-q4f16_1-iphone.tar # ===> the model library - - ~/mlc-llm > ls dist/phi-2-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - -Upload the ``phi-2-q4f16_1-iphone.tar`` to a github repository (for us, -it is in `binary-mlc-llm-libs `__). Then -upload the weights ``phi-2-q4f16_1-MLC`` to a Huggingface repo: - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-phi-weight-huggingface-repo - cd my-phi-weight-huggingface-repo - cp path/to/mlc-llm/dist/phi-2-q4f16_1-MLC/* . - git add . && git commit -m "Add phi-2 model weights" - git push origin main - -This would result in something like `phi-2-q4f16_1-MLC -`_. - - -**Step 4. Register as a ModelRecord** - -Finally, we update the code snippet for -`app-config.json `__ -pasted above. - -We simply specify the Huggingface link as ``model_url``, while using the new ``model_lib`` for -``phi-2``. Regarding the field ``estimated_vram_bytes``, we can use the output of the last step -rounded up to MB. - -.. code:: javascript - - "model_list": [ - // Other records here omitted... - { - // Substitute model_url with the one you created `my-huggingface-account/my-phi-weight-huggingface-repo` - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_id": "phi-2-q4f16_1", - "model_lib": "phi_msft_q4f16_1", - "model_lib_path": "lib/phi-2/phi-2-q4f16_1-iphone.tar", - "estimated_vram_bytes": 3043000000 - } - ] - - -Now, the app will use the ``phi-2`` model library you just added. - - -Build Apps with MLC Swift API ------------------------------ - -We also provide a Swift package that you can use to build -your own app. The package is located under `ios/MLCSwift`. - -- First make sure you have run the same steps listed - in the previous section. This will give us the necessary libraries - under ``/path/to/ios/build/lib``. -- Then you can add ``ios/MLCSwift`` package to your app in Xcode. - Under "Frameworks, Libraries, and Embedded Content", click add package dependencies - and add local package that points to ``ios/MLCSwift``. -- Finally, we need to add the libraries dependencies. Under build settings: - - - Add library search path ``/path/to/ios/build/lib``. - - Add the following items to "other linker flags". - - .. code:: - - -Wl,-all_load - -lmodel_iphone - -lmlc_llm -ltvm_runtime - -ltokenizers_cpp - -lsentencepiece - -ltokenizers_c - - -You can then import the `MLCSwift` package into your app. -The following code shows an illustrative example of how to use the chat module. - -.. 
code:: swift - - import MLCSwift - - let threadWorker = ThreadWorker() - let chat = ChatModule() - - threadWorker.push { - let modelLib = "model-lib-name" - let modelPath = "/path/to/model/weights" - let input = "What is the capital of Canada?" - chat.reload(modelLib, modelPath: modelPath) - - chat.prefill(input) - while (!chat.stopped()) { - displayReply(chat.getMessage()) - chat.decode() - } - } - -.. note:: - - Because the chat module makes heavy use of GPU and thread-local - resources, it needs to run on a dedicated background thread. - Therefore, **avoid using** `DispatchQueue`, which can cause context switching to - different threads and segfaults due to thread-safety issues. - Use the `ThreadWorker` class to launch all the jobs related - to the chat module. You can check out the source code of - the MLCChat app for a complete example. diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst deleted file mode 100644 index bd92908cff..0000000000 --- a/docs/deploy/javascript.rst +++ /dev/null @@ -1,360 +0,0 @@ -.. _webllm-runtime: - -WebLLM and JavaScript API -========================= - -.. contents:: Table of Contents - :local: - :depth: 2 - -`WebLLM `_ is an MLC chat web runtime -that allows you to build chat applications directly in the browser, leveraging -`WebGPU `_ and providing users a natural layer of abstraction. - -Try out the Prebuilt Webpage ----------------------------- - -To get started, you can try out `WebLLM prebuilt webpage `__. - -A WebGPU-compatible browser and a local GPU are needed to run WebLLM. -You can download the latest Google Chrome and use `WebGPU Report `__ -to verify the functionality of WebGPU on your browser. - - -Use WebLLM NPM Package ----------------------- - -WebLLM is available as an `npm package `_. -The source code is available in `the WebLLM repo `_, -where you can make your own modifications and build from source. - -Note that the `WebLLM prebuilt webpage `__ above -is powered by the WebLLM npm package, specifically with the code in -the `simple-chat `__ example. - -Each of the model in the `WebLLM prebuilt webpage `__ -is registered as an instance of ``ModelRecord``. Looking at the most straightforward example -`get-started `__, -we see the code snippet: - -.. code:: typescript - - const myAppConfig: AppConfig = { - model_list: [ - { - "model_url": "https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/", - "local_id": "Llama-2-7b-chat-hf-q4f32_1", - "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f32_1-ctx4k_cs1k-webgpu.wasm", - }, - { - "model_url": "https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/resolve/main/", - "local_id": "Mistral-7B-Instruct-v0.2-q4f16_1", - "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "required_features": ["shader-f16"], - }, - // Add your own models here... - ] - } - const selectedModel = "Llama-2-7b-chat-hf-q4f32_1" - // const selectedModel = "Mistral-7B-Instruct-v0.1-q4f16_1" - await chat.reload(selectedModel, undefined, myAppConfig); - -Just like any other platforms, to run a model with on WebLLM, you need: - -1. **Model weights** converted to MLC format (e.g. `Llama-2-7b-hf-q4f32_1-MLC - `_.): downloaded through ``model_url`` -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__): downloaded through ``model_lib_url``. 
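Before registering a new ``ModelRecord``, it can help to quickly confirm that both URLs resolve.
The check below is only an illustrative sketch reusing the URLs from the snippet above:

.. code:: shell

   # HEAD-request both artifacts and print the final HTTP status (expect 200)
   curl -sIL -o /dev/null -w "%{http_code}\n" \
       https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f32_1-MLC/resolve/main/mlc-chat-config.json
   curl -sIL -o /dev/null -w "%{http_code}\n" \
       https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f32_1-ctx4k_cs1k-webgpu.wasm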
- -Verify Installation for Adding Models -------------------------------------- - -In sections below, we walk you through two examples of adding models to WebLLM. Before proceeding, -please verify installation of ``mlc_llm`` and ``tvm``: - -**Step 1. Verify mlc_llm** - -We use the python package ``mlc_llm`` to compile models. This can be installed by -following :ref:`install-mlc-packages`, either by building from source, or by -installing the prebuilt package. Verify ``mlc_llm`` installation in command line via: - -.. code:: bash - - $ mlc_llm --help - # You should see help information with this line - usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config} - -.. note:: - If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``. - -**Step 2. Verify TVM** - -To compile models, you also need to follow :ref:`install-tvm-unity`. -Here we verify ``tvm`` quickly with command line (for full verification, see :ref:`tvm-unity-validate`): - -.. code:: bash - - $ python -c "import tvm; print(tvm.__file__)" - /some-path/lib/python3.11/site-packages/tvm/__init__.py - - -.. _webllm-add-model-variant: - -Bring Your Own Model Variant ----------------------------- - -In cases where the model you are adding is simply a variant of an existing -model, we only need to convert weights and reuse existing model library. For instance: - -- Adding ``OpenMistral`` when MLC supports ``Mistral`` -- Adding ``Llama2-uncensored`` when MLC supports ``Llama2`` - - -In this section, we walk you through adding ``WizardMath-7B-V1.1-q4f16_1`` to the -`get-started `__ example. -According to the model's ``config.json`` on `its Huggingface repo `_, -it reuses the Mistral model architecture. - -.. note:: - - This section largely replicates :ref:`convert-weights-via-MLC`. - See that page for more details. Note that the weights are shared across - all platforms in MLC. - -**Step 1 Clone from HF and convert_weight** - -You can be under the mlc-llm repo, or your own working directory. Note that all platforms -can share the same compiled/quantized weights. See :ref:`compile-command-specification` -for specification of ``convert_weight``. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/WizardLM/WizardMath-7B-V1.1 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/WizardMath-7B-V1.1/ \ - --quantization q4f16_1 \ - -o dist/WizardMath-7B-V1.1-q4f16_1-MLC - -**Step 2 Generate MLC Chat Config** - -Use ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers. -See :ref:`compile-command-specification` for specification of ``gen_config``. - -.. code:: shell - - mlc_llm gen_config ./dist/models/WizardMath-7B-V1.1/ \ - --quantization q4f16_1 --conv-template wizard_coder_or_math \ - -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/ - -For the ``conv-template``, `conv_template.cc `__ -contains a full list of conversation templates that MLC provides. - -If the model you are adding requires a new conversation template, you would need to add your own. -Follow `this PR `__ as an example. Besides, you also need to add the new template to ``/path/to/web-llm/src/conversation.ts``. -We look up the template to use with the ``conv_template`` field in ``mlc-chat-config.json``. - -For more details, please see :ref:`configure-mlc-chat-json`. - -.. 
note:: - - If you added your conversation template in ``src/conversation.ts``, you need to build WebLLM - from source following the instruction in - `the WebLLM repo's README `_. - - Alternatively, you could use the ``"custom"`` conversation template so that you can pass in - your own ``ConvTemplateConfig`` in runtime without having to build the package from source. - -**Step 3 Upload weights to HF** - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-wizardMath-weight-huggingface-repo - cd my-wizardMath-weight-huggingface-repo - cp path/to/mlc-llm/dist/WizardMath-7B-V1.1-q4f16_1-MLC/* . - git add . && git commit -m "Add wizardMath model weights" - git push origin main - -After successfully following all steps, you should end up with a Huggingface repo similar to -`WizardMath-7B-V1.1-q4f16_1-MLC `__, -which includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files. - - -**Step 4 Register as a ModelRecord** - -Finally, we modify the code snippet for -`get-started `__ -pasted above. - -We simply specify the Huggingface link as ``model_url``, while reusing the ``model_lib_url`` for -``Mistral-7B``. Note that we need the suffix to be ``/resolve/main/``. - -.. code:: typescript - - const myAppConfig: AppConfig = { - model_list: [ - // Other records here omitted... - { - // Substitute model_url with the one you created `my-huggingface-account/my-wizardMath-weight-huggingface-repo` - "model_url": "https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC/resolve/main/", - "local_id": "WizardMath-7B-V1.1-q4f16_1", - "model_lib_url": "https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm", - "required_features": ["shader-f16"], - }, - ] - } - - const selectedModel = "WizardMath-7B-V1.1-q4f16_1" - await chat.reload(selectedModel, undefined, myAppConfig); - -Now, running the ``get-started`` example will use the ``WizardMath`` model you just added. -See `get-started's README `__ -on how to run it. - - -Bring Your Own Model Library ----------------------------- - -A model library is specified by: - - - The model architecture (e.g. ``llama-2``, ``gpt-neox``) - - Quantization (e.g. ``q4f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) - -In cases where the model you want to run is not compatible with the provided MLC -prebuilt model libraries (e.g. having a different quantization, a different -metadata spec, or even a different model architecture), you need to build your -own model library. - -In this section, we walk you through adding ``RedPajama-INCITE-Chat-3B-v1`` to the -`get-started `__ example. - -This section largely replicates :ref:`compile-model-libraries`. See that page for -more details, specifically the ``WebGPU`` option. - -**Step 0. Install dependencies** - -To compile model libraries for webgpu, you need to :ref:`build mlc_llm from source `. -Besides, you also need to follow :ref:`install-web-build`. Otherwise, it would run into error: - -.. code:: text - - RuntimeError: Cannot find libraries: wasm_runtime.bc - -**Step 1. Clone from HF and convert_weight** - -You can be under the mlc-llm repo, or your own working directory. 
Note that all platforms -can share the same compiled/quantized weights. - -.. code:: shell - - # Create directory - mkdir -p dist/models && cd dist/models - # Clone HF weights - git lfs install - git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1 - cd ../.. - # Convert weight - mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - -**Step 2. Generate mlc-chat-config and compile** - -A model library is specified by: - - - The model architecture (e.g. ``llama-2``, ``gpt-neox``) - - Quantization (e.g. ``q4f16_1``, ``q0f32``) - - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning - - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``) - -All these knobs are specified in ``mlc-chat-config.json`` generated by ``gen_config``. - -.. code:: shell - - # 1. gen_config: generate mlc-chat-config.json and process tokenizers - mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \ - --quantization q4f16_1 --conv-template redpajama_chat \ - -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/ - # 2. compile: compile model library with specification in mlc-chat-config.json - mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \ - --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm - -.. note:: - When compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill_chunk_size 1024`` or - lower ``context_window_size`` to decrease memory usage. Otherwise, during runtime, - you may run into issues like: - - .. code:: text - - TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from - 'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range. - - -**Step 3. Distribute model library and model weights** - -After following the steps above, you should end up with: - -.. code:: shell - - ~/mlc-llm > ls dist/libs - RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm # ===> the model library - - ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC - mlc-chat-config.json # ===> the chat config - ndarray-cache.json # ===> the model weight info - params_shard_0.bin # ===> the model weights - params_shard_1.bin - ... - tokenizer.json # ===> the tokenizer files - tokenizer_config.json - -Upload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm`` to a github repository (for us, -it is in `binary-mlc-llm-libs `__). Then -upload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC`` to a Huggingface repo: - -.. code:: shell - - # First, please create a repository on Hugging Face. - # With the repository created, run - git lfs install - git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo - cd my-redpajama3b-weight-huggingface-repo - cp path/to/mlc-llm/dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/* . - git add . && git commit -m "Add redpajama-3b instruct model weights" - git push origin main - -This would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC -`_. - -**Step 4. Register as a ModelRecord** - -Finally, we are able to run the model we added in WebLLM's `get-started `__: - -.. code:: typescript - - const myAppConfig: AppConfig = { - model_list: [ - // Other records here omitted... 
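      // The entry below is a sketch: "model_url" points at the Hugging Face weight repo you uploaded in
      // Step 3 (keep the "/resolve/main/" suffix), "model_lib_url" points at the wasm you distributed in
      // Step 3, and "local_id" is any unique name you pick for this record (it must match `selectedModel`).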
- { - "model_url": "https://huggingface.co/my-hf-account/my-redpajama3b-weight-huggingface-repo/resolve/main/", - "local_id": "RedPajama-INCITE-Instruct-3B-v1", - "model_lib_url": "https://raw.githubusercontent.com/my-gh-account/my-repo/main/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm", - "required_features": ["shader-f16"], - }, - ] - } - - const selectedModel = "RedPajama-INCITE-Instruct-3B-v1" - await chat.reload(selectedModel, undefined, myAppConfig); - -Now, running the ``get-started`` example will use the ``RedPajama`` model you just added. -See `get-started's README `__ -on how to run it. \ No newline at end of file diff --git a/docs/deploy/mlc_chat_config.rst b/docs/deploy/mlc_chat_config.rst deleted file mode 100644 index 948d50bddd..0000000000 --- a/docs/deploy/mlc_chat_config.rst +++ /dev/null @@ -1,210 +0,0 @@ -.. _configure-mlc-chat-json: - -Customize MLC Config File in JSON -================================= - -``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes: - -1. Specify how we compile a model (shown in :ref:`compile-model-libraries`), and -2. Specify conversation behavior in runtime. - -**This page focuses on the second purpose.** We explain the components of a chat -configuration and how to customize them by modifying the file. Additionally, -the runtimes also provide APIs to optionally override some of the configurations. - -In runtime, this file is stored under the directory of each compiled model -(e.g. `RedPajama chat config `__). - - -.. _struct-mlc-chat-conv: - -Structure of MLCChat Configuration ----------------------------------- - -Below is the ``mlc-chat-config.json`` file corresponding to Llama2 model: - -.. code:: json - - // mlc-chat-config.json - { - // 1. Metadata used to specify how to compile a model - "model_type": "llama", - "quantization": "q4f16_1", - "version": "0.1.0", - "model_config": { - "hidden_size": 4096, - "intermediate_size": 11008, - // more fields here... - }, - "vocab_size": 32000, - "context_window_size": 4096, - "sliding_window_size": -1, - "prefill_chunk_size": 4096, - "tensor_parallel_shards": 1, - - // 2. Tokenizer-related fields - "pad_token_id": 0, - "bos_token_id": 1, - "eos_token_id": 2, - "tokenizer_files": [ - "tokenizer.model", - "tokenizer.json", - "tokenizer_config.json" - ] - - // 3. Conversation template related fields - "conv_template": { - "name": "llama-2", - "system_template": "[INST] <>\n{system_message}\n<>\n\n ", - "system_message": "You are a helpful, respectful and honest assistant.", - // more fields here... - }, - - // 4. Chat related fields that affect runtime behavior - "mean_gen_len": 128, - "max_gen_len": 512, - "shift_fill_factor": 0.3, - "temperature": 0.6, - "repetition_penalty": 1.0, - "top_p": 0.9 - } - -.. note:: - Fields in the first part of ``mlc-chat-config.json`` (e.g. ``context-window-size``) - is only for compile-time. Changing them during runtime may lead to unexpected behavior. - -**As shown above, the file is divided into three parts. We focus on the third part, which -can be customized to change the behavior of the model.** - -``conv_template`` - .. note:: - Legacy ``mlc-chat-config.json`` may specify a string for this field to look up a registered conversation - template. It will be deprecated in the future. Re-generate config using the latest version of mlc_llm - to make sure this field is a complete JSON object. - - The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure `. 
- -``temperature`` - The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs. - -``repetition_penalty`` - The repetition penalty controls the likelihood of the model generating repeated texts. The default value is set to ``1.0``, indicating that no repetition penalty is applied. Increasing the value reduces the likelihood of repeat text generation. However, setting a high ``repetition_penalty`` may result in the model generating meaningless texts. The ideal choice of repetition penalty may vary among models. - - For more details on how repetition penalty controls text generation, please check out the `CTRL paper `_. - -``top_p`` - This parameter determines the set of tokens from which we sample during decoding. The default value is set to ``0.95``. At each step, we select tokens from the minimal set that has a cumulative probability exceeding the ``top_p`` parameter. - - For additional information on top-p sampling, please refer to this `blog post `_. - -``mean_gen_len`` - The approximated average number of generated tokens in each round. Used to determine whether the maximum window size would be exceeded. - -``max_gen_len`` - This parameter determines the maximum length of the generated text. If it is not set, the model will generate text until it encounters a stop token. - -``shift_fill_factor`` - The fraction of maximum window size to shift when it is exceeded. - -.. _struct-conv: - -Conversation Structure -^^^^^^^^^^^^^^^^^^^^^^ - -MLC-LLM provided a set of pre-defined conversation templates, which you can directly use by -specifying ``--conv-template [name]`` when generating config. Below is a list (not complete) of -supported conversation templates: - -- ``llama-2`` -- ``mistral_default`` -- ``chatml`` -- ``phi-2`` -- ... - -Please refer to `conversation_template.py `_ for the full list of supported templates and their implementations. - -Below is a generic structure of a JSON conversation configuration (we use vicuna as an example): - -.. code:: json - - // mlc-chat-config.json - { - // ... - "conv_template": { - "name": "llama-2", - "system_template": "[INST] <>\n{system_message}\n<>\n\n ", - "system_message": "You are a helpful, respectful and honest assistant.", - "roles": { - "user": "[INST]", - "assistant": "[/INST]", - "tool": "[INST]" - }, - "role_templates": { - "user": "{user_message}", - "assistant": "{assistant_message}", - "tool": "{tool_message}" - }, - "messages": [], - "seps": [ - " " - ], - "role_content_sep": " ", - "role_empty_sep": " ", - "stop_str": [ - "[INST]" - ], - "stop_token_ids": [ - 2 - ], - "function_string": "", - "use_function_calling": false - } - } - -``name`` - Name of the conversation. -``system_template`` - The system prompt template, it optionally contains the system - message placeholder, and the placeholder will be replaced with - the system message below. -``system_message`` - The content of the system prompt (without the template format). -``system_prefix_token_ids`` - The system token ids to be prepended at the beginning of tokenized - generated prompt. -``roles`` - The conversation roles -``role_templates`` - The roles prompt template, it optionally contains the defaults - message placeholders and will be replaced by actual content -``messages`` - The conversation history messages. - Each message is a pair of strings, denoting "(role, content)". - The content can be None. 
-``seps`` - An array of strings indicating the separators to be used after a user - message and a model message respectively. -``role_content_sep`` - The separator between the role and the content in a message. -``role_empty_sep`` - The separator between the role and empty contents. -``stop_str`` - When the ``stop_str`` is encountered, the model will stop generating output. -``stop_token_ids`` - A list of token IDs that act as stop tokens. -``function_string`` - The function calling string. -``use_function_calling`` - Whether using function calling or not, helps check for output message format in API call. - - -Given a conversation template, the corresponding prompt generated out -from it is in the following format: - -.. code:: text - - <><><><><> - <><><><> - ... - <><><><> - <><> diff --git a/docs/deploy/python_chat_module.rst b/docs/deploy/python_chat_module.rst deleted file mode 100644 index 5776e29138..0000000000 --- a/docs/deploy/python_chat_module.rst +++ /dev/null @@ -1,369 +0,0 @@ -.. _deploy-python-chat-module: - -Python API (Chat Module) -======================== - -.. note:: - ❗ The Python API with :class:`mlc_llm.ChatModule` introduced in this page will be - deprecated in the near future. - Please go to :ref:`deploy-python-engine` for the latest Python API with complete - OpenAI API support. - -.. contents:: Table of Contents - :local: - :depth: 2 - -We expose ChatModule Python API for the MLC-LLM for easy integration into other Python projects. - -The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via -the :doc:`installation page <../install/mlc_llm>`. - -Instead of following this page, you could also checkout the following tutorials in -Python notebook (all runnable in Colab): - -- `Getting Started with MLC-LLM `_: - how to quickly download prebuilt models and chat with it -- `Raw Text Generation with MLC-LLM `_: - how to perform raw text generation with MLC-LLM in Python - -.. These notebooks are not up-to-date with SLM yet -.. - `Compiling Llama-2 with MLC-LLM `_: -.. how to use Python APIs to compile models with the MLC-LLM workflow -.. - `Extensions to More Model Variants `_: -.. how to use Python APIs to compile and chat with any model variant you'd like - - -Verify Installation -------------------- - -.. code:: bash - - python -c "from mlc_llm import ChatModule; print(ChatModule)" - -You are expected to see the information about the :class:`mlc_llm.ChatModule` class. - -If the command above results in error, follow :ref:`install-mlc-packages` (either install the prebuilt pip wheels -or :ref:`mlcchat_build_from_source`). - -Run MLC Models w/ Python ------------------------- - -To run a model with MLC LLM in any platform/runtime, you need: - -1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-MLC - `_.) -2. **Model library** that comprises the inference logic (see repo `binary-mlc-llm-libs `__). - -There are two ways to obtain the model weights and libraries: - -1. Compile your own model weights and libraries following :doc:`the model compilation page `. -2. Use off-the-shelf `prebuilt models weights `__ and - `prebuilt model libraries `__ (see :ref:`Model Prebuilts` for details). - -We use off-the-shelf prebuilt models in this page. However, same steps apply if you want to run -the models you compiled yourself. - -**Step 1: Download prebuilt model weights and libraries** - -Skip this step if you have already obtained the model weights and libraries. - -.. 
code:: shell - - # Activate your conda environment - conda install -c conda-forge git-lfs - - # Download pre-conveted weights - git lfs install && mkdir dist/ - git clone https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC \ - dist/Llama-2-7b-chat-hf-q4f16_1-MLC - - # Download pre-compiled model library - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - -**Step 2: Run the model in Python** - -Use the conda environment you used to install ``mlc_llm``. -From the ``mlc-llm`` directory, you can create a Python -file ``sample_mlc_llm.py`` and paste the following lines: - -.. code:: python - - from mlc_llm import ChatModule - from mlc_llm.callback import StreamToStdout - - # Create a ChatModule instance - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - # You can change to other models that you downloaded - # Model variants of the same architecture can reuse the same model library - # Here WizardMath reuses Mistral's model library - # cm = ChatModule( - # model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" - # model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" - # ) - - # Generate a response for a given prompt - output = cm.generate( - prompt="What is the meaning of life?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # Print prefill and decode performance statistics - print(f"Statistics: {cm.stats()}\n") - - output = cm.generate( - prompt="How many points did you list out?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # Reset the chat module by - # cm.reset_chat() - - -Now run the Python file to start the chat - -.. code:: bash - - python sample_mlc_llm.py - - -.. collapse:: See output - - .. code:: - - Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 - Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json - Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so - - Thank you for your question! The meaning of life is a complex and subjective topic that has been debated by philosophers, theologians, scientists, and many others for centuries. There is no one definitive answer to this question, as it can vary depending on a person's beliefs, values, experiences, and perspectives. - - However, here are some possible ways to approach the question: - - 1. Religious or spiritual beliefs: Many people believe that the meaning of life is to fulfill a divine or spiritual purpose, whether that be to follow a set of moral guidelines, to achieve spiritual enlightenment, or to fulfill a particular destiny. - 2. Personal growth and development: Some people believe that the meaning of life is to learn, grow, and evolve as individuals, to develop one's talents and abilities, and to become the best version of oneself. - 3. Relationships and connections: Others believe that the meaning of life is to form meaningful connections and relationships with others, to love and be loved, and to build a supportive and fulfilling social network. - 4. 
Contribution and impact: Some people believe that the meaning of life is to make a positive impact on the world, to contribute to society in a meaningful way, and to leave a lasting legacy. - 5. Simple pleasures and enjoyment: Finally, some people believe that the meaning of life is to simply enjoy the present moment, to find pleasure and happiness in the simple things in life, and to appreciate the beauty and wonder of the world around us. - - Ultimately, the meaning of life is a deeply personal and subjective question, and each person must find their own answer based on their own beliefs, values, and experiences. - - Statistics: prefill: 3477.5 tok/s, decode: 153.6 tok/s - - I listed out 5 possible ways to approach the question of the meaning of life. - -| - -**Running other models** - -Checkout the :doc:`/prebuilt_models` page to run other pre-compiled models. - -For models other than the prebuilt ones we provided: - -1. If the model is a variant to an existing model library (e.g. ``WizardMathV1.1`` and ``OpenHermes`` are variants of ``Mistral`` as - shown in the code snippet), follow :ref:`convert-weights-via-MLC` to convert the weights and reuse existing model libraries. -2. Otherwise, follow :ref:`compile-model-libraries` to compile both the model library and weights. - - -Configure MLCChat in Python ---------------------------- -If you have checked out :ref:`Configure MLCChat in JSON`, you would know -that you could configure MLCChat through various fields such as ``temperature``. We provide the -option of overriding any field you'd like in Python, so that you do not need to manually edit -``mlc-chat-config.json``. - -Since there are two concepts -- `MLCChat Configuration` and `Conversation Configuration` -- we correspondingly -provide two dataclasses :class:`mlc_llm.ChatConfig` and :class:`mlc_llm.ConvConfig`. - -We provide an example below. - -.. code:: python - - from mlc_llm import ChatModule, ChatConfig, ConvConfig - from mlc_llm.callback import StreamToStdout - - # Using a `ConvConfig`, we modify `system`, a field in the conversation template - # `system` refers to the prompt encoded before starting the chat - conv_config = ConvConfig(system_message='Please show as much happiness as you can when talking to me.') - - # We then include the `ConvConfig` instance in `ChatConfig` while overriding `max_gen_len` - # Note that `conv_config` is an optional subfield of `chat_config` - chat_config = ChatConfig(max_gen_len=256, conv_config=conv_config) - - # Using the `chat_config` we created, instantiate a `ChatModule` - cm = ChatModule( - chat_config=chat_config, - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - output = cm.generate( - prompt="What is one plus one?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # You could also pass in a `ConvConfig` instance to `reset_chat()` - conv_config = ConvConfig(system='Please show as much sadness as you can when talking to me.') - chat_config = ChatConfig(max_gen_len=128, conv_config=conv_config) - cm.reset_chat(chat_config) - - output = cm.generate( - prompt="What is one plus one?", - progress_callback=StreamToStdout(callback_interval=2), - ) - - -.. collapse:: See output - - .. 
code:: - - Using model folder: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1 - Using mlc chat config: ./dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/mlc-chat-config.json - Using library model: ./dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so - - Oh, wow, *excitedly* one plus one? *grinning* Well, let me see... *counting on fingers* One plus one is... *eureka* Two! - ... - - *Sobs* Oh, the tragedy of it all... *sobs* One plus one... *chokes back tears* It's... *gulps* it's... *breaks down in tears* TWO! - ... - -| - -.. note:: - You do not need to specify the entire ``ChatConfig`` or ``ConvConfig``. Instead, we will first - load all the fields defined in ``mlc-chat-config.json``, a file required when instantiating - a :class:`mlc_llm.ChatModule`. Then, we will load in the optional ``ChatConfig`` you provide, overriding the - fields specified. - - It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template - specified by the field ``conv_template`` in the chat configuration. Learn more about it in - :ref:`Configure MLCChat in JSON`. - -Raw Text Generation in Python ------------------------------ - -Raw text generation allows the user to have more flexibility over his prompts, -without being forced to create a new conversational template, making prompt customization easier. -This serves other demands for APIs to handle LLM generation without the usual system prompts and other items. - -We provide an example below. - -.. code:: python - - from mlc_llm import ChatModule, ChatConfig, ConvConfig - from mlc_llm.callback import StreamToStdout - - # Use a `ConvConfig` to define the generation settings - # Since the "LM" template only supports raw text generation, - # System prompts will not be executed even if provided - conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]") - - # Note that `conv_config` is an optional subfield of `chat_config` - # The "LM" template serves the basic purposes of raw text generation - chat_config = ChatConfig(conv_config=conv_config, conv_template="LM") - - # Using the `chat_config` we created, instantiate a `ChatModule` - cm = ChatModule( - chat_config=chat_config, - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - # To make the model follow conversations a chat structure should be provided - # This allows users to build their own prompts without building a new template - system_prompt = "<>\nYou are a helpful, respectful and honest assistant.\n<>\n\n" - inst_prompt = "What is mother nature?" - - # Concatenate system and instruction prompts, and add instruction tags - output = cm.generate( - prompt=f"[INST] {system_prompt+inst_prompt} [/INST]", - progress_callback=StreamToStdout(callback_interval=2), - ) - - # The LM template has no memory, so it will be reset every single generation - # In this case the model will just follow normal text completion - # because there isn't a chat structure - output = cm.generate( - prompt="Life is a quality that distinguishes", - progress_callback=StreamToStdout(callback_interval=2), - ) - -.. note:: - The ``LM`` is a template without memory, which means that every execution will be cleared. 
- Additionally, system prompts will not be run when instantiating a `mlc_llm.ChatModule`, - unless explicitly given inside the prompt. - -Stream Iterator in Python -------------------------- - -Stream Iterator gives users an option to stream generated text to the function that the API is called from, -instead of streaming to stdout, which could be a necessity when building services on top of MLC Chat. - -We provide an example below. - -.. code:: python - - from mlc_llm import ChatModule - from mlc_llm.callback import StreamIterator - - # Create a ChatModule instance - cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} - ) - - # Stream to an Iterator - from threading import Thread - - stream = StreamIterator(callback_interval=2) - generation_thread = Thread( - target=cm.generate, - kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, - ) - generation_thread.start() - - output = "" - for delta_message in stream: - output += delta_message - - generation_thread.join() - - -API Reference -------------- - -User can initiate a chat module by creating :class:`mlc_llm.ChatModule` class, which is a wrapper of the MLC-LLM model. -The :class:`mlc_llm.ChatModule` class provides the following methods: - -.. currentmodule:: mlc_llm - -.. autoclass:: ChatModule - :members: - :exclude-members: evaluate - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - -.. autoclass:: ChatConfig - :members: - -.. autoclass:: ConvConfig - :members: - -.. autoclass:: GenerationConfig - :members: diff --git a/docs/deploy/python_engine.rst b/docs/deploy/python_engine.rst deleted file mode 100644 index 89c60ac422..0000000000 --- a/docs/deploy/python_engine.rst +++ /dev/null @@ -1,264 +0,0 @@ -.. _deploy-python-engine: - -Python API -========== - -.. note:: - This page introduces the Python API with MLCEngine in MLC LLM. - If you want to check out the old Python API which uses :class:`mlc_llm.ChatModule`, - please go to :ref:`deploy-python-chat-module` - -.. contents:: Table of Contents - :local: - :depth: 2 - - -MLC LLM provides Python API through classes :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` -which **support full OpenAI API completeness** for easy integration into other Python projects. - -This page introduces how to use the engines in MLC LLM. -The Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via -the :ref:`installation page `. - - -Verify Installation -------------------- - -.. code:: bash - - python -c "from mlc_llm import MLCEngine; print(MLCEngine)" - -You are expected to see the output of ````. - -If the command above results in error, follow :ref:`install-mlc-packages` to install prebuilt pip -packages or build MLC LLM from source. - - -Run MLCEngine -------------- - -:class:`mlc_llm.MLCEngine` provides the interface of OpenAI chat completion synchronously. -:class:`mlc_llm.MLCEngine` does not batch concurrent request due to the synchronous design, -and please use :ref:`AsyncMLCEngine ` for request batching process. - -**Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`, -we introduced the basic use of :class:`mlc_llm.MLCEngine`. - -.. 
code:: python - - from mlc_llm import MLCEngine - - # Create engine - model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = MLCEngine(model) - - # Run chat completion in OpenAI API. - for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, - ): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) - print("\n") - - engine.terminate() - -This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 8B Llama-3 model. -**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.MLCEngine` in the same way of using -`OpenAI's Python package `_ -for both synchronous and asynchronous generation. - -**Non-stream Response.** The code example above uses the synchronous chat completion -interface and iterate over all the stream responses. -If you want to run without streaming, you can run - -.. code:: python - - response = engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=False, - ) - print(response) - -Please refer to `OpenAI's Python package `_ -and `OpenAI chat completion API `_ -for the complete chat completion interface. - - -.. _python-engine-async-llm-engine: - -Run AsyncMLCEngine ------------------- - -:class:`mlc_llm.AsyncMLCEngine` provides the interface of OpenAI chat completion with -asynchronous features. -**We recommend using** :class:`mlc_llm.AsyncMLCEngine` **to batch concurrent request for better throughput.** - -**Stream Response.** The core use of :class:`mlc_llm.AsyncMLCEngine` for stream responses is as follows. - -.. code:: python - - async for response in await engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, - ): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) - -.. collapse:: The collapsed is a complete runnable example of AsyncMLCEngine in Python. - - .. code:: python - - import asyncio - from typing import Dict - - from mlc_llm.serve import AsyncMLCEngine - - model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - prompts = [ - "Write a three-day travel plan to Pittsburgh.", - "What is the meaning of life?", - ] - - - async def test_completion(): - # Create engine - async_engine = AsyncMLCEngine(model=model) - - num_requests = len(prompts) - output_texts: Dict[str, str] = {} - - async def generate_task(prompt: str): - async for response in await async_engine.chat.completions.create( - messages=[{"role": "user", "content": prompt}], - model=model, - stream=True, - ): - if response.id not in output_texts: - output_texts[response.id] = "" - output_texts[response.id] += response.choices[0].delta.content - - tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)] - await asyncio.gather(*tasks) - - # Print output. - for request_id, output in output_texts.items(): - print(f"Output of request {request_id}:\n{output}\n") - - async_engine.terminate() - - - asyncio.run(test_completion()) - -| - -**Non-stream Response.** Similarly, :class:`mlc_llm.AsyncEngine` provides the non-stream response -interface. - -.. 
code:: python - - response = await engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=False, - ) - print(response) - -Please refer to `OpenAI's Python package `_ -and `OpenAI chat completion API `_ -for the complete chat completion interface. - - -Engine Mode ------------ - -To ease the engine configuration, the constructors of :class:`mlc_llm.MLCEngine` and -:class:`mlc_llm.AsyncMLCEngine` have an optional argument ``mode``, -which falls into one of the three options ``"local"``, ``"interactive"`` or ``"server"``. -The default mode is ``"local"``. - -Each mode denotes a pre-defined configuration of the engine to satisfy different use cases. -The choice of the mode controls the request concurrency of the engine, -as well as engine's KV cache token capacity (or in other words, the maximum -number of tokens that the engine's KV cache can hold), -and further affects the GPU memory usage of the engine. - -In short, - -- mode ``"local"`` uses low request concurrency and low KV cache capacity, which is suitable for cases where **concurrent requests are not too many, and the user wants to save GPU memory usage**. -- mode ``"interactive"`` uses 1 as the request concurrency and low KV cache capacity, which is designed for **interactive use cases** such as chats and conversations. -- mode ``"server"`` uses as much request concurrency and KV cache capacity as possible. This mode aims to **fully utilize the GPU memory for large server scenarios** where concurrent requests may be many. - -**For system benchmark, please select mode** ``"server"``. -Please refer to :ref:`python-engine-api-reference` for detailed documentation of the engine mode. - - -Deploy Your Own Model with Python API -------------------------------------- - -The :ref:`introduction page ` introduces how we can deploy our -own models with MLC LLM. -This section introduces how you can use the model weights you convert and the model library you build -in :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`. - -We use the `Phi-2 `_ as the example model. - -**Specify Model Weight Path.** Assume you have converted the model weights for your own model, -you can construct a :class:`mlc_llm.MLCEngine` as follows: - -.. code:: python - - from mlc_llm import MLCEngine - - model = "models/phi-2" # Assuming the converted phi-2 model weights are under "models/phi-2" - engine = MLCEngine(model) - - -**Specify Model Library Path.** Further, if you build the model library on your own, -you can use it in :class:`mlc_llm.MLCEngine` by passing the library path through argument ``model_lib_path``. - -.. code:: python - - from mlc_llm import MLCEngine - - model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" # Assuming the phi-2 model library is built at "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) - - -The same applies to :class:`mlc_llm.AsyncMLCEngine`. - - -.. _python-engine-api-reference: - -API Reference -------------- - -The :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` classes provide the following constructors. - -The MLCEngine and AsyncMLCEngine have full OpenAI API completeness. -Please refer to `OpenAI's Python package `_ -and `OpenAI chat completion API `_ -for the complete chat completion interface. - -.. currentmodule:: mlc_llm - -.. autoclass:: MLCEngine - :members: - :exclude-members: evaluate - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ - -.. 
autoclass:: AsyncMLCEngine - :members: - :exclude-members: evaluate - :undoc-members: - :show-inheritance: - - .. automethod:: __init__ diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst deleted file mode 100644 index 07d39dbfad..0000000000 --- a/docs/deploy/rest.rst +++ /dev/null @@ -1,367 +0,0 @@ -.. _deploy-rest-api: - -REST API -======== - -.. contents:: Table of Contents - :local: - :depth: 2 - -We provide `REST API `_ -for a user to interact with MLC-LLM in their own programs. - -Install MLC-LLM Package ------------------------- - -SERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here `. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful: - -.. code:: bash - - mlc_llm serve --help - -You should see serve help message if the installation was successful. - -Quick start ------------- - -This section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command: - -.. code:: bash - - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] - -where ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process `. Information about other arguments can be found under :ref:`Launch the server ` section. - -Once you have launched the Server, you can use the API in your own program to send requests. Below is an example of using the API to interact with MLC-LLM in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``): - -.. code:: bash - - import requests - - # Get a response using a prompt without streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [ - {"role": "user", "content": "Write a haiku about apples."}, - ], - "stream": False, - # "n": 1, - "max_tokens": 300, - } - r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) - choices = r.json()["choices"] - for choice in choices: - print(f"{choice['message']['content']}\n") - ------------------------------------------------- - - -.. _rest_launch_server: - - -Launch the Server ------------------ - -To launch the MLC Server for MLC-LLM, run the following command in your terminal. - -.. code:: bash - - mlc_llm serve MODEL [--model-lib-path MODEL_LIB_PATH] [--device DEVICE] [--max-batch-size MAX_BATCH_SIZE] [--max-total-seq-length MAX_TOTAL_SEQ_LENGTH] [--prefill-chunk-size PREFILL_CHUNK_SIZE] [--enable-tracing] [--host HOST] [--port PORT] [--allow-credentials] [--allowed-origins ALLOWED_ORIGINS] [--allowed-methods ALLOWED_METHODS] [--allowed-headers ALLOWED_HEADERS] - -MODEL The model folder after compiling with MLC-LLM build process. The parameter - can either be the model name with its quantization scheme - (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model - folder. In the former case, we will use the provided name to search - for the model folder over possible paths. ---model-lib-path A field to specify the full path to the model library file to use (e.g. a ``.so`` file). ---device The description of the device to run on. User should provide a string in the - form of 'device_name:device_id' or 'device_name', where 'device_name' is one of - 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the - local device), and 'device_id' is the device id to run on. The default value is ``auto``, - with the device id set to 0 for default. ---host The host at which the server should be started, defaults to ``127.0.0.1``. 
---port The port on which the server should be started, defaults to ``8000``. ---allow-credentials A flag to indicate whether the server should allow credentials. If set, the server will - include the ``CORS`` header in the response ---allowed-origins Specifies the allowed origins. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all origins. ---allowed-methods Specifies the allowed methods. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all methods. ---allowed-headers Specifies the allowed headers. It expects a JSON list of strings, with the default value being ``["*"]``, allowing all headers. ---max-batch-size The maximum batch size for processing. ---max-total-seq-length The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. Set it to None to enable automatic computation of the max total sequence length. ---prefill-chunk-size The maximum total sequence length in a prefill. If not specified, it will be automatically inferred from model config. ---enable-tracing A boolean indicating if to enable event logging for requests. - -You can access ``http://127.0.0.1:PORT/docs`` (replace ``PORT`` with the port number you specified) to see the list of -supported endpoints. - -API Endpoints -------------- - -The REST API provides the following endpoints: - -.. http:get:: /v1/models - ------------------------------------------------- - - Get a list of models available for MLC-LLM. - -**Example** - -.. code:: bash - - import requests - - url = "http://127.0.0.1:8000/v1/models" - headers = {"accept": "application/json"} - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - print("Response:") - print(response.json()) - else: - print("Error:", response.status_code) - - -.. http:post:: /v1/chat/completions - ------------------------------------------------- - - Get a response from MLC-LLM using a prompt, either with or without streaming. - -**Chat Completion Request Object** - -- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. Each message in the conversation is represented by a `ChatCompletionMessage` object, which includes the following fields: - - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages. - - **role** (*Literal["system", "user", "assistant", "tool"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool. - - **name** (*Optional[str]*): An optional name for the sender of the message. - - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is `tool`. - - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services. - -- **model** (*str*, required): The model to be used for generating responses. - -- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens. - -- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens. 
- -- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response. - -- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 5. It determines the number of tokens, most likely to appear at each position, to be returned. Each token is accompanied by a log probability. If this parameter is used, 'logprobs' must be set to true. - -- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation. - -- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s). - -- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt. - -- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output. - -- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop. - -- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated. - -- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions. - -- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses. - -- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat. - -- **tool_choice** (*Optional[Union[Literal["none", "auto"], Dict]]*): Controls how tools are selected for use in responses. - -- **user** (*Optional[str]*): An optional identifier for the user initiating the request. - -- **ignore_eos** (*bool*, optional, default=False): If `True`, the model will ignore the end-of-sequence token for generating responses. - -- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. Can be either "text" or "json_object", with optional schema definition for JSON responses. - -**Returns** - -- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s). -- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses. - - -**ChatCompletionResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls", "error"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error. - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response. - -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionStreamResponseChoice** - -- **finish_reason** (*Optional[Literal["stop", "length", "tool_calls"]]*, optional): Specifies why the streaming completion process ended. Valid reasons are "stop", "length", and "tool_calls". - -- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices. - -- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream. 
- -- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token - -**ChatCompletionResponse** - -- **id** (*str*, required): A unique identifier for the chat completion session. - -- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model. - -- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated. - -- **model** (*str*, required): The name of the model used to generate the chat completions. - -- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment. - -- **object** (*Literal["chat.completion"]*, required, default="chat.completion"): A string literal indicating the type of object, here always "chat.completion". - -- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request. - -**ChatCompletionStreamResponse** - -- **id** (*str*, required): A unique identifier for the streaming chat completion session. - -- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response. - -- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp. - -- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions. - -- **system_fingerprint** (*str*, required): A unique identifier for the system generating the streaming completions. - -- **object** (*Literal["chat.completion.chunk"]*, required, default="chat.completion.chunk"): A literal indicating that this object represents a chunk of a streaming chat completion. - ------------------------------------------------- - - -**Example** - -Below is an example of using the API to interact with MLC-LLM in Python with Streaming. - -.. code:: bash - - import requests - import json - - # Get a response using a prompt with streaming - payload = { - "model": "./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": True, - } - with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: - for chunk in r.iter_content(chunk_size=None): - chunk = chunk.decode("utf-8") - if "[DONE]" in chunk[6:]: - break - response = json.loads(chunk[6:]) - content = response["choices"][0]["delta"].get("content", "") - print(content, end="", flush=True) - print("\n") - ------------------------------------------------- - -There is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). Below is an example on how to use function calling in Python. - -.. code:: bash - - import requests - import json - - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } - ] - - payload = { - "model": "./dist/gorilla-openfunctions-v1-q4f16_1-MLC/", - "messages": [ - { - "role": "user", - "content": "What is the current weather in Pittsburgh, PA in fahrenheit?", - } - ], - "stream": False, - "tools": tools, - } - - r = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload) - print(f"{r.json()['choices'][0]['message']['tool_calls'][0]['function']}\n") - - # Output: {'name': 'get_current_weather', 'arguments': {'location': 'Pittsburgh, PA', 'unit': 'fahrenheit'}} - ------------------------------------------------- - -Function Calling with streaming is also supported. Below is an example on how to use function calling with streaming in Python. - -.. code:: bash - - import requests - import json - - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } - ] - - payload = { - "model": "./dist/gorilla-openfunctions-v1-q4f16_1-MLC/", - "messages": [ - { - "role": "user", - "content": "What is the current weather in Pittsburgh, PA and Tokyo, JP in fahrenheit?", - } - ], - "stream": True, - "tools": tools, - } - - with requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, stream=True) as r: - for chunk in r.iter_content(chunk_size=None): - chunk = chunk.decode("utf-8") - if "[DONE]" in chunk[6:]: - break - response = json.loads(chunk[6:]) - content = response["choices"][0]["delta"].get("content", "") - print(f"{content}", end="", flush=True) - print("\n") - - # Output: ["get_current_weather(location='Pittsburgh,PA',unit='fahrenheit')", "get_current_weather(location='Tokyo,JP',unit='fahrenheit')"] - - -.. note:: - The API is a uniform interface that supports multiple languages. You can also utilize these functionalities in languages other than Python. - - - diff --git a/docs/get_started/introduction.rst b/docs/get_started/introduction.rst deleted file mode 100644 index 29060d5a60..0000000000 --- a/docs/get_started/introduction.rst +++ /dev/null @@ -1,322 +0,0 @@ -.. _introduction-to-mlc-llm: - -Introduction to MLC LLM -======================= - -.. contents:: Table of Contents - :local: - :depth: 2 - -Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance -universal LLM deployment engine. The mission of this project is to enable everyone to develop, -optimize and deploy AI models natively on everyone's devices with ML compilation techniques. - -This page is a quick tutorial to introduce how to try out MLC LLM, and the steps to -deploy your own models with MLC LLM. - -Installation ------------- - -:ref:`MLC LLM ` is available via pip. -It is always recommended to install it in an isolated conda virtual environment. - -To verify the installation, activate your virtual environment, run - -.. code:: bash - - python -c "import mlc_llm; print(mlc_llm.__path__)" - -You are expected to see the installation path of MLC LLM Python package. - - -Chat CLI --------- - -As the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model. -You can run MLC chat through a one-liner command: - -.. 
code:: bash - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - -It may take 1-2 minutes for the first time running this command. -After waiting, this command launch a chat interface where you can enter your prompt and chat with the model. - -.. code:: - - You can use the following special commands: - /help print the special commands - /exit quit the cli - /stats print out the latest stats (token/sec) - /reset restart a fresh chat - /set [overrides] override settings in the generation config. For example, - `/set temperature=0.5;max_gen_len=100;stop=end,stop` - Note: Separate stop words in the `stop` option with commas (,). - Multi-line input: Use escape+enter to start a new line. - - user: What's the meaning of life - assistant: - What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life. - - The concept of the meaning of life has been debated and... - - -The figure below shows what run under the hood of this chat CLI command. -For the first time running the command, there are three major phases. - -- **Phase 1. Pre-quantized weight download.** This phase automatically downloads pre-quantized Llama-3 model from `Hugging Face `_ and saves it to your local cache directory. -- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-3 model to accelerate model inference on GPU with techniques of machine learning compilation in `Apache TVM `_ compiler, and generate the binary model library that enables the execution language models on your local GPU. -- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, launches a platform-native chat runtime to drive the execution of Llama-3 model. - -We cache the pre-quantized model weights and compiled model library locally. -Therefore, phase 1 and 2 will only execute **once** over multiple runs. - -.. figure:: /_static/img/project-workflow.svg - :width: 700 - :align: center - :alt: Project Workflow - - Workflow in MLC LLM - -| - -.. _introduction-to-mlc-llm-python-api: - -Python API ----------- - -In the second example, we run the Llama-3 model with the chat completion Python API of MLC LLM. -You can save the code below into a Python file and run it. - -.. code:: python - - from mlc_llm import MLCEngine - - # Create engine - model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = MLCEngine(model) - - # Run chat completion in OpenAI API. - for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, - ): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) - print("\n") - - engine.terminate() - -.. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg - :width: 500 - :align: center - - MLC LLM Python API - -This code example first creates an :class:`mlc_llm.MLCEngine` instance with the 4-bit quantized Llama-3 model. -**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with OpenAI API**, -which means you can use :class:`mlc_llm.MLCEngine` in the same way of using -`OpenAI's Python package `_ -for both synchronous and asynchronous generation. - -In this code example, we use the synchronous chat completion interface and iterate over -all the stream responses. -If you want to run without streaming, you can run - -.. 
code:: python - - response = engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=False, - ) - print(response) - -You can also try different arguments supported in `OpenAI chat completion API `_. -If you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncMLCEngine` instead. - -REST Server ------------ - -For the third example, we launch a REST server to serve the 4-bit quantized Llama-3 model -for OpenAI chat completion requests. The server can be launched in command line with - -.. code:: bash - - mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - -The server is hooked at ``http://127.0.0.1:8000`` by default, and you can use ``--host`` and ``--port`` -to set a different host and port. -When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), -we can open a new shell and send a cURL request via the following command: - -.. code:: bash - - curl -X POST \ - -H "Content-Type: application/json" \ - -d '{ - "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} - ] - }' \ - http://127.0.0.1:8000/v1/chat/completions - -The server will process this request and send back the response. -Similar to :ref:`introduction-to-mlc-llm-python-api`, you can pass argument ``"stream": true`` -to request for stream responses. - -.. _introduction-deploy-your-own-model: - -Deploy Your Own Model ---------------------- - -So far we have been using pre-converted models weights from Hugging Face. -This section introduces the core workflow regarding how you can *run your own models with MLC LLM*. - -We use the `Phi-2 `_ as the example model. -Assuming the Phi-2 model is downloaded and placed under ``models/phi-2``, -there are two major steps to prepare your own models. - -- **Step 1. Generate MLC config.** The first step is to generate the configuration file of MLC LLM. - - .. code:: bash - - export LOCAL_MODEL_PATH=models/phi-2 # The path where the model resides locally. - export MLC_MODEL_PATH=dist/phi-2-MLC/ # The path where to place the model processed by MLC. - export QUANTIZATION=q0f16 # The choice of quantization. - export CONV_TEMPLATE=phi-2 # The choice of conversation template. - mlc_llm gen_config $LOCAL_MODEL_PATH \ - --quantization $QUANTIZATION \ - --conv-template $CONV_TEMPLATE \ - -o $MLC_MODEL_PATH - - The config generation command takes in the local model path, the target path of MLC output, - the conversation template name in MLC and the quantization name in MLC. - Here the quantization ``q0f16`` means float16 without quantization, - and the conversation template ``phi-2`` is the Phi-2 model's template in MLC. - - If you want to enable tensor parallelism on multiple GPUs, add argument - ``--tensor-parallel-shards $NGPU`` to the config generation command. - - - `The full list of supported quantization in MLC `_. You can try different quantization methods with MLC LLM. Typical quantization methods are ``q4f16_1`` for 4-bit group quantization, ``q4f16_ft`` for 4-bit FasterTransformer format quantization. - - `The full list of conversation template in MLC `_. - -- **Step 2. Convert model weights.** In this step, we convert the model weights to MLC format. - - .. 
code:: bash - - mlc_llm convert_weight $LOCAL_MODEL_PATH \ - --quantization $QUANTIZATION \ - -o $MLC_MODEL_PATH - - This step consumes the raw model weights and converts them to for MLC format. - The converted weights will be stored under ``$MLC_MODEL_PATH``, - which is the same directory where the config file generated in Step 1 resides. - -Now, we can try to run your own model with chat CLI: - -.. code:: bash - - mlc_llm chat $MLC_MODEL_PATH - -For the first run, model compilation will be triggered automatically to optimize the -model for GPU accelerate and generate the binary model library. -The chat interface will be displayed after model JIT compilation finishes. -You can also use this model in Python API, MLC serve and other use scenarios. - -(Optional) Compile Model Library -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -In previous sections, model libraries are compiled when the :class:`mlc_llm.MLCEngine` launches, -which is what we call "JIT (Just-in-Time) model compilation". -In some cases, it is beneficial to explicitly compile the model libraries. -We can deploy LLMs with reduced dependencies by shipping the library for deployment without going through compilation. -It will also enable advanced options such as cross-compiling the libraries for web and mobile deployments. - - -Below is an example command of compiling model libraries in MLC LLM: - -.. code:: bash - - export $MODEL_LIB_PATH=$MLC_MODEL_PATH/lib.so # ".dylib" for Intel Macs. - # ".dll" for Windows. - # ".wasm" for web. - # ".tar" for iPhone/Android. - mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB_PATH - -At runtime, we need to specify this model library path to use it. For example, - -.. code:: bash - - # For chat CLI - mlc_llm chat $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH - # For REST server - mlc_llm serve $MLC_MODEL_PATH --model-lib-path $MODEL_LIB_PATH - -.. code:: python - - from mlc_llm import MLCEngine - - # For Python API - model = "models/phi-2" - model_lib_path = "models/phi-2/lib.so" - engine = MLCEngine(model, model_lib_path=model_lib_path) - -:ref:`compile-model-libraries` introduces the model compilation command in detail, -where you can find instructions and example commands to compile model to different -hardware backends, such as WebGPU, iOS and Android. - -Universal Deployment --------------------- - -MLC LLM is a high-performance universal deployment solution for large language models, -to enable native deployment of any large language models with native APIs with compiler acceleration -So far, we have gone through several examples running on a local GPU environment. -The project supports multiple kinds of GPU backends. - -You can use `--device` option in compilation and runtime to pick a specific GPU backend. -For example, if you have an NVIDIA or AMD GPU, you can try to use the option below -to run chat through the vulkan backend. Vulkan-based LLM applications run in less typical -environments (e.g. SteamDeck). - -.. code:: bash - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device vulkan - -The same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as -long as they fit within the memory and computing budget of the corresponding hardware backend. -We also leverage machine learning compilation to build backend-specialized optimizations to -get out the best performance on the targetted backend when possible, and reuse key insights and optimizations -across backends we support. 
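The same backend selection is available from the Python API. Below is a minimal sketch mirroring the CLI command above; it assumes the :class:`mlc_llm.MLCEngine` constructor accepts a ``device`` argument (``"auto"`` by default), which may differ across package versions.

.. code:: python

    from mlc_llm import MLCEngine

    # Explicitly pick the Vulkan backend; other values such as "cuda", "rocm",
    # "metal", or "auto" select the corresponding runtime when it is available.
    # (The `device` argument here is an assumption; check your installed version.)
    model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"
    engine = MLCEngine(model, device="vulkan")

    response = engine.chat.completions.create(
        messages=[{"role": "user", "content": "What is the meaning of life?"}],
        model=model,
        stream=False,
    )
    print(response)

    engine.terminate()
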
- -Please checkout the what to do next sections below to find out more about different deployment scenarios, -such as WebGPU-based browser deployment, mobile and other settings. - -Summary and What to Do Next ---------------------------- - -To briefly summarize this page, - -- We went through three examples (chat CLI, Python API, and REST server) of MLC LLM, -- we introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models. -- We also discussed the universal deployment capability of MLC LLM. - -Next, please feel free to check out the pages below for quick start examples and more detailed information -on specific platforms - -- :ref:`Quick start examples ` for Python API, chat CLI, REST server, web browser, iOS and Android. -- Depending on your use case, check out our API documentation and tutorial pages: - - - :ref:`webllm-runtime` - - :ref:`deploy-rest-api` - - :ref:`deploy-cli` - - :ref:`deploy-python-engine` - - :ref:`deploy-ios` - - :ref:`deploy-android` - - :ref:`deploy-ide-integration` - -- :ref:`Convert model weight to MLC format `, if you want to run your own models. -- :ref:`Compile model libraries `, if you want to deploy to web/iOS/Android or control the model optimizations. -- Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/get_started/project_overview.rst b/docs/get_started/project_overview.rst deleted file mode 100644 index ef631e40c8..0000000000 --- a/docs/get_started/project_overview.rst +++ /dev/null @@ -1,88 +0,0 @@ -.. _project-overview: - -Project Overview -================ - -This page introduces high-level project concepts to help us use and customize MLC LLM. -The MLC-LLM project consists of three distinct submodules: model definition, model compilation, and runtimes. - -.. figure:: /_static/img/project-structure.svg - :width: 600 - :align: center - :alt: Project Structure - - Three independent submodules in MLC LLM - -**➀ Model definition in Python.** MLC offers a variety of pre-defined architectures, such as Llama (e.g., Llama2, Vicuna, OpenLlama, Wizard), GPT-NeoX (e.g., RedPajama, Dolly), RNNs (e.g., RWKV), and GPT-J (e.g., MOSS). Model developers could solely define the model in pure Python, without having to touch code generation and runtime. - -**➁ Model compilation in Python.** Models are compiled by :doc:`TVM Unity ` compiler, where the compilation is configured in pure Python. MLC LLM quantizes and exports the Python-based model to a model library and quantized model weights. Quantization and optimization algorithms can be developed in pure Python to compress and accelerate LLMs for specific usecases. - -**➂ Platform-native runtimes.** Variants of MLCChat are provided on each platform: **C++** for command line, **Javascript** for web, **Swift** for iOS, and **Java** for Android, configurable with a JSON chat config. App developers only need to familiarize with the platform-naive runtimes to integrate MLC-compiled LLMs into their projects. - -.. _terminologies: - -Terminologies -------------- - -It is helpful for us to familiarize the basic terminologies used in the MLC chat applications. Below are the -three things you need to run a model with MLC. - -- **model lib**: The model library refers to the executable libraries that enable - the execution of a specific model architecture. 
On Linux and M-chip macOS, these libraries have the suffix - ``.so``; on intel macOS, the suffix is ``.dylib``; on Windows, the library file ends with ``.dll``; - on web browser, the library suffix is ``.wasm``. (see `binary-mlc-llm-libs `__). - -- **model weights**: The model weight is a folder that contains the quantized neural network weights - of the language models as well as the tokenizer configurations. (e.g. `Llama-2-7b-chat-hf-q4f16_1-MLC `__) - -- **chat config**: The chat configuration includes settings that allow customization of parameters such as temperature and system prompt. - The default chat config usually resides in the same directory as model weights. (e.g. see ``Llama-2-7b-chat-hf-q4f16_1``'s - `mlc-chat-config.json `__) - -Model Preparation ------------------ - - -There are several ways to prepare the model weights and model lib. - -- :ref:`Model Prebuilts` contains models that can be directly used. -- You can also :doc:`run model compilation ` for model weight variants for given supported architectures. -- Finally, you can incorporate a new model architecture/inference logic following :doc:`Define New Models `. - -A default chat config usually comes with the model weight directory. You can further customize -the system prompt, temperature, and other options by modifying the JSON file. -MLC chat runtimes also provide API to override these options during model reload. -Please refer to :ref:`configure-mlc-chat-json` for more details. - - -Runtime Flow Overview ---------------------- - -Once the model weights, model library, and chat configuration are prepared, an MLC chat runtime can consume them as an engine to drive a chat application. -The diagram below shows a typical workflow for a MLC chat application. - -.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/a05d4598bae6eb5a3133652d5cc0323ced3b0e17/images/mlc-llm/tutorials/mlc-llm-flow-slm.svg - :width: 90% - :align: center - -On the right side of the figure, you can see pseudo-code illustrating the structure of an MLC chat API during the execution of a chat app. -Typically, there is a ``ChatModule`` that manages the model. We instantiate the chat app with two files: the model weights (which include an ``mlc-chat-config.json``) -and the model library. We also have an optional chat configuration, which allows for overriding settings such as the system prompt and temperature. - -All MLC runtimes, including iOS, Web, CLI, and others, use these three elements. -All the runtime can read the same model weight folder. The packaging of the model libraries may vary depending on the runtime. -For the CLI, the model libraries are stored in a DLL directory. -iOS and Android include pre-packaged model libraries within the app due to dynamic loading restrictions. -WebLLM utilizes URLs of local or Internet-hosted WebAssembly (Wasm) files. - -What to Do Next ---------------- - -Thank you for reading and learning the high-level concepts. -Moving next, feel free to check out documents on the left navigation panel and -learn about topics you are interested in. - -- :ref:`configure-mlc-chat-json` shows how to configure specific chat behavior. -- Build and Deploy App section contains guides to build apps - and platform-specific MLC chat runtimes. -- Compile models section provides guidelines to convert model weights and produce model libs. 
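As a concrete illustration of the three elements described on this page (model weights, model lib, and chat config), below is a minimal Python sketch of driving a chat session with ``ChatModule``. The argument names and file paths are illustrative assumptions; consult the Python API reference for the authoritative signature.

.. code:: python

    from mlc_llm import ChatModule

    # The weights folder holds the quantized weights plus mlc-chat-config.json;
    # the model library is the platform-specific binary described above.
    # Both paths below are hypothetical examples.
    cm = ChatModule(
        model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
        model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.so",
    )

    print(cm.generate(prompt="What is the meaning of life?"))
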
diff --git a/docs/get_started/quick_start.rst b/docs/get_started/quick_start.rst deleted file mode 100644 index 8349197eda..0000000000 --- a/docs/get_started/quick_start.rst +++ /dev/null @@ -1,190 +0,0 @@ -.. _quick-start: - -Quick Start -=========== - -Examples --------- - -To begin with, try out MLC LLM support for int4-quantized Llama3 8B. -It is recommended to have at least 6GB free VRAM to run it. - -.. tabs:: - - .. tab:: Python - - **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM: - - .. code:: python - - from mlc_llm import MLCEngine - - # Create engine - model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" - engine = MLCEngine(model) - - # Run chat completion in OpenAI API. - for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, - ): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) - print("\n") - - engine.terminate() - - .. Todo: link the colab notebook when ready: - - **Documentation and tutorial.** Python API reference and its tutorials are :ref:`available online `. - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg - :width: 600 - :align: center - - MLC LLM Python API - - .. tab:: REST Server - - **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``. - - .. code:: shell - - mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - - **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``), - open a new shell and send a request via the following command: - - .. code:: shell - - curl -X POST \ - -H "Content-Type: application/json" \ - -d '{ - "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC", - "messages": [ - {"role": "user", "content": "Hello! Our project is MLC LLM. What is the name of our project?"} - ] - }' \ - http://127.0.0.1:8000/v1/chat/completions - - **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial. - Our REST API has complete OpenAI API support. - - .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg - :width: 600 - :align: center - - Send HTTP request to REST server in MLC LLM - - .. tab:: Command Line - - **Install MLC LLM**. :ref:`MLC LLM ` is available via pip. - It is always recommended to install it in an isolated conda virtual environment. - - For Windows/Linux users, make sure to have latest :ref:`Vulkan driver ` installed. - - **Run in command line**. - - .. code:: bash - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - - - If you are using windows/linux/steamdeck and would like to use vulkan, - we recommend installing necessary vulkan loader dependency via conda - to avoid vulkan not found issues. - - .. code:: bash - - conda install -c conda-forge gcc libvulkan-loader - - - .. tab:: Web Browser - - `WebLLM `__. 
MLC LLM generates performant code for WebGPU and WebAssembly, - so that LLMs can be run locally in a web browser without server resources. - - **Download pre-quantized weights**. This step is self-contained in WebLLM. - - **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute. - - **Check browser compatibility**. The latest Google Chrome provides WebGPU runtime and `WebGPU Report `__ as a useful tool to verify WebGPU capabilities of your browser. - - .. figure:: https://blog.mlc.ai/img/redpajama/web.gif - :width: 300 - :align: center - - MLC LLM on Web - - .. tab:: iOS - - **Install MLC Chat iOS**. It is available on AppStore: - - .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg - :width: 135 - :target: https://apps.apple.com/us/app/mlc-chat/id6448482937 - - | - - **Requirement**. Llama3-8B model needs an iOS device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - - **Tutorial and source code**. The source code of the iOS app is fully `open source `__, - and a :ref:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif - :width: 300 - :align: center - - MLC Chat on iOS - - .. tab:: Android - - **Install MLC Chat Android**. A prebuilt is available as an APK: - - .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png - :width: 135 - :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android/mlc-chat.apk - - | - - **Requirement**. Llama3-8B model needs a device with a minimum of 6GB RAM, whereas the RedPajama-3B model runs with at least 4GB RAM. - The demo is tested on - - - Samsung S23 with Snapdragon 8 Gen 2 chip - - Redmi Note 12 Pro with Snapdragon 685 - - Google Pixel phones - - **Tutorial and source code**. The source code of the android app is fully `open source `__, - and a :ref:`tutorial ` is included in documentation. - - .. figure:: https://blog.mlc.ai/img/android/android-recording.gif - :width: 300 - :align: center - - MLC LLM on Android - - -What to Do Next ---------------- - -- Check out :ref:`introduction-to-mlc-llm` for the introduction of a complete workflow in MLC LLM. -- Depending on your use case, check out our API documentation and tutorial pages: - - - :ref:`webllm-runtime` - - :ref:`deploy-rest-api` - - :ref:`deploy-cli` - - :ref:`deploy-python-engine` - - :ref:`deploy-ios` - - :ref:`deploy-android` - - :ref:`deploy-ide-integration` - -- `Convert model weight to MLC format `_, if you want to run your own models. -- `Compile model libraries `_, if you want to deploy to web/iOS/Android or control the model optimizations. -- Report any problem or ask any question: open new issues in our `GitHub repo `_. diff --git a/docs/index.rst b/docs/index.rst deleted file mode 100644 index 2d5597d18e..0000000000 --- a/docs/index.rst +++ /dev/null @@ -1,82 +0,0 @@ -👋 Welcome to MLC LLM -===================== - -`Discord `_ | `GitHub `_ - -Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. - -Quick Start ------------ - -Check out :ref:`quick-start` for quick start examples of using MLC LLM. 
- -Introduction to MLC LLM ------------------------ - -Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a complete workflow in MLC LLM. - - -.. toctree:: - :maxdepth: 1 - :caption: Get Started - :hidden: - - get_started/quick_start.rst - get_started/introduction.rst - -.. toctree:: - :maxdepth: 1 - :caption: Build and Deploy Apps - :hidden: - - deploy/javascript.rst - deploy/rest.rst - deploy/cli.rst - deploy/python_engine.rst - deploy/ios.rst - deploy/android.rst - deploy/ide_integration.rst - deploy/mlc_chat_config.rst - -.. toctree:: - :maxdepth: 1 - :caption: Compile Models - :hidden: - - compilation/convert_weights.rst - compilation/compile_models.rst - compilation/define_new_models.rst - -.. toctree:: - :maxdepth: 1 - :caption: Model Prebuilts - :hidden: - - prebuilt_models.rst - -.. toctree:: - :maxdepth: 1 - :caption: Dependency Installation - :hidden: - - install/tvm.rst - install/mlc_llm.rst - install/conda.rst - install/gpu.rst - install/emcc.rst - -.. toctree:: - :maxdepth: 1 - :caption: Community - :hidden: - - community/guideline.rst - community/faq.rst - - -.. toctree:: - :maxdepth: 1 - :caption: Privacy - :hidden: - - privacy.rst diff --git a/docs/install/conda.rst b/docs/install/conda.rst deleted file mode 100644 index 305c0f6748..0000000000 --- a/docs/install/conda.rst +++ /dev/null @@ -1,66 +0,0 @@ -Install Conda -============= - -MLC LLM does not depend on, but generally recommends conda as a generic dependency manager, primarily because it creates unified cross-platform experience to make windows/Linux/macOS development equally easy. Moreover, conda is python-friendly and provides all the python packages needed for MLC LLM, such as numpy. - -.. contents:: Table of Contents - :depth: 2 - - -Install Miniconda ------------------ - -**Use installer.** Miniconda, a minimal distribution of conda, comes with out-of-box installer across Windows/macOS/Linux. Please refer to its `official website `_ link for detailed instructions. - -**Set libmamba as the dependency solver.** The default dependency solver in conda could be slow in certain scenarios, and it is always recommended to upgrade it to libmamba, a faster solver. - -.. code-block:: bash - :caption: Set libmamba as the default solver - - # update conda - conda update --yes -n base -c defaults conda - # install `conda-libmamba-solver` - conda install --yes -n base conda-libmamba-solver - # set it as the default solver - conda config --set solver libmamba - -.. note:: - Conda is a generic dependency manager, which is not necessarily related to any Python distributions. - In fact, some of our tutorials recommends to use conda to install cmake, git and rust for its unified experience across OS platforms. - - -Validate installation ---------------------- - -**Step 1. Check conda-arch mismatch.** Nowadays macOS runs on two different architectures: arm64 and x86_64, which could particularly lead to many misuses in MLC LLM, where the error message hints about "architecture mismatch". Use the following command to make sure particular conda architecture is installed accordingly: - -.. code-block:: bash - :caption: Check conda architecture - - >>> conda info | grep platform - # for arm mac - platform : osx-arm64 - # for x86 mac - platform : osx-64 - -**Step 2. Check conda virtual environment.** If you have installed python in your conda virtual environment, make sure conda, Python and pip are all from this environment: - -.. 
code-block:: bash - :caption: Check conda virtual environment (macOS, Linux) - - >>> echo $CONDA_PREFIX - /.../miniconda3/envs/mlc-doc-venv - >>> which python - /.../miniconda3/envs/mlc-doc-venv/bin/python - >>> which pip - /.../miniconda3/envs/mlc-doc-venv/bin/pip - -.. code-block:: bat - :caption: Check conda virtual environment (Windows) - - >>> echo $Env:CONDA_PREFIX - \...\miniconda3\envs\mlc-doc-venv - >>> Get-Command python.exe - \...\miniconda3\envs\mlc-doc-venv\bin\python.exe - >>> Get-Command pip.exe - \...\miniconda3\envs\mlc-doc-venv\bin\pip.exe diff --git a/docs/install/emcc.rst b/docs/install/emcc.rst deleted file mode 100644 index f82292e00c..0000000000 --- a/docs/install/emcc.rst +++ /dev/null @@ -1,68 +0,0 @@ -.. _install-web-build: - -Install Wasm Build Environment -============================== - -This page describes the steps to setup build environment for WebAssembly and WebGPU builds. - -Step 1: Install EMSDK ---------------------- - -Emscripten is an LLVM-based compiler that compiles C/C++ source code to WebAssembly. -We need to install emscripten for webgpu build. - -- Please follow the installation instruction `here `__ - to install the latest emsdk. -- Source path/to/emsdk_env.sh so emcc is reachable from PATH and the command emcc works. - -Validate that emcc is accessible in shell - -.. code:: bash - - emcc --version - -Step 2: Set TVM_HOME and MLC_LLM_HOME -------------------------------------- - -We need to set a path to a tvm source in order to build tvm runtime. -Note that you do not need to build tvm unity from the source. The source here is only used to build the web runtime component. -Set environment variable in your shell startup profile in to point to ``3rdparty/tvm`` (if preferred, you could also -point to your own TVM address if you installed TVM from source). - -Besides, we also need to set ``MLC_LLM_HOME`` so that we can locate ``mlc_wasm_runtime.bc`` when compiling a model library wasm. - -.. code:: bash - - export TVM_HOME=/path/to/3rdparty/tvm - export MLC_LLM_HOME=/path/to/mlc-llm - - -Step 3: Prepare Wasm Runtime ----------------------------- - -First, we need to obtain a copy of the mlc-llm source code for the setup script - -.. code:: bash - - git clone https://github.com/mlc-ai/mlc-llm.git --recursive - cd mlc-llm - -Now we can prepare wasm runtime using the script in mlc-llm repo - -.. code:: bash - - ./web/prep_emcc_deps.sh - -We can then validate the outcome - -.. code:: bash - - >>> echo ${TVM_HOME} - - /path/set/in/step2 - - >>> ls -l ${TVM_HOME}/web/dist/wasm/*.bc - - tvmjs_support.bc - wasm_runtime.bc - webgpu_runtime.bc diff --git a/docs/install/gpu.rst b/docs/install/gpu.rst deleted file mode 100644 index 608c238265..0000000000 --- a/docs/install/gpu.rst +++ /dev/null @@ -1,201 +0,0 @@ -GPU Drivers and SDKs -==================== - -.. contents:: Table of Contents - :depth: 2 - -MLC LLM is a universal deployment solution that allows efficient CPU/GPU code generation without AutoTVM-based performance tuning. This section focuses on generic GPU environment setup and troubleshooting. - -CUDA ----- - -CUDA is required to compile and run models with CUDA backend. - -Installation -^^^^^^^^^^^^ - -If you have a NVIDIA GPU and you want to use models compiled with CUDA -backend, you should install CUDA, which can be downloaded from -`here `__. - -Validate Installation -^^^^^^^^^^^^^^^^^^^^^ - -To verify you have correctly installed CUDA runtime and NVIDIA driver, run ``nvidia-smi`` in command line and see if you can get the GPU information. 
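In addition to ``nvidia-smi``, you can check whether the TVM runtime bundled with the prebuilt CUDA package detects the GPU, using the same device-detection probe described in the TVM installation page:

.. code-block:: python

    import tvm

    # True when the NVIDIA driver and CUDA runtime are visible to the TVM runtime;
    # False usually indicates a driver or CUDA setup issue.
    print(tvm.cuda().exist)
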
- -ROCm ----- - -ROCm is required to compile and run models with ROCm backend. - -Installation -^^^^^^^^^^^^ - -Right now MLC LLM only supports ROCm 5.6. -If you have AMD GPU and you want to use models compiled with ROCm -backend, you should install ROCm 5.6 from `here `__. - -Validate Installation -^^^^^^^^^^^^^^^^^^^^^ - -To verify you have correctly installed ROCm 5.6, run ``rocm-smi`` in command line. -If you see the list of AMD devices printed out in a table, it means the ROCm is correctly installed. - -.. _vulkan_driver: - -Vulkan Driver -------------- - -Installation -^^^^^^^^^^^^ - -To run pre-trained models (e.g. pulled from MLC-AI's Hugging Face repository) compiled with Vulkan backend, you are expected to install Vulkan driver on your machine. - -Please check `this -page `__ and find the -Vulkan driver according to your GPU vendor. - -AMD Radeon and Radeon PRO -######################### - -For AMD Radeon and Radeon PRO users, please download AMD's drivers from official website (`Linux `__ / `Windows `__). -For Linux users, after you installed the ``amdgpu-install`` package, you can follow the instructions in its `documentation `__ to install -the driver. We recommend you installing ROCr OpenCL and PRO Vulkan (proprietary) for best performance, which can be done by running the following command: - -.. code:: bash - - amdgpu-install --usecase=graphics,opencl --opencl=rocr --vulkan=pro --no-32 - -Validate Installation -^^^^^^^^^^^^^^^^^^^^^ - -To verify whether Vulkan installation is successful or not, you are encouraged to install ``vulkaninfo``, below are the instructions to install ``vulkaninfo`` on different platforms: - -.. tabs :: - - .. code-tab :: bash Ubuntu/Debian - - sudo apt-get update - sudo apt-get install vulkan-tools - - .. code-tab :: bash Windows - - # It comes with your GPU driver - - .. code-tab :: bash Fedora - - sudo dnf install vulkan-tools - - .. code-tab :: bash Arch Linux - - sudo pacman -S vulkan-tools - # Arch Linux has maintained an awesome wiki page for Vulkan which you can refer to for troubleshooting: https://wiki.archlinux.org/title/Vulkan - - .. code-tab :: bash Other Distributions - - # Please install Vulkan SDK for your platform - # https://vulkan.lunarg.com/sdk/home - - -After installation, you can run ``vulkaninfo`` in command line and see if you can get the GPU information. - -.. note:: - WSL support for Windows is work-in-progress at the moment. Please do not use WSL on Windows to run Vulkan. - -Vulkan SDK ----------- - -Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ships with Vulkan SDK. - -Check Vulkan SDK installation guide according to your platform: - -.. tabs :: - - .. tab :: Windows - - `Getting Started with the Windows Tarball Vulkan SDK `__ - - .. tab :: Linux - - For Ubuntu user, please check - `Getting Started with the Ubuntu Vulkan SDK `__ - - For other Linux distributions, please check - `Getting Started with the Linux Tarball Vulkan SDK `__ - - .. tab :: Mac - - `Getting Started with the macOS Vulkan SDK `__ - -Please refer to installation and setup page for next steps to build TVM-Unity from source. - -OpenCL SDK ----------- - -OpenCL SDK is only required when you want to build your own models for OpenCL backend. Please refer to `OpenCL's Github Repository `__ for installation guide of OpenCL-SDK. 
- -Orange Pi 5 (RK3588 based SBC) ------------------------------- - -OpenCL SDK and Mali GPU driver is required to compile and run models for OpenCL backend. - -Installation -^^^^^^^^^^^^ - -* Download and install the Ubuntu 22.04 for your board from `here `__ - -* Download and install ``libmali-g610.so`` - -.. code-block:: bash - - cd /usr/lib && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/lib/aarch64-linux-gnu/libmali-valhall-g610-g6p0-x11-wayland-gbm.so - -* Check if file ``mali_csffw.bin`` exist under path ``/lib/firmware``, if not download it with command: - -.. code-block:: bash - - cd /lib/firmware && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/firmware/g610/mali_csffw.bin - -* Download OpenCL ICD loader and manually add libmali to ICD - -.. code-block:: bash - - sudo apt update - sudo apt install mesa-opencl-icd - sudo mkdir -p /etc/OpenCL/vendors - echo "/usr/lib/libmali-valhall-g610-g6p0-x11-wayland-gbm.so" | sudo tee /etc/OpenCL/vendors/mali.icd - -* Download and install ``libOpenCL`` - -.. code-block:: bash - - sudo apt install ocl-icd-opencl-dev - -* Download and install dependencies for Mali OpenCL - -.. code-block:: bash - - sudo apt install libxcb-dri2-0 libxcb-dri3-0 libwayland-client0 libwayland-server0 libx11-xcb1 - -* Download and install clinfo to check if OpenCL successfully installed - -.. code-block:: bash - - sudo apt install clinfo - -Validate Installation -^^^^^^^^^^^^^^^^^^^^^ - -To verify you have correctly installed OpenCL runtime and Mali GPU driver, run ``clinfo`` in command line and see if you can get the GPU information. -You are expect to see the following information: - -.. code-block:: bash - - $ clinfo - arm_release_ver: g13p0-01eac0, rk_so_ver: 3 - Number of platforms 2 - Platform Name ARM Platform - Platform Vendor ARM - Platform Version OpenCL 2.1 v1.g6p0-01eac0.2819f9d4dbe0b5a2f89c835d8484f9cd - Platform Profile FULL_PROFILE - ... diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst deleted file mode 100644 index ce15616957..0000000000 --- a/docs/install/mlc_llm.rst +++ /dev/null @@ -1,242 +0,0 @@ -.. _install-mlc-packages: - -Install MLC LLM Python Package -============================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -MLC LLM Python Package can be installed directly from a prebuilt developer package, or built from source. - -Option 1. Prebuilt Package --------------------------- - -We provide nightly built pip wheels for MLC-LLM via pip. -Select your operating system/compute platform and run the command in your terminal: - -.. note:: - ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. - -.. tabs:: - - .. tab:: Linux - - .. tabs:: - - .. tab:: CPU - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly - - .. tab:: CUDA 12.1 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu121 mlc-ai-nightly-cu121 - - .. tab:: CUDA 12.2 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu122 mlc-ai-nightly-cu122 - - .. tab:: ROCm 5.6 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm56 mlc-ai-nightly-rocm56 - - .. 
tab:: ROCm 5.7 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm57 mlc-ai-nightly-rocm57 - - .. tab:: Vulkan - - Supported in all Linux packages. Checkout the following instructions - to install the latest vulkan loader to avoid vulkan not found issue. - - .. note:: - - - .. code-block:: bash - - conda install -c conda-forge gcc libvulkan-loader - - - If encountering issues with GLIBC not found, please install the latest glibc in conda: - - .. code-block:: bash - - conda install -c conda-forge libgcc-ng - - Besides, we would recommend using Python 3.11; so if you are creating a new environment, - you could use the following command: - - .. code-block:: bash - - conda create --name mlc-prebuilt python=3.11 - - .. tab:: macOS - - .. tabs:: - - .. tab:: CPU + Metal - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly - - .. note:: - - Always check if conda is installed properly in macOS using the command below: - - .. code-block:: bash - - conda info | grep platform - - It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. - - .. tab:: Windows - - .. tabs:: - - .. tab:: CPU + Vulkan - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly - - .. note:: - Make sure you also install vulkan loader and clang to avoid vulkan - not found error or clang not found(needed for jit compile) - - .. code-block:: bash - - conda install -c conda-forge clang libvulkan-loader - - If encountering the error below: - - .. code-block:: bash - - FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - - It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: - - .. code-block:: bash - - conda install zstd - - -Then you can verify installation in command line: - -.. code-block:: bash - - python -c "import mlc_llm; print(mlc_llm)" - # Prints out: - -| - -.. _mlcchat_build_from_source: - -Option 2. Build from Source ---------------------------- - -We also provide options to build mlc runtime libraries ``mlc_llm`` from source. -This step is useful when you want to make modification or obtain a specific version of mlc runtime. - - -**Step 1. Set up build dependency.** To build from source, you need to ensure that the following build dependencies are satisfied: - -* CMake >= 3.24 -* Git -* `Rust and Cargo `_, required by Hugging Face's tokenizer -* One of the GPU runtimes: - - * CUDA >= 11.8 (NVIDIA GPUs) - * Metal (Apple GPUs) - * Vulkan (NVIDIA, AMD, Intel GPUs) - -.. code-block:: bash - :caption: Set up build dependencies in Conda - - # make sure to start with a fresh environment - conda env remove -n mlc-chat-venv - # create the conda environment with build dependency - conda create -n mlc-chat-venv -c conda-forge \ - "cmake>=3.24" \ - rust \ - git \ - python=3.11 - # enter the build environment - conda activate mlc-chat-venv - -.. note:: - For runtime, :doc:`TVM Unity ` compiler is not a dependency for MLCChat CLI or Python API. Only TVM's runtime is required, which is automatically included in `3rdparty/tvm `_. - However, if you would like to compile your own models, you need to follow :doc:`TVM Unity `. - -**Step 2. 
Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool: - -.. code-block:: bash - :caption: Configure and build - - # clone from GitHub - git clone --recursive https://github.com/mlc-ai/mlc-llm.git && cd mlc-llm/ - # create build directory - mkdir -p build && cd build - # generate build configuration - python3 ../cmake/gen_cmake_config.py - # build mlc_llm libraries - cmake .. && cmake --build . --parallel $(nproc) && cd .. - -.. note:: - If you are using CUDA and your compute capability is above 80, then it is require to build with - ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during - runtime. - - To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. - -**Step 3. Install via Python.** We recommend that you install ``mlc_llm`` as a Python package, giving you -access to ``mlc_llm.compile``, ``mlc_llm.ChatModule``, and the CLI. -There are two ways to do so: - - .. tabs :: - - .. code-tab :: bash Install via environment variable - - export MLC_LLM_HOME=/path-to-mlc-llm - export PYTHONPATH=$MLC_LLM_HOME/python:$PYTHONPATH - alias mlc_llm="python -m mlc_llm" - - .. code-tab :: bash Install via pip local project - - conda activate your-own-env - which python # make sure python is installed, expected output: path_to_conda/envs/your-own-env/bin/python - cd /path-to-mlc-llm/python - pip install -e . - -**Step 4. Validate installation.** You may validate if MLC libarires and mlc_llm CLI is compiled successfully using the following command: - -.. code-block:: bash - :caption: Validate installation - - # expected to see `libmlc_llm.so` and `libtvm_runtime.so` - ls -l ./build/ - # expected to see help message - mlc_llm chat -h - -Finally, you can verify installation in command line. You should see the path you used to build from source with: - -.. code:: bash - - python -c "import mlc_llm; print(mlc_llm)" diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst deleted file mode 100644 index ed4977e5e3..0000000000 --- a/docs/install/tvm.rst +++ /dev/null @@ -1,307 +0,0 @@ -.. _install-tvm-unity: - -Install TVM Unity Compiler -========================== - -.. contents:: Table of Contents - :local: - :depth: 2 - -`TVM Unity `__, the latest development in Apache TVM, is required to build MLC LLM. Its features include: - -- High-performance CPU/GPU code generation instantly without tuning; -- Dynamic shape and symbolic shape tracking by design; -- Supporting both inference and training; -- Productive python-first compiler implementation. As a concrete example, MLC LLM compilation is implemented in pure python using its API. - -TVM Unity can be installed directly from a prebuilt developer package, or built from source. - -.. _tvm-unity-prebuilt-package: - -Option 1. Prebuilt Package --------------------------- - -A nightly prebuilt Python package of Apache TVM Unity is provided. - -.. note:: - ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. - -.. tabs:: - - .. tab:: Linux - - .. tabs:: - - .. tab:: CPU - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly - - .. tab:: CUDA 12.1 - - .. 
code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 - - .. tab:: CUDA 12.2 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 - - .. tab:: ROCm 5.6 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 - - .. tab:: ROCm 5.7 - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 - - .. tab:: Vulkan - - Supported in all Linux packages. - - .. note:: - - If encountering issues with GLIBC not found, please install the latest glibc in conda: - - .. code-block:: bash - - conda install -c conda-forge libgcc-ng - - .. tab:: macOS - - .. tabs:: - - .. tab:: CPU + Metal - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly - - .. note:: - - Always check if conda is installed properly in macOS using the command below: - - .. code-block:: bash - - conda info | grep platform - - It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. - - .. tab:: Windows - - .. tabs:: - - .. tab:: CPU + Vulkan - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly - - .. note:: - Make sure you also install vulkan loader and clang to avoid vulkan - not found error or clang not found(needed for jit compile) - - .. code-block:: bash - - conda install -c conda-forge clang libvulkan-loader - - If encountering the error below: - - .. code-block:: bash - - FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - - It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: - - .. code-block:: bash - - conda install zstd - -.. _tvm-unity-build-from-source: - -Option 2. Build from Source ---------------------------- - -While it is generally recommended to always use the prebuilt TVM Unity, if you require more customization, you may need to build it from source. **NOTE.** this should only be attempted if you are familiar with the intricacies of C++, CMake, LLVM, Python, and other related systems. - -.. collapse:: Details - - **Step 1. Set up build dependency.** To build from source, you need to ensure that the following build dependencies are met: - - - CMake >= 3.24 - - LLVM >= 15 - - Git - - (Optional) CUDA >= 11.8 (targeting NVIDIA GPUs) - - (Optional) Metal (targeting Apple GPUs such as M1 and M2) - - (Optional) Vulkan (targeting NVIDIA, AMD, Intel and mobile GPUs) - - (Optional) OpenCL (targeting NVIDIA, AMD, Intel and mobile GPUs) - - .. note:: - - To target NVIDIA GPUs, either CUDA or Vulkan is required (CUDA is recommended); - - For AMD and Intel GPUs, Vulkan is necessary; - - When targeting Apple (macOS, iOS, iPadOS), Metal is a mandatory dependency; - - Some Android devices only support OpenCL, but most of them support Vulkan. - - To easiest way to manage dependency is via conda, which maintains a set of toolchains including LLVM across platforms. To create the environment of those build dependencies, one may simply use: - - .. 
code-block:: bash - :caption: Set up build dependencies in conda - - # make sure to start with a fresh environment - conda env remove -n tvm-build-venv - # create the conda environment with build dependency - conda create -n tvm-build-venv -c conda-forge \ - "llvmdev>=15" \ - "cmake>=3.24" \ - git \ - python=3.11 - # enter the build environment - conda activate tvm-build-venv - - **Step 2. Configure and build.** Standard git-based workflow are recommended to download Apache TVM Unity, and then specify build requirements in ``config.cmake``: - - .. code-block:: bash - :caption: Download TVM Unity from GitHub - - # clone from GitHub - git clone --recursive git@github.com:mlc-ai/relax.git tvm-unity && cd tvm-unity - # create the build directory - rm -rf build && mkdir build && cd build - # specify build requirements in `config.cmake` - cp ../cmake/config.cmake . - - .. note:: - We are temporarily using `mlc-ai/relax `_ instead, which comes with several temporary outstanding changes that we will upstream to Apache TVM's `unity branch `_. - - We want to specifically tweak the following flags by appending them to the end of the configuration file: - - .. code-block:: bash - :caption: Configure build in ``config.cmake`` - - # controls default compilation flags - echo "set(CMAKE_BUILD_TYPE RelWithDebInfo)" >> config.cmake - # LLVM is a must dependency - echo "set(USE_LLVM \"llvm-config --ignore-libllvm --link-static\")" >> config.cmake - echo "set(HIDE_PRIVATE_SYMBOLS ON)" >> config.cmake - # GPU SDKs, turn on if needed - echo "set(USE_CUDA OFF)" >> config.cmake - echo "set(USE_METAL OFF)" >> config.cmake - echo "set(USE_VULKAN OFF)" >> config.cmake - echo "set(USE_OPENCL OFF)" >> config.cmake - # FlashInfer related, requires CUDA w/ compute capability 80;86;89;90 - echo "set(USE_FLASHINFER OFF)" >> config.cmake - echo "set(FLASHINFER_CUDA_ARCHITECTURES YOUR_CUDA_COMPUTE_CAPABILITY_HERE)" >> config.cmake - echo "set(CMAKE_CUDA_ARCHITECTURES YOUR_CUDA_COMPUTE_CAPABILITY_HERE)" >> config.cmake - - .. note:: - ``HIDE_PRIVATE_SYMBOLS`` is a configuration option that enables the ``-fvisibility=hidden`` flag. This flag helps prevent potential symbol conflicts between TVM and PyTorch. These conflicts arise due to the frameworks shipping LLVMs of different versions. - - `CMAKE_BUILD_TYPE `_ controls default compilation flag: - - - ``Debug`` sets ``-O0 -g`` - - ``RelWithDebInfo`` sets ``-O2 -g -DNDEBUG`` (recommended) - - ``Release`` sets ``-O3 -DNDEBUG`` - - .. note:: - If you are using CUDA and your compute capability is above 80, then it is require to build with - ``set(USE_FLASHINFER ON)``. Otherwise, you may run into ``Cannot find PackedFunc`` issue during - runtime. - - To check your CUDA compute capability, you can use ``nvidia-smi --query-gpu=compute_cap --format=csv``. - - Once ``config.cmake`` is edited accordingly, kick off build with the commands below: - - .. code-block:: bash - :caption: Build ``libtvm`` using cmake and cmake - - cmake .. && cmake --build . --parallel $(nproc) - - A success build should produce ``libtvm`` and ``libtvm_runtime`` under ``/path-tvm-unity/build/`` directory. - - Leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment: - - .. tabs :: - - .. code-tab :: bash Install via environment variable - - export PYTHONPATH=/path-to-tvm-unity/python:$PYTHONPATH - - .. 
code-tab :: bash Install via pip local project - - conda activate your-own-env - conda install python # make sure python is installed - cd /path-to-tvm-unity/python - pip install -e . - -.. `|` adds a blank line - -| - -.. _tvm-unity-validate: - -Validate TVM Installation -------------------------- - -Using a compiler infrastructure with multiple language bindings could be error-prone. -Therefore, it is highly recommended to validate TVM Unity installation before use. - -**Step 1. Locate TVM Python package.** The following command can help confirm that TVM is properly installed as a python package and provide the location of the TVM python package: - -.. code-block:: bash - - >>> python -c "import tvm; print(tvm.__file__)" - /some-path/lib/python3.11/site-packages/tvm/__init__.py - -**Step 2. Confirm which TVM library is used.** When maintaining multiple build or installation of TVM, it becomes important to double check if the python package is using the proper ``libtvm`` with the following command: - -.. code-block:: bash - - >>> python -c "import tvm; print(tvm._ffi.base._LIB)" - - -**Step 3. Reflect TVM build option.** Sometimes when downstream application fails, it could likely be some mistakes with a wrong TVM commit, or wrong build flags. To find it out, the following commands will be helpful: - -.. code-block:: bash - - >>> python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))" - ... # Omitted less relevant options - GIT_COMMIT_HASH: 4f6289590252a1cf45a4dc37bce55a25043b8338 - HIDE_PRIVATE_SYMBOLS: ON - USE_LLVM: llvm-config --link-static - LLVM_VERSION: 15.0.7 - USE_VULKAN: OFF - USE_CUDA: OFF - CUDA_VERSION: NOT-FOUND - USE_OPENCL: OFF - USE_METAL: ON - USE_ROCM: OFF - -.. note:: - ``GIT_COMMIT_HASH`` indicates the exact commit of the TVM build, and it can be found on GitHub via ``https://github.com/mlc-ai/relax/commit/$GIT_COMMIT_HASH``. - -**Step 4. Check device detection.** Sometimes it could be helpful to understand if TVM could detect your device at all with the following commands: - -.. code-block:: bash - - >>> python -c "import tvm; print(tvm.metal().exist)" - True # or False - >>> python -c "import tvm; print(tvm.cuda().exist)" - False # or True - >>> python -c "import tvm; print(tvm.vulkan().exist)" - False # or True - -Please note that the commands above verify the presence of an actual device on the local machine for the TVM runtime (not the compiler) to execute properly. However, TVM compiler can perform compilation tasks without requiring a physical device. As long as the necessary toolchain, such as NVCC, is available, TVM supports cross-compilation even in the absence of an actual device. diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 32bb24529f..0000000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=. -set BUILDDIR=_build - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. 
- echo.If you don't have Sphinx installed, grab it from - echo.https://www.sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "" goto help - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst deleted file mode 100644 index 2f772a5d7e..0000000000 --- a/docs/prebuilt_models.rst +++ /dev/null @@ -1,773 +0,0 @@ -.. _Model Prebuilts: - -Model Prebuilts -================== - -.. contents:: Table of Contents - :depth: 3 - :local: - -.. _model-prebuilts-overview: - -Overview --------- - -MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ -(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the -help of :doc:`TVM Unity `. - -There are two ways to run a model on MLC-LLM (this page focuses on the second one): - -1. Compile your own models following :doc:`the model compilation page `. -2. Use off-the-shelf prebuilt models following this current page. - -In order to run a specific model on MLC-LLM, you need: - -**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). -See the full list of all precompiled model libraries `here `__. - -**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model -(e.g. https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC). See the full list of all precompiled weights `here `__. - -In this page, we first quickly go over :ref:`how to use prebuilts ` for different platforms, -then track what current :ref:`prebuilt models we provide `. - - -.. _using-model-prebuilts: - -Using Prebuilt Models for Different Platforms -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. - -.. _using-prebuilt-models-cli: - -**Prebuilt Models on CLI / Python** - -For more, please see :ref:`the CLI page `, and the :ref:`the Python page `. - -.. collapse:: Click to show details - - First create the conda environment if you have not done so. - - .. code:: shell - - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - conda install git git-lfs - git lfs install - - Download the prebuilt model libraries from github. - - .. code:: shell - - mkdir dist/ - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt_libs - - Run the model with CLI: - - .. code:: shell - - mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC - - - To run the model with Python API, see :ref:`the Python page ` (all other downloading steps are the same as CLI). - - -.. for a blank line - -| - -.. _using-prebuilt-models-ios: - -**Prebuilt Models on iOS** - -For more, please see :doc:`the iOS page `. - -.. collapse:: Click to show details - - The `iOS app `_ has builtin RedPajama-3B and Mistral-7B-Instruct-v0.2 support. - - All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: - - .. 
list-table:: Prebuilt Models for iOS - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model Code - - Model Series - - Quantization Mode - - MLC HuggingFace Weights Repo - * - `Mistral-7B-Instruct-v0.2-q3f16_1` - - `Mistral `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `phi-2-q4f16_1` - - `Microsoft Phi-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ -.. for a blank line - -| - -.. _prebuilt-models-android: - -**Prebuilt Models on Android** - -For more, please see :doc:`the Android page `. - -.. collapse:: Click to show details - - The apk for demo Android app includes the following models. To add more, check out the Android page. - - .. list-table:: Prebuilt Models for Android - :widths: 15 15 15 15 - :header-rows: 1 - - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q4f16_1` - - `Llama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ -.. for a blank line - -| - -.. _supported-model-architectures: - -Level 1: Supported Model Architectures (The All-In-One Table) -------------------------------------------------------------- - -For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. - -Each entry below hyperlinks to the corresponding level 2 and level 3 tables. - -MLC-LLM supports the following model architectures: - -.. 
list-table:: Supported Model Architectures - :widths: 10 10 15 15 - :header-rows: 1 - - * - Model Architecture - - Support - - Available MLC Prebuilts - - Unavailable in MLC Prebuilts - * - `LLaMA `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Llama-2-chat ` - - * `Code Llama `__ - * `Vicuna `__ - * `WizardLM `__ - * `WizardCoder (new) `__ - * `OpenOrca Platypus2 `__ - * `FlagAlpha Llama-2 Chinese `__ - * `georgesung Llama-2 Uncensored `__ - * `Alpaca `__ - * `Guanaco `__ - * `OpenLLaMA `__ - * `Gorilla `__ - * `YuLan-Chat `__ - * - `Mistral `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Mistral-7B-Instruct-v0.2 ` - * :ref:`NeuralHermes-2.5-Mistral-7B ` - * :ref:`OpenHermes-2.5-Mistral-7B ` - * :ref:`WizardMath-7B-V1.1 ` - - - * - `GPT-NeoX `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`RedPajama ` - - * `Dolly `__ - * `Pythia `__ - * `StableCode `__ - * - `GPTBigCode `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - - - * `StarCoder `__ - * `SantaCoder `__ - * `WizardCoder (old) `__ - * - `Phi `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`Phi-1_5 ` - * :ref:`Phi-2 ` - - - * - `GPT2 `__ - - * :ref:`Prebuilt Model Library ` - * `MLC Implementation `__ - - * :ref:`GPT2 ` - - - -If the model variant you are interested in uses one of these model architectures we support, -(but we have not provided the prebuilt weights yet), you can check out -:doc:`/compilation/convert_weights` on how to convert the weights. -Afterwards, you may follow :ref:`distribute-compiled-models` to upload your prebuilt -weights to hugging face, and submit a PR that adds an entry to this page, -contributing to the community. - -For models structured in an architecture we have not supported yet, you could: - -- Either `create a [Model Request] issue `__ which - automatically shows up on our `Model Request Tracking Board `__. - -- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. - - -.. _model-library-tables: - -Level 2: Model Library Tables (Precompiled Binary Files) --------------------------------------------------------- - -As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). - -Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: - -- **Size**: each size of model has its own distinct model library file (e.g. 7B or 13B number of parameters) - -- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) - -- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`quantization page ` - (e.g. ``q3f16_1`` vs. ``q4f16_1``). - -Each entry links to the specific model library file found in `this github repo `__. - -If the model library you found is not available as a prebuilt, you can compile it yourself by following :doc:`the model compilation page `, -and submit a PR to the repo `binary-mlc-llm-libs `__ afterwards. - -.. _llama_library_table: - -Llama -^^^^^ -.. 
list-table:: Llama - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 7B - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - - * - 13B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - - - - - `q4f16_1 `__ - - - * - 34B - - - - - - - - - - - - - - - - - - - - - * - 70B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - - - - - `q4f16_1 `__ - - - -.. _mistral_library_table: - -Mistral -^^^^^^^ -.. list-table:: Mistral - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 7B - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q4f16_1 `__ - - - - `q3f16_1 `__ - - `q4f16_1 `__ - - `q4f16_1 `__ - - - - -.. _gpt_neox_library_table: - -GPT-NeoX (RedPajama-INCITE) -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. list-table:: GPT-NeoX (RedPajama-INCITE) - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 3B - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f32_1 `__ - - - - `q4f16_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - `q4f16_1 `__ - - `q4f32_1 `__ - - - -.. _gpt_big_code_library_table: - -GPTBigCode -^^^^^^^^^^ - -.. list-table:: GPTBigCode - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - 15B - - - - - - - - - - - - - - - - - - - - - -.. _phi_library_table: - -Phi -^^^ -.. list-table:: Phi - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - Phi-2 - - (2.7B) - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - - - - - `q0f16 `__ - - `q4f16_1 `__ - - - * - Phi-1.5 - - (1.3B) - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - `q0f16 `__ - - `q4f16_1 `__ - - - - - - - - `q0f16 `__ - - `q4f16_1 `__ - - - -.. _gpt2_library_table: - -GPT2 -^^^^ -.. list-table:: GPT2 - :widths: 8 8 8 8 8 8 8 8 8 8 8 - :header-rows: 1 - :stub-columns: 1 - - * - - - CUDA - - ROCm - - Vulkan - - (Linux) - - Vulkan - - (Windows) - - Metal - - (M Chip) - - Metal - - (Intel) - - iOS - - Android - - webgpu - - mali - * - GPT2 - - (124M) - - `q0f16 `__ - - - - `q0f16 `__ - - - - `q0f16 `__ - - - - - - - - `q0f16 `__ - - - * - GPT2-med - - (355M) - - `q0f16 `__ - - - - `q0f16 `__ - - - - `q0f16 `__ - - - - - - - - `q0f16 `__ - - - -.. _model-variant-tables: - -Level 3: Model Variant Tables (Precompiled Weights) ---------------------------------------------------- - -Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. - -Each precompiled weight is categorized by its model size (e.g. 7B vs. 
13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. - -Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. Note that multiple model variants can share a common conversation template. - -Some of these files are uploaded by our community contributors--thank you! - -.. _llama2_variant_table: - -`Llama-2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``llama-2`` - -.. list-table:: Llama-2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 (Chat) `__ - * `q4f32_1 (Chat) `__ - - * - 13B - - * `q4f16_1 `__ - - * - 70B - - * `q4f16_1 `__ - -.. _mistralinstruct_variant_table: - -`Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``mistral_default`` - -.. list-table:: Mistral - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q3f16_1 (Instruct) `__ - * `q4f16_1 (Instruct) `__ - -.. _neuralhermes_variant_table: - -`NeuralHermes-2.5-Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``neural_hermes_mistral`` - -.. list-table:: Neural Hermes - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - -.. _openhermes_variant_table: - -`OpenHermes-2-Mistral `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``open_hermes_mistral`` - -.. list-table:: Open Hermes - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - - - -.. _wizardmathv1.1_variant_table: - -`WizardMath V1.1 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``wizard_coder_or_math`` - -.. list-table:: WizardMath - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 7B - - * `q4f16_1 `__ - - -.. _red_pajama_variant_table: - -`RedPajama `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``redpajama_chat`` - -.. list-table:: Red Pajama - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - 3B - - * `q4f16_1 (Chat) `__ - * `q4f32_1 (Chat) `__ - - -.. _phi_variant_table: - -`Phi `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``phi-2`` - -.. list-table:: Phi - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - Phi-2 (2.7B) - - * `q0f16 `__ - * `q4f16_1 `__ - * - Phi-1.5 (1.3B) - - * `q0f16 `__ - * `q4f16_1 `__ - - -.. _gpt2_variant_table: - -`GPT2 `__ -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Conversation template: ``gpt2`` - -.. list-table:: GPT2 - :widths: 30 30 - :header-rows: 1 - - * - Size - - Hugging Face Repo Link - * - GPT2 (124M) - - * `q0f16 `__ - * - GPT2-medium (355M) - - * `q0f16 `__ - - ------------------- - - -.. _contribute-models-to-mlc-llm: - -Contribute Models to MLC-LLM ----------------------------- - -Ready to contribute your compiled models/new model architectures? Awesome! Please check :ref:`contribute-new-models` on how to contribute new models to MLC-LLM. 
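To make the three levels concrete: running a prebuilt model pairs compiled weights (Level 3) with the model library built for your platform and quantization scheme (Level 2). Below is a minimal Python sketch mirroring the repository's ``sample_mlc_chat.py`` example; it assumes the ``mlc_llm`` package is installed and that the weights and prebuilt libraries have been downloaded into ``dist/`` as described in the CLI section above. The CUDA library path is only illustrative.

.. code:: python

   from mlc_llm import ChatModule
   from mlc_llm.callback import StreamToStdout

   # Pair compiled weights (Level 3) with the matching model library (Level 2).
   # The library file is platform-specific: swap the `-cuda.so` suffix for
   # `-vulkan.so`, `-metal.so`, etc., depending on your backend.
   cm = ChatModule(
       model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC",
       model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so",
   )

   # Stream a response to stdout, then print prefill/decode statistics.
   cm.generate(
       prompt="What is the meaning of life?",
       progress_callback=StreamToStdout(callback_interval=2),
   )
   print(cm.stats())

Because model variants that share an architecture reuse the same model library, the same ``model_lib_path`` can serve different weights of the same quantization, e.g. WizardMath weights with Mistral's library.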
diff --git a/docs/privacy.rst b/docs/privacy.rst deleted file mode 100644 index cdd3c91671..0000000000 --- a/docs/privacy.rst +++ /dev/null @@ -1,5 +0,0 @@ -MLC Chat App Privacy -==================== - -MLC Chat run all generation locally. -All data stays in users' device and is not collected by the app. diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 0156a180b0..0000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -sphinx-tabs == 3.4.1 -sphinx-rtd-theme -sphinx == 5.2.3 -sphinx-toolbox == 3.4.0 -tlcpack-sphinx-addon==0.2.2 -sphinxcontrib_httpdomain==1.8.1 -sphinxcontrib-napoleon==0.7 -sphinx-reredirects==0.1.2 -shortuuid -pydantic -uvicorn -fastapi ---find-links https://mlc.ai/wheels -mlc-ai-nightly diff --git a/examples/python/benchmark.py b/examples/python/benchmark.py deleted file mode 100644 index 7c897215d1..0000000000 --- a/examples/python/benchmark.py +++ /dev/null @@ -1,11 +0,0 @@ -from mlc_llm import ChatModule - -# From the mlc-llm directory, run -# $ python examples/python/benchmark.py - -# Create a ChatModule instance -cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") - -output = cm.benchmark_generate("What's the meaning of life?", generate_length=256) -print(f"Generated text:\n{output}\n") -print(f"Statistics: {cm.stats()}") diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py deleted file mode 100644 index a290eb892c..0000000000 --- a/examples/python/run_llama_batched_vllm.py +++ /dev/null @@ -1,448 +0,0 @@ -import argparse -import math -import os -import json -from collections import defaultdict -from typing import List -from dataclasses import dataclass - -import numpy as np - -import tvm -from tvm import relax -from tvm.runtime import disco as di - -import torch -from transformers import AutoTokenizer - -from mlc_llm.relax_model.llama import LlamaConfig -from mlc_llm import utils - - -class KVCache: - def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): - if disco_session: - init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - else: - init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") - - self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) - - self.block_tables = defaultdict(list) - self.slot_mappings = defaultdict(list) - self.block_size = block_size - - -class CacheManager: - block_size: int = 16 - - def __init__( - self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None - ): - self.num_blocks = num_blocks - self.free_blocks = list(range(num_blocks)) - self.kv_cache = KVCache( - num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session - ) - - if sliding_window: - assert sliding_window % self.kv_cache.block_size == 0 - self.block_sliding_window = sliding_window // self.kv_cache.block_size - else: - self.block_sliding_window = None - - def set_size(self, request_ids: List[int], target_sizes: List[int]): - for id, size in zip(request_ids, target_sizes): - num_needed_block = math.ceil(size / self.block_size) - - if self.block_sliding_window: - num_needed_block = min(num_needed_block, self.block_sliding_window) - - if id in self.kv_cache.block_tables and size == 0: - self.free_blocks.extend(self.kv_cache.block_tables[id]) - del self.kv_cache.block_tables[id] - del self.kv_cache.slot_mappings[id] - - elif id in self.kv_cache.block_tables: - # Decoding - if 
len(self.kv_cache.block_tables[id]) < num_needed_block: - # Need to allocate a new block for this request - assert len(self.kv_cache.block_tables[id]) + 1 == num_needed_block - self.kv_cache.block_tables[id].append(self.free_blocks.pop()) - - pos = size - 1 - block_number = self.kv_cache.block_tables[id][-1] - - if self.block_sliding_window: - block_number = self.kv_cache.block_tables[id][ - (pos // self.block_size) % self.block_sliding_window - ] - else: - block_number = self.kv_cache.block_tables[id][-1] - - block_offset = pos % self.block_size - slot = block_number * self.block_size + block_offset - self.kv_cache.slot_mappings[id].append(slot) - - elif id not in self.kv_cache.block_tables: - assert len(self.free_blocks) >= num_needed_block, "Not enough free blocks." - - for _ in range(num_needed_block): - self.kv_cache.block_tables[id].append(self.free_blocks.pop()) - - for i in range(size): - block_idx = i // self.block_size - - if self.block_sliding_window: - block_idx %= self.block_sliding_window - - block_number = self.kv_cache.block_tables[id][block_idx] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - self.kv_cache.slot_mappings[id].append(slot) - - def get(self): - return self.kv_cache - - -@dataclass -class SequenceGenerationRequest: - request_id: int - token_ids: List[int] - - -@dataclass -class SequenceGenerationResponse: - request_id: int - token_id: int - - -def sample(logits): - logits = torch.from_dlpack(logits) - return torch.argmax(logits, -1).cpu().numpy() - - -def load_params_disco(artifact_path, lib_path, num_shards): - sess = di.ProcessSession(num_workers=num_shards) - devices = range(num_shards) - sess.init_ccl("nccl", *devices) - module = sess.load_vm_module(lib_path) - - loader_create = sess.get_global_func("runtime.disco.ShardLoader") - metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") - with open(metadata_path, "r", encoding="utf-8") as f: - ndarray_cache_metadata = f.read() - - loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) - loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll") - params = loader_load(loader) - - return module, params, sess - - -def copy_to_worker_0(sess: di.Session, host_array): - x_array = sess.empty(host_array.shape, host_array.dtype) - sess.copy_to_worker_0(host_array, x_array) - return x_array - - -def get_tvm_model(artifact_path, model, quantization, num_shards, dev): - lib_path = os.path.join(artifact_path, f"{model}-{quantization}-cuda.so") - - if num_shards == 1: - ex = tvm.runtime.load_module(lib_path) - vm = relax.VirtualMachine(ex, dev) - params = utils.load_params(artifact_path, dev) - return vm.module, params, None - - return load_params_disco(artifact_path, lib_path, num_shards) - - -def _prepare_inputs( - requests, - all_slot_mappings, - all_block_tables, - sliding_window, - dev, - is_prefill, -): - block_tables = [] - seq_lens = [] - input_ids = [] - slot_mapping = [] - positions = [] - max_num_blocks_per_seq = 0 - indices_within_window = [] - start_idx = 0 - - for request in requests: - request_id = request.request_id - token_ids = request.token_ids - - if is_prefill: - input_ids += token_ids - prompt_len = len(token_ids) - seq_lens.append(prompt_len) - positions += range(prompt_len) - slot_mapping += all_slot_mappings[request_id] - - if sliding_window: - indices_within_window += range( - start_idx + max(0, prompt_len - sliding_window), - start_idx + prompt_len, - ) - start_idx += prompt_len - - else: - 
input_ids.append(token_ids[-1]) - pos = len(token_ids) - 1 - positions.append(pos) - block_table = all_block_tables[request_id] - max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) - block_tables.append(block_table) - slot_mapping.append(all_slot_mappings[request_id][-1]) - - if sliding_window: - seq_lens.append(min(len(token_ids), sliding_window)) - else: - seq_lens.append(len(token_ids)) - - input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) - positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) - seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) - slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) - - if is_prefill and sliding_window: - indices_within_window = tvm.nd.array(np.array(indices_within_window, dtype="int32"), dev) - else: - indices_within_window = None - - if not is_prefill: - - def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) - - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in block_tables - ] - - block_tables_np = np.vstack(padded_block_tables).astype("int32") - block_tables = tvm.nd.array(np.array(block_tables_np, dtype="int32"), dev) - else: - block_tables = None - - return ( - input_ids, - positions, - seq_lens, - slot_mapping, - indices_within_window, - block_tables, - ) - - -class Model: - def __init__( - self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window - ): - self.mod, self.params, self.disco_session = get_tvm_model( - artifact_path, model_name, quant, num_shards, dev - ) - self.dev = dev - self.vocab_size = vocab_size - self.sliding_window = sliding_window - - if sliding_window: - self.block_sliding_window = sliding_window // CacheManager.block_size - else: - self.block_sliding_window = None - - def generate( - self, requests: List[SequenceGenerationRequest], cache: KVCache, is_prefill: bool - ) -> List[SequenceGenerationResponse]: - ( - input_ids, - positions, - seq_lens, - slot_mapping, - indices_within_window, - block_tables, - ) = _prepare_inputs( - requests, - cache.slot_mappings, - cache.block_tables, - self.sliding_window, - self.dev, - is_prefill, - ) - - if self.disco_session: - input_ids = copy_to_worker_0(self.disco_session, input_ids) - positions = copy_to_worker_0(self.disco_session, positions) - seq_lens = copy_to_worker_0(self.disco_session, seq_lens) - slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) - - kv_cache = cache.cache - - if is_prefill: - if self.sliding_window: - if self.disco_session: - indices_within_window = copy_to_worker_0( - self.disco_session, indices_within_window - ) - - out = self.mod["prefill"]( - input_ids, - positions, - seq_lens, - kv_cache, - slot_mapping, - indices_within_window, - self.params, - ) - else: - out = self.mod["prefill"]( - input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params - ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] # Ignore returned KV cache since it is updated in-place anyway. 
- else: - if self.disco_session: - block_tables = copy_to_worker_0(self.disco_session, block_tables) - - out = self.mod["decode"]( - input_ids, - positions, - seq_lens, - kv_cache, - slot_mapping, - block_tables, - self.params, - ) - - if self.disco_session: - logits, _ = out.debug_get_from_remote(0) - else: - logits = out[0] - - next_tokens = sample(logits) - - return [ - SequenceGenerationResponse(request.request_id, new_token) - for request, new_token in zip(requests, next_tokens) - ] - - -def parse_args(): - # Example - # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention - # python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q4f16_ft - # - # For Disco: - # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --build-model-only --num-shards 2 - # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --convert-weight-only - # CUDA_VISIBLE_DEVICES=0,1 python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q0f16 --num-shards 2 - - args = argparse.ArgumentParser() - args.add_argument("--local-id", type=str, required=True) - args.add_argument("--artifact-path", type=str, default="dist") - args.add_argument("--num-shards", type=int, default=1) - args.add_argument("--num-decode-steps", type=int, default=20) - parsed = args.parse_args() - parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) - utils.argparse_postproc_common(parsed) - parsed.artifact_path = os.path.join( - parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" - ) - return parsed - - -def run(args): - quantization = args.quantization.name - artifact_path = args.artifact_path - model_name = args.model - model_path = f"dist/models/{model_name}" - - dev = tvm.device("cuda", 0) - - with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: - config = LlamaConfig(**json.load(i_f)) - - model = Model( - artifact_path, - model_name, - quantization, - config.vocab_size, - args.num_shards, - dev, - config.sliding_window, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) - - num_kv_heads = config.get_num_key_value_heads() // args.num_shards - head_size = config.hidden_size // config.num_attention_heads - num_blocks = 500 - - cache_manager = CacheManager( - num_blocks, - config.num_hidden_layers, - num_kv_heads, - head_size, - model.disco_session, - sliding_window=config.sliding_window, - ) - cache = cache_manager.get() - - model.block_sliding_window = cache_manager.block_sliding_window - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - batched_token_ids = [tokenizer.encode(p) for p in prompts] - prompts_len = [len(ids) for ids in batched_token_ids] - request_ids = list(range(len(prompts))) - target_sizes = [] - requests = [] - - for token_ids, request_id in zip(batched_token_ids, request_ids): - request_ids.append(request_id) - target_sizes.append(len(token_ids)) - requests.append(SequenceGenerationRequest(request_id, token_ids)) - - cache_manager.set_size(request_ids, target_sizes) - - out = model.generate(requests, cache, True) - - for _ in range(args.num_decode_steps): - for i, response in enumerate(out): - new_token_id = response.token_id - requests[i].token_ids.append(new_token_id) - 
target_sizes[i] += 1 - - cache_manager.set_size(request_ids, target_sizes) - - out = model.generate(requests, cache, False) - - output_tokens = [ - tokenizer.convert_ids_to_tokens( - requests[i].token_ids[prompts_len[i] :], skip_special_tokens=True - ) - for i in range(len(requests)) - ] - - generated = [tokenizer.convert_tokens_to_string(tokens) for tokens in output_tokens] - - for p, g in zip(prompts, generated): - print("Prompt = '{}', generated text = '{}'".format(p, g)) - - -if __name__ == "__main__": - run(parse_args()) diff --git a/examples/python/sample_chat_stream.py b/examples/python/sample_chat_stream.py deleted file mode 100644 index 7b6beea0a3..0000000000 --- a/examples/python/sample_chat_stream.py +++ /dev/null @@ -1,30 +0,0 @@ -from mlc_llm import ChatModule -from mlc_llm.callback import StreamToStdout, StreamIterator - -# From the mlc-llm directory, run -# $ python examples/python/sample_chat_stream.py - -# Create a ChatModule instance -cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") - -# Stream to Stdout -output = cm.generate( - prompt="What is the meaning of life?", - progress_callback=StreamToStdout(callback_interval=2), -) - -# Stream to an Iterator -from threading import Thread - -stream = StreamIterator(callback_interval=2) -generation_thread = Thread( - target=cm.generate, - kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, -) -generation_thread.start() - -output = "" -for delta_message in stream: - output += delta_message - -generation_thread.join() diff --git a/examples/python/sample_mlc_chat.py b/examples/python/sample_mlc_chat.py deleted file mode 100644 index de00e84ff6..0000000000 --- a/examples/python/sample_mlc_chat.py +++ /dev/null @@ -1,39 +0,0 @@ -from mlc_llm import ChatModule -from mlc_llm.callback import StreamToStdout - -# From the mlc-llm directory, run -# $ python examples/python/sample_mlc_llm.py - -# Create a ChatModule instance -cm = ChatModule( - model="dist/Llama-2-7b-chat-hf-q4f16_1-MLC", - model_lib_path="dist/prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - # Vulkan on Linux: Llama-2-7b-chat-hf-q4f16_1-vulkan.so - # Metal on macOS: Llama-2-7b-chat-hf-q4f16_1-metal.so - # Other platforms: Llama-2-7b-chat-hf-q4f16_1-{backend}.{suffix} -) - -# You can change to other models that you downloaded -# Model variants of the same architecture can reuse the same model library -# Here WizardMath reuses Mistral's model library -# cm = ChatModule( -# model="dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC", # or "dist/WizardMath-7B-V1.1-q4f16_1-MLC" -# model_lib_path="dist/prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so" -# ) - -# Generate a response for a given prompt -output = cm.generate( - prompt="What is the meaning of life?", - progress_callback=StreamToStdout(callback_interval=2), -) - -# Print prefill and decode performance statistics -print(f"Statistics: {cm.stats()}\n") - -output = cm.generate( - prompt="How many points did you list out?", - progress_callback=StreamToStdout(callback_interval=2), -) - -# Reset the chat module by -# cm.reset_chat() diff --git a/examples/python/sample_mlc_engine.py b/examples/python/sample_mlc_engine.py deleted file mode 100644 index e4f869930f..0000000000 --- a/examples/python/sample_mlc_engine.py +++ /dev/null @@ -1,17 +0,0 @@ -from mlc_llm import MLCEngine - -# Create engine -model = "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC" -engine = MLCEngine(model) - -# Run chat completion in OpenAI API. 
-for response in engine.chat.completions.create( - messages=[{"role": "user", "content": "What is the meaning of life?"}], - model=model, - stream=True, -): - for choice in response.choices: - print(choice.delta.content, end="", flush=True) -print("\n") - -engine.terminate() diff --git a/examples/rest/nodejs/README.MD b/examples/rest/nodejs/README.MD deleted file mode 100755 index 419b959ef3..0000000000 --- a/examples/rest/nodejs/README.MD +++ /dev/null @@ -1,21 +0,0 @@ -# Node/Javascript/Typescript Access Examples for mlc_llm REST APIs - -Please make sure you are running v18.17.x of node (and npm v9.6.7) -- v20.x currently has some compatibility problems with typescript used in the langchain example. - -First install dependencies. - -`npm i` - -Copy `dotenv.exmaple` to `.env`. - -To run JS chat completion (both streaming and non-streaming) example: - -`node sample_client.js` - -To run OpenAI (chat completion streaming and non-streaming, and legacy completion) example: - -`node sample_openai.js` - -To run LangchainJS Typescript example: - -`npm run example` diff --git a/examples/rest/nodejs/dotenv.example b/examples/rest/nodejs/dotenv.example deleted file mode 100755 index 5312f497f0..0000000000 --- a/examples/rest/nodejs/dotenv.example +++ /dev/null @@ -1,2 +0,0 @@ -OPENAI_API_KEY="none" -OPENAI_API_BASE="http://127.0.0.1:8000/v1" \ No newline at end of file diff --git a/examples/rest/nodejs/package.json b/examples/rest/nodejs/package.json deleted file mode 100755 index 2a3ebf25a7..0000000000 --- a/examples/rest/nodejs/package.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "name": "mlc-llm-js-examples", - "version": "1.0.0", - "description": "", - "main": "index.js", - "type": "module", - "license": "AGPL-version-3.0", - "private": false, - "engines": { - "node": ">= 14.0.0", - "npm": ">= 6.0.0" - }, - "homepage": "", - "repository": { - "type": "git", - "url": "" - }, - "bugs": "", - "keywords": [], - "author": { - "name": "", - "email": "", - "url": "" - }, - "contributors": [], - "scripts": { - "example": "ts-node --esm ./sample_langchain.ts" - }, - "dependencies": { - "@types/node": "^20.4.4", - "dotenv": "^16.3.1", - "langchain": "^0.0.117", - "needle": "^3.2.0", - "openai": "^3.3.0", - "typescript": "^5.1.6" - }, - "devDependencies": { - "ts-node": "^10.9.1" - } -} diff --git a/examples/rest/nodejs/sample_client.js b/examples/rest/nodejs/sample_client.js deleted file mode 100755 index 9a85072aaa..0000000000 --- a/examples/rest/nodejs/sample_client.js +++ /dev/null @@ -1,74 +0,0 @@ -import request from 'needle'; - -( async () => { -const color = { - PURPLE : '\x1b[95m', - CYAN : '\x1b[96m', - DARKCYAN : '\x1b[36m', - BLUE : '\x1b[94m', - GREEN : '\x1b[92m', - YELLOW : '\x1b[93m', - RED : '\x1b[91m', - BOLD : '\x1b[1m', - UNDERLINE : '\x1b[4m', - END : '\x1b[0m' -}; - -let payload = { - model : 'vicuna-v1-7b', - messages: [{"role": "user", "content": "Write a haiku"}], - stream: false -}; - -const print = ( str ) => { - process.stdout.write(str); -}; - -const newline = () => { - print('\n'); -} - -newline(); -print(color.BOLD + "Without streaming:" + color.END); -newline(); - -let r = await request("post", "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}); - -print(color.GREEN + r.body.choices[0].message.content + color.END); -print('\n'); -// Reset the chat -r = await request("post", "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}); -print(color.BOLD + "Reset chat" + color.END); -newline(); - -// Get a response using a prompt with streaming - -payload = 
{ - "model": "vicuna-v1-7b", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": true -} - -print( color.BOLD + "With streaming:" + color.END); -newline(); -r = request.post( "http://127.0.0.1:8000/v1/chat/completions", payload, {json: true}) -.on('readable', function() { - let jsData = ''; - let data = ''; - while (data = this.read()) { - const chunk = data.toString().substring(6); - if (chunk.trim() === "[DONE]") break; - jsData = JSON.parse(chunk); - print(color.GREEN + jsData.choices[0].delta.content + color.END); - } -}) -.on('done', async function () { - newline(); - let txtresp = await request("get", "http://127.0.0.1:8000/stats"); - print(color.BOLD + "Runtime stats:" + color.END + txtresp.body); - -}) - -})() - - diff --git a/examples/rest/nodejs/sample_langchain.ts b/examples/rest/nodejs/sample_langchain.ts deleted file mode 100644 index 48e849dfa5..0000000000 --- a/examples/rest/nodejs/sample_langchain.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { OpenAI } from "langchain/llms/openai"; -import { BufferWindowMemory } from "langchain/memory"; -import { LLMChain } from "langchain/chains"; -import { PromptTemplate } from "langchain/prompts"; -import {TextLoader } from "langchain/document_loaders/fs/text"; -import { loadQAStuffChain } from "langchain/chains"; - -const color = { - PURPLE : '\x1b[95m', - CYAN : '\x1b[96m', - DARKCYAN : '\x1b[36m', - BLUE : '\x1b[94m', - GREEN : '\x1b[92m', - YELLOW : '\x1b[93m', - RED : '\x1b[91m', - BOLD : '\x1b[1m', - UNDERLINE : '\x1b[4m', - END : '\x1b[0m' -}; - -function print(str: string) { - process.stdout.write(str); -} - -const newline = () => { - print('\n'); -} - - const chat = new OpenAI( { - openAIApiKey: "empty", - temperature: 0 - }, { - basePath: 'http://127.0.0.1:8000/v1' - }); - -// Conversational LLMChain example - const memory = new BufferWindowMemory({ memoryKey: "history", k: 1 }); - - const template = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. - - Current conversation: - {history} - Human: {human_input} - AI:`; - - - const prompt = PromptTemplate.fromTemplate(template); - let chain = new LLMChain({ llm: chat, prompt, memory }); - - let input = "Write a poem about Pittsburgh."; - print(color.BOLD + input + "..." + color.END); - newline(); - let res = await chain.call({ human_input: input }); - newline(); - print(color.GREEN + res.text + color.END); - newline(); - input = "What does it mean?"; - print(color.BOLD + input + "..." 
+ color.END); - newline(); - res = await chain.call({ human_input: input }); - newline(); - print(color.GREEN + res.text + color.END); - newline(); - -// Question and answer stuff chain example with text loader -const loader = new TextLoader('../resources/linux.txt'); -const documents = await loader.load(); -const schain = loadQAStuffChain(chat); -const query = "When was Linux released?"; -newline(); newline(); -print(color.BOLD + "Query: " + color.END + color.BLUE + query + color.END); -newline(); -const result = await schain.call({ input_documents: documents, question: query}); -print(color.BOLD + "Response: " + color.END + color.GREEN + result.text + color.END); - diff --git a/examples/rest/nodejs/sample_openai.js b/examples/rest/nodejs/sample_openai.js deleted file mode 100755 index 6e061148f0..0000000000 --- a/examples/rest/nodejs/sample_openai.js +++ /dev/null @@ -1,77 +0,0 @@ -import { Configuration, OpenAIApi } from "openai"; -import dotenv from "dotenv"; -dotenv.config(); - -( async () => { - -const configuration = new Configuration({ - apiKey: process.env.OPENAI_API_KEY, - basePath : process.env.OPENAI_API_BASE -}) -const openai = new OpenAIApi(configuration); -let model = "vicuna-v1-7b" - -const color = { - PURPLE : '\x1b[95m', - CYAN : '\x1b[96m', - DARKCYAN : '\x1b[36m', - BLUE : '\x1b[94m', - GREEN : '\x1b[92m', - YELLOW : '\x1b[93m', - RED : '\x1b[91m', - BOLD : '\x1b[1m', - UNDERLINE : '\x1b[4m', - END : '\x1b[0m' -}; - -const print = ( str ) => { - process.stdout.write(str); -}; - -const newline = () => { - print('\n'); -} - -// Chat completion example without streaming -newline(); -print(color.BOLD + "OpenAI chat completion example without streaming:" + color.END); -newline(); - -let completion = await openai.createChatCompletion({ - model: model, - messages: [{"role": "user", "content": "Write a poem about OpenAI"}] -}); - - -print(color.GREEN + completion.data.choices[0].message.content + color.END) -newline(); newline(); - - -// Chat completion example with streaming -// (raw implementation since npm module does not support it yet - it will have support in upcoming 4.x) - -print(color.BOLD + "OpenAI chat completion example with streaming:" + color.END); -newline(); -completion = await openai.createChatCompletion({ - model: model, - messages: [{"role": "user", "content": "Write a poem about OpenAI"}], - stream: true, -}, {responseType: 'stream'}); - -completion.data.on('data', async (data) => { - const parsed = JSON.parse(data.toString().substring(6)); - print(color.GREEN + parsed.choices[0].delta.content + color.END); -}); - -completion.data.on('close', async () => { - newline(); newline(); - - // Completion example - print(color.BOLD + "OpenAI completion example:" + color.END) - newline(); - let res = await openai.createCompletion({ prompt: "Write a poem about OpenAI", model: model}); - print(color.GREEN + res.data.choices[0].text + color.END); - newline(); newline(); - - }); -})() \ No newline at end of file diff --git a/examples/rest/nodejs/tsconfig.json b/examples/rest/nodejs/tsconfig.json deleted file mode 100755 index bc563cb043..0000000000 --- a/examples/rest/nodejs/tsconfig.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "compilerOptions": { - "target": "es2020", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ - "lib": ["es2020"], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ - "module": "nodenext", /* Specify what module code is generated. 
*/ - "rootDir": "src", /* Specify the root folder within your source files. */ - "outDir": "./dist", /* Specify an output folder for all emitted files. */ - "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */ - "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */ - "strict": true, /* Enable all strict type-checking options. */ - "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */ - "skipLibCheck": true /* Skip type checking all .d.ts files. */ - } -} diff --git a/examples/rest/python/sample_client.py b/examples/rest/python/sample_client.py deleted file mode 100644 index 1af1d837af..0000000000 --- a/examples/rest/python/sample_client.py +++ /dev/null @@ -1,46 +0,0 @@ -import requests -import json - -class color: - PURPLE = '\033[95m' - CYAN = '\033[96m' - DARKCYAN = '\033[36m' - BLUE = '\033[94m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - RED = '\033[91m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - END = '\033[0m' - -# Get a response using a prompt without streaming -payload = { - "model": "vicuna-v1-7b", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": False -} -r = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload) -print(f"{color.BOLD}Without streaming:{color.END}\n{color.GREEN}{r.json()['choices'][0]['message']['content']}{color.END}\n") - -# Reset the chat -r = requests.post("http://127.0.0.1:8000/chat/reset", json=payload) -print(f"{color.BOLD}Reset chat:{color.END} {str(r)}\n") - -# Get a response using a prompt with streaming -payload = { - "model": "vicuna-v1-7b", - "messages": [{"role": "user", "content": "Write a haiku"}], - "stream": True -} -with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as r: - print(f"{color.BOLD}With streaming:{color.END}") - for chunk in r: - if (chunk[6:].decode('utf-8').strip() == '[DONE]'): - break - content = json.loads(chunk[6:])["choices"][0]["delta"].get("content", "") - print(f"{color.GREEN}{content}{color.END}", end="", flush=True) - print("\n") - -# Get the latest runtime stats -r = requests.get("http://127.0.0.1:8000/stats") -print(f"{color.BOLD}Runtime stats:{color.END} {r.json()}\n") diff --git a/examples/rest/python/sample_langchain.py b/examples/rest/python/sample_langchain.py deleted file mode 100644 index 1bfe80bd26..0000000000 --- a/examples/rest/python/sample_langchain.py +++ /dev/null @@ -1,165 +0,0 @@ -from langchain.chat_models import ChatOpenAI -from langchain import LLMChain, PromptTemplate -from langchain.memory import ConversationBufferWindowMemory -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain.document_loaders import TextLoader, UnstructuredRSTLoader, DirectoryLoader -from langchain.chains.question_answering import load_qa_chain -from langchain.llms import OpenAI -from langchain.text_splitter import CharacterTextSplitter -from langchain.chains import RetrievalQA -from langchain.vectorstores import Chroma - -# Note that Langchain support for embedding documents using MLC is currently blocked on -# https://github.com/langchain-ai/langchain/pull/7815 -# We have subclassed `OpenAIEmbeddings` in the meantime to get around this dependency. 
-from mlc_llm.embeddings.openai import MLCEmbeddings - - -# First set the following in your environment: -# export OPENAI_API_BASE=http://127.0.0.1:8000/v1 -# export OPENAI_API_KEY=EMPTY - -# Note that Langchain does not currently support Pydantic v2: -# https://github.com/langchain-ai/langchain/issues/6841 -# Please ensure that your `pydantic` version is < 2.0 - - -class color: - PURPLE = "\033[95m" - CYAN = "\033[96m" - DARKCYAN = "\033[36m" - BLUE = "\033[94m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - END = "\033[0m" - - -def llm_chain_example(): - template = """ - {history} - USER: {human_input} - ASSISTANT:""" - - prompt = PromptTemplate(input_variables=["history", "human_input"], template=template) - - llm_chain = LLMChain( - llm=ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()]), - prompt=prompt, - verbose=True, - memory=ConversationBufferWindowMemory(human_prefix="USER", ai_prefix="ASSISTANT"), - ) - - output = llm_chain.predict(human_input="Write a short poem about Pittsburgh.") - output = llm_chain.predict(human_input="What does the poem mean?") - - -def load_qa_chain_example(): - loader = TextLoader("../resources/linux.txt") - documents = loader.load() - chain = load_qa_chain(llm=OpenAI(), chain_type="stuff", verbose=False) - query = "When was Linux released?" - print(f"{color.BOLD}Query:{color.END} {color.BLUE} {query}{color.END}") - print( - f"{color.BOLD}Response:{color.END} {color.GREEN}{chain.run(input_documents=documents, question=query)}{color.END}" - ) - - -def retrieval_qa_sotu_example(): - prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. - - {context} - - USER: {question} - ASSISTANT:""" - - PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) - - loader = TextLoader("../resources/state_of_the_union.txt") - documents = loader.load() - - text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100) - texts = text_splitter.split_documents(documents) - # print(texts) - embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) - db = Chroma.from_documents(documents=texts, embedding=embeddings) - retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2}) - qa = RetrievalQA.from_chain_type( - llm=OpenAI(), - chain_type="stuff", - retriever=retriever, - return_source_documents=True, - chain_type_kwargs={"prompt": PROMPT}, - ) - questions = [ - "What is the American Rescue Plan?", - "What did the president say about Ketanji Brown Jackson?", - "Who is mentioned in the speech?", - "To whom is the speech addressed?", - "Tell me more about the Made in America campaign.", - ] - - for qn in questions: - print(f"{color.BOLD}QUESTION:{color.END} {qn}") - res = qa({"query": qn}) - print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") - print( - f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}" - ) - print() - - -def retrieval_qa_mlc_docs_example(): - prompt_template = """Use only the following pieces of context to answer the question at the end. Don't use any other knowledge. 
- - {context} - - USER: {question} - ASSISTANT:""" - - PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) - - loader = DirectoryLoader( - "../../../docs", - glob="*/*.rst", - show_progress=True, - loader_cls=UnstructuredRSTLoader, - loader_kwargs={"mode": "single"}, - ) - documents = loader.load() - text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100) - texts = text_splitter.split_documents(documents) - embeddings = MLCEmbeddings(deployment="text-embedding-ada-002", embedding_ctx_length=None) - db = Chroma.from_documents(collection_name="abc", documents=texts, embedding=embeddings) - retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) - qa = RetrievalQA.from_chain_type( - llm=OpenAI(), - chain_type="stuff", - retriever=retriever, - return_source_documents=True, - chain_type_kwargs={"prompt": PROMPT}, - ) - while True: - qn = input(f"{color.BOLD}QUESTION:{color.END} ") - res = qa({"query": qn}) - print(f"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}") - print( - f"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}" - ) - print() - - # Some example questions: - # - What is the chat config? - # - What is temperature? - # - What are the REST API endpoints? - # - What are the available quantization options? - - -# Uncomment one of the following lines to try out the corresponding demo: - -# llm_chain_example() -# load_qa_chain_example() -# retrieval_qa_sotu_example() -# retrieval_qa_mlc_docs_example() diff --git a/examples/rest/python/sample_openai.py b/examples/rest/python/sample_openai.py deleted file mode 100644 index 1c4acb0ffc..0000000000 --- a/examples/rest/python/sample_openai.py +++ /dev/null @@ -1,43 +0,0 @@ -import openai - -openai.api_key = "None" -openai.api_base = "http://127.0.0.1:8000/v1" - -model = "vicuna-v1-7b" - -class color: - PURPLE = '\033[95m' - CYAN = '\033[96m' - DARKCYAN = '\033[36m' - BLUE = '\033[94m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - RED = '\033[91m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - END = '\033[0m' - -# Chat completion example without streaming -print(f"{color.BOLD}OpenAI chat completion example without streaming:{color.END}\n") -completion = openai.ChatCompletion.create( - model=model, - messages=[{"role": "user", "content": "Write a poem about OpenAI"}] -) -print(f"{color.GREEN}{completion.choices[0].message.content}{color.END}\n\n") - -# Chat completion example with streaming -print(f"{color.BOLD}OpenAI chat completion example with streaming:{color.END}\n") -res = openai.ChatCompletion.create( - model=model, - messages=[{"role": "user", "content": "Write a poem about OpenAI"}], - stream=True -) -for chunk in res: - content = chunk["choices"][0]["delta"].get("content", "") - print(f"{color.GREEN}{content}{color.END}", end="", flush=True) -print("\n") - -# Completion example -print(f"{color.BOLD}OpenAI completion example:{color.END}\n") -res = openai.Completion.create(prompt="Write a poem about OpenAI", model=model) -print(f"{color.GREEN}{res.choices[0].text}{color.END}\n\n") diff --git a/examples/rest/resources/linux.txt b/examples/rest/resources/linux.txt deleted file mode 100644 index 9f09b49c02..0000000000 --- a/examples/rest/resources/linux.txt +++ /dev/null @@ -1,23 +0,0 @@ -Linux is a family of open-source Unix-like operating systems based on the Linux kernel, an operating system kernel first released on September 17, 1991, by Linus Torvalds. 
Linux is typically packaged as a Linux distribution, which includes the kernel and supporting system software and libraries, many of which are provided by the GNU Project. Many Linux distributions use the word "Linux" in their name, but the Free Software Foundation uses the name "GNU/Linux" to emphasize the importance of GNU software, causing some controversy. - -Popular Linux distributions include Debian, Fedora Linux, and Ubuntu, the latter of which itself consists of many different distributions and modifications, including Lubuntu and Xubuntu. Commercial distributions include Red Hat Enterprise Linux and SUSE Linux Enterprise. Desktop Linux distributions include a windowing system such as X11 or Wayland, and a desktop environment such as GNOME or KDE Plasma. Distributions intended for servers may omit graphics altogether, or include a solution stack such as LAMP. Because Linux is freely redistributable, anyone may create a distribution for any purpose. - -Linux was originally developed for personal computers based on the Intel x86 architecture, but has since been ported to more platforms than any other operating system. Because of the dominance of the Linux-based Android on smartphones, Linux, including Android, has the largest installed base of all general-purpose operating systems, as of May 2022. Although Linux is, as of November 2022, used by only around 2.6 percent of desktop computers, the Chromebook, which runs the Linux kernel-based ChromeOS, dominates the US K–12 education market and represents nearly 20 percent of sub-$300 notebook sales in the US. Linux is the leading operating system on servers (over 96.4% of the top 1 million web servers' operating systems are Linux), leads other big iron systems such as mainframe computers, and is used on all of the world's 500 fastest supercomputers (since November 2017, having gradually displaced all competitors). - -Linux also runs on embedded systems, i.e. devices whose operating system is typically built into the firmware and is highly tailored to the system. This includes routers, automation controls, smart home devices, video game consoles, televisions (Samsung and LG Smart TVs), automobiles (Tesla, Audi, Mercedes-Benz, Hyundai and Toyota), and spacecraft (Falcon 9 rocket, Dragon crew capsule and the Perseverance rover). - -Linux is one of the most prominent examples of free and open-source software collaboration. The source code may be used, modified and distributed commercially or non-commercially by anyone under the terms of its respective licenses, such as the GNU General Public License (GPL). The Linux kernel, for example, is licensed under the GPLv2, with an exception for system calls that allows code that calls the kernel via system calls not to be licensed under the GPL. - -The Unix operating system was conceived and implemented in 1969, at AT&T's Bell Labs, in the United States by Ken Thompson, Dennis Ritchie, Douglas McIlroy, and Joe Ossanna. First released in 1971, Unix was written entirely in assembly language, as was common practice at the time. In 1973, in a key pioneering approach, it was rewritten in the C programming language by Dennis Ritchie (with the exception of some hardware and I/O routines). The availability of a high-level language implementation of Unix made its porting to different computer platforms easier. - -Due to an earlier antitrust case forbidding it from entering the computer business, AT&T licensed the operating system's source code as a trade secret to anyone who asked. 
As a result, Unix grew quickly and became widely adopted by academic institutions and businesses. In 1984, AT&T divested itself of its regional operating companies, and was released from its obligation not to enter the computer business; freed of that obligation, Bell Labs began selling Unix as a proprietary product, where users were not legally allowed to modify it. - -Onyx Systems began selling early microcomputer-based Unix workstations in 1980. Later, Sun Microsystems, founded as a spin-off of a student project at Stanford University, also began selling Unix-based desktop workstations in 1982. While Sun workstations did not utilize commodity PC hardware, for which Linux was later originally developed, it represented the first successful commercial attempt at distributing a primarily single-user microcomputer that ran a Unix operating system. - -With Unix increasingly "locked in" as a proprietary product, the GNU Project, started in 1983 by Richard Stallman, had the goal of creating a "complete Unix-compatible software system" composed entirely of free software. Work began in 1984. Later, in 1985, Stallman started the Free Software Foundation and wrote the GNU General Public License (GNU GPL) in 1989. By the early 1990s, many of the programs required in an operating system (such as libraries, compilers, text editors, a command-line shell, and a windowing system) were completed, although low-level elements such as device drivers, daemons, and the kernel, called GNU Hurd, were stalled and incomplete. - -MINIX was created by Andrew S. Tanenbaum, a computer science professor, and released in 1987 as a minimal Unix-like operating system targeted at students and others who wanted to learn operating system principles. Although the complete source code of MINIX was freely available, the licensing terms prevented it from being free software until the licensing changed in April 2000. - -Although not released until 1992, due to legal complications, development of 386BSD, from which NetBSD, OpenBSD and FreeBSD descended, predated that of Linux. - -Linus Torvalds has stated on separate occasions that if the GNU kernel or 386BSD had been available at the time (1991), he probably would not have created Linux. \ No newline at end of file diff --git a/examples/rest/resources/state_of_the_union.txt b/examples/rest/resources/state_of_the_union.txt deleted file mode 100644 index d50175de40..0000000000 --- a/examples/rest/resources/state_of_the_union.txt +++ /dev/null @@ -1,723 +0,0 @@ -Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. - -Last year COVID-19 kept us apart. This year we are finally together again. - -Tonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. - -With a duty to one another to the American people to the Constitution. - -And with an unwavering resolve that freedom will always triumph over tyranny. - -Six days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. - -He thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. - -He met the Ukrainian people. - -From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. - -Groups of citizens blocking tanks with their bodies. 
Everyone from students to retirees teachers turned soldiers defending their homeland. - -In this struggle as President Zelenskyy said in his speech to the European Parliament “Light will win over darkness.” The Ukrainian Ambassador to the United States is here tonight. - -Let each of us here tonight in this Chamber send an unmistakable signal to Ukraine and to the world. - -Please rise if you are able and show that, Yes, we the United States of America stand with the Ukrainian people. - -Throughout our history we’ve learned this lesson when dictators do not pay a price for their aggression they cause more chaos. - -They keep moving. - -And the costs and the threats to America and the world keep rising. - -That’s why the NATO Alliance was created to secure peace and stability in Europe after World War 2. - -The United States is a member along with 29 other nations. - -It matters. American diplomacy matters. American resolve matters. - -Putin’s latest attack on Ukraine was premeditated and unprovoked. - -He rejected repeated efforts at diplomacy. - -He thought the West and NATO wouldn’t respond. And he thought he could divide us at home. Putin was wrong. We were ready. Here is what we did. - -We prepared extensively and carefully. - -We spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin. - -I spent countless hours unifying our European allies. We shared with the world in advance what we knew Putin was planning and precisely how he would try to falsely justify his aggression. - -We countered Russia’s lies with truth. - -And now that he has acted the free world is holding him accountable. - -Along with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland. - -We are inflicting pain on Russia and supporting the people of Ukraine. Putin is now isolated from the world more than ever. - -Together with our allies –we are right now enforcing powerful economic sanctions. - -We are cutting off Russia’s largest banks from the international financial system. - -Preventing Russia’s central bank from defending the Russian Ruble making Putin’s $630 Billion “war fund” worthless. - -We are choking off Russia’s access to technology that will sap its economic strength and weaken its military for years to come. - -Tonight I say to the Russian oligarchs and corrupt leaders who have bilked billions of dollars off this violent regime no more. - -The U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs. - -We are joining with our European allies to find and seize your yachts your luxury apartments your private jets. We are coming for your ill-begotten gains. - -And tonight I am announcing that we will join our allies in closing off American air space to all Russian flights – further isolating Russia – and adding an additional squeeze –on their economy. The Ruble has lost 30% of its value. - -The Russian stock market has lost 40% of its value and trading remains suspended. Russia’s economy is reeling and Putin alone is to blame. - -Together with our allies we are providing support to the Ukrainians in their fight for freedom. Military assistance. Economic assistance. Humanitarian assistance. - -We are giving more than $1 Billion in direct assistance to Ukraine. 
- -And we will continue to aid the Ukrainian people as they defend their country and to help ease their suffering. - -Let me be clear, our forces are not engaged and will not engage in conflict with Russian forces in Ukraine. - -Our forces are not going to Europe to fight in Ukraine, but to defend our NATO Allies – in the event that Putin decides to keep moving west. - -For that purpose we’ve mobilized American ground forces, air squadrons, and ship deployments to protect NATO countries including Poland, Romania, Latvia, Lithuania, and Estonia. - -As I have made crystal clear the United States and our Allies will defend every inch of territory of NATO countries with the full force of our collective power. - -And we remain clear-eyed. The Ukrainians are fighting back with pure courage. But the next few days weeks, months, will be hard on them. - -Putin has unleashed violence and chaos. But while he may make gains on the battlefield – he will pay a continuing high price over the long run. - -And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. - -To all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. - -And I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. - -Tonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. - -America will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. - -These steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. - -But I want you to know that we are going to be okay. - -When the history of this era is written Putin’s war on Ukraine will have left Russia weaker and the rest of the world stronger. - -While it shouldn’t have taken something so terrible for people around the world to see what’s at stake now everyone sees it clearly. - -We see the unity among leaders of nations and a more unified Europe a more unified West. And we see unity among the people who are gathering in cities in large crowds around the world even in Russia to demonstrate their support for Ukraine. - -In the battle between democracy and autocracy, democracies are rising to the moment, and the world is clearly choosing the side of peace and security. - -This is a real test. It’s going to take time. So let us continue to draw inspiration from the iron will of the Ukrainian people. - -To our fellow Ukrainian Americans who forge a deep bond that connects our two nations we stand with you. - -Putin may circle Kyiv with tanks, but he will never gain the hearts and souls of the Ukrainian people. - -He will never extinguish their love of freedom. He will never weaken the resolve of the free world. - -We meet tonight in an America that has lived through two of the hardest years this nation has ever faced. - -The pandemic has been punishing. - -And so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more. - -I understand. - -I remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. 
I grew up in a family where if the price of food went up, you felt it. - -That’s why one of the first things I did as President was fight to pass the American Rescue Plan. - -Because people were hurting. We needed to act, and we did. - -Few pieces of legislation have done more in a critical moment in our history to lift us out of crisis. - -It fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans. - -Helped put food on their table, keep a roof over their heads, and cut the cost of health insurance. - -And as my Dad used to say, it gave people a little breathing room. - -And unlike the $2 Trillion tax cut passed in the previous administration that benefitted the top 1% of Americans, the American Rescue Plan helped working people—and left no one behind. - -And it worked. It created jobs. Lots of jobs. - -In fact—our economy created over 6.5 Million new jobs just last year, more jobs created in one year -than ever before in the history of America. - -Our economy grew at a rate of 5.7% last year, the strongest growth in nearly 40 years, the first step in bringing fundamental change to an economy that hasn’t worked for the working people of this nation for too long. - -For the past 40 years we were told that if we gave tax breaks to those at the very top, the benefits would trickle down to everyone else. - -But that trickle-down theory led to weaker economic growth, lower wages, bigger deficits, and the widest gap between those at the top and everyone else in nearly a century. - -Vice President Harris and I ran for office with a new economic vision for America. - -Invest in America. Educate Americans. Grow the workforce. Build the economy from the bottom up -and the middle out, not from the top down. - -Because we know that when the middle class grows, the poor have a ladder up and the wealthy do very well. - -America used to have the best roads, bridges, and airports on Earth. - -Now our infrastructure is ranked 13th in the world. - -We won’t be able to compete for the jobs of the 21st Century if we don’t fix that. - -That’s why it was so important to pass the Bipartisan Infrastructure Law—the most sweeping investment to rebuild America in history. - -This was a bipartisan effort, and I want to thank the members of both parties who worked to make it happen. - -We’re done talking about infrastructure weeks. - -We’re going to have an infrastructure decade. - -It is going to transform America and put us on a path to win the economic competition of the 21st Century that we face with the rest of the world—particularly with China. - -As I’ve told Xi Jinping, it is never a good bet to bet against the American people. - -We’ll create good jobs for millions of Americans, modernizing roads, airports, ports, and waterways all across America. - -And we’ll do it all to withstand the devastating effects of the climate crisis and promote environmental justice. - -We’ll build a national network of 500,000 electric vehicle charging stations, begin to replace poisonous lead pipes—so every child—and every American—has clean water to drink at home and at school, provide affordable high-speed internet for every American—urban, suburban, rural, and tribal communities. - -4,000 projects have already been announced. - -And tonight, I’m announcing that this year we will start fixing over 65,000 miles of highway and 1,500 bridges in disrepair. 
- -When we use taxpayer dollars to rebuild America – we are going to Buy American: buy American products to support American jobs. - -The federal government spends about $600 Billion a year to keep the country safe and secure. - -There’s been a law on the books for almost a century -to make sure taxpayers’ dollars support American jobs and businesses. - -Every Administration says they’ll do it, but we are actually doing it. - -We will buy American to make sure everything from the deck of an aircraft carrier to the steel on highway guardrails are made in America. - -But to compete for the best jobs of the future, we also need to level the playing field with China and other competitors. - -That’s why it is so important to pass the Bipartisan Innovation Act sitting in Congress that will make record investments in emerging technologies and American manufacturing. - -Let me give you one example of why it’s so important to pass it. - -If you travel 20 miles east of Columbus, Ohio, you’ll find 1,000 empty acres of land. - -It won’t look like much, but if you stop and look closely, you’ll see a “Field of dreams,” the ground on which America’s future will be built. - -This is where Intel, the American company that helped build Silicon Valley, is going to build its $20 billion semiconductor “mega site”. - -Up to eight state-of-the-art factories in one place. 10,000 new good-paying jobs. - -Some of the most sophisticated manufacturing in the world to make computer chips the size of a fingertip that power the world and our everyday lives. - -Smartphones. The Internet. Technology we have yet to invent. - -But that’s just the beginning. - -Intel’s CEO, Pat Gelsinger, who is here tonight, told me they are ready to increase their investment from -$20 billion to $100 billion. - -That would be one of the biggest investments in manufacturing in American history. - -And all they’re waiting for is for you to pass this bill. - -So let’s not wait any longer. Send it to my desk. I’ll sign it. - -And we will really take off. - -And Intel is not alone. - -There’s something happening in America. - -Just look around and you’ll see an amazing story. - -The rebirth of the pride that comes from stamping products “Made In America.” The revitalization of American manufacturing. - -Companies are choosing to build new factories here, when just a few years ago, they would have built them overseas. - -That’s what is happening. Ford is investing $11 billion to build electric vehicles, creating 11,000 jobs across the country. - -GM is making the largest investment in its history—$7 billion to build electric vehicles, creating 4,000 jobs in Michigan. - -All told, we created 369,000 new manufacturing jobs in America just last year. - -Powered by people I’ve met like JoJo Burgess, from generations of union steelworkers from Pittsburgh, who’s here with us tonight. - -As Ohio Senator Sherrod Brown says, “It’s time to bury the label “Rust Belt.” - -It’s time. - -But with all the bright spots in our economy, record job growth and higher wages, too many families are struggling to keep up with the bills. - -Inflation is robbing them of the gains they might otherwise feel. - -I get it. That’s why my top priority is getting prices under control. - -Look, our economy roared back faster than most predicted, but the pandemic meant that businesses had a hard time hiring enough workers to keep up production in their factories. - -The pandemic also disrupted global supply chains. 
- -When factories close, it takes longer to make goods and get them from the warehouse to the store, and prices go up. - -Look at cars. - -Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy. - -And guess what, prices of automobiles went up. - -So—we have a choice. - -One way to fight inflation is to drive down wages and make Americans poorer. - -I have a better plan to fight inflation. - -Lower your costs, not your wages. - -Make more cars and semiconductors in America. - -More infrastructure and innovation in America. - -More goods moving faster and cheaper in America. - -More jobs where you can earn a good living in America. - -And instead of relying on foreign supply chains, let’s make it in America. - -Economists call it “increasing the productive capacity of our economy.” - -I call it building a better America. - -My plan to fight inflation will lower your costs and lower the deficit. - -17 Nobel laureates in economics say my plan will ease long-term inflationary pressures. Top business leaders and most Americans support my plan. And here’s the plan: - -First – cut the cost of prescription drugs. Just look at insulin. One in ten Americans has diabetes. In Virginia, I met a 13-year-old boy named Joshua Davis. - -He and his Dad both have Type 1 diabetes, which means they need insulin every day. Insulin costs about $10 a vial to make. - -But drug companies charge families like Joshua and his Dad up to 30 times more. I spoke with Joshua’s mom. - -Imagine what it’s like to look at your child who needs insulin and have no idea how you’re going to pay for it. - -What it does to your dignity, your ability to look your child in the eye, to be the parent you expect to be. - -Joshua is here with us tonight. Yesterday was his birthday. Happy birthday, buddy. - -For Joshua, and for the 200,000 other young people with Type 1 diabetes, let’s cap the cost of insulin at $35 a month so everyone can afford it. - -Drug companies will still do very well. And while we’re at it let Medicare negotiate lower prices for prescription drugs, like the VA already does. - -Look, the American Rescue Plan is helping millions of families on Affordable Care Act plans save $2,400 a year on their health care premiums. Let’s close the coverage gap and make those savings permanent. - -Second – cut energy costs for families an average of $500 a year by combatting climate change. - -Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again. - -Third – cut the cost of child care. Many families pay up to $14,000 a year for child care per child. - -Middle-class and working families shouldn’t have to pay more than 7% of their income for care of young children. - -My plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work. - -My plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old. - -All of these will lower costs. - -And under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. Nobody. 
- -The one thing all Americans agree on is that the tax system is not fair. We have to fix it. - -I’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share. - -Just last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax. - -That’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations. - -We got more than 130 countries to agree on a global minimum tax rate so companies can’t get out of paying their taxes at home by shipping jobs and factories overseas. - -That’s why I’ve proposed closing loopholes so the very wealthy don’t pay a lower tax rate than a teacher or a firefighter. - -So that’s my plan. It will grow the economy and lower costs for families. - -So what are we waiting for? Let’s get this done. And while you’re at it, confirm my nominees to the Federal Reserve, which plays a critical role in fighting inflation. - -My plan will not only lower costs to give families a fair shot, it will lower the deficit. - -The previous Administration not only ballooned the deficit with tax cuts for the very wealthy and corporations, it undermined the watchdogs whose job was to keep pandemic relief funds from being wasted. - -But in my administration, the watchdogs have been welcomed back. - -We’re going after the criminals who stole billions in relief money meant for small businesses and millions of Americans. - -And tonight, I’m announcing that the Justice Department will name a chief prosecutor for pandemic fraud. - -By the end of this year, the deficit will be down to less than half what it was before I took office. - -The only president ever to cut the deficit by more than one trillion dollars in a single year. - -Lowering your costs also means demanding more competition. - -I’m a capitalist, but capitalism without competition isn’t capitalism. - -It’s exploitation—and it drives up prices. - -When corporations don’t have to compete, their profits go up, your prices go up, and small businesses and family farmers and ranchers go under. - -We see it happening with ocean carriers moving goods in and out of America. - -During the pandemic, these foreign-owned companies raised prices by as much as 1,000% and made record profits. - -Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. - -And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. - -That ends on my watch. - -Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. - -We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. - -Let’s pass the Paycheck Fairness Act and paid leave. - -Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. - -Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges. - -And let’s pass the PRO Act when a majority of workers want to form a union—they shouldn’t be stopped. - -When we invest in our workers, when we build the economy from the bottom up and the middle out together, we can do something we haven’t done in a long time: build a better America. 
- -For more than two years, COVID-19 has impacted every decision in our lives and the life of the nation. - -And I know you’re tired, frustrated, and exhausted. - -But I also know this. - -Because of the progress we’ve made, because of your resilience and the tools we have, tonight I can say -we are moving forward safely, back to more normal routines. - -We’ve reached a new moment in the fight against COVID-19, with severe cases down to a level not seen since last July. - -Just a few days ago, the Centers for Disease Control and Prevention—the CDC—issued new mask guidelines. - -Under these new guidelines, most Americans in most of the country can now be mask free. - -And based on the projections, more of the country will reach that point across the next couple of weeks. - -Thanks to the progress we have made this past year, COVID-19 need no longer control our lives. - -I know some are talking about “living with COVID-19”. Tonight – I say that we will never just accept living with COVID-19. - -We will continue to combat the virus as we do other diseases. And because this is a virus that mutates and spreads, we will stay on guard. - -Here are four common sense steps as we move forward safely. - -First, stay protected with vaccines and treatments. We know how incredibly effective vaccines are. If you’re vaccinated and boosted you have the highest degree of protection. - -We will never give up on vaccinating more Americans. Now, I know parents with kids under 5 are eager to see a vaccine authorized for their children. - -The scientists are working hard to get that done and we’ll be ready with plenty of vaccines when they do. - -We’re also ready with anti-viral treatments. If you get COVID-19, the Pfizer pill reduces your chances of ending up in the hospital by 90%. - -We’ve ordered more of these pills than anyone in the world. And Pfizer is working overtime to get us 1 Million pills this month and more than double that next month. - -And we’re launching the “Test to Treat” initiative so people can get tested at a pharmacy, and if they’re positive, receive antiviral pills on the spot at no cost. - -If you’re immunocompromised or have some other vulnerability, we have treatments and free high-quality masks. - -We’re leaving no one behind or ignoring anyone’s needs as we move forward. - -And on testing, we have made hundreds of millions of tests available for you to order for free. - -Even if you already ordered free tests tonight, I am announcing that you can order more from covidtests.gov starting next week. - -Second – we must prepare for new variants. Over the past year, we’ve gotten much better at detecting new variants. - -If necessary, we’ll be able to deploy new vaccines within 100 days instead of many more months or years. - -And, if Congress provides the funds we need, we’ll have new stockpiles of tests, masks, and pills ready if needed. - -I cannot promise a new variant won’t come. But I can promise you we’ll do everything within our power to be ready if it does. - -Third – we can end the shutdown of schools and businesses. We have the tools we need. - -It’s time for Americans to get back to work and fill our great downtowns again. People working from home can feel safe to begin to return to the office. - -We’re doing that here in the federal government. The vast majority of federal workers will once again work in person. - -Our schools are open. Let’s keep it that way. Our kids need to be in school. 
- -And with 75% of adult Americans fully vaccinated and hospitalizations down by 77%, most Americans can remove their masks, return to work, stay in the classroom, and move forward safely. - -We achieved this because we provided free vaccines, treatments, tests, and masks. - -Of course, continuing this costs money. - -I will soon send Congress a request. - -The vast majority of Americans have used these tools and may want to again, so I expect Congress to pass it quickly. - -Fourth, we will continue vaccinating the world. - -We’ve sent 475 Million vaccine doses to 112 countries, more than any other nation. - -And we won’t stop. - -We have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. - -Let’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. - -Let’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. - -We can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. - -I recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. - -They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. - -Officer Mora was 27 years old. - -Officer Rivera was 22. - -Both Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. - -I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. - -I’ve worked on these issues a long time. - -I know what works: Investing in crime preventionand community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety. - -So let’s not abandon our streets. Or choose between safety and equal justice. - -Let’s come together to protect our communities, restore trust, and hold law enforcement accountable. - -That’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. - -That’s why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope. - -We should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities. - -I ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe. - -And I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and can’t be traced. - -And I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon? - -Ban assault weapons and high-capacity magazines. - -Repeal the liability shield that makes gun manufacturers the only industry in America that can’t be sued. - -These laws don’t infringe on the Second Amendment. They save lives. - -The most fundamental right in America is the right to vote – and to have it counted. And it’s under assault. 
- -In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. - -We cannot let this happen. - -Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. - -Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. - -One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. - -And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. - -A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. - -And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. - -We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. - -We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. - -We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster. - -We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders. - -We can do all this while keeping lit the torch of liberty that has led generations of immigrants to this land—my forefathers and so many of yours. - -Provide a pathway to citizenship for Dreamers, those on temporary status, farm workers, and essential workers. - -Revise our laws so businesses have the workers they need and families don’t wait decades to reunite. - -It’s not only the right thing to do—it’s the economically smart thing to do. - -That’s why immigration reform is supported by everyone from labor unions to religious leaders to the U.S. Chamber of Commerce. - -Let’s get it done once and for all. - -Advancing liberty and justice also requires protecting the rights of women. - -The constitutional right affirmed in Roe v. Wade—standing precedent for half a century—is under attack as never before. - -If we want to go forward—not backward—we must protect access to health care. Preserve a woman’s right to choose. And let’s continue to advance maternal health care in America. - -And for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong. - -As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. - -While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice. 
- -And soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things. - -So tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together. - -First, beat the opioid epidemic. - -There is so much we can do. Increase funding for prevention, treatment, harm reduction, and recovery. - -Get rid of outdated rules that stop doctors from prescribing treatments. And stop the flow of illicit drugs by working with state and local law enforcement to go after traffickers. - -If you’re suffering from addiction, know you are not alone. I believe in recovery, and I celebrate the 23 million Americans in recovery. - -Second, let’s take on mental health. Especially among our children, whose lives and education have been turned upside down. - -The American Rescue Plan gave schools money to hire teachers and help students make up for lost learning. - -I urge every parent to make sure your school does just that. And we can all play a part—sign up to be a tutor or a mentor. - -Children were also struggling before the pandemic. Bullying, violence, trauma, and the harms of social media. - -As Frances Haugen, who is here with us tonight, has shown, we must hold social media platforms accountable for the national experiment they’re conducting on our children for profit. - -It’s time to strengthen privacy protections, ban targeted advertising to children, demand tech companies stop collecting personal data on our children. - -And let’s get all Americans the mental health services they need. More people they can turn to for help, and full parity between physical and mental health care. - -Third, support our veterans. - -Veterans are the best of us. - -I’ve always believed that we have a sacred obligation to equip all those we send to war and care for them and their families when they come home. - -My administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free. - -Our troops in Iraq and Afghanistan faced many dangers. - -One was stationed at bases and breathing in toxic smoke from “burn pits” that incinerated wastes of war—medical and hazard material, jet fuel, and more. - -When they came home, many of the world’s fittest and best trained warriors were never the same. - -Headaches. Numbness. Dizziness. - -A cancer that would put them in a flag-draped coffin. - -I know. - -One of those soldiers was my son Major Beau Biden. - -We don’t know for sure if a burn pit was the cause of his brain cancer, or the diseases of so many of our troops. - -But I’m committed to finding out everything we can. - -Committed to military families like Danielle Robinson from Ohio. - -The widow of Sergeant First Class Heath Robinson. - -He was born a soldier. Army National Guard. Combat medic in Kosovo and Iraq. - -Stationed near Baghdad, just yards from burn pits the size of football fields. - -Heath’s widow Danielle is here with us tonight. They loved going to Ohio State football games. He loved building Legos with their daughter. - -But cancer from prolonged exposure to burn pits ravaged Heath’s lungs and body. - -Danielle says Heath was a fighter to the very end. - -He didn’t know how to stop fighting, and neither did she. - -Through her pain she found purpose to demand we do better. - -Tonight, Danielle—we are. - -The VA is pioneering new ways of linking toxic exposures to diseases, already helping more veterans get benefits. 
- -And tonight, I’m announcing we’re expanding eligibility to veterans suffering from nine respiratory cancers. - -I’m also calling on Congress: pass a law to make sure veterans devastated by toxic exposures in Iraq and Afghanistan finally get the benefits and comprehensive health care they deserve. - -And fourth, let’s end cancer as we know it. - -This is personal to me and Jill, to Kamala, and to so many of you. - -Cancer is the #2 cause of death in America–second only to heart disease. - -Last month, I announced our plan to supercharge -the Cancer Moonshot that President Obama asked me to lead six years ago. - -Our goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases. - -More support for patients and families. - -To get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. - -It’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. - -ARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. - -A unity agenda for the nation. - -We can do this. - -My fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. - -In this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. - -We have fought for freedom, expanded liberty, defeated totalitarianism and terror. - -And built the strongest, freest, and most prosperous nation the world has ever known. - -Now is the hour. - -Our moment of responsibility. - -Our test of resolve and conscience, of history itself. - -It is in this moment that our character is formed. Our purpose is found. Our future is forged. - -Well I know this nation. - -We will meet the test. - -To protect freedom and liberty, to expand fairness and opportunity. - -We will save democracy. - -As hard as these times have been, I am more optimistic about America today than I have been my whole life. - -Because I see the future that is within our grasp. - -Because I know there is simply nothing beyond our capacity. - -We are the only nation on Earth that has always turned every crisis we have faced into an opportunity. - -The only nation that can be defined by a single word: possibilities. - -So on this night, in our 245th year as a nation, I have come to report on the State of the Union. - -And my report is this: the State of the Union is strong—because you, the American people, are strong. - -We are stronger today than we were a year ago. - -And we will be stronger a year from now than we are today. - -Now is our moment to meet and overcome the challenges of our time. - -And we will, as one people. - -One America. - -The United States of America. - -May God bless you all. May God protect our troops. \ No newline at end of file diff --git a/ios/.gitignore b/ios/.gitignore deleted file mode 100644 index 31d064cacb..0000000000 --- a/ios/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -xuserdata -*~ diff --git a/ios/MLCChat.xcodeproj/project.pbxproj b/ios/MLCChat.xcodeproj/project.pbxproj deleted file mode 100644 index 4c5173fa3c..0000000000 --- a/ios/MLCChat.xcodeproj/project.pbxproj +++ /dev/null @@ -1,530 +0,0 @@ -// !$*UTF8*$! 
-{ - archiveVersion = 1; - classes = { - }; - objectVersion = 56; - objects = { - -/* Begin PBXBuildFile section */ - 1453A4CF2A1354B9001B909F /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; }; - 1453A4D02A1354B9001B909F /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; }; - 1453A4D12A1354B9001B909F /* AppState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* AppState.swift */; }; - 1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; }; - 1453A4D32A1354B9001B909F /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; }; - A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */ = {isa = PBXBuildFile; fileRef = A773CC642A5DC98200467BFE /* ImageProcessing.swift */; }; - AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; }; - AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; }; - AEC27F022A86337E00254E67 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; }; - C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; }; - C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; }; - C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; }; - C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; }; - C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; }; - C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C229F99B07004DDAA4 /* ChatView.swift */; }; - C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; }; - C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; }; - C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */; }; -/* End PBXBuildFile section */ - -/* Begin PBXCopyFilesBuildPhase section */ - C06A74F129F9A78000BC4BE6 /* CopyFiles */ = { - isa = PBXCopyFilesBuildPhase; - buildActionMask = 2147483647; - dstPath = ""; - dstSubfolderSpec = 7; - files = ( - C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */, - C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; - C0D643CF29F99C5D004DDAA4 /* Embed Libraries */ = { - isa = PBXCopyFilesBuildPhase; - buildActionMask = 2147483647; - dstPath = ""; - dstSubfolderSpec = 10; - files = ( - ); - name = "Embed Libraries"; - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXCopyFilesBuildPhase section */ - -/* Begin PBXFileReference section */ - 1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; 
fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartView.swift; sourceTree = ""; }; - 1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelView.swift; sourceTree = ""; }; - 1453A4CC2A1354B9001B909F /* AppState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AppState.swift; sourceTree = ""; }; - 1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = ""; }; - 1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelState.swift; sourceTree = ""; }; - A773CC642A5DC98200467BFE /* ImageProcessing.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageProcessing.swift; sourceTree = ""; }; - AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ParamsConfig.swift; sourceTree = ""; }; - AEC27EFB2A85C3B000254E67 /* AppConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppConfig.swift; sourceTree = ""; }; - AEC27F012A86337E00254E67 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; - C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = ""; }; - C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = ""; }; - C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = ""; }; - C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; }; - C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCChatApp.swift; sourceTree = ""; }; - C0D643B629F99A80004DDAA4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; - C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; - C0D643C029F99B07004DDAA4 /* ChatState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatState.swift; sourceTree = ""; }; - C0D643C229F99B07004DDAA4 /* ChatView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatView.swift; sourceTree = ""; }; - C0D643C729F99B34004DDAA4 /* MessageView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MessageView.swift; sourceTree = ""; }; - C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */ = {isa = PBXFileReference; lastKnownFileType = wrapper; path = MLCSwift; sourceTree = ""; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - C0D643AC29F99A7F004DDAA4 /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - C0DDBE0D2A3BCD8000E9D060 /* MLCSwift in Frameworks */, - ); - 
runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - AEC27EF82A85C29000254E67 /* Models */ = { - isa = PBXGroup; - children = ( - 1453A4CD2A1354B9001B909F /* ModelConfig.swift */, - AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */, - AEC27EFB2A85C3B000254E67 /* AppConfig.swift */, - ); - path = Models; - sourceTree = ""; - }; - AEC27EFF2A85EE2800254E67 /* States */ = { - isa = PBXGroup; - children = ( - 1453A4CE2A1354B9001B909F /* ModelState.swift */, - 1453A4CC2A1354B9001B909F /* AppState.swift */, - C0D643C029F99B07004DDAA4 /* ChatState.swift */, - ); - path = States; - sourceTree = ""; - }; - AEC27F002A86306800254E67 /* Views */ = { - isa = PBXGroup; - children = ( - A773CC642A5DC98200467BFE /* ImageProcessing.swift */, - 1453A4CB2A1354B9001B909F /* ModelView.swift */, - 1453A4CA2A1354B9001B909F /* StartView.swift */, - C0D643C729F99B34004DDAA4 /* MessageView.swift */, - C0D643C229F99B07004DDAA4 /* ChatView.swift */, - ); - path = Views; - sourceTree = ""; - }; - AEC27F032A86338800254E67 /* Common */ = { - isa = PBXGroup; - children = ( - AEC27F012A86337E00254E67 /* Constants.swift */, - ); - path = Common; - sourceTree = ""; - }; - C0D643A629F99A7F004DDAA4 = { - isa = PBXGroup; - children = ( - C0DDBDF02A39068900E9D060 /* Packages */, - C06A74E029F99C9F00BC4BE6 /* dist */, - C0D643B129F99A7F004DDAA4 /* MLCChat */, - C0D643B029F99A7F004DDAA4 /* Products */, - C0D643C929F99BDA004DDAA4 /* Frameworks */, - ); - sourceTree = ""; - }; - C0D643B029F99A7F004DDAA4 /* Products */ = { - isa = PBXGroup; - children = ( - C0D643AF29F99A7F004DDAA4 /* MLCChat.app */, - ); - name = Products; - sourceTree = ""; - }; - C0D643B129F99A7F004DDAA4 /* MLCChat */ = { - isa = PBXGroup; - children = ( - C09834182A16F4CB00A05B51 /* app-config.json */, - AEC27F032A86338800254E67 /* Common */, - AEC27EF82A85C29000254E67 /* Models */, - AEC27EFF2A85EE2800254E67 /* States */, - AEC27F002A86306800254E67 /* Views */, - C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */, - C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */, - C0D643B629F99A80004DDAA4 /* Assets.xcassets */, - C0D643B829F99A80004DDAA4 /* Preview Content */, - ); - path = MLCChat; - sourceTree = ""; - }; - C0D643B829F99A80004DDAA4 /* Preview Content */ = { - isa = PBXGroup; - children = ( - C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */, - ); - path = "Preview Content"; - sourceTree = ""; - }; - C0D643C929F99BDA004DDAA4 /* Frameworks */ = { - isa = PBXGroup; - children = ( - ); - name = Frameworks; - sourceTree = ""; - }; - C0DDBDF02A39068900E9D060 /* Packages */ = { - isa = PBXGroup; - children = ( - C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */, - ); - name = Packages; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - C0D643AE29F99A7F004DDAA4 /* MLCChat */ = { - isa = PBXNativeTarget; - buildConfigurationList = C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget "MLCChat" */; - buildPhases = ( - C0D643AB29F99A7F004DDAA4 /* Sources */, - C0D643AC29F99A7F004DDAA4 /* Frameworks */, - C0D643AD29F99A7F004DDAA4 /* Resources */, - C0D643CF29F99C5D004DDAA4 /* Embed Libraries */, - C06A74F129F9A78000BC4BE6 /* CopyFiles */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = MLCChat; - packageProductDependencies = ( - C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */, - ); - productName = MLCChat; - productReference = C0D643AF29F99A7F004DDAA4 /* MLCChat.app */; - productType = "com.apple.product-type.application"; - }; -/* End 
PBXNativeTarget section */ - -/* Begin PBXProject section */ - C0D643A729F99A7F004DDAA4 /* Project object */ = { - isa = PBXProject; - attributes = { - BuildIndependentTargetsInParallel = 1; - LastSwiftUpdateCheck = 1430; - LastUpgradeCheck = 1430; - TargetAttributes = { - C0D643AE29F99A7F004DDAA4 = { - CreatedOnToolsVersion = 14.3; - LastSwiftMigration = 1430; - }; - }; - }; - buildConfigurationList = C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject "MLCChat" */; - compatibilityVersion = "Xcode 14.0"; - developmentRegion = en; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = C0D643A629F99A7F004DDAA4; - productRefGroup = C0D643B029F99A7F004DDAA4 /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - C0D643AE29F99A7F004DDAA4 /* MLCChat */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - C0D643AD29F99A7F004DDAA4 /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */, - C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - C0D643AB29F99A7F004DDAA4 /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */, - 1453A4D12A1354B9001B909F /* AppState.swift in Sources */, - C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */, - C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */, - C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */, - 1453A4D32A1354B9001B909F /* ModelState.swift in Sources */, - C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */, - 1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */, - AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */, - AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */, - AEC27F022A86337E00254E67 /* Constants.swift in Sources */, - 1453A4D02A1354B9001B909F /* ModelView.swift in Sources */, - 1453A4CF2A1354B9001B909F /* StartView.swift in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - C0D643BB29F99A80004DDAA4 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; 
- CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 16.0; - MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; - MTL_FAST_MATH = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; - SWIFT_OPTIMIZATION_LEVEL = "-Onone"; - }; - name = Debug; - }; - C0D643BC29F99A80004DDAA4 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 16.0; - MTL_ENABLE_DEBUG_INFO = NO; - MTL_FAST_MATH = YES; - SDKROOT = iphoneos; - SWIFT_COMPILATION_MODE = wholemodule; - SWIFT_OPTIMIZATION_LEVEL = "-O"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; - C0D643BE29F99A80004DDAA4 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; - CLANG_ENABLE_MODULES = YES; - CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; - CODE_SIGN_IDENTITY = "Apple Development"; - CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; - DEVELOPMENT_TEAM = 3FR42MXLK9; - ENABLE_PREVIEWS = YES; - GENERATE_INFOPLIST_FILE = YES; - "HEADER_SEARCH_PATHS[arch=*]" = ""; - INFOPLIST_FILE = MLCChat/Info.plist; - INFOPLIST_KEY_LSApplicationCategoryType = 
"public.app-category.productivity"; - INFOPLIST_KEY_NSCameraUsageDescription = "This app requires usage of camera to function properly."; - INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; - INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; - INFOPLIST_KEY_UILaunchScreen_Generation = YES; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - LD_RUNPATH_SEARCH_PATHS = ( - "$(inherited)", - "@executable_path/Frameworks", - ); - LIBRARY_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/build/lib", - ); - MARKETING_VERSION = 1.3; - OTHER_LDFLAGS = ( - "-Wl,-all_load", - "-lmodel_iphone", - "-lmlc_llm", - "-ltvm_runtime", - "-ltokenizers_cpp", - "-lsentencepiece", - "-ltokenizers_c", - ); - PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat; - PRODUCT_NAME = "$(TARGET_NAME)"; - PROVISIONING_PROFILE_SPECIFIER = ""; - SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OBJC_BRIDGING_HEADER = ""; - SWIFT_OPTIMIZATION_LEVEL = "-Onone"; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - C0D643BF29F99A80004DDAA4 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; - CLANG_ENABLE_MODULES = YES; - CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements; - CODE_SIGN_IDENTITY = "Apple Development"; - CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\""; - DEVELOPMENT_TEAM = 3FR42MXLK9; - ENABLE_PREVIEWS = YES; - GENERATE_INFOPLIST_FILE = YES; - "HEADER_SEARCH_PATHS[arch=*]" = ""; - INFOPLIST_FILE = MLCChat/Info.plist; - INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.productivity"; - INFOPLIST_KEY_NSCameraUsageDescription = "This app requires usage of camera to function properly."; - INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; - INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; - INFOPLIST_KEY_UILaunchScreen_Generation = YES; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - LD_RUNPATH_SEARCH_PATHS = ( - "$(inherited)", - "@executable_path/Frameworks", - ); - LIBRARY_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/build/lib", - ); - MARKETING_VERSION = 1.3; - OTHER_LDFLAGS = ( - "-Wl,-all_load", - "-lmodel_iphone", - "-lmlc_llm", - "-ltvm_runtime", - "-ltokenizers_cpp", - "-lsentencepiece", - "-ltokenizers_c", - ); - PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat; - PRODUCT_NAME = "$(TARGET_NAME)"; - PROVISIONING_PROFILE_SPECIFIER = ""; - SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OBJC_BRIDGING_HEADER = ""; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject "MLCChat" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 
C0D643BB29F99A80004DDAA4 /* Debug */, - C0D643BC29F99A80004DDAA4 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget "MLCChat" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - C0D643BE29F99A80004DDAA4 /* Debug */, - C0D643BF29F99A80004DDAA4 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - -/* Begin XCSwiftPackageProductDependency section */ - C0DDBE0C2A3BCD8000E9D060 /* MLCSwift */ = { - isa = XCSwiftPackageProductDependency; - productName = MLCSwift; - }; -/* End XCSwiftPackageProductDependency section */ - }; - rootObject = C0D643A729F99A7F004DDAA4 /* Project object */; -} diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata deleted file mode 100644 index 919434a625..0000000000 --- a/ios/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata +++ /dev/null @@ -1,7 +0,0 @@ - - - - - diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist deleted file mode 100644 index 18d981003d..0000000000 --- a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +++ /dev/null @@ -1,8 +0,0 @@ - - - - - IDEDidComputeMac32BitWarning - - - diff --git a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings deleted file mode 100644 index 0c67376eba..0000000000 --- a/ios/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme b/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme deleted file mode 100644 index 311123f330..0000000000 --- a/ios/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme +++ /dev/null @@ -1,81 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json b/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json deleted file mode 100644 index eb87897008..0000000000 --- a/ios/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "colors" : [ - { - "idiom" : "universal" - } - ], - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json deleted file mode 100644 index 7324dc205a..0000000000 --- a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "images" : [ - { - "filename" : "mlc-logo.png", - "idiom" : "universal", - "platform" : "ios", - "size" : "1024x1024" - } - ], - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png b/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png deleted file mode 100644 index 4ae381da6c..0000000000 Binary files a/ios/MLCChat/Assets.xcassets/AppIcon.appiconset/mlc-logo.png and /dev/null differ diff --git a/ios/MLCChat/Assets.xcassets/Contents.json b/ios/MLCChat/Assets.xcassets/Contents.json deleted file mode 100644 index 
73c00596a7..0000000000 --- a/ios/MLCChat/Assets.xcassets/Contents.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/ios/MLCChat/Common/Constants.swift b/ios/MLCChat/Common/Constants.swift deleted file mode 100644 index cf3a240fcf..0000000000 --- a/ios/MLCChat/Common/Constants.swift +++ /dev/null @@ -1,11 +0,0 @@ -// -// Constants.swift -// MLCChat -// - -struct Constants { - static let prebuiltModelDir = "dist" - static let appConfigFileName = "app-config.json" - static let modelConfigFileName = "mlc-chat-config.json" - static let paramsConfigFileName = "ndarray-cache.json" -} diff --git a/ios/MLCChat/Info.plist b/ios/MLCChat/Info.plist deleted file mode 100644 index 0c67376eba..0000000000 --- a/ios/MLCChat/Info.plist +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/ios/MLCChat/MLCChat.entitlements b/ios/MLCChat/MLCChat.entitlements deleted file mode 100644 index caa3d58396..0000000000 --- a/ios/MLCChat/MLCChat.entitlements +++ /dev/null @@ -1,10 +0,0 @@ - - - - - com.apple.developer.kernel.extended-virtual-addressing - - com.apple.developer.kernel.increased-memory-limit - - - diff --git a/ios/MLCChat/MLCChatApp.swift b/ios/MLCChat/MLCChatApp.swift deleted file mode 100644 index fcefd6fbb4..0000000000 --- a/ios/MLCChat/MLCChatApp.swift +++ /dev/null @@ -1,28 +0,0 @@ -// -// MLCChatApp.swift -// MLCChat -// -// Created by Tianqi Chen on 4/26/23. -// - -import SwiftUI - -@main -struct MLCChatApp: App { - @StateObject private var appState = AppState() - - init() { - UITableView.appearance().separatorStyle = .none - UITableView.appearance().tableFooterView = UIView() - } - - var body: some Scene { - WindowGroup { - StartView() - .environmentObject(appState) - .task { - appState.loadAppConfigAndModels() - } - } - } -} diff --git a/ios/MLCChat/Models/AppConfig.swift b/ios/MLCChat/Models/AppConfig.swift deleted file mode 100644 index 69867b0857..0000000000 --- a/ios/MLCChat/Models/AppConfig.swift +++ /dev/null @@ -1,28 +0,0 @@ -// -// AppConfig.swift -// MLCChat -// - -struct AppConfig: Codable { - struct ModelRecord: Codable { - let modelPath: String? - let modelURL: String? - let modelLib: String - let estimatedVRAMReq: Int - let modelID: String - - enum CodingKeys: String, CodingKey { - case modelPath = "model_path" - case modelURL = "model_url" - case modelLib = "model_lib" - case estimatedVRAMReq = "estimated_vram_bytes" - case modelID = "model_id" - } - } - - var modelList: [ModelRecord] - - enum CodingKeys: String, CodingKey { - case modelList = "model_list" - } -} diff --git a/ios/MLCChat/Models/ModelConfig.swift b/ios/MLCChat/Models/ModelConfig.swift deleted file mode 100644 index 4ed8819c1b..0000000000 --- a/ios/MLCChat/Models/ModelConfig.swift +++ /dev/null @@ -1,18 +0,0 @@ -// -// ModelConfig.swift -// MLCChat -// - -struct ModelConfig: Decodable { - let tokenizerFiles: [String] - var modelLib: String? - var modelID: String? - var estimatedVRAMReq: Int? 
- - enum CodingKeys: String, CodingKey { - case tokenizerFiles = "tokenizer_files" - case modelLib = "model_lib" - case modelID = "model_id" - case estimatedVRAMReq = "estimated_vram_req" - } -} diff --git a/ios/MLCChat/Models/ParamsConfig.swift b/ios/MLCChat/Models/ParamsConfig.swift deleted file mode 100644 index 2635afabe8..0000000000 --- a/ios/MLCChat/Models/ParamsConfig.swift +++ /dev/null @@ -1,12 +0,0 @@ -// -// ParamsConfig.swift -// MLCChat -// - -struct ParamsConfig: Decodable { - struct ParamsRecord: Decodable { - let dataPath: String - } - - let records: [ParamsRecord] -} diff --git a/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json b/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json deleted file mode 100644 index 73c00596a7..0000000000 --- a/ios/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/ios/MLCChat/States/AppState.swift b/ios/MLCChat/States/AppState.swift deleted file mode 100644 index 4b8af5086f..0000000000 --- a/ios/MLCChat/States/AppState.swift +++ /dev/null @@ -1,269 +0,0 @@ -// -// AppState.swift -// MLCChat -// -// Created by Yaxing Cai on 5/13/23. -// - -import Foundation - -final class AppState: ObservableObject { - @Published var models = [ModelState]() - @Published var chatState = ChatState() - - @Published var alertMessage = "" // TODO: Should move out - @Published var alertDisplayed = false // TODO: Should move out - - private var appConfig: AppConfig? - private var modelIDs = Set() - - private let fileManager: FileManager = FileManager.default - private lazy var cacheDirectoryURL: URL = { - fileManager.urls(for: .cachesDirectory, in: .userDomainMask)[0] - }() - - private let jsonDecoder = JSONDecoder() - private let jsonEncoder = JSONEncoder() - - func loadAppConfigAndModels() { - appConfig = loadAppConfig() - // Can't do anything without a valid app config - guard let appConfig else { - return - } - loadModelsConfig(modelList: appConfig.modelList) - } - - func requestDeleteModel(modelID: String) { - // model dir should have been deleted in ModelState - assert(!fileManager.fileExists(atPath: cacheDirectoryURL.appending(path: modelID).path())) - modelIDs.remove(modelID) - models.removeAll(where: {$0.modelConfig.modelID == modelID}) - updateAppConfig { - appConfig?.modelList.removeAll(where: {$0.modelID == modelID}) - } - } -} - -private extension AppState { - func loadAppConfig() -> AppConfig? { - // models in cache to download - var appConfigFileURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName) - if !fileManager.fileExists(atPath: appConfigFileURL.path()) { - appConfigFileURL = Bundle.main.bundleURL.appending(path: Constants.appConfigFileName) - } - assert(fileManager.fileExists(atPath: appConfigFileURL.path())) - - do { - let fileHandle = try FileHandle(forReadingFrom: appConfigFileURL) - let data = fileHandle.readDataToEndOfFile() - - let appConfig = try jsonDecoder.decode(AppConfig.self, from: data) - return appConfig - } catch { - showAlert(message: "Failed to load app config: \(error.localizedDescription)") - return nil - } - } - - func loadModelsConfig(modelList: [AppConfig.ModelRecord]) { - for model in modelList { - if model.modelPath != nil { - // local model - let modelDir = Bundle.main.bundleURL.appending(path: Constants.prebuiltModelDir).appending(path: model.modelPath!) 
- let modelConfigURL = modelDir.appending(path: Constants.modelConfigFileName) - if fileManager.fileExists(atPath: modelConfigURL.path()) { - if let modelConfig = loadModelConfig( - modelConfigURL: modelConfigURL, - modelLib: model.modelLib, - modelID: model.modelID, - estimatedVRAMReq: model.estimatedVRAMReq - ) { - addModelConfig( - modelConfig: modelConfig, - modelPath: model.modelPath!, - modelURL: nil, - isBuiltin: true - ) - } else { - showAlert(message: "Failed to load prebuilt model: \(model.modelPath!)") - } - } else { - showAlert(message: "Prebuilt mlc-chat-config.json file not found: \(model.modelPath!)") - } - } else if model.modelURL != nil { - // remote model - let modelConfigFileURL = cacheDirectoryURL - .appending(path: model.modelID) - .appending(path: Constants.modelConfigFileName) - if fileManager.fileExists(atPath: modelConfigFileURL.path()) { - if let modelConfig = loadModelConfig( - modelConfigURL: modelConfigFileURL, - modelLib: model.modelLib, - modelID: model.modelID, - estimatedVRAMReq: model.estimatedVRAMReq - ) { - addModelConfig( - modelConfig: modelConfig, - modelPath: nil, - modelURL: URL(string: model.modelURL!), - isBuiltin: true - ) - } - } else { - downloadConfig( - modelURL: URL(string: model.modelURL!), - modelLib: model.modelLib, - modelID: model.modelID, - estimatedVRAMReq: model.estimatedVRAMReq, - isBuiltin: true - ) - } - } else { - showAlert(message: "Path or URL should be provided in app config: \(model.modelID)") - } - } - } - - func loadModelConfig(modelConfigURL: URL, modelLib: String, modelID: String, estimatedVRAMReq: Int) -> ModelConfig? { - do { - assert(fileManager.fileExists(atPath: modelConfigURL.path())) - let fileHandle = try FileHandle(forReadingFrom: modelConfigURL) - let data = fileHandle.readDataToEndOfFile() - var modelConfig = try jsonDecoder.decode(ModelConfig.self, from: data) - modelConfig.modelLib = modelLib - modelConfig.modelID = modelID - modelConfig.estimatedVRAMReq = estimatedVRAMReq - return modelConfig - } catch { - showAlert(message: "Failed to resolve model config: \(error.localizedDescription)") - } - return nil - } - - func showAlert(message: String) { - DispatchQueue.main.async { [weak self] in - guard let self = self else { return } - if !self.alertDisplayed { - self.alertMessage = message - self.alertDisplayed = true - } else { - self.alertMessage.append("\n" + message) - } - } - } - - func downloadConfig(modelURL: URL?, modelLib: String, modelID: String, estimatedVRAMReq: Int, isBuiltin: Bool) { - guard let modelConfigURL = modelURL?.appending(path: "resolve").appending(path: "main").appending(path: Constants.modelConfigFileName) else { - return - } - - let downloadTask = URLSession.shared.downloadTask(with: modelConfigURL) { - [weak self] urlOrNil, responseOrNil, errorOrNil in - guard let self else { - return - } - if let error = errorOrNil { - self.showAlert(message: "Failed to download model config: \(error.localizedDescription)") - return - } - guard let fileUrl = urlOrNil else { - self.showAlert(message: "Failed to download model config") - return - } - - // cache temp file to avoid being deleted by system automatically - let tempName = UUID().uuidString - let tempFileURL = self.cacheDirectoryURL.appending(path: tempName) - - do { - try self.fileManager.moveItem(at: fileUrl, to: tempFileURL) - } catch { - self.showAlert(message: "Failed to cache downloaded file: \(error.localizedDescription)") - return - } - - do { - guard let modelConfig = loadModelConfig( - modelConfigURL: tempFileURL, - modelLib: 
modelLib, - modelID: modelID, - estimatedVRAMReq: estimatedVRAMReq - ) else { - try fileManager.removeItem(at: tempFileURL) - return - } - - if modelIDs.contains(modelConfig.modelID!) { - try fileManager.removeItem(at: tempFileURL) - return - } - - let modelBaseUrl = cacheDirectoryURL.appending(path: modelConfig.modelID!) - try fileManager.createDirectory(at: modelBaseUrl, withIntermediateDirectories: true) - let modelConfigUrl = modelBaseUrl.appending(path: Constants.modelConfigFileName) - try fileManager.moveItem(at: tempFileURL, to: modelConfigUrl) - assert(fileManager.fileExists(atPath: modelConfigUrl.path())) - assert(!fileManager.fileExists(atPath: tempFileURL.path())) - addModelConfig( - modelConfig: modelConfig, - modelPath: nil, - modelURL: modelURL, - isBuiltin: isBuiltin - ) - } catch { - showAlert(message: "Failed to import model: \(error.localizedDescription)") - } - } - downloadTask.resume() - } - - func addModelConfig(modelConfig: ModelConfig, modelPath: String?, modelURL: URL?, isBuiltin: Bool) { - assert(!modelIDs.contains(modelConfig.modelID!)) - modelIDs.insert(modelConfig.modelID!) - let modelBaseURL: URL - - // model_id dir should exist - if modelURL == nil { - // prebuilt model in dist - modelBaseURL = Bundle.main.bundleURL.appending(path: Constants.prebuiltModelDir).appending(path: modelPath!) - } else { - // download model in cache - modelBaseURL = cacheDirectoryURL.appending(path: modelConfig.modelID!) - } - assert(fileManager.fileExists(atPath: modelBaseURL.path())) - - // mlc-chat-config.json should exist - let modelConfigURL = modelBaseURL.appending(path: Constants.modelConfigFileName) - assert(fileManager.fileExists(atPath: modelConfigURL.path())) - - let model = ModelState(modelConfig: modelConfig, modelLocalBaseURL: modelBaseURL, startState: self, chatState: chatState) - model.checkModelDownloadState(modelURL: modelURL) - models.append(model) - - if modelURL != nil && !isBuiltin { - updateAppConfig { - appConfig?.modelList.append( - AppConfig.ModelRecord( - modelPath: nil, - modelURL: modelURL!.absoluteString, - modelLib: modelConfig.modelLib!, - estimatedVRAMReq: modelConfig.estimatedVRAMReq!, - modelID: modelConfig.modelID! 
- ) - ) - } - } - } - - func updateAppConfig(action: () -> Void) { - action() - let appConfigURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName) - do { - let data = try jsonEncoder.encode(appConfig) - try data.write(to: appConfigURL, options: Data.WritingOptions.atomic) - } catch { - print(error.localizedDescription) - } - } -} diff --git a/ios/MLCChat/States/ChatState.swift b/ios/MLCChat/States/ChatState.swift deleted file mode 100644 index 7a5a60f66f..0000000000 --- a/ios/MLCChat/States/ChatState.swift +++ /dev/null @@ -1,339 +0,0 @@ -// -// ChatState.swift -// LLMChat -// - -import Foundation -import MLCSwift - -enum MessageRole { - case user - case bot -} - -extension MessageRole { - var isUser: Bool { self == .user } -} - -struct MessageData: Hashable { - let id = UUID() - var role: MessageRole - var message: String -} - -final class ChatState: ObservableObject { - fileprivate enum ModelChatState { - case generating - case resetting - case reloading - case terminating - case ready - case failed - case pendingImageUpload - case processingImage - } - - @Published var messages = [MessageData]() - @Published var infoText = "" - @Published var displayName = "" - @Published var useVision = false - - private let modelChatStateLock = NSLock() - private var modelChatState: ModelChatState = .ready - - private let threadWorker = ThreadWorker() - private let chatModule = ChatModule() - private var modelLib = "" - private var modelPath = "" - var modelID = "" - - init() { - threadWorker.qualityOfService = QualityOfService.userInteractive - threadWorker.start() - } - - var isInterruptible: Bool { - return getModelChatState() == .ready - || getModelChatState() == .generating - || getModelChatState() == .failed - || getModelChatState() == .pendingImageUpload - } - - var isChattable: Bool { - return getModelChatState() == .ready - } - - var isUploadable: Bool { - return getModelChatState() == .pendingImageUpload - } - - var isResettable: Bool { - return getModelChatState() == .ready - || getModelChatState() == .generating - } - - func requestResetChat() { - assert(isResettable) - interruptChat(prologue: { - switchToResetting() - }, epilogue: { [weak self] in - self?.mainResetChat() - }) - } - - func requestTerminateChat(callback: @escaping () -> Void) { - assert(isInterruptible) - interruptChat(prologue: { - switchToTerminating() - }, epilogue: { [weak self] in - self?.mainTerminateChat(callback: callback) - }) - } - - func requestReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) { - if (isCurrentModel(modelID: modelID)) { - return - } - assert(isInterruptible) - interruptChat(prologue: { - switchToReloading() - }, epilogue: { [weak self] in - self?.mainReloadChat(modelID: modelID, - modelLib: modelLib, - modelPath: modelPath, - estimatedVRAMReq: estimatedVRAMReq, - displayName: displayName) - }) - } - - func requestGenerate(prompt: String) { - assert(isChattable) - switchToGenerating() - appendMessage(role: .user, message: prompt) - appendMessage(role: .bot, message: "") - threadWorker.push {[weak self] in - guard let self else { return } - chatModule.prefill(prompt) - while !chatModule.stopped() { - chatModule.decode() - if let newText = chatModule.getMessage() { - DispatchQueue.main.async { - self.updateMessage(role: .bot, message: newText) - } - } - - if getModelChatState() != .generating { - break - } - } - if getModelChatState() == .generating { - if let runtimeStats = chatModule.runtimeStatsText(useVision) { - 
DispatchQueue.main.async { - self.infoText = runtimeStats - self.switchToReady() - } - } - } - } - } - - func requestProcessImage(image: UIImage) { - assert(getModelChatState() == .pendingImageUpload) - switchToProcessingImage() - threadWorker.push {[weak self] in - guard let self else { return } - assert(messages.count > 0) - DispatchQueue.main.async { - self.updateMessage(role: .bot, message: "[System] Processing image") - } - // step 1. resize image - let new_image = resizeImage(image: image, width: 112, height: 112) - // step 2. prefill image by chatModule.prefillImage() - chatModule.prefillImage(new_image, prevPlaceholder: "", postPlaceholder: " ") - DispatchQueue.main.async { - self.updateMessage(role: .bot, message: "[System] Ready to chat") - self.switchToReady() - } - } - } - - func isCurrentModel(modelID: String) -> Bool { - return self.modelID == modelID - } -} - -private extension ChatState { - func getModelChatState() -> ModelChatState { - modelChatStateLock.lock() - defer { modelChatStateLock.unlock() } - return modelChatState - } - - func setModelChatState(_ newModelChatState: ModelChatState) { - modelChatStateLock.lock() - modelChatState = newModelChatState - modelChatStateLock.unlock() - } - - func appendMessage(role: MessageRole, message: String) { - messages.append(MessageData(role: role, message: message)) - } - - func updateMessage(role: MessageRole, message: String) { - messages[messages.count - 1] = MessageData(role: role, message: message) - } - - func clearHistory() { - messages.removeAll() - infoText = "" - } - - func switchToResetting() { - setModelChatState(.resetting) - } - - func switchToGenerating() { - setModelChatState(.generating) - } - - func switchToReloading() { - setModelChatState(.reloading) - } - - func switchToReady() { - setModelChatState(.ready) - } - - func switchToTerminating() { - setModelChatState(.terminating) - } - - func switchToFailed() { - setModelChatState(.failed) - } - - func switchToPendingImageUpload() { - setModelChatState(.pendingImageUpload) - } - - func switchToProcessingImage() { - setModelChatState(.processingImage) - } - - func interruptChat(prologue: () -> Void, epilogue: @escaping () -> Void) { - assert(isInterruptible) - if getModelChatState() == .ready - || getModelChatState() == .failed - || getModelChatState() == .pendingImageUpload { - prologue() - epilogue() - } else if getModelChatState() == .generating { - prologue() - threadWorker.push { - DispatchQueue.main.async { - epilogue() - } - } - } else { - assert(false) - } - } - - func mainResetChat() { - threadWorker.push {[weak self] in - guard let self else { return } - chatModule.resetChat() - if useVision { - chatModule.resetImageModule() - } - DispatchQueue.main.async { - self.clearHistory() - if self.useVision { - self.appendMessage(role: .bot, message: "[System] Upload an image to chat") - self.switchToPendingImageUpload() - } else { - self.switchToReady() - } - } - } - } - - func mainTerminateChat(callback: @escaping () -> Void) { - threadWorker.push {[weak self] in - guard let self else { return } - if useVision { - chatModule.unloadImageModule() - } - chatModule.unload() - DispatchQueue.main.async { - self.clearHistory() - self.modelID = "" - self.modelLib = "" - self.modelPath = "" - self.displayName = "" - self.useVision = false - self.switchToReady() - callback() - } - } - } - - func mainReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) { - clearHistory() - let prevUseVision = useVision - 
self.modelID = modelID - self.modelLib = modelLib - self.modelPath = modelPath - self.displayName = displayName - self.useVision = displayName.hasPrefix("minigpt") - threadWorker.push {[weak self] in - guard let self else { return } - DispatchQueue.main.async { - self.appendMessage(role: .bot, message: "[System] Initalize...") - } - if prevUseVision { - chatModule.unloadImageModule() - } - chatModule.unload() - let vRAM = os_proc_available_memory() - if (vRAM < estimatedVRAMReq) { - let requiredMemory = String ( - format: "%.1fMB", Double(estimatedVRAMReq) / Double(1 << 20) - ) - let errorMessage = ( - "Sorry, the system cannot provide \(requiredMemory) VRAM as requested to the app, " + - "so we cannot initialize this model on this device." - ) - DispatchQueue.main.sync { - self.messages.append(MessageData(role: MessageRole.bot, message: errorMessage)) - self.switchToFailed() - } - return - } - - if useVision { - // load vicuna model - let dir = (modelPath as NSString).deletingLastPathComponent - let vicunaModelLib = "vicuna-7b-v1.3-q3f16_0" - let vicunaModelPath = dir + "/" + vicunaModelLib - let appConfigJSONData = try? JSONSerialization.data(withJSONObject: ["conv_template": "minigpt"], options: []) - let appConfigJSON = String(data: appConfigJSONData!, encoding: .utf8) - chatModule.reload(vicunaModelLib, modelPath: vicunaModelPath, appConfigJson: appConfigJSON) - // load image model - chatModule.reloadImageModule(modelLib, modelPath: modelPath) - } else { - chatModule.reload(modelLib, modelPath: modelPath, appConfigJson: "") - } - - DispatchQueue.main.async { - if self.useVision { - self.updateMessage(role: .bot, message: "[System] Upload an image to chat") - self.switchToPendingImageUpload() - } else { - self.updateMessage(role: .bot, message: "[System] Ready to chat") - self.switchToReady() - } - } - } - } -} diff --git a/ios/MLCChat/States/ModelState.swift b/ios/MLCChat/States/ModelState.swift deleted file mode 100644 index ed229101f3..0000000000 --- a/ios/MLCChat/States/ModelState.swift +++ /dev/null @@ -1,414 +0,0 @@ -// -// ModelState.swift -// MLCChat -// - -import Foundation - -final class ModelState: ObservableObject, Identifiable { - enum ModelDownloadState { - case initializing - case indexing - case paused - case downloading - case pausing - case verifying - case finished - case failed - case clearing - case deleting - } - - fileprivate struct DownloadTask: Hashable { - let remoteURL: URL - let localURL: URL - } - - @Published var modelConfig: ModelConfig - @Published var modelDownloadState: ModelDownloadState = .initializing - @Published var progress: Int = 0 - @Published var total: Int = 1 - - private var modelLocalBaseURL: URL - private var startState: AppState - private var chatState: ChatState - - private let fileManager: FileManager = FileManager.default - private let decoder = JSONDecoder() - private var paramsConfig: ParamsConfig? - private var modelRemoteBaseURL: URL? - private var remainingTasks: Set = Set() - private var downloadingTasks: Set = Set() - private var maxDownloadingTasks: Int = 3 - - init(modelConfig: ModelConfig, - modelLocalBaseURL: URL, - startState: AppState, - chatState: ChatState) { - self.modelConfig = modelConfig - self.modelLocalBaseURL = modelLocalBaseURL - self.startState = startState - self.chatState = chatState - } - - func checkModelDownloadState(modelURL: URL?) 
{ - createModelFolderIfNeeded() - - guard let modelURL else { - switchToVerifying() - return - } - - modelRemoteBaseURL = modelURL.appending(path: "resolve").appending(path: "main") - - // create local params dir - let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) - if fileManager.fileExists(atPath: paramsConfigURL.path()) { - // ndarray-cache.json already downloaded - loadParamsConfig() - switchToIndexing() - } else { - // download ndarray-cache.json - downloadParamsConfig() - } - } - - func startChat(chatState: ChatState) { - chatState.requestReloadChat( - modelID: modelConfig.modelID!, - modelLib: modelConfig.modelLib!, - modelPath: modelLocalBaseURL.path(), - estimatedVRAMReq: modelConfig.estimatedVRAMReq!, - displayName: modelConfig.modelID!.components(separatedBy: "-")[0] - ) - } - - func handleStart() { - // start downloading - switchToDownloading() - } - - func handlePause() { - // pause downloading - switchToPausing() - } - - func handleClear() { - assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished) - switchToClearing() - } - - func handleDelete() { - assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished || modelDownloadState == .failed) - switchToDeleting() - } -} - -private extension ModelState { - func createModelFolderIfNeeded() { - if !fileManager.fileExists(atPath: modelLocalBaseURL.path()) { - do { - try fileManager.createDirectory(at: modelLocalBaseURL, withIntermediateDirectories: true) - } catch { - print(error.localizedDescription) - } - } - } - - func loadParamsConfig() { - let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) - assert(fileManager.fileExists(atPath: paramsConfigURL.path())) - do { - let fileHandle = try FileHandle(forReadingFrom: paramsConfigURL) - let data = fileHandle.readDataToEndOfFile() - paramsConfig = try self.decoder.decode(ParamsConfig.self, from: data) - } catch { - print(error.localizedDescription) - } - } - - func downloadParamsConfig() { - guard let modelRemoteBaseURL else { - return - } - - let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) - let downloadTask = URLSession.shared.downloadTask(with: modelRemoteBaseURL.appending(path: Constants.paramsConfigFileName)) { - [weak self] urlOrNil, responseOrNil, errorOrNil in - guard let self else { return } - guard let fileURL = urlOrNil else { return } - do { - try? 
self.fileManager.removeItem(at: paramsConfigURL) - try self.fileManager.moveItem(at: fileURL, to: paramsConfigURL) - DispatchQueue.main.async { - self.loadParamsConfig() - self.switchToIndexing() - } - } catch { - print(error.localizedDescription) - } - } - downloadTask.resume() - } - - func switchToIndexing() { - guard let paramsConfig, let modelRemoteBaseURL else { - return - } - - modelDownloadState = .indexing - progress = 0 - total = modelConfig.tokenizerFiles.count + paramsConfig.records.count - - // collect tokenizer download tasks - for tokenizerFile in modelConfig.tokenizerFiles { - let remoteURL = modelRemoteBaseURL.appending(path: tokenizerFile) - let localURL = modelLocalBaseURL.appending(path: tokenizerFile) - - if fileManager.fileExists(atPath: localURL.path()) { - progress += 1 - } else { - remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL)) - } - } - - // collect params download tasks - for paramsRecord in paramsConfig.records { - let remoteURL = modelRemoteBaseURL.appending(path: paramsRecord.dataPath) - let localURL = modelLocalBaseURL.appending(path: paramsRecord.dataPath) - - if fileManager.fileExists(atPath: localURL.path()) { - progress += 1 - } else { - remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL)) - } - } - - if progress < total { - switchToPaused() - } else { - switchToFinished() - } - } - - func handleNewDownload(downloadTask: DownloadTask) { - // start one download task - assert(downloadingTasks.count < maxDownloadingTasks) - let task = URLSession.shared.downloadTask(with: downloadTask.remoteURL) { - [weak self] urlOrNil, responseOrNil, errorOrNil in - guard let self else { return } - guard let fileUrl = urlOrNil else { - DispatchQueue.main.async { - self.handleCancelDownload(downloadTask: downloadTask) - } - return - } - - do { - try self.fileManager.createDirectory(at: downloadTask.localURL.deletingLastPathComponent(), withIntermediateDirectories: true) - try? 
self.fileManager.removeItem(at: downloadTask.localURL) - try self.fileManager.moveItem(at: fileUrl, to: downloadTask.localURL) - } catch { - print(error.localizedDescription) - } - DispatchQueue.main.async { - self.handleFinishDownload(downloadTask: downloadTask) - } - } - downloadingTasks.insert(downloadTask) - task.resume() - } - - func handleFinishDownload(downloadTask: DownloadTask) { - // update the finished download task - remainingTasks.remove(downloadTask) - downloadingTasks.remove(downloadTask) - progress += 1 - assert(modelDownloadState == .downloading || - modelDownloadState == .pausing || - modelDownloadState == .clearing || - modelDownloadState == .deleting - ) - if modelDownloadState == .downloading { - if remainingTasks.isEmpty && downloadingTasks.isEmpty { - switchToFinished() - } else { - handleNextDownload() - } - } else if modelDownloadState == .pausing && downloadingTasks.isEmpty { - switchToPaused() - } else if modelDownloadState == .clearing && downloadingTasks.isEmpty { - clear() - } else if modelDownloadState == .deleting && downloadingTasks.isEmpty { - delete() - } - } - - func handleCancelDownload(downloadTask: DownloadTask) { - // withdraw the failed download task - assert(modelDownloadState == .downloading || modelDownloadState == .pausing) - downloadingTasks.remove(downloadTask) - if modelDownloadState == .downloading { - handleNextDownload() - } else if modelDownloadState == .pausing && downloadingTasks.count == 0 { - switchToPaused() - } - } - - func handleNextDownload() { - // start next download task - assert(modelDownloadState == .downloading) - for downloadTask in remainingTasks { - if !downloadingTasks.contains(downloadTask) { - handleNewDownload(downloadTask: downloadTask) - break - } - } - } - - func switchToPaused() { - modelDownloadState = .paused - } - - func switchToPausing() { - modelDownloadState = .pausing - } - - func switchToVerifying() { - modelDownloadState = .verifying - - let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName) - guard fileManager.fileExists(atPath: paramsConfigURL.path()) else { - switchToFailed() - return - } - - loadParamsConfig() - guard let paramsConfig else { - switchToFailed() - return - } - progress = 0 - total = modelConfig.tokenizerFiles.count + paramsConfig.records.count - - if !verifyTokenizers() { - switchToFailed() - return - } - - if !verifyParams() { - switchToFailed() - return - } - - switchToFinished() - } - - func verifyTokenizers() -> Bool { - for tokenizerFile in modelConfig.tokenizerFiles { - let localURL = modelLocalBaseURL.appending(path: tokenizerFile) - - if !fileManager.fileExists(atPath: localURL.path()) { - switchToFailed() - return false - } - progress += 1 - } - return true - } - - func verifyParams() -> Bool { - guard let paramsConfig else { - return false - } - - for paramsRecord in paramsConfig.records { - let localUrl = modelLocalBaseURL.appending(path: paramsRecord.dataPath) - - if !fileManager.fileExists(atPath: localUrl.path()) { - switchToFailed() - return false - } - - progress += 1 - } - return true - } - - func switchToClearing() { - if modelDownloadState == .paused { - modelDownloadState = .clearing - clear() - } else if modelDownloadState == .finished { - if chatState.modelID == modelConfig.modelID { - chatState.requestTerminateChat { [weak self] in - self?.clear() - } - } else { - clear() - } - } else { - modelDownloadState = .clearing - } - } - - func switchToDeleting() { - if modelDownloadState == .paused || modelDownloadState == .failed { - 
modelDownloadState = .deleting - delete() - } else if modelDownloadState == .finished { - if chatState.modelID == modelConfig.modelID { - chatState.requestTerminateChat { [weak self] in - self?.delete() - } - } else { - delete() - } - } else { - modelDownloadState = .deleting - } - } - - func switchToFinished() { - modelDownloadState = .finished - } - - func switchToFailed() { - modelDownloadState = .failed - } - - func switchToDownloading() { - modelDownloadState = .downloading - for downloadTask in remainingTasks { - if downloadingTasks.count < maxDownloadingTasks { - handleNewDownload(downloadTask: downloadTask) - } else { - return - } - } - } - - func clear() { - do { - let fileURLs = try fileManager.contentsOfDirectory(at: modelLocalBaseURL, includingPropertiesForKeys: nil) - for fileURL in fileURLs where fileURL.lastPathComponent != Constants.modelConfigFileName { - try fileManager.removeItem(at: fileURL) - assert(!fileManager.fileExists(atPath: fileURL.path())) - } - assert(fileManager.fileExists(atPath: modelLocalBaseURL.appending(path: Constants.modelConfigFileName).path())) - switchToIndexing() - } catch { - print(error.localizedDescription) - } - } - - func delete() { - do { - try fileManager.removeItem(at: modelLocalBaseURL) - assert(!fileManager.fileExists(atPath: modelLocalBaseURL.path())) - startState.requestDeleteModel(modelID: modelConfig.modelID!) // TODO: can it decouple? - } catch { - print(error.localizedDescription) - } - } -} diff --git a/ios/MLCChat/Views/ChatView.swift b/ios/MLCChat/Views/ChatView.swift deleted file mode 100644 index d1d5de44ab..0000000000 --- a/ios/MLCChat/Views/ChatView.swift +++ /dev/null @@ -1,176 +0,0 @@ -// -// ChatView.swift -// MLCChat -// - -import SwiftUI -import GameController - -struct ChatView: View { - @EnvironmentObject private var chatState: ChatState - - @State private var inputMessage: String = "" - @FocusState private var inputIsFocused: Bool - @Environment(\.dismiss) private var dismiss - @Namespace private var messagesBottomID - - // vision-related properties - @State private var showActionSheet: Bool = false - @State private var showImagePicker: Bool = false - @State private var imageConfirmed: Bool = false - @State private var imageSourceType: UIImagePickerController.SourceType = .photoLibrary - @State private var image: UIImage? - - var body: some View { - VStack { - modelInfoView - messagesView - uploadImageView - messageInputView - } - .navigationBarTitle("MLC Chat: \(chatState.displayName)", displayMode: .inline) - .navigationBarBackButtonHidden() - .toolbar { - ToolbarItem(placement: .navigationBarLeading) { - Button { - dismiss() - } label: { - Image(systemName: "chevron.backward") - } - .buttonStyle(.borderless) - .disabled(!chatState.isInterruptible) - } - ToolbarItem(placement: .navigationBarTrailing) { - Button("Reset") { - image = nil - imageConfirmed = false - chatState.requestResetChat() - } - .padding() - .disabled(!chatState.isResettable) - } - } - } -} - -private extension ChatView { - var modelInfoView: some View { - Text(chatState.infoText) - .multilineTextAlignment(.center) - .opacity(0.5) - .listRowSeparator(.hidden) - } - - var messagesView: some View { - ScrollViewReader { scrollViewProxy in - ScrollView { - VStack { - let messageCount = chatState.messages.count - let hasSystemMessage = messageCount > 0 && chatState.messages[0].role == MessageRole.bot - let startIndex = hasSystemMessage ? 
1 : 0 - - // display the system message - if hasSystemMessage { - MessageView(role: chatState.messages[0].role, message: chatState.messages[0].message) - } - - // display image - if let image, imageConfirmed { - ImageView(image: image) - } - - // display conversations - ForEach(chatState.messages[startIndex...], id: \.id) { message in - MessageView(role: message.role, message: message.message) - } - HStack { EmptyView() } - .id(messagesBottomID) - } - } - .onChange(of: chatState.messages) { _ in - withAnimation { - scrollViewProxy.scrollTo(messagesBottomID, anchor: .bottom) - } - } - } - } - - @ViewBuilder - var uploadImageView: some View { - if chatState.useVision && !imageConfirmed { - if image == nil { - Button("Upload picture to chat") { - showActionSheet = true - } - .actionSheet(isPresented: $showActionSheet) { - ActionSheet(title: Text("Choose from"), buttons: [ - .default(Text("Photo Library")) { - showImagePicker = true - imageSourceType = .photoLibrary - }, - .default(Text("Camera")) { - showImagePicker = true - imageSourceType = .camera - }, - .cancel() - ]) - } - .sheet(isPresented: $showImagePicker) { - ImagePicker(image: $image, - showImagePicker: $showImagePicker, - imageSourceType: imageSourceType) - } - .disabled(!chatState.isUploadable) - } else { - VStack { - if let image { - Image(uiImage: image) - .resizable() - .frame(width: 300, height: 300) - - HStack { - Button("Undo") { - self.image = nil - } - .padding() - - Button("Submit") { - imageConfirmed = true - chatState.requestProcessImage(image: image) - } - .padding() - } - } - } - } - } - } - - var messageInputView: some View { - HStack { - TextField("Inputs...", text: $inputMessage, axis: .vertical) - .textFieldStyle(RoundedBorderTextFieldStyle()) - .frame(minHeight: CGFloat(30)) - .focused($inputIsFocused) - .onSubmit { - let isKeyboardConnected = GCKeyboard.coalesced != nil - if isKeyboardConnected { - send() - } - } - Button("Send") { - send() - } - .bold() - .disabled(!(chatState.isChattable && inputMessage != "")) - } - .frame(minHeight: CGFloat(70)) - .padding() - } - - func send() { - inputIsFocused = false - chatState.requestGenerate(prompt: inputMessage) - inputMessage = "" - } -} diff --git a/ios/MLCChat/Views/ImageProcessing.swift b/ios/MLCChat/Views/ImageProcessing.swift deleted file mode 100644 index 3d7260e3a0..0000000000 --- a/ios/MLCChat/Views/ImageProcessing.swift +++ /dev/null @@ -1,66 +0,0 @@ -// -// ImageProcessing.swift -// MLCChat -// -// Created by Kathryn Chen on 7/8/23. -// - -import Foundation -import SwiftUI -import UIKit - -// adapted from Mohammad Azam: https://github.com/azamsharp/SwiftUICamera -// delegate task to the coordinator to produce the image -struct ImagePicker : UIViewControllerRepresentable { - typealias UIViewControllerType = UIImagePickerController - typealias Coordinator = ImagePickerCoordinator - - @Binding var image: UIImage? 
- @Binding var showImagePicker: Bool - var imageSourceType: UIImagePickerController.SourceType = .photoLibrary - - func makeCoordinator() -> ImagePicker.Coordinator { - return ImagePickerCoordinator(image: $image, showImagePicker: $showImagePicker) - } - - func makeUIViewController(context: UIViewControllerRepresentableContext) -> UIImagePickerController { - let picker = UIImagePickerController() - picker.sourceType = imageSourceType - picker.delegate = context.coordinator - return picker - } - - func updateUIViewController(_ uiViewController: UIImagePickerController, context: UIViewControllerRepresentableContext) {} -} - -// image picker coordinator handling selecting from library or taking a photo -class ImagePickerCoordinator: NSObject, UINavigationControllerDelegate, UIImagePickerControllerDelegate { - @Binding var image: UIImage? - @Binding var showImagePicker: Bool - - init(image: Binding, showImagePicker: Binding) { - _image = image - _showImagePicker = showImagePicker - } - - func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) { - if let optionalImage = info[UIImagePickerController.InfoKey.originalImage] as? UIImage { - image = optionalImage - showImagePicker = false - } - } - - func imagePickerControllerDidCancel(_ picker: UIImagePickerController) { - showImagePicker = false - } -} - -// resize the input image to given width and height -func resizeImage(image: UIImage, width: Int, height: Int) -> UIImage { - let shape = CGSize(width: width, height: height) - UIGraphicsBeginImageContextWithOptions(shape, true, 0.0) - image.draw(in: CGRect(x: 0, y: 0, width: width, height: height)) - let resizedImage: UIImage? = UIGraphicsGetImageFromCurrentImageContext() - UIGraphicsEndImageContext() - return resizedImage ?? image -} diff --git a/ios/MLCChat/Views/MessageView.swift b/ios/MLCChat/Views/MessageView.swift deleted file mode 100644 index 4553f6bad1..0000000000 --- a/ios/MLCChat/Views/MessageView.swift +++ /dev/null @@ -1,66 +0,0 @@ -// -// MessageView.swift -// MLCChat -// - -import SwiftUI - -struct MessageView: View { - let role: MessageRole; - let message: String - - var body: some View { - let textColor = role.isUser ? Color.white : Color(UIColor.label) - let background = role.isUser ? 
Color.blue : Color(UIColor.secondarySystemBackground) - - HStack { - if role.isUser { - Spacer() - } - Text(message) - .padding(10) - .foregroundColor(textColor) - .background(background) - .cornerRadius(10) - .textSelection(.enabled) - if !role.isUser { - Spacer() - } - } - .padding() - .listRowSeparator(.hidden) - } -} - -struct ImageView: View { - let image: UIImage - - var body: some View { - let background = Color.blue - HStack { - Spacer() - Image(uiImage: image) - .resizable() - .frame(width: 150, height: 150) - .padding(15) - .background(background) - .cornerRadius(20) - } - .padding() - .listRowSeparator(.hidden) - } -} - -struct MessageView_Previews: PreviewProvider { - static var previews: some View { - NavigationView { - VStack (spacing: 0){ - ScrollView { - MessageView(role: MessageRole.user, message: "Message 1") - MessageView(role: MessageRole.bot, message: "Message 2") - MessageView(role: MessageRole.user, message: "Message 3") - } - } - } - } -} diff --git a/ios/MLCChat/Views/ModelView.swift b/ios/MLCChat/Views/ModelView.swift deleted file mode 100644 index 4676fb2ad7..0000000000 --- a/ios/MLCChat/Views/ModelView.swift +++ /dev/null @@ -1,97 +0,0 @@ -// -// ModelView.swift -// MLCChat -// -// Created by Yaxing Cai on 5/14/23. -// - -import SwiftUI - -struct ModelView: View { - @EnvironmentObject private var modelState: ModelState - @EnvironmentObject private var chatState: ChatState - @Binding var isRemoving: Bool - - @State private var isShowingDeletionConfirmation: Bool = false - - var body: some View { - VStack(alignment: .leading) { - if (modelState.modelDownloadState == .finished) { - NavigationLink(destination: - ChatView() - .environmentObject(chatState) - .onAppear { - modelState.startChat(chatState: chatState) - } - ) { - HStack { - Text(modelState.modelConfig.modelID!) - Spacer() - if chatState.isCurrentModel(modelID: modelState.modelConfig.modelID!) 
{ - Image(systemName: "checkmark").foregroundColor(.blue) - } - } - } - .buttonStyle(.borderless) - } else { - Text(modelState.modelConfig.modelID!).opacity(0.5) - } - HStack{ - if modelState.modelDownloadState != .finished || isRemoving { - ProgressView(value: Double(modelState.progress) / Double(modelState.total)) - .progressViewStyle(.linear) - } - - if (modelState.modelDownloadState == .paused) { - Button { - modelState.handleStart() - } label: { - Image(systemName: "icloud.and.arrow.down") - } - .buttonStyle(.borderless) - } else if (modelState.modelDownloadState == .downloading) { - Button { - modelState.handlePause() - } label: { - Image(systemName: "stop.circle") - } - .buttonStyle(.borderless) - } else if (modelState.modelDownloadState == .failed) { - Image(systemName: "exclamationmark.triangle") - .foregroundColor(.red) - } - - if isRemoving { - Button(role: .destructive) { - isShowingDeletionConfirmation = true - } label: { - Image(systemName: "trash") - } - .confirmationDialog("Delete Model", isPresented: $isShowingDeletionConfirmation) { - Button("Delete Model", role: .destructive) { - modelState.handleDelete() - } - .disabled( - modelState.modelDownloadState != .downloading && - modelState.modelDownloadState != .paused && - modelState.modelDownloadState != .finished && - modelState.modelDownloadState != .failed) - Button("Clear Data") { - modelState.handleClear() - } - .disabled( - modelState.modelDownloadState != .downloading && - modelState.modelDownloadState != .paused && - modelState.modelDownloadState != .finished) - Button("Cancel", role: .cancel) { - isShowingDeletionConfirmation = false - } - } message: { - Text("Delete model will delete the all files with model config, and delete the entry in list. \n Clear model will keep the model config only, and keep the entry in list for future re-downloading.") - } - .buttonStyle(.borderless) - } - } - } - } -} diff --git a/ios/MLCChat/Views/StartView.swift b/ios/MLCChat/Views/StartView.swift deleted file mode 100644 index 87e585d52d..0000000000 --- a/ios/MLCChat/Views/StartView.swift +++ /dev/null @@ -1,46 +0,0 @@ -// -// DownloadView.swift -// MLCChat -// -// Created by Yaxing Cai on 5/11/23. 
-// - -import SwiftUI - -struct StartView: View { - @EnvironmentObject private var appState: AppState - @State private var isAdding: Bool = false - @State private var isRemoving: Bool = false - @State private var inputModelUrl: String = "" - - var body: some View { - NavigationStack { - List{ - Section(header: Text("Models")) { - ForEach(appState.models) { modelState in - ModelView(isRemoving: $isRemoving) - .environmentObject(modelState) - .environmentObject(appState.chatState) - } - if !isRemoving { - Button("Edit model") { - isRemoving = true - } - .buttonStyle(.borderless) - } else { - Button("Cancel edit model") { - isRemoving = false - } - .buttonStyle(.borderless) - } - } - } - .navigationTitle("MLC Chat") - .alert("Error", isPresented: $appState.alertDisplayed) { - Button("OK") { } - } message: { - Text(appState.alertMessage) - } - } - } -} diff --git a/ios/MLCChat/app-config.json b/ios/MLCChat/app-config.json deleted file mode 100644 index 1379fc6647..0000000000 --- a/ios/MLCChat/app-config.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "model_list": [ - { - "model_path": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_id": "Mistral-7B-Instruct-v0.2-q3f16_1", - "model_lib": "mistral_q3f16_1", - "estimated_vram_bytes": 3316000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC", - "model_id": "RedPajama-INCITE-Chat-3B-v1-q4f16_1", - "model_lib": "gpt_neox_q4f16_1", - "estimated_vram_bytes": 2960000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_id": "phi-2-q4f16_1", - "model_lib": "phi_msft_q4f16_1", - "estimated_vram_bytes": 3043000000 - }, - { - "model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC", - "model_id": "gemma-2b-q4f16_1", - "model_lib": "gemma_q4f16_1", - "estimated_vram_bytes": 3000000000 - } - ], - "model_lib_path_for_prepare_libs": { - "mistral_q3f16_1": "lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q3f16_1-iphone.tar", - "gpt_neox_q4f16_1": "lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar", - "phi_msft_q4f16_1": "lib/phi-2/phi-2-q4f16_1-iphone.tar", - "gemma_q4f16_1": "lib/gemma-2b-it/gemma-2b-it-q4f16_1-iphone.tar" - } -} diff --git a/ios/MLCSwift/Package.swift b/ios/MLCSwift/Package.swift deleted file mode 100644 index eac88dbbf2..0000000000 --- a/ios/MLCSwift/Package.swift +++ /dev/null @@ -1,32 +0,0 @@ -// swift-tools-version:5.5 -// The swift-tools-version declares the minimum version of Swift required to build this package. - -import PackageDescription - -let package = Package( - name: "MLCSwift", - products: [ - .library( - name: "MLCSwift", - targets: ["LLMChatObjC", "MLCSwift"] - ) - ], - dependencies: [], - targets: [ - .target( - name: "LLMChatObjC", - path: "Sources/ObjC", - cxxSettings: [ - .headerSearchPath("../../tvm_home/include"), - .headerSearchPath("../../tvm_home/3rdparty/dmlc-core/include"), - .headerSearchPath("../../tvm_home/3rdparty/dlpack/include") - ] - ), - .target( - name: "MLCSwift", - dependencies: ["LLMChatObjC"], - path: "Sources/Swift" - ) - ], - cxxLanguageStandard: .cxx17 -) diff --git a/ios/MLCSwift/README.md b/ios/MLCSwift/README.md deleted file mode 100644 index 3a7c2b578d..0000000000 --- a/ios/MLCSwift/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# MLCSwift - -This is a simple swift package that exposes the chat module to swift. -Checkout our [documentation](https://llm.mlc.ai/docs/) for more examples. 
diff --git a/ios/MLCSwift/Sources/ObjC/LLMChat.mm b/ios/MLCSwift/Sources/ObjC/LLMChat.mm deleted file mode 100644 index dcf57c5db2..0000000000 --- a/ios/MLCSwift/Sources/ObjC/LLMChat.mm +++ /dev/null @@ -1,242 +0,0 @@ -// -// LLMChat.mm -// LLMChat -// -#import -#import -#include - -#include "LLMChat.h" - -#define TVM_USE_LIBBACKTRACE 0 -#define DMLC_USE_LOGGING_LIBRARY - -#include -#include - -using namespace tvm::runtime; - -enum PlaceInPrompt : int { - // The input message should have role names and corresponding seperators appended both - // prior to it and after it, making it a complete prompt. - kAll, - // The input message is only the beginning part of a prompt, no role name and separator should be - // appended after the message since there will be future messages appended after the message. - kBegin, - // The input message is in the middle of a prompt, nothing should be appended before or after the - // message. - kMiddle, - // The input message is the ending part of a prompt, no role name and separator should be appended - // prior to it since the message is concatenated to some prior messages. - kEnd, -}; - -@implementation ChatModule { - // Internal c++ classes - // chat-related module and functions - Module llm_chat_; - PackedFunc unload_func_; - PackedFunc reload_func_; - PackedFunc prefill_func_; - PackedFunc embed_func_; - PackedFunc prefill_with_embed_func_; - PackedFunc decode_func_; - PackedFunc get_message_; - PackedFunc stopped_func_; - PackedFunc reset_chat_func_; - PackedFunc runtime_stats_text_func_; - PackedFunc process_system_prompts_func_; - // image-related module and functions - Module llm_image_mod_; - PackedFunc image_mod_unload_func_; - PackedFunc image_mod_reload_func_; - PackedFunc image_mod_embed_func_; - PackedFunc image_mod_reset_func_; - PackedFunc image_mod_runtime_stats_text_func_; - // helper variables - bool first_input_after_image; - std::vector image_data; - NSUInteger image_width; - NSUInteger image_height; -} - -- (instancetype)init { - if (self = [super init]) { - // load chat module - const PackedFunc* f_chat_create = Registry::Get("mlc.llm_chat_create"); - ICHECK(f_chat_create) << "Cannot find mlc.llm_chat_create"; - llm_chat_ = (*f_chat_create)(static_cast(kDLMetal), 0); - // load image module - const PackedFunc* f_image_mod_create = Registry::Get("mlc.llm_image_module_create"); - ICHECK(f_image_mod_create) << "Cannot find mlc.llm_image_module_create"; - llm_image_mod_ = (*f_image_mod_create)(static_cast(kDLMetal), 0); - - // chat-related functions - reload_func_ = llm_chat_->GetFunction("reload"); - unload_func_ = llm_chat_->GetFunction("unload"); - prefill_func_ = llm_chat_->GetFunction("prefill"); - embed_func_ = llm_chat_->GetFunction("embed"); - prefill_with_embed_func_ = llm_chat_->GetFunction("prefill_with_embed"); - decode_func_ = llm_chat_->GetFunction("decode"); - get_message_ = llm_chat_->GetFunction("get_message"); - stopped_func_ = llm_chat_->GetFunction("stopped"); - reset_chat_func_ = llm_chat_->GetFunction("reset_chat"); - runtime_stats_text_func_ = llm_chat_->GetFunction("runtime_stats_text"); - process_system_prompts_func_ = llm_chat_->GetFunction("process_system_prompts"); - // image-module-related functions - image_mod_reload_func_ = llm_image_mod_->GetFunction("reload"); - image_mod_unload_func_ = llm_image_mod_->GetFunction("unload"); - image_mod_embed_func_ = llm_image_mod_->GetFunction("embed"); - image_mod_reset_func_ = llm_image_mod_->GetFunction("reset"); - image_mod_runtime_stats_text_func_ = 
llm_image_mod_->GetFunction("runtime_stats_text"); - // helper variables - first_input_after_image = false; - image_height = 224; - image_width = 224; - image_data.reserve(image_height * image_width * 4); - - ICHECK(reload_func_ != nullptr); - ICHECK(unload_func_ != nullptr); - ICHECK(prefill_func_ != nullptr); - ICHECK(embed_func_ != nullptr); - ICHECK(prefill_with_embed_func_ != nullptr); - ICHECK(decode_func_ != nullptr); - ICHECK(get_message_ != nullptr); - ICHECK(stopped_func_ != nullptr); - ICHECK(reset_chat_func_ != nullptr); - ICHECK(runtime_stats_text_func_ != nullptr); - ICHECK(process_system_prompts_func_ != nullptr); - ICHECK(image_mod_unload_func_ != nullptr); - ICHECK(image_mod_reload_func_ != nullptr); - ICHECK(image_mod_embed_func_ != nullptr); - ICHECK(image_mod_reset_func_ != nullptr); - ICHECK(image_mod_runtime_stats_text_func_ != nullptr); - } - return self; -} - -- (void)unload { - unload_func_(); -} - -- (void)reload:(NSString*)modelLib - modelPath:(NSString*)modelPath - appConfigJson:(NSString*)appConfigJson { - std::string lib_prefix = modelLib.UTF8String; - std::string model_path = modelPath.UTF8String; - std::string app_config_json = appConfigJson.UTF8String; - std::replace(lib_prefix.begin(), lib_prefix.end(), '-', '_'); - lib_prefix += '_'; - Module lib = (*Registry::Get("runtime.SystemLib"))(lib_prefix); - reload_func_(lib, model_path, app_config_json); -} - -- (void)resetChat { - reset_chat_func_(); -} - -- (void)prefill:(NSString*)input { - std::string prompt = input.UTF8String; - if (first_input_after_image) { - prefill_func_(prompt, true, (int)PlaceInPrompt::kEnd); - first_input_after_image = false; - } else { - prefill_func_(prompt); - } -} - -- (void)decode { - decode_func_(); -} - -- (NSString*)getMessage { - std::string ret = get_message_(); - return [NSString stringWithUTF8String:ret.c_str()]; -} - -- (bool)stopped { - return stopped_func_().operator bool(); -} - -- (NSString*)runtimeStatsText:(bool)useVision { - std::string chat_mod_stats = runtime_stats_text_func_(); - if (useVision) { - std::string image_mod_stats = image_mod_runtime_stats_text_func_(); - chat_mod_stats += ", " + image_mod_stats; - } - return [NSString stringWithUTF8String:chat_mod_stats.c_str()]; -} - -- (void)processSystemPrompts { - process_system_prompts_func_(); -} - -- (void)evaluate { - LOG(INFO) << "Total-mem-budget=" << os_proc_available_memory() / (1 << 20) << "MB"; - llm_chat_->GetFunction("evaluate")(); - LOG(INFO) << "Left-mem-budget=" << os_proc_available_memory() / (1 << 20) << "MB"; -} - -- (void)unloadImageModule { - image_mod_unload_func_(); - first_input_after_image = false; -} - -- (void)reloadImageModule:(NSString*)modelLib modelPath:(NSString*)modelPath { - first_input_after_image = false; - std::string lib_prefix = modelLib.UTF8String; - std::string model_path = modelPath.UTF8String; - std::replace(lib_prefix.begin(), lib_prefix.end(), '-', '_'); - lib_prefix += '_'; - Module lib = (*Registry::Get("runtime.SystemLib"))(lib_prefix); - image_mod_reload_func_(lib, model_path); -} - -- (void)resetImageModule { - image_mod_reset_func_(); - first_input_after_image = false; -} - -- (void)prefillImage:(UIImage*)image - prevPlaceholder:(NSString*)prevPlaceholder - postPlaceholder:(NSString*)postPlaceholder { - // prefill the previous placeholder string - std::string prev_placeholder = prevPlaceholder.UTF8String; - prefill_func_(prev_placeholder, false, (int)PlaceInPrompt::kBegin); - - // prefill with image embedding - // step 1. 
get image rawdata: credit from https://stackoverflow.com/a/1262893 - CGImageRef imageRef = [image CGImage]; - CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); - NSUInteger bytesPerPixel = 4; - NSUInteger bytesPerRow = bytesPerPixel * image_width; - NSUInteger bitsPerComponent = 8; - CGContextRef context = CGBitmapContextCreate( - image_data.data(), image_width, image_height, bitsPerComponent, bytesPerRow, colorSpace, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(colorSpace); - CGContextDrawImage(context, CGRectMake(0, 0, image_width, image_height), imageRef); - CGContextRelease(context); - // step 2. create tvm NDArray - ShapeTuple shape = {1, int(image_height), int(image_width), 4}; - DLDataType dtype = DataType::UInt(8); - DLDevice device = DLDevice{kDLMetal, 0}; - size_t nbytes = size_t(dtype.bits / 8); - for (auto s : shape) { - nbytes *= (size_t)s; - } - NDArray input_image = NDArray::Empty(shape, dtype, device); - input_image.CopyFromBytes(image_data.data(), nbytes); - // step 3. prefill with image embedding - NDArray embedding = image_mod_embed_func_(input_image); - prefill_with_embed_func_(embedding, false); - - // prefill the post placeholder string - std::string post_placeholder = postPlaceholder.UTF8String; - prefill_func_(post_placeholder, false, (int)PlaceInPrompt::kMiddle); - - // update the flag - first_input_after_image = true; -} - -@end diff --git a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h deleted file mode 100644 index 0aab17adb1..0000000000 --- a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h +++ /dev/null @@ -1,127 +0,0 @@ -// -// Use this file to import your target's public headers that you would like to expose to Swift. -// LLM Chat Module -// -// Exposed interface of Object-C, enables swift binding. -#import -#import -#include - -/** - * The chat module that can be used by the swift app. - * It is a centralized interface that also provides multimodal support, i.e. vision modules. - * - * A chat flow can be implemented as follows, for each round of conversation - * - * @code - * - * chat.prefill(input); - * while(!chat.stopped()) { - * displayReply(chat.getMessage()); - * chat.decode(); - * } - * - * @endcode - * - * The execution logic of this module should be placed on a dedicated thread. - * - * @seealso ThreadWorker - */ -@interface ChatModule : NSObject - -/** - * Unload the current model and free all memory. - * @note This function is useful to get memory estimation before launch next model. - */ -- (void)unload; - -/** - * Reload the chat module to a new model. - * - * @param modelLib The name of the modelLib - * @param modelPath The path to the model artifacts. - * @param appConfigJson The partial config that is used to partially override the model - * configuration. - */ -- (void)reload:(NSString*)modelLib - modelPath:(NSString*)modelPath - appConfigJson:(NSString*)appConfigJson; - -/** - * Reset the current chat session. - */ -- (void)resetChat; - -/** - * Run prefill stage for a given input and decode the first output token. - * - *@param input The user input prompt. - */ -- (void)prefill:(NSString*)input; - -/** - *Run one decode step to decode the next token. - */ -- (void)decode; - -/** - * @returns The output message in the current round. - */ -- (NSString*)getMessage; - -/** - * @returns Whether the current round stopped - */ -- (bool)stopped; - -/** - * Get the runtime statistics for the chat module, and optionally the image module. 
- * - *@param useVision Whether an image module is used. - */ -- (NSString*)runtimeStatsText:(bool)useVision; - -/** - * Pre-process by prefilling the system prompts, running prior to any user input. - */ -- (void)processSystemPrompts; - -/** - * \brief Run one round of prefill and decode. - * - * This function is not supposed to be used by apps. - * and is only included here when setting up the app - * for debugging purposes. - */ -- (void)evaluate; - -/** - * Unload the current image model and free all memory. - * @note This function is useful to get memory estimation before launch next model. - */ -- (void)unloadImageModule; - -/** - * Reload the image module to a new model. - * - * @param modelLib The name of the modelLib - * @param modelPath The path to the model artifacts. - */ -- (void)reloadImageModule:(NSString*)modelLib modelPath:(NSString*)modelPath; - -/** - * Reset the current image model. - */ -- (void)resetImageModule; - -/** - * Prefill the LLM with the embedding of the input image. - * - * @param image The uploaded image. - * @param prevPlaceholder The previous placeholder in the prompt, i.e. . - * @param postPlaceholder The post placeholder in the prompt, i.e. . - */ -- (void)prefillImage:(UIImage*)image - prevPlaceholder:(NSString*)prevPlaceholder - postPlaceholder:(NSString*)postPlaceholder; -@end diff --git a/ios/MLCSwift/Sources/Swift/LLMChat.swift b/ios/MLCSwift/Sources/Swift/LLMChat.swift deleted file mode 100644 index fa7d889259..0000000000 --- a/ios/MLCSwift/Sources/Swift/LLMChat.swift +++ /dev/null @@ -1 +0,0 @@ -@_exported import LLMChatObjC diff --git a/ios/MLCSwift/Sources/Swift/ThreadWorker.swift b/ios/MLCSwift/Sources/Swift/ThreadWorker.swift deleted file mode 100644 index 79f1eb2004..0000000000 --- a/ios/MLCSwift/Sources/Swift/ThreadWorker.swift +++ /dev/null @@ -1,31 +0,0 @@ -import Foundation - -// A simple thread worker that is backed by a single thread -// -// Instead of dispatch queue, we need a dedicated thread for metal compute -// so all thread local resources are centralized at a single thread -public class ThreadWorker : Thread { - private var cond = NSCondition(); - private var queue = Array<()->Void>(); - - public override func main() { - Thread.setThreadPriority(1) - while (true) { - self.cond.lock() - while (queue.isEmpty) { - self.cond.wait() - } - let task = self.queue.removeFirst() - self.cond.unlock() - task() - } - } - - public func push(task: @escaping ()->Void) { - self.cond.lock() - self.queue.append(task) - self.cond.signal() - self.cond.unlock() - - } -} diff --git a/ios/MLCSwift/tvm_home b/ios/MLCSwift/tvm_home deleted file mode 120000 index e15bf649f5..0000000000 --- a/ios/MLCSwift/tvm_home +++ /dev/null @@ -1 +0,0 @@ -../../3rdparty/tvm \ No newline at end of file diff --git a/ios/README.md b/ios/README.md deleted file mode 100644 index de94ee75a0..0000000000 --- a/ios/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# MLC-LLM IOS - -[Documentation page](https://llm.mlc.ai/docs/deploy/ios.html) diff --git a/ios/prepare_libs.sh b/ios/prepare_libs.sh deleted file mode 100755 index d87423890d..0000000000 --- a/ios/prepare_libs.sh +++ /dev/null @@ -1,74 +0,0 @@ -function help { - echo -e "OPTION:" - echo -e " -s, --simulator Build for Simulator" - echo -e " -a, --arch x86_64 | arm64 Simulator arch " - echo -e " -h, --help Prints this help\n" -} - -is_simulator="false" -arch="arm64" - -# Args while-loop -while [ "$1" != "" ]; -do - case $1 in - -s | --simulator ) is_simulator="true" - ;; - -a | --arch ) shift - arch=$1 - ;; - -h | --help ) 
help - exit - ;; - *) - echo "$script: illegal option $1" - usage - exit 1 # error - ;; - esac - shift -done - -set -euxo pipefail - -sysroot="iphoneos" -type="Release" - -if [ "$is_simulator" = "true" ]; then - if [ "$arch" = "arm64" ]; then - # iOS simulator on Apple processors - rustup target add aarch64-apple-ios-sim - else - # iOS simulator on x86 processors - rustup target add x86_64-apple-ios - fi - sysroot="iphonesimulator" - type="Debug" -else - # iOS devices - rustup target add aarch64-apple-ios -fi - -mkdir -p build/ && cd build/ - -cmake ../..\ - -DCMAKE_BUILD_TYPE=$type\ - -DCMAKE_SYSTEM_NAME=iOS\ - -DCMAKE_SYSTEM_VERSION=14.0\ - -DCMAKE_OSX_SYSROOT=$sysroot\ - -DCMAKE_OSX_ARCHITECTURES=$arch\ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0\ - -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON\ - -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON\ - -DCMAKE_INSTALL_PREFIX=.\ - -DCMAKE_CXX_FLAGS="-O3"\ - -DMLC_LLM_INSTALL_STATIC_LIB=ON\ - -DUSE_METAL=ON -make mlc_llm_static -cmake --build . --target install --config release -j -cd .. - -rm -rf MLCSwift/tvm_home -ln -s ../../3rdparty/tvm MLCSwift/tvm_home - -python prepare_model_lib.py diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py deleted file mode 100644 index ff56236321..0000000000 --- a/ios/prepare_model_lib.py +++ /dev/null @@ -1,88 +0,0 @@ -import json -import os -import sys -from tvm.contrib import cc - - -def get_model_libs(lib_path): - global_symbol_map = cc.get_global_symbol_section_map(lib_path) - libs = [] - suffix = "___tvm_dev_mblob" - for name in global_symbol_map.keys(): - if name.endswith(suffix): - model_lib = name[: -len(suffix)] - if model_lib.startswith("_"): - model_lib = model_lib[1:] - libs.append(model_lib) - return libs - - -def main(): - app_config_path = "MLCChat/app-config.json" - app_config = json.load(open(app_config_path, "r")) - artifact_path = os.path.abspath(os.path.join("..", "dist")) - - tar_list = [] - model_set = set() - - for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): - paths = [ - os.path.join(artifact_path, model_lib_path), - os.path.join(artifact_path, "prebuilt", model_lib_path), - os.path.join(model_lib_path), - ] - valid_paths = [p for p in paths if os.path.isfile(p)] - if not valid_paths: - raise RuntimeError( - f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" - ) - tar_list.append(valid_paths[0]) - model_set.add(model) - - lib_path = os.path.join("build", "lib", "libmodel_iphone.a") - - cc.create_staticlib(lib_path, tar_list) - available_model_libs = get_model_libs(lib_path) - print(f"Creating lib from {tar_list}..") - print(f"Validating the library {lib_path}...") - print( - f"List of available model libs packaged: {available_model_libs}," - " if we have '-' in the model_lib string, it will be turned into '_'" - ) - global_symbol_map = cc.get_global_symbol_section_map(lib_path) - error_happened = False - for item in app_config["model_list"]: - model_lib = item["model_lib"] - model_id = item["model_id"] - if model_lib not in model_set: - print( - f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " - "is not included in model_lib_path_for_prepare_libs field, " - "This will cause the specific model not being able to load, " - f"please check {app_config_path}." 
- ) - error_happened = True - - model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" - if ( - model_prefix_pattern not in global_symbol_map - and "_" + model_prefix_pattern not in global_symbol_map - ): - model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] - print( - "ValidationError:\n" - f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" - f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" - f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" - ) - error_happened = True - - if not error_happened: - print("Validation pass") - else: - print("Validation failed") - exit(255) - - -if __name__ == "__main__": - main() diff --git a/ios/prepare_params.sh b/ios/prepare_params.sh deleted file mode 100755 index 0ac293228c..0000000000 --- a/ios/prepare_params.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/bin/bash -set -euxo pipefail - -# NOTE: this is optional, prepackage weight into app -rm -rf dist -mkdir -p dist - -declare -a builtin_list=( - "Mistral-7B-Instruct-v0.2-q3f16_1" - # "OpenHermes-2.5-Mistral-7B-q3f16_1" - # "Llama-2-7b-chat-hf-q3f16_1" - # "RedPajama-INCITE-Chat-3B-v1-q4f16_1" - # "vicuna-v1-7b-q3f16_0" - # "rwkv-raven-1b5-q8f16_0" - # "rwkv-raven-3b-q8f16_0" - # "rwkv-raven-7b-q8f16_0" -) - -for model in "${builtin_list[@]}"; do - if [ -d ../dist/$model/params ]; then - cp -r ../dist/$model/params dist/$model - elif [ -d ../dist/prebuilt/$model ]; then - cp -r ../dist/prebuilt/$model dist/$model - elif [ -d ../dist/prebuilt/mlc-chat-$model ]; then - cp -r ../dist/prebuilt/mlc-chat-$model dist/$model - elif [ -d ../dist/prebuilt/$model-MLC ]; then - cp -r ../dist/prebuilt/$model-MLC dist/$model - else - echo "Cannot find prebuilt weights for " $model - exit 1 - fi -done diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index 1654010664..3ff7cb47f4 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -2,9 +2,9 @@ MLC Chat is the app runtime of MLC LLM. """ - -# from . import protocol, serve -# from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig +""" +# NOTE(@sunggg): These are disabled because we don't use them +from . import protocol, serve from .libinfo import __version__ - -# from .serve import AsyncLLMEngine, LLMEngine +from .serve import AsyncMLCEngine, MLCEngine +""" diff --git a/python/mlc_llm/__main__.py b/python/mlc_llm/__main__.py index 857cfc479a..671faf6467 100644 --- a/python/mlc_llm/__main__.py +++ b/python/mlc_llm/__main__.py @@ -14,7 +14,15 @@ def main(): parser.add_argument( "subcommand", type=str, - choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench"], + choices=[ + "compile", + "convert_weight", + "gen_config", + "chat", + "serve", + "package", + "calibrate", + ], help="Subcommand to to run. 
(choices: %(choices)s)", ) parsed = parser.parse_args(sys.argv[1:2]) @@ -39,8 +47,12 @@ def main(): from mlc_llm.cli import serve as cli cli.main(sys.argv[2:]) - elif parsed.subcommand == "bench": - from mlc_llm.cli import bench as cli + elif parsed.subcommand == "package": + from mlc_llm.cli import package as cli + + cli.main(sys.argv[2:]) + elif parsed.subcommand == "calibrate": + from mlc_llm.cli import calibrate as cli cli.main(sys.argv[2:]) else: diff --git a/python/mlc_llm/_ffi_api.py b/python/mlc_llm/_ffi_api.py index ee303681fc..c46811e95d 100644 --- a/python/mlc_llm/_ffi_api.py +++ b/python/mlc_llm/_ffi_api.py @@ -3,4 +3,4 @@ # Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc" prefix. # e.g. TVM_REGISTER_GLOBAL("mlc.Tokenizer") -tvm._ffi._init_api("mlc", __name__) # pylint: disable=protected-access +tvm._ffi._init_api("mlc.tokenizers", __name__) # pylint: disable=protected-access diff --git a/python/mlc_llm/bench/__init__.py b/python/mlc_llm/bench/__init__.py new file mode 100644 index 0000000000..2594486ff6 --- /dev/null +++ b/python/mlc_llm/bench/__init__.py @@ -0,0 +1,6 @@ +"""Subdirectory of bench.""" + +from .metrics import MetricsProcessor +from .prompts import PromptsGenerator +from .replay import load_replay_log, replay +from .request import OpenAIRequestSender diff --git a/python/mlc_llm/bench/metrics.py b/python/mlc_llm/bench/metrics.py new file mode 100644 index 0000000000..ab414c2ad9 --- /dev/null +++ b/python/mlc_llm/bench/metrics.py @@ -0,0 +1,253 @@ +""" MLC LLM bench Metrics""" +import json +from typing import Any, Callable, Dict, List, Optional, Union + +from pydantic import BaseModel + +from mlc_llm.support import logging + +from .request import RequestRecords + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +class ServerMetrics(BaseModel): + """The metrics from the server side.""" + + prompt_tokens: int + prefill_tokens: int + completion_tokens: int + decode_tokens_per_s: float + prefill_tokens_per_s: float + end_to_end_latency_s: float + inter_token_latency_s: float + ttft_s: Optional[float] = None + + +class Metrics(BaseModel): + """The list of metric keys""" + + prompt_tokens: int + completion_tokens: int + end_to_end_latency_s: float + inter_token_latency_s: float + decode_tokens_per_s: float + ttft: Optional[float] = None + server_metrics: Optional[ServerMetrics] = None + + +class MetricsProcessor: + """The metrics processor class + + Parameters + ---------- + tokenizer : Optional[Tokenizer] + The tokenizer. + + request_records : List[RequestRecords] + The list of request records. + """ + + def __init__(self, request_records: List[RequestRecords], tokenizer=None) -> None: + self.tokenizer = tokenizer + if self.tokenizer is None: + from transformers import ( # pylint: disable=import-outside-toplevel,import-error + LlamaTokenizerFast, + ) + + self.tokenizer = LlamaTokenizerFast.from_pretrained( + "hf-internal-testing/llama-tokenizer" + ) + logger.warning("No tokenizer provided. Using default tokenizer.") + self.all_metrics: List[Metrics] = self.extract_metrics_from_request_records(request_records) + + def count_tokens(self, prompt: str) -> int: + """Count the number of tokens in the text + + Parameters + ---------- + prompt : str + The text to count the tokens. + + Returns + ------- + prompt_tokens : int + The number of tokens in the prompt. 
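+
+        Notes
+        -----
+        Token counts are produced by ``self.tokenizer``; when no tokenizer is
+        supplied to the constructor, the LlamaTokenizerFast fallback noted above
+        is used.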
+ """ + return len(self.tokenizer.encode(prompt)) + + def extract_metrics_from_request_records( + self, request_records: List[RequestRecords] + ) -> List[Metrics]: + """ + Extract the metrics from request records. + + Parameters + ---------- + request_records : List[RequestRecords] + The list of raw request records collected. + + Returns + ------- + metrics : List[Metrics] + The list of extracted metrics with additional items. + """ + + result = [] + for metric in request_records: + prompt_tokens = self.count_tokens(metric.input) + completion_tokens = self.count_tokens(metric.output) + assert prompt_tokens > 0 and completion_tokens >= 0, "Invalid prompt tokens" + end_to_end_latency_s = metric.end_to_end_latency_s + ttft = metric.ttft if metric.ttft is not None else 0 + server_metric = None + if metric.server_metrics is not None: + server_metric = ServerMetrics( + prompt_tokens=metric.server_metrics["prompt_tokens"], + prefill_tokens=metric.server_metrics["prefill_tokens"], + completion_tokens=metric.server_metrics["completion_tokens"], + decode_tokens_per_s=metric.server_metrics["decode_tokens_per_s"], + prefill_tokens_per_s=metric.server_metrics["prefill_tokens_per_s"], + end_to_end_latency_s=metric.server_metrics["end_to_end_latency_s"], + inter_token_latency_s=metric.server_metrics["inter_token_latency_s"], + ttft_s=metric.server_metrics["ttft_s"], + ) + refined_metric = Metrics( + inter_token_latency_s=end_to_end_latency_s / completion_tokens, + decode_tokens_per_s=(completion_tokens - 1) / (end_to_end_latency_s - ttft), + ttft=metric.ttft, + end_to_end_latency_s=end_to_end_latency_s, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + server_metrics=server_metric, + ) + result.append(refined_metric) + return result + + def get_metrics(self) -> List[Metrics]: + """ + Get the metrics collected. + + Returns + ------- + all_metrics : List[Metrics] + The list of metrics collected. + """ + return self.all_metrics + + def reset_metrics(self, metrics: List[Metrics]) -> None: + """Reset the metrics collected. + + Parameters + ---------- + metrics : List[Metrics] + The list of metrics to reset. + """ + self.all_metrics = metrics + + def filter_metrics(self, criteria: Optional[Callable[[Metrics], bool]] = None) -> List[Metrics]: + """ + Filters the metrics based on the provided criteria. If no criteria are provided, + it filters out metrics with any fields set to None or 0. + + Parameters + ---------- + criteria : Optional[Callable[[Metrics], bool]] + A function that takes a metric as input, + returns True if the metric should be included. + + Returns + ------- + filtered_metrics : List[Metrics] + The list of metrics that meet the specified criteria. + """ + if criteria is None: + # Default criteria to filter out metrics with None or 0 in certain fields + def criteria(metric: Metrics) -> bool: + for field, _ in Metrics.model_fields.items(): + val = getattr(metric, field) + if val is None or val == 0: + return False + return True + + filered_metrics = [metric for metric in self.all_metrics if criteria(metric)] + self.reset_metrics(filered_metrics) + return filered_metrics + + def generate_metrics_summary(self, start_time: float, end_time: float) -> Dict[str, Any]: + """ + Computes summary statistics across all metrics collected. + + Parameters + ---------- + all_metrics : List[RequestRecords] + All the metrics data collected in the monitoring period. + + start_time : float + The start time of the monitoring period. + + end_time : float + The end time of the monitoring period. 
+ + Returns + ------- + report : Dict + A dictionary containing the summary statistics of the collected metrics. + """ + if not self.all_metrics: + return {} + + # Generate the client metrics statistics + report = self._compute_metrics_statistics(self.all_metrics) + report["num_completed_requests"] = len(self.all_metrics) + total_tokens = sum(metric.completion_tokens for metric in self.all_metrics) + report["overall_output_throughput"] = total_tokens / (end_time - start_time) + + # Generate the server metrics statistics + server_metrics = [ + metric.server_metrics for metric in self.all_metrics if metric.server_metrics + ] + server_report = self._compute_metrics_statistics(server_metrics) + report["server_metrics"] = server_report + + logger.info("Metrics Summary:\n%s", json.dumps(report, indent=4, default=str)) + return report + + def _compute_metrics_statistics(self, metrics: List[Union[Metrics, ServerMetrics]]) -> Dict: + """ + Compute the statistics of the metrics. + + Parameters + ---------- + metrics : List[Union[Metrics, ServerMetrics]] + The list of metrics to get the statistics. + + Returns + ------- + report : Dict + The statistics of the metrics. + """ + import pandas as pd # pylint: disable=import-outside-toplevel,import-error + + report: Dict = {} + if not metrics: + return report + + df = pd.DataFrame([metric.model_dump() for metric in metrics]) + for key, _ in metrics[0].model_fields.items(): + if key == "server_metrics": + continue + if key in df.columns: + series = df[key].dropna() + report[key] = { + "quantiles": { + f"p{int(q * 100)}": v + for q, v in series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).items() + }, + "mean": series.mean(), + "min": series.min(), + "max": series.max(), + "stddev": series.std(), + } + return report diff --git a/python/mlc_llm/bench/prompts.py b/python/mlc_llm/bench/prompts.py new file mode 100644 index 0000000000..143d49f0c3 --- /dev/null +++ b/python/mlc_llm/bench/prompts.py @@ -0,0 +1,149 @@ +"""MLC LLM bench prompts generator""" + +import json +import random +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional + +from mlc_llm.support import logging + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +class PromptsGenerator: # pylint: disable=too-few-public-methods + """ + Generates prompts of a specified token length from a text file containing potential prompts. + """ + + def __init__( + self, + prompts_path: Optional[str] = None, + json_prompts_path: Optional[str] = None, + tokenizer: Optional[Any] = None, + seed: Optional[int] = 11111, + ) -> None: + """ + Initializes the PromptsGenerator with the file path and tokenizer. + + Parameters + ---------- + prompts_path : Optional[str] + The path to the file containing the source prompts. This file can be + a plain text file, with each line representing a separate prompt str, + or a .jsonl file where each line is a JSON object formatted as + {"prompt": "prompt text", "prompt_tokens": 10}. + + json_prompts_path : Optional[str] + The path to the file containing the source json prompts. This file a + .jsonl file where each line is a JSON object formatted as + {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}. + + tokenizer : Optional[Any] + The tokenizer object to use for tokenizing the prompts. + + seed : Optional[int] + The seed for the random number generator. 
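+
+        Notes
+        -----
+        When ``prompts_path`` is not given, the bundled ``prompts.txt`` next to
+        this module is used as the prompt source. An illustrative construction
+        (the file name is hypothetical) is
+        ``PromptsGenerator(prompts_path="my_prompts.jsonl", seed=0)``.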
+ """ + random.seed(seed) + self.tokenizer = tokenizer + if not self.tokenizer: + from transformers import ( # pylint: disable=import-outside-toplevel,import-error + LlamaTokenizerFast, + ) + + self.tokenizer = LlamaTokenizerFast.from_pretrained( + "hf-internal-testing/llama-tokenizer" + ) + logger.warning("No tokenizer provided. Using default tokenizer.") + + self.prompts: List[Dict] = [] + if prompts_path is not None and prompts_path.endswith(".jsonl"): + with open(prompts_path, "r", encoding="utf-8") as file: + for line in file: + json_line = json.loads(line) + assert "prompt" in json_line, "The prompt field is required in the JSONL file." + if "prompt_tokens" not in json_line: + json_line["prompt_tokens"] = self._count_tokens(json_line["prompt"]) + self.prompts.append(json_line) + else: + if not prompts_path: + prompts_path = Path(__file__).parent / "prompts.txt" # type: ignore + with open(prompts_path, "r", encoding="utf-8") as file: + prompt_line = file.readline() + prompt_tokens = self._count_tokens(prompt_line) + self.prompts.append({"prompt": prompt_line, "prompt_tokens": prompt_tokens}) + if json_prompts_path: + self.json_prompts = defaultdict(list) + with open(json_prompts_path, "r", encoding="utf-8") as file: + for line in file: + json_line = json.loads(line) + assert ( + "messages" in json_line + ), "The messages field is required in the JSONL file." + assert ( + "response_format" in json_line + ), "The response_format field is required in the JSONL file." + self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append( + json_line["messages"] + ) + else: + self.json_prompts = None + + def _count_tokens(self, text: str) -> int: + """Get the number of tokens. + + Parameters + ---------- + text : str + The text to tokenize. + + Returns + ------- + output : int + The number of tokens + """ + return len(self.tokenizer.encode(text)) + + def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]: + """ + Generates a prompt based on the params, e.g. prompt_tokens, response_format. + + Parameters + ---------- + params : Dict[str, Any] + The desired mean number of tokens in the prompt. + + Returns + ------- + override_params: Dict[str, Any] + The params to override the original request, e.g. messages, response_format. + """ + if "response_format" in params: + response_format = params["response_format"] + if response_format.get("type") == "json_object": + if response_format.get("schema") in self.json_prompts: + assert len(self.json_prompts[response_format["schema"]]) > 0 + return {"messages": random.choice(self.json_prompts[response_format["schema"]])} + schema, prompts = random.choice(list(self.json_prompts.items())) + response_format["schema"] = schema + return {"messages": random.choice(prompts), "response_format": response_format} + tokens_mean = params.get("prompt_tokens", 128) + assert tokens_mean > 0, "The mean number of tokens must be greater than 0." 
+ remaining_prompt_tokens = tokens_mean + result_prompt = "" + override_params = None + while remaining_prompt_tokens > 0: + prompt_dict = random.choice(self.prompts) + cur_prompt_tokens = prompt_dict["prompt_tokens"] + cur_prompt = prompt_dict["prompt"] + if override_params is None: + override_params = prompt_dict["override_params"] + if remaining_prompt_tokens - cur_prompt_tokens < 0: + result_prompt += cur_prompt[:remaining_prompt_tokens] + remaining_prompt_tokens = 0 + break + result_prompt += cur_prompt + remaining_prompt_tokens -= cur_prompt_tokens + return {"messages": [{"role": "system", "content": result_prompt}]} diff --git a/python/mlc_llm/bench/replay.py b/python/mlc_llm/bench/replay.py new file mode 100644 index 0000000000..65fb325c34 --- /dev/null +++ b/python/mlc_llm/bench/replay.py @@ -0,0 +1,115 @@ +"""MLC LLM bench replay request""" +import asyncio +import json +from datetime import datetime +from typing import Dict, List, Optional + + +def load_replay_log(log_path: str) -> List[Dict]: + """ + Load replay log from file + + Parameters + ---------- + log_path : str + The path to the event log CSV or JSONL file containing the events to replay. + + Returns + ------- + res: List[Dict] + A list of preprocessed event data for replay. + """ + if log_path.endswith(".csv"): + import pandas as pd # pylint: disable=import-outside-toplevel,import-error + + df = pd.read_csv(log_path) + column_names = df.columns.values + assert ( + ("Date" in column_names) + and ("@request" in column_names) + and ("Message" in column_names) + ) + df["timestamp"] = pd.to_datetime(df["Date"]) + df.sort_values("timestamp", inplace=True) + # Get the request params from the loaded CSV + params = [] + for _, row in df.iterrows(): + request = row["@request"] + payload = json.loads(str(request)) + params.append( + { + "timestamp": row["timestamp"], + "payload": payload, + } + ) + return params + if log_path.endswith(".jsonl"): + with open(log_path, "r", encoding="utf-8") as file: + data = [json.loads(line) for line in file] + for row in data: + row["timestamp"] = datetime.fromisoformat(str(row["timestamp"])) + return data + raise ValueError("Unsupported file format. Please use .csv or .jsonl.") + + +async def replay( + replay_log: List[Dict], + callback, + *, + base_timestamp: Optional[float] = None, + start_timestamp: Optional[float] = None, + max_schedule_gap: Optional[float] = 0.1, + wait_until_last_task_done: bool = True, +): # pylint: disable=too-many-arguments + """ + Replay generated events based on historical timestamps. The replaying requests start + from a new start time while preserving the ordering of requests. + + Parameters + ---------- + replay_log : List[Dict] + A list of event data, each containing a 'timestamp' and replay parameters. + + callback : coroutine function + The async function to be called for each log item. + + base_timestamp : Optional[float] + The timestamp of the first log entry, used as a reference point for scheduling. + Defaults to the timestamp of the first item in `replay_log`. + + start_timestamp : Optional[float] + The time when the replay starts. + + max_schedule_gap : Optional[float] + The maximum allowed delay between the scheduled time in seconds. Defaults to 0.1 seconds. + + wait_until_last_task_done : bool + Whether to wait until the last task is done. Defaults to True. + + Raises + ------ + TypeError + If the callback is not a coroutine or an awaitable function. 
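+
+    Examples
+    --------
+    Illustrative usage (the log file name and the ``send_request`` callback are
+    hypothetical)::
+
+        log = load_replay_log("requests.jsonl")
+        await replay(log, send_request)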
+ """ + if not replay_log: + return + loop = asyncio.get_running_loop() + if base_timestamp is None: + base_timestamp = replay_log[0]["timestamp"].timestamp() + if start_timestamp is None: + start_timestamp = loop.time() + max_schedule_gap + + for item in replay_log: + cur_time = loop.time() + launch_time = item["timestamp"].timestamp() - base_timestamp + start_timestamp + if launch_time - cur_time > max_schedule_gap: + await asyncio.sleep(launch_time - cur_time - max_schedule_gap) + loop.call_at( + launch_time, + lambda: asyncio.create_task(callback(item)), # pylint: disable=cell-var-from-loop + ) + + if wait_until_last_task_done: + # Wait for all tasks to be scheduled + await asyncio.sleep(launch_time - loop.time() + max_schedule_gap) + await asyncio.gather(*asyncio.all_tasks(loop) - {asyncio.current_task()}) diff --git a/python/mlc_llm/bench/request.py b/python/mlc_llm/bench/request.py new file mode 100644 index 0000000000..eea0a8afa4 --- /dev/null +++ b/python/mlc_llm/bench/request.py @@ -0,0 +1,201 @@ +"""MLC LLM Bench Request""" + +import json +import os +import time +from typing import Any, Dict, List, Optional + +from openai import AsyncOpenAI +from pydantic import BaseModel +from typing_extensions import Self + +from mlc_llm.protocol.openai_api_protocol import ChatCompletionRequest +from mlc_llm.support import logging + +from .prompts import PromptsGenerator + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +class RequestRecords(BaseModel): + """The request records collected from LLM inference requests.""" + + input: str + output: str + end_to_end_latency_s: float + ttft: Optional[float] = None + server_metrics: Optional[Dict] = None + + +class OpenAIRequestSender: # pylint: disable=too-many-instance-attributes + """ + Manages the sending of requests to a specified API endpoint and gathers inference statistics. + + Parameters + ---------- + host : Optional[str] + The host address for the API, defaulting to "127.0.0.1". + port : Optional[int] + The port number for the API, defaulting to 8008. + stream : Optional[bool] + Specifies if streaming should be enabled, default is True. + timeout : Optional[float] + The maximum duration in seconds for each request, default is 180. + client : Optional[Any] + The client to use for sending requests. + include_server_metrics : Optional[bool] + Specifies if server metrics should be included, default is False. + prompt_generator : Optional[PromptsGenerator] + The prompt generator for missing messages fields. + + Attributes + ---------- + stats : dict + Statistics about the performance. 
+ """ + + def __init__( # pylint: disable=too-many-arguments + self, + host: Optional[str] = "127.0.0.1", + port: Optional[int] = 8008, + stream: Optional[bool] = None, + timeout: Optional[float] = None, + client: Optional[Any] = None, + include_server_metrics: Optional[bool] = False, + prompt_generator: Optional[PromptsGenerator] = None, + ) -> None: + import aiohttp # pylint: disable=import-outside-toplevel,import-error + from transformers import ( # pylint: disable=import-outside-toplevel,import-error + LlamaTokenizerFast, + ) + + self.stream = stream + self.timeout = timeout + self.tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + self.prompt_generator = PromptsGenerator() if prompt_generator is None else prompt_generator + self.request_records: List[RequestRecords] = [] + self.client = client if client else aiohttp.ClientSession() + self.include_server_metrics = include_server_metrics + self.url = f"http://{host}:{port}/v1/chat/completions" + self.headers = {"Content-Type": "application/json"} + if os.getenv("MLC_LLM_API_KEY"): + self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}" + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.client.close() + + async def __call__( # pylint: disable=too-many-locals, too-many-branches, too-many-statements + self, params: Dict[str, Any] = None + ) -> None: + if "messages" not in params: + override_params = self.prompt_generator.generate_prompt(params) + assert "messages" in override_params, "override params must contain messages field" + params.update(override_params) + prompt = params["messages"][-1]["content"] + chat_params = self._get_chat_completion_params(params) + if "stream" not in chat_params: + chat_params["stream"] = self.stream + if "timeout" not in chat_params: + chat_params["timeout"] = self.timeout + if self.include_server_metrics: + if "stream_options" not in chat_params: + chat_params["stream_options"] = {"include_usage": True} + else: + chat_params["stream_options"]["include_usage"] = True + + total_request_time = 0 + generated_text = "" + ttft = None + start_time = time.monotonic() + server_metrics = None + + # AsyncOpenAI chat completion + if isinstance(self.client, AsyncOpenAI): + response = await self.client.chat.completions.create(**chat_params) + if chat_params["stream"]: + async for chunk in response: + if chunk.usage: + server_metrics = chunk.usage.extra + elif chunk.choices[0].delta.content is not None: + if not ttft: + ttft = time.monotonic() - start_time # type: ignore + generated_text += chunk.choices[0].delta.content + else: + generated_text = response.choices[0].message.content + else: + try: + async with self.client.post( + self.url, json=chat_params, headers=self.headers + ) as response: + if chat_params["stream"]: + async for chunk in response.content: + chunk = chunk.strip() + if not chunk or chunk == b"\n": + continue + # Get rid of the prefix "data: " and suffix "\n" + raw_data = chunk[6:].strip() + if raw_data == b"[DONE]": + continue + data = json.loads(raw_data) + if data["usage"] is not None: + server_metrics = data["usage"]["extra"] + if not data["choices"]: + continue + delta = data["choices"][0]["delta"] + if delta.get("content", None): + if not ttft: + ttft = time.monotonic() - start_time + + generated_text += delta["content"] + else: + data = await response.json() + generated_text = data["choices"][0]["message"]["content"] + except Exception as e: # pylint: 
disable=broad-except + logger.error("Error sending request: %s", str(e)) + raise e + + total_request_time = time.monotonic() - start_time # type: ignore + + req_rec = RequestRecords( + input=prompt, + output=generated_text, + end_to_end_latency_s=total_request_time, + ttft=ttft, + server_metrics=server_metrics, + ) + self.request_records.append(req_rec) + + def _get_chat_completion_params(self, params: Dict) -> Dict: + """ + Extract chat completion parameters from the provided request parameters. + + Parameters + ---------- + params : Dict[str, Any] + The parameters for the request. + + Returns + ------- + result : Dict + The chat completion parameters. + """ + chat_completion_params = {} + for k, _ in ChatCompletionRequest.model_fields.items(): + if k in params: + chat_completion_params[k] = params[k] + return chat_completion_params + + def get_request_records(self) -> List[RequestRecords]: + """ + Retrieve the collected reqeust records. + + Returns + ------- + request_records : List[RequestRecords] + The list of collected request records. + """ + return self.request_records diff --git a/python/mlc_llm/cli/bench.py b/python/mlc_llm/cli/bench.py deleted file mode 100644 index 26b74b1f10..0000000000 --- a/python/mlc_llm/cli/bench.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Command line entrypoint of benchmark.""" -from mlc_llm.help import HELP -from mlc_llm.interface.bench import bench -from mlc_llm.interface.chat import ChatConfigOverride -from mlc_llm.support.argparse import ArgumentParser - - -def main(argv): - """Parse command line arguments and call `mlc_llm.interface.bench`.""" - parser = ArgumentParser("MLC LLM Chat CLI") - - parser.add_argument( - "model", - type=str, - help=HELP["model"] + " (required)", - ) - parser.add_argument( - "--prompt", - type=str, - default="What is the meaning of life?", - help=HELP["prompt"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--opt", - type=str, - default="O2", - help=HELP["opt"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--device", - type=str, - default="auto", - help=HELP["device_deploy"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--overrides", - type=ChatConfigOverride.from_str, - default="", - help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--generate-length", - type=int, - default=256, - help=HELP["generate_length"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--model-lib-path", - type=str, - default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', - ) - parsed = parser.parse_args(argv) - bench( - model=parsed.model, - prompt=parsed.prompt, - device=parsed.device, - opt=parsed.opt, - overrides=parsed.overrides, - generate_length=parsed.generate_length, - model_lib_path=parsed.model_lib_path, - ) diff --git a/python/mlc_llm/cli/benchmark.py b/python/mlc_llm/cli/benchmark.py deleted file mode 100644 index 72c86fab03..0000000000 --- a/python/mlc_llm/cli/benchmark.py +++ /dev/null @@ -1,86 +0,0 @@ -"""A command line tool for benchmarking a chat model.""" -import argparse -from pathlib import Path - -from mlc_llm import ChatConfig, ChatModule - -parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.") -parser.add_argument( - "--model", - type=str, - help="""The model folder after compiling with MLC-LLM build process. The parameter can either - be the model name with its quantization scheme (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a - full path to the model folder. 
In the former case, we will use the provided name to search for - the model folder over possible paths.""", - required=True, -) -parser.add_argument( - "--model-lib", - type=str, - help="""The compiled model library. In MLC LLM, an LLM is compiled to a shared or static - library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat, - as the runtime of MLC LLM, depends on the compiled model library to generate tokens. - """, - required=False, -) -parser.add_argument( - "--tensor-parallel-shards", - "--num-shards", - type=int, - help="Number of GPUs to be used.", - dest="tensor_parallel_shards", - required=False, -) -parser.add_argument( - "--device", - type=str, - help="""The description of the device to run on. User should provide a string in the form of - 'device_name:device_id' or 'device_name', where 'device_name' is one of 'cuda', 'metal', - 'vulkan', 'rocm', 'opencl', and 'device_id' is the device id to run on. If no 'device_id' is - provided, it will be set to 0 by default. - """, - required=True, -) -parser.add_argument( - "--prompt", - type=str, - help="The prompt to generate from.", - required=True, -) -parser.add_argument( - "--generate-length", - type=int, - help="The length (numer of tokens) of the generated text.", - required=True, -) - - -def _load_prompt(path_or_prompt: str) -> str: - """Load the prompt from a file or use the provided prompt.""" - try: - path = Path(path_or_prompt) - if path.is_file(): - with path.open("r", encoding="utf-8") as in_file: - return in_file.read() - except: # pylint: disable=bare-except - pass - return path_or_prompt - - -def main(): - """The main function that runs the benchmarking.""" - args = parser.parse_args() - chat_module = ChatModule( - model=args.model, - device=args.device, - chat_config=ChatConfig(tensor_parallel_shards=args.tensor_parallel_shards), - model_lib_path=args.model_lib, - ) - prompt = _load_prompt(args.prompt) - output = chat_module.benchmark_generate(prompt, generate_length=args.generate_length) - print(f"Generated text:\n{output}\n") - print(f"Statistics: {chat_module.stats(verbose=True)}") - - -if __name__ == "__main__": - main() diff --git a/python/mlc_llm/cli/calibrate.py b/python/mlc_llm/cli/calibrate.py new file mode 100644 index 0000000000..87c81161bb --- /dev/null +++ b/python/mlc_llm/cli/calibrate.py @@ -0,0 +1,73 @@ +"""Command line entrypoint of calibration.""" + +from mlc_llm.interface.calibrate import calibrate +from mlc_llm.interface.help import HELP +from mlc_llm.support.argparse import ArgumentParser + +from .serve import EngineConfigOverride + + +def main(argv): + """Main entrypoint for calibration.""" + parser = ArgumentParser("MLC LLM Calibration CLI") + parser.add_argument( + "model", + type=str, + help=HELP["model"] + " (required)", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help=HELP["device_deploy"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--model-lib", + type=str, + default=None, + help=HELP["model_lib"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--output", "-o", type=str, required=True, help=HELP["output_calibration"] + " (required)" + ) + # Download dataset from + # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + parser.add_argument( + "--dataset", type=str, required=True, help=HELP["calibration_dataset"] + " (required)" + ) + + parser.add_argument( + "--num-calibration-samples", + type=int, + default=16, + 
help=HELP["num_calibration_samples"] + ' (default: "%(default)s")', + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help=HELP["seed_calibrate"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--overrides", + type=EngineConfigOverride.from_str, + default="", + help=HELP["overrides_serve"], + ) + + parsed = parser.parse_args(argv) + calibrate( + model=parsed.model, + device=parsed.device, + model_lib=parsed.model_lib, + output=parsed.output, + dataset=parsed.dataset, + num_calibration_samples=parsed.num_calibration_samples, + max_num_sequence=parsed.overrides.max_num_sequence, + max_total_sequence_length=parsed.overrides.max_total_seq_length, + prefill_chunk_size=parsed.overrides.prefill_chunk_size, + max_history_size=parsed.overrides.max_history_size, + gpu_memory_utilization=parsed.overrides.gpu_memory_utilization, + seed=parsed.seed, + ) diff --git a/python/mlc_llm/cli/chat.py b/python/mlc_llm/cli/chat.py index 13c83a64ec..cb2d0899f7 100644 --- a/python/mlc_llm/cli/chat.py +++ b/python/mlc_llm/cli/chat.py @@ -1,6 +1,7 @@ """Command line entrypoint of chat.""" -from mlc_llm.help import HELP -from mlc_llm.interface.chat import ChatConfigOverride, chat + +from mlc_llm.interface.chat import ModelConfigOverride, chat +from mlc_llm.interface.help import HELP from mlc_llm.support.argparse import ArgumentParser @@ -13,12 +14,6 @@ def main(argv): type=str, help=HELP["model"] + " (required)", ) - parser.add_argument( - "--opt", - type=str, - default="O2", - help=HELP["opt"] + ' (default: "%(default)s")', - ) parser.add_argument( "--device", type=str, @@ -26,22 +21,21 @@ def main(argv): help=HELP["device_deploy"] + ' (default: "%(default)s")', ) parser.add_argument( - "--overrides", - type=ChatConfigOverride.from_str, - default="", - help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")', - ) - parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--overrides", + type=ModelConfigOverride.from_str, + default="", + help=HELP["modelconfig_overrides"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) chat( model=parsed.model, device=parsed.device, - opt=parsed.opt, + model_lib=parsed.model_lib, overrides=parsed.overrides, - model_lib_path=parsed.model_lib_path, ) diff --git a/python/mlc_llm/cli/check_device.py b/python/mlc_llm/cli/check_device.py index a78fd4d6d5..d6099f7ce3 100644 --- a/python/mlc_llm/cli/check_device.py +++ b/python/mlc_llm/cli/check_device.py @@ -1,4 +1,5 @@ """Check if a device exists.""" + import sys from tvm.runtime import Device diff --git a/python/mlc_llm/cli/compile.py b/python/mlc_llm/cli/compile.py index 7d7025a91f..ebf7bc630c 100644 --- a/python/mlc_llm/cli/compile.py +++ b/python/mlc_llm/cli/compile.py @@ -1,4 +1,5 @@ """Command line entrypoint of compilation.""" + import argparse import json import re @@ -6,12 +7,12 @@ from pathlib import Path from typing import Union -from mlc_llm.help import HELP from mlc_llm.interface.compile import ( # pylint: disable=redefined-builtin ModelConfigOverride, OptimizationFlags, compile, ) +from mlc_llm.interface.help import HELP from mlc_llm.model import MODELS from mlc_llm.quantization import QUANTIZATION from mlc_llm.support.argparse import ArgumentParser @@ -24,7 +25,7 @@ def main(argv): - """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" + """Parse command line arguments and call 
`mlc_llm.compiler.compile`.""" def _parse_output(path: Union[str, Path]) -> Path: path = Path(path) diff --git a/python/mlc_llm/cli/convert_weight.py b/python/mlc_llm/cli/convert_weight.py index 08d98c421d..01d6886b2a 100644 --- a/python/mlc_llm/cli/convert_weight.py +++ b/python/mlc_llm/cli/convert_weight.py @@ -1,10 +1,11 @@ """Command line entrypoint of weight conversion.""" + import argparse from pathlib import Path from typing import Union -from mlc_llm.help import HELP from mlc_llm.interface.convert_weight import convert_weight +from mlc_llm.interface.help import HELP from mlc_llm.model import MODELS from mlc_llm.quantization import QUANTIZATION from mlc_llm.support.argparse import ArgumentParser diff --git a/python/mlc_llm/cli/delivery.py b/python/mlc_llm/cli/delivery.py deleted file mode 100644 index a7dd6408b0..0000000000 --- a/python/mlc_llm/cli/delivery.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Continuous model delivery for MLC LLM models.""" - -import argparse -import dataclasses -import json -import os -import shutil -import subprocess -import sys -import tempfile -from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union - -from huggingface_hub import HfApi # pylint: disable=import-error -from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error - -from mlc_llm.support import logging -from mlc_llm.support.argparse import ArgumentParser -from mlc_llm.support.constants import MLC_TEMP_DIR -from mlc_llm.support.download import git_clone -from mlc_llm.support.style import bold, green, red - -logging.enable_logging() -logger = logging.getLogger(__name__) - -GEN_CONFIG_OPTIONAL_ARGS = [ - "context_window_size", - "sliding_window_size", - "prefill_chunk_size", - "attention_sink_size", - "tensor_parallel_shards", -] - - -@dataclasses.dataclass -class ModelInfo: # pylint: disable=too-many-instance-attributes - """Necessary information for the model delivery""" - - model_id: str - model: Path - conv_template: str - quantization: str - source_format: str = "auto" - # If unspecified in CLI, remains to be None and will not be - # passed to `gen_config` or `convert_weight` - context_window_size: int = None - sliding_window_size: int = None - prefill_chunk_size: int = None - attention_sink_size: int = None - tensor_parallel_shards: int = None - - -class DeferredScope: - """A context manager that defers execution of functions until exiting the scope.""" - - def __init__(self): - self.deferred_functions = [] - - def add(self, func: Callable[[], None]): - """Add a function to be executed when exiting the scope.""" - self.deferred_functions.append(func) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - for func in reversed(self.deferred_functions): - func() - return False - - def create_temp_dir(self) -> Path: - """Create a temporary directory that will be deleted when exiting the scope.""" - temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR) - self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) - return Path(temp_dir) - - -def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path: - if isinstance(model, Path): - if not model.exists(): - raise ValueError(f"Invalid model source: {model}") - return model - if model.startswith("https://") or model.startswith("git://"): - result = deferred.create_temp_dir() / "repo" - git_clone(model, result, ignore_lfs=False) - return result - result = Path(model) - if result.exists(): - return result - raise ValueError(f"Invalid model source: 
{model}") - - -def _run_quantization( - model_info: ModelInfo, - repo: str, - api: HfApi, -) -> bool: - logger.info("[HF] Creating repo https://huggingface.co/%s", repo) - try: - api.create_repo(repo_id=repo, private=False) - except HfHubHTTPError as error: - if error.response.status_code != 409: - raise - logger.info("[HF] Repo already exists. Recreating...") - api.delete_repo(repo_id=repo) - api.create_repo(repo_id=repo, private=False) - logger.info("[HF] Repo recreated") - succeeded = True - with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as output_dir: - log_path = Path(output_dir) / "logs.txt" - with log_path.open("a", encoding="utf-8") as log_file: - assert isinstance(model_info.model, Path) - logger.info("[MLC] Processing in directory: %s", output_dir) - # Required arguments - cmd = [ - sys.executable, - "-m", - "mlc_llm", - "gen_config", - str(model_info.model), - "--quantization", - model_info.quantization, - "--conv-template", - model_info.conv_template, - "--output", - output_dir, - ] - # Optional arguments - for optional_arg in GEN_CONFIG_OPTIONAL_ARGS: - optional_arg_val = getattr(model_info, optional_arg, None) - if optional_arg_val is not None: - # e.g. --context-window-size 4096 - cmd += ["--" + optional_arg.replace("_", "-"), str(optional_arg_val)] - - print(" ".join(cmd), file=log_file, flush=True) - subprocess.run( - cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ - ) - cmd = [ - sys.executable, - "-m", - "mlc_llm", - "convert_weight", - str(model_info.model), - "--quantization", - model_info.quantization, - "--source-format", - model_info.source_format, - "--output", - output_dir, - ] - print(" ".join(cmd), file=log_file, flush=True) - subprocess.run( - cmd, check=False, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ - ) - logger.info("[MLC] Complete!") - if not (Path(output_dir) / "ndarray-cache.json").exists(): - logger.error( - "[%s] Model %s. Quantization %s. No weights metadata found.", - red("FAILED"), - model_info.model_id, - model_info.quantization, - ) - succeeded = False - logger.info("[HF] Uploading to: https://huggingface.co/%s", repo) - for _retry in range(10): - try: - api.upload_folder( - folder_path=output_dir, - repo_id=repo, - commit_message="Initial commit", - ) - except Exception as exc: # pylint: disable=broad-except - logger.error("[%s] %s. Retrying...", red("FAILED"), exc) - else: - break - else: - raise RuntimeError("Failed to upload to HuggingFace Hub with 10 retries") - return succeeded - - -def _main( # pylint: disable=too-many-locals - username: str, - api: HfApi, - spec: Dict[str, Any], -): - failed_cases: List[Tuple[str, str]] = [] - for task_index, task in enumerate(spec["tasks"], 1): - with DeferredScope() as deferred: - logger.info( - bold("[{task_index}/{total_tasks}] Processing model: ").format( - task_index=task_index, - total_tasks=len(spec["tasks"]), - ) - + green(task["model_id"]) - ) - model = _clone_repo(task["model"], deferred) - for quantization in spec["default_quantization"] + task.get("quantization", []): - model_info = { - "model_id": task["model_id"], - "model": model, - "conv_template": task["conv_template"], - } - # Process optional arguments - for optional_arg in GEN_CONFIG_OPTIONAL_ARGS: - # e.g. 
"context_window_size": task.get("context_window_size", None) - model_info[optional_arg] = task.get(optional_arg, None) - if isinstance(quantization, str): - model_info["quantization"] = quantization - else: - model_info["quantization"] = quantization.pop("format") - model_info.update(quantization) - repo = spec.get("destination", "{username}/{model_id}-{quantization}-MLC").format( - username=username, - model_id=model_info["model_id"], - quantization=model_info["quantization"], - ) - logger.info( - "%s%s. %s%s. %s%s", - bold("Model: "), - green(task["model_id"]), - bold("Quantization: "), - green(model_info["quantization"]), - bold("Repo: "), - green(f"https://huggingface.co/{repo}"), - ) - with DeferredScope() as inner_deferred: - model_info["model"] = _clone_repo(model_info["model"], inner_deferred) - result = _run_quantization( - ModelInfo(**model_info), - repo=spec["destination"].format( - username=username, - model_id=model_info["model_id"], - quantization=model_info["quantization"], - ), - api=api, - ) - if not result: - failed_cases.append( - (task["model_id"], model_info["quantization"]), - ) - if failed_cases: - logger.info("Total %s %s:", len(failed_cases), red("failures")) - for model_id, quantization in failed_cases: - logger.info(" Model %s. Quantization %s.", model_id, quantization) - - -def main(): - """Entry point.""" - - def _load_spec(path_spec: str) -> Dict[str, Any]: - path = Path(path_spec) - if not path.exists(): - raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") - with path.open("r", encoding="utf-8") as i_f: - return json.load(i_f) - - parser = ArgumentParser("MLC LLM continuous model delivery") - parser.add_argument( - "--username", - type=str, - required=True, - help="HuggingFace username", - ) - parser.add_argument( - "--token", - type=str, - required=True, - help="HuggingFace access token, obtained under https://huggingface.co/settings/tokens", - ) - parser.add_argument( - "--spec", - type=_load_spec, - required=True, - help="Path to the spec file", - ) - parsed = parser.parse_args() - _main( - parsed.username, - spec=parsed.spec, - api=HfApi(token=parsed.token), - ) - - -if __name__ == "__main__": - main() diff --git a/python/mlc_llm/cli/gen_config.py b/python/mlc_llm/cli/gen_config.py index b58b546678..62944e2285 100644 --- a/python/mlc_llm/cli/gen_config.py +++ b/python/mlc_llm/cli/gen_config.py @@ -1,9 +1,10 @@ """Command line entrypoint of configuration generation.""" + from pathlib import Path from typing import Union -from mlc_llm.help import HELP from mlc_llm.interface.gen_config import CONV_TEMPLATES, gen_config +from mlc_llm.interface.help import HELP from mlc_llm.model import MODELS from mlc_llm.quantization import QUANTIZATION from mlc_llm.support.argparse import ArgumentParser diff --git a/python/mlc_llm/cli/model_metadata.py b/python/mlc_llm/cli/model_metadata.py index 81473b1ec7..80f63ff34f 100644 --- a/python/mlc_llm/cli/model_metadata.py +++ b/python/mlc_llm/cli/model_metadata.py @@ -85,20 +85,18 @@ def _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBa temp_func_bytes = 0.0 for _func_name, func_bytes in metadata["memory_usage"].items(): temp_func_bytes = max(temp_func_bytes, func_bytes) - kv_cache_bytes = metadata["kv_cache_bytes"] - return params_bytes, temp_func_bytes, kv_cache_bytes + return params_bytes, temp_func_bytes def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBase]) -> None: - params_bytes, temp_func_bytes, kv_cache_bytes = _compute_memory_usage(metadata, 
config) - total_size = params_bytes + temp_func_bytes + kv_cache_bytes + params_bytes, temp_func_bytes = _compute_memory_usage(metadata, config) + total_size = params_bytes + temp_func_bytes logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB)", - green("Total memory usage"), + "%s: %.2f MB (Parameters: %.2f MB. Temporary buffer: %.2f MB)", + green("Total memory usage without KV cache:"), total_size / 1024 / 1024, params_bytes / 1024 / 1024, - kv_cache_bytes / 1024 / 1024, temp_func_bytes / 1024 / 1024, ) @@ -108,23 +106,6 @@ def _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBas ) -def _print_memory_usage_in_json(metadata: Dict[str, Any], config: Dict) -> None: - params_bytes, temp_func_bytes, kv_cache_bytes = _compute_memory_usage(metadata, config) - print( - json.dumps( - { - "params_bytes": params_bytes, - "temp_func_bytes": temp_func_bytes, - "kv_cache_bytes": kv_cache_bytes, - } - ) - ) - - -def _print_kv_cache_metadata_in_json(metadata: Dict[str, Any]) -> None: - print(json.dumps(metadata["kv_cache"])) - - def main(): """Entry point for the model metadata tool.""" parser = ArgumentParser(description="A tool that inspects the metadata of a model lib.") @@ -154,16 +135,6 @@ def main(): the basic information in JSON. """, ) - parser.add_argument( - "--print-memory-usage-in-json-only", - action="store_true", - help="""If set, only inspect the metadata in memory usage and print usage in raw JSON.""", - ) - parser.add_argument( - "--print-kv-cache-metadata-in-json-only", - action="store_true", - help="""If set, only inspect the metadata in KV cache and print usage in raw JSON.""", - ) parsed = parser.parse_args() # Load metadata from model lib try: @@ -180,12 +151,8 @@ def main(): with open(mlc_chat_config_path, "r", encoding="utf-8") as config_file: cfg = json.load(config_file) # Main body - if parsed.print_memory_usage_in_json_only: - _print_memory_usage_in_json(metadata, cfg) - elif parsed.memory_only: + if parsed.memory_only: _report_memory_usage(metadata, cfg) - elif parsed.print_kv_cache_metadata_in_json_only: - _print_kv_cache_metadata_in_json(metadata) else: _report_all(metadata) diff --git a/python/mlc_llm/cli/package.py b/python/mlc_llm/cli/package.py new file mode 100644 index 0000000000..9628d51400 --- /dev/null +++ b/python/mlc_llm/cli/package.py @@ -0,0 +1,68 @@ +"""Command line entrypoint of package.""" + +import os +from pathlib import Path +from typing import Union + +from mlc_llm.interface.help import HELP +from mlc_llm.interface.package import package +from mlc_llm.support.argparse import ArgumentParser + + +def main(argv): + """Parse command line arguments and call `mlc_llm.interface.package`.""" + parser = ArgumentParser("MLC LLM Package CLI") + + def _parse_package_config(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.exists(): + raise ValueError( + f"Path {str(path)} is expected to be a JSON file, but the file does not exist." 
+ ) + if not path.is_file(): + raise ValueError(f"Path {str(path)} is expected to be a JSON file.") + return path + + def _parse_mlc_llm_source_dir(path: str) -> Path: + os.environ["MLC_LLM_SOURCE_DIR"] = path + return Path(path) + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + return path + + parser.add_argument( + "--package-config", + type=_parse_package_config, + default="mlc-package-config.json", + help=HELP["config_package"] + ' (default: "%(default)s")', + ) + parser.add_argument( + "--mlc-llm-source-dir", + type=_parse_mlc_llm_source_dir, + default=os.environ.get("MLC_LLM_SOURCE_DIR", None), + help=HELP["mlc_llm_source_dir"] + + " (default: the $MLC_LLM_SOURCE_DIR environment variable)", + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + default="dist", + help=HELP["output_package"] + ' (default: "%(default)s")', + ) + parsed = parser.parse_args(argv) + if parsed.mlc_llm_source_dir is None: + raise ValueError( + "MLC LLM home is not specified. " + "Please obtain a copy of MLC LLM source code by " + "cloning https://github.com/mlc-ai/mlc-llm, and set environment variable " + '"MLC_LLM_SOURCE_DIR=path/to/mlc-llm"' + ) + package( + package_config_path=parsed.package_config, + mlc_llm_source_dir=parsed.mlc_llm_source_dir, + output=parsed.output, + ) diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index 6663a0c230..28d01ad4b6 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -1,13 +1,84 @@ """Command line entrypoint of serve.""" +import dataclasses import json +from io import StringIO +from typing import Optional -from mlc_llm.help import HELP +from mlc_llm.interface.help import HELP from mlc_llm.interface.serve import serve -from mlc_llm.serve.config import SpeculativeMode +from mlc_llm.support import argparse from mlc_llm.support.argparse import ArgumentParser +@dataclasses.dataclass +class EngineConfigOverride: # pylint: disable=too-many-instance-attributes + """Arguments for overriding engine config.""" + + # Overrides for EngineConfig (runtime) + max_num_sequence: Optional[int] = None + max_total_seq_length: Optional[int] = None + prefill_chunk_size: Optional[int] = None + max_history_size: Optional[int] = None + gpu_memory_utilization: Optional[float] = None + spec_draft_length: Optional[int] = None + prefix_cache_max_num_recycling_seqs: Optional[int] = None + context_window_size: Optional[int] = None + sliding_window_size: Optional[int] = None + attention_sink_size: Optional[int] = None + tensor_parallel_shards: Optional[int] = None + + def __repr__(self) -> str: + out = StringIO() + print(f"max_num_sequence={self.max_num_sequence}", file=out, end="") + print(f";max_total_seq_length={self.max_total_seq_length}", file=out, end="") + print(f";prefill_chunk_size={self.prefill_chunk_size}", file=out, end="") + print(f";max_history_size={self.max_history_size}", file=out, end="") + print(f";gpu_memory_utilization={self.gpu_memory_utilization}", file=out, end="") + print(f";spec_draft_length={self.spec_draft_length}", file=out, end="") + print( + f";prefix_cache_max_num_recycling_seqs={self.prefix_cache_max_num_recycling_seqs}", + file=out, + end="", + ) + print(f";context_window_size={self.context_window_size}", file=out, end="") + print(f";sliding_window_size={self.sliding_window_size}", file=out, end="") + print(f";attention_sink_size={self.attention_sink_size}", file=out, end="") + 
print(f";tensor_parallel_shards={self.tensor_parallel_shards}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "EngineConfigOverride": + """Parse engine config override values from a string.""" + parser = argparse.ArgumentParser(description="Engine config override values") + + parser.add_argument("--max_num_sequence", type=int, default=None) + parser.add_argument("--max_total_seq_length", type=int, default=None) + parser.add_argument("--prefill_chunk_size", type=int, default=None) + parser.add_argument("--max_history_size", type=int, default=None) + parser.add_argument("--gpu_memory_utilization", type=float, default=None) + parser.add_argument("--spec_draft_length", type=int, default=None) + parser.add_argument("--prefix_cache_max_num_recycling_seqs", type=int, default=None) + parser.add_argument("--context_window_size", type=int, default=None) + parser.add_argument("--sliding_window_size", type=int, default=None) + parser.add_argument("--attention_sink_size", type=int, default=None) + parser.add_argument("--tensor_parallel_shards", type=int, default=None) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return EngineConfigOverride( + max_num_sequence=results.max_num_sequence, + max_total_seq_length=results.max_total_seq_length, + prefill_chunk_size=results.prefill_chunk_size, + max_history_size=results.max_history_size, + gpu_memory_utilization=results.gpu_memory_utilization, + spec_draft_length=results.spec_draft_length, + prefix_cache_max_num_recycling_seqs=results.prefix_cache_max_num_recycling_seqs, + context_window_size=results.context_window_size, + sliding_window_size=results.sliding_window_size, + attention_sink_size=results.attention_sink_size, + tensor_parallel_shards=results.tensor_parallel_shards, + ) + + def main(argv): """Parse command line arguments and call `mlc_llm.interface.serve`.""" parser = ArgumentParser("MLC LLM Serve CLI") @@ -24,10 +95,10 @@ def main(argv): help=HELP["device_deploy"] + ' (default: "%(default)s")', ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, default=None, - help=HELP["model_lib_path"] + ' (default: "%(default)s")', + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parser.add_argument( "--mode", @@ -37,28 +108,32 @@ def main(argv): help=HELP["mode_serve"] + ' (default: "%(default)s")', ) parser.add_argument( - "--additional-models", type=str, nargs="*", help=HELP["additional_models_serve"] - ) - parser.add_argument("--max-batch-size", type=int, help=HELP["max_batch_size"]) - parser.add_argument( - "--max-total-seq-length", type=int, help=HELP["max_total_sequence_length_serve"] + "--enable-debug", + action="store_true", + help="whether we enable debug end points and debug config when accepting requests", ) - parser.add_argument("--prefill-chunk-size", type=int, help=HELP["prefill_chunk_size_serve"]) parser.add_argument( - "--max-history-size", type=int, default=1, help=HELP["max_history_size_serve"] + "--additional-models", type=str, nargs="*", help=HELP["additional_models_serve"] ) parser.add_argument( - "--gpu-memory-utilization", type=float, help=HELP["gpu_memory_utilization_serve"] + "--speculative-mode", + type=str, + choices=["disable", "small_draft", "eagle", "medusa"], + default="disable", + help=HELP["speculative_mode_serve"] + ' (default: "%(default)s")', ) parser.add_argument( - "--speculative-mode", + "--prefix-cache-mode", type=str, - choices=["DISABLE", "SMALL_DRAFT", "EAGLE"], - default="DISABLE", - 
help=HELP["speculative_mode_serve"], + choices=["disable", "radix"], + default="radix", + help=HELP["prefix_cache_mode_serve"] + ' (default: "%(default)s")', ) parser.add_argument( - "--spec-draft-length", type=int, default=4, help=HELP["spec_draft_length_serve"] + "--overrides", + type=EngineConfigOverride.from_str, + default="", + help=HELP["overrides_serve"], ) parser.add_argument("--enable-tracing", action="store_true", help=HELP["enable_tracing_serve"]) parser.add_argument( @@ -94,19 +169,35 @@ def main(argv): ) parsed = parser.parse_args(argv) + additional_models = [] + if parsed.additional_models is not None: + for additional_model in parsed.additional_models: + splits = additional_model.split(",", maxsplit=1) + if len(splits) == 2: + additional_models.append((splits[0], splits[1])) + else: + additional_models.append(splits[0]) + serve( model=parsed.model, device=parsed.device, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, mode=parsed.mode, - additional_models=parsed.additional_models, - max_batch_size=parsed.max_batch_size, - max_total_sequence_length=parsed.max_total_seq_length, - prefill_chunk_size=parsed.prefill_chunk_size, - max_history_size=parsed.max_history_size, - gpu_memory_utilization=parsed.gpu_memory_utilization, - speculative_mode=SpeculativeMode[parsed.speculative_mode], - spec_draft_length=parsed.spec_draft_length, + enable_debug=parsed.enable_debug, + additional_models=additional_models, + tensor_parallel_shards=parsed.overrides.tensor_parallel_shards, + speculative_mode=parsed.speculative_mode, + prefix_cache_mode=parsed.prefix_cache_mode, + max_num_sequence=parsed.overrides.max_num_sequence, + max_total_sequence_length=parsed.overrides.max_total_seq_length, + max_single_sequence_length=parsed.overrides.context_window_size, + prefill_chunk_size=parsed.overrides.prefill_chunk_size, + sliding_window_size=parsed.overrides.sliding_window_size, + attention_sink_size=parsed.overrides.attention_sink_size, + max_history_size=parsed.overrides.max_history_size, + gpu_memory_utilization=parsed.overrides.gpu_memory_utilization, + spec_draft_length=parsed.overrides.spec_draft_length, + prefix_cache_max_num_recycling_seqs=parsed.overrides.prefix_cache_max_num_recycling_seqs, enable_tracing=parsed.enable_tracing, host=parsed.host, port=parsed.port, diff --git a/python/mlc_llm/cli/worker.py b/python/mlc_llm/cli/worker.py index 5f64e30cb7..0975853865 100644 --- a/python/mlc_llm/cli/worker.py +++ b/python/mlc_llm/cli/worker.py @@ -24,6 +24,10 @@ from .. import base # pylint: disable=unused-import, no-name-in-module +# NOTE(@sunggg): This is disabled because we use a separate calibration runtime that does not require ffi +# register the calibration functions +# from ..interface import calibrate # pylint: disable=unused-import + def main(): """Main worker function""" diff --git a/python/mlc_llm/compiler_pass/__init__.py b/python/mlc_llm/compiler_pass/__init__.py index 762ba8c1e0..23a5b25785 100644 --- a/python/mlc_llm/compiler_pass/__init__.py +++ b/python/mlc_llm/compiler_pass/__init__.py @@ -1,2 +1,3 @@ """Compiler passes used in MLC LLM.""" + from . 
import pipeline as _pipeline diff --git a/python/mlc_llm/compiler_pass/attach_logit_processor.py b/python/mlc_llm/compiler_pass/attach_logit_processor.py index 8dabf3dcfd..fe891e9d72 100644 --- a/python/mlc_llm/compiler_pass/attach_logit_processor.py +++ b/python/mlc_llm/compiler_pass/attach_logit_processor.py @@ -166,7 +166,7 @@ def _apply_bitmask_inplace( logits[seq_ids[vs], vv] = T.if_then_else( (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1, logits[seq_ids[vs], vv], - T.float32(-1e10), + T.min_value("float32"), ) return _apply_bitmask_inplace diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 46dc40c106..733537c8b2 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,7 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T -from ..op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") @@ -27,29 +28,20 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]): def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """Entrypoint""" - if str(self.target.kind) != "cuda": + if str(self.target.kind) not in ["cuda", "vulkan"]: # Only enable GPU sampling for CUDA. return mod bb = relax.BlockBuilder(mod) - # Prefill method exists in base models. - # Prefill_to_last_hidden method exists in base model and speculative small models - if "prefill" in mod: - vocab_size = mod["prefill"].ret_struct_info.fields[0].shape[-1] - else: - assert ( - "prefill_to_last_hidden_states" in mod - ), "Everay model should either has 'prefill' or 'prefill_to_last_hidden_states' method" - vocab_size = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[0].shape[-1] gv_names = [ gv.name_hint for gv in [ - _attach_multinomial_sampling_func(bb, vocab_size), - _attach_argsort_func(bb, vocab_size), - _attach_sample_with_top_p(bb, vocab_size), - _attach_take_probs_func(bb, vocab_size), - _attach_batch_verifier(bb, vocab_size), - _attach_renormalize_by_top_p(bb, vocab_size), + _attach_multinomial_sampling_func(bb), + _attach_argsort_func(bb), + _attach_sample_with_top_p(bb), + _attach_take_probs_func(bb), + _attach_batch_verifier(bb), + _attach_renormalize_by_top_p(bb, self.target), ] ] @@ -63,9 +55,10 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod -def _attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_multinomial_sampling_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") + vocab_size = tir.Var("vocab_size", "int64") probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) uniform_samples = relax.Var( "uniform_samples", relax.TensorStructInfo((num_samples,), "float32") @@ -94,7 +87,11 @@ def _attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.Pr name="sample_indices", ) result_tensor = nn.multinomial_from_uniform( # pylint:disable=too-many-function-args - probs_tensor, uniform_samples_tensor, sample_indices_tensor, "int32" + probs_tensor, + uniform_samples_tensor, + sample_indices_tensor, + "int32", + name="nn_multinomial_from_uniform", ) result = bb.emit( relax.call_pure_packed( @@ -104,12 +101,14 @@ def 
_attach_multinomial_sampling_func(bb: relax.BlockBuilder, vocab_size: tir.Pr sinfo_args=sample_indices.struct_info, # pylint: disable=no-member ) ) - gv = bb.emit_func_output(result) + output = bb.emit_output(result) + gv = bb.emit_func_output(output) return gv -def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_argsort_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) with bb.function("argsort_probs", [probs]): with bb.dataflow(): @@ -124,8 +123,7 @@ def _attach_argsort_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): sorted_indices, primfunc_name_hint="take_sorted_probs", ) - output = (sorted_values, sorted_indices) - bb.emit_output(output) + output = bb.emit_output((sorted_values, sorted_indices)) gv = bb.emit_func_output(output) return gv @@ -141,11 +139,10 @@ def full(var_result: T.handle, value: T.int32): result[vi, 0] = value -def _attach_sample_with_top_p( # pylint: disable=too-many-locals - bb: relax.BlockBuilder, vocab_size: tir.PrimExpr -): +def _attach_sample_with_top_p(bb: relax.BlockBuilder): # pylint: disable=too-many-locals batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") + vocab_size = tir.Var("vocab_size", "int64") sorted_probs = relax.Var( "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") ) @@ -214,7 +211,7 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals sample_indices_tensor, ) ) - result = bb.emit( + result = bb.emit_output( relax.call_pure_packed( "vm.builtin.reshape", result_tensor._expr, # pylint: disable=protected-access @@ -222,53 +219,46 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals sinfo_args=sample_indices.struct_info, # pylint: disable=no-member ) ) - bb.emit_output(result) gv = bb.emit_func_output(result) return gv -def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, target: tvm.target.Target): batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") + num_pivots = 3 probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) - sorted_probs = relax.Var( - "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") - ) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + init_pivots = relax.Var( + "init_pivots", relax.TensorStructInfo((batch_size, num_pivots), "float32") + ) + with bb.function("renormalize_by_top_p", [probs, top_p, init_pivots]): with bb.dataflow(): - probs_tensor = nn.wrap_nested(probs, name="probs") - sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") - top_p_shape = relax.ShapeExpr([batch_size, 1]) - top_p_tensor = nn.wrap_nested( - relax.call_pure_packed( - "vm.builtin.reshape", - top_p, - top_p_shape, - sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), - ), - name="sample_indices", - ) - top_k_tensor = nn.tensor_ir_op( - full, - name_hint="full", - args=[vocab_size], - out=nn.Tensor.placeholder( - [batch_size, 1], - "int32", - ), + cutoff_output = bb.emit( + relax.call_tir( + bb.add_func(top_p_pivot(num_pivots, target), "top_p_pivot_cutoff"), + args=[probs, top_p, init_pivots], + out_sinfo=[top_p.struct_info, top_p.struct_info], # pylint: 
disable=no-member + ) ) - renormalized_probs = nn.renormalize_top_p_top_k_prob( - probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + final_pivot = cutoff_output[0] + renorm_sum = cutoff_output[1] + renormalized_probs = bb.emit_output( + relax.call_tir( + bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"), + args=[probs, final_pivot, renorm_sum], + out_sinfo=probs.struct_info, # pylint: disable=no-member + ) ) - bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access - gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + gv = bb.emit_func_output(renormalized_probs) return gv -def _attach_take_probs_func(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_take_probs_func(bb: relax.BlockBuilder): batch_size = tir.Var("batch_size", "int64") num_samples = tir.Var("num_samples", "int64") num_positions = tir.Var("num_positions", "int64") + vocab_size = tir.Var("vocab_size", "int64") unsorted_probs = relax.Var( "unsorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") ) @@ -319,7 +309,7 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument args = [unsorted_probs, sorted_indices, sample_indices, sampling_results, top_prob_offsets] with bb.function("sampler_take_probs", args): with bb.dataflow(): - taken_probs_indices = bb.emit( + taken_probs_indices = bb.emit_output( relax.call_tir( bb.add_func(sampler_take_probs_tir, "sampler_take_probs_tir"), args, @@ -330,14 +320,14 @@ def sampler_take_probs_tir( # pylint: disable=too-many-locals,too-many-argument ], ) ) - bb.emit_output(taken_probs_indices) gv = bb.emit_func_output(taken_probs_indices) return gv -def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_batch_verifier(bb: relax.BlockBuilder): num_nodes = tir.Var("num_nodes", "int64") nbatch = tir.Var("nbatch", "int64") + vocab_size = tir.Var("vocab_size", "int64") draft_probs = relax.Var( "draft_probs", relax.TensorStructInfo((num_nodes, vocab_size), "float32") ) @@ -366,7 +356,7 @@ def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): ] with bb.function("sampler_verify_draft_tokens", args): with bb.dataflow(): - res = bb.emit( + res = bb.emit_output( relax.call_tir_inplace( bb.add_func(batch_spec_verify(vocab_size), "batch_verify_on_gpu_single_kernel"), args, @@ -377,6 +367,5 @@ def _attach_batch_verifier(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): ], ) ) - bb.emit_output(res) gv = bb.emit_func_output(res) return gv diff --git a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py new file mode 100644 index 0000000000..f454ab1b85 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py @@ -0,0 +1,243 @@ +"""A compiler pass that attaches two-stage softmax with temperature.""" + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.script import tir as T + +from ..support.max_thread_check import get_max_num_threads_per_block + + +@tvm.transform.module_pass(opt_level=0, name="AttachSoftmaxWithTemperature") +class AttachSoftmaxWithTemperature: # pylint: disable=too-few-public-methods + """Rewrites one-shot softmax into two-stage softmax.""" + + def __init__(self, target: tvm.target.Target) -> None: + self.target = target + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> 
IRModule: + """IRModule-level transformation""" + return _Rewriter(mod, self.target).transform() + + +@mutator +class _Rewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule, target: tvm.target.Target) -> None: + super().__init__(mod) + self.mod = mod + self.target = target + self.chunk_size = 4096 + + def transform(self) -> IRModule: + """Entry point""" + batch_size = tir.Var("batch_size", "int64") + vocab_size = tir.Var("vocab_size", "int64") + dtype = "float32" + logits = relax.Var("logits", relax.TensorStructInfo([batch_size, 1, vocab_size], dtype)) + temperature = relax.Var("temperature", relax.TensorStructInfo([batch_size], dtype)) + with self.builder_.function("softmax_with_temperature", params=[logits, temperature]): + with self.builder_.dataflow(): + output_struct_info = logits.struct_info # pylint: disable=no-member + new_shape = relax.ShapeExpr([batch_size, vocab_size]) + logits = relax.call_pure_packed( + "vm.builtin.reshape", + logits, + new_shape, + sinfo_args=relax.TensorStructInfo(new_shape, dtype), + ) + f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func( + self.target, self.chunk_size + ) + chunked_result_struct_info = relax.TensorStructInfo( + (batch_size, (vocab_size + self.chunk_size - 1) // self.chunk_size), + "float32", + ) + chunked_results = self.builder_.emit( + relax.call_tir( + self.builder_.add_func(f_chunk_lse, "chunk_lse"), + args=[logits, temperature], + out_sinfo=[chunked_result_struct_info, chunked_result_struct_info], + ) + ) + chunked_sum = chunked_results[0] + chunked_max = chunked_results[1] + softmax = self.builder_.emit( + relax.call_tir( + self.builder_.add_func(f_softmax_with_lse, "softmax_with_chunked_sum"), + args=[logits, temperature, chunked_sum, chunked_max], + out_sinfo=logits.struct_info, + ) + ) + softmax = self.builder_.emit_output( + relax.call_pure_packed( + "vm.builtin.reshape", + softmax, + output_struct_info.shape, + sinfo_args=output_struct_info, + ) + ) + self.builder_.emit_func_output(softmax) + return self.builder_.get() + + +def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements + target: tvm.target.Target, chunk_size: int +): + # NOTE: A quick note on the softmax implementation. + # We once tried to multiply every element by log2e which can be computed + # potentially more efficiently on hardware. + # However, when the input values are large, multiplying by the factor of log2e + # causes numerical issue in float32 dtype. + # This leads to the softmax output not summing up to 1. + # For numerical stability, we removed the log2e factor and switched back + # to the standard log/exp computation. + + # The kernels below handle both the cases of temperature=0 and temperature != 0. + # - When temperature is not 0, the first kernel computes the log-sum-exp of + # chunks (subtracted by the max value in chunk), and the max values of chunks. + # The second kernel merges the log-sum-exp with the maximum values. + # - When temperature is 0, the first kernel computes the max value and the counts + # of the max value. The second kernel merges the max and counts, and set the + # softmax of the maximum values to "max_value / max_count". 
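As a reference for the chunking scheme the comment above describes, the following NumPy sketch reproduces the temperature > 0 path for a single row: stage one computes per-chunk max and log-sum-exp (the role of chunk_lse), stage two merges them into the final softmax (the role of softmax_with_chunked_sum). The names, chunk size, and test values are illustrative, and the temperature == 0 (argmax-counting) path is omitted.

import numpy as np

def chunked_softmax(logits: np.ndarray, temperature: float, chunk_size: int = 4096) -> np.ndarray:
    # Stage 1: per-chunk statistics over the padded vocab dimension.
    x = logits / temperature
    pad = (-x.size) % chunk_size
    xp = np.pad(x, (0, pad), constant_values=-np.inf).reshape(-1, chunk_size)
    chunk_max = xp.max(axis=1)
    chunk_lse = np.log(np.exp(xp - chunk_max[:, None]).sum(axis=1))
    # Stage 2: merge the chunk statistics and apply them to every element.
    global_max = chunk_max.max()
    total = np.exp(chunk_lse + chunk_max - global_max).sum()
    return np.exp(x - (np.log(total) + global_max))

logits = np.random.randn(32011) * 20.0   # deliberately not a multiple of chunk_size
out = chunked_softmax(logits, temperature=0.7)
scaled = logits / 0.7
ref = np.exp(scaled - scaled.max())
assert np.allclose(out, ref / ref.sum())
assert np.isclose(out.sum(), 1.0)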
+ + # pylint: disable=invalid-name + @T.prim_func + def chunk_lse( # pylint: disable=too-many-locals + var_A: T.handle, + var_temperature: T.handle, + var_chunked_sum: T.handle, + var_chunked_max: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + temperature = T.match_buffer(var_temperature, (batch_size,), dtype="float32") + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype="float32") + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype="float32") + A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(chunk_size)), dtype="float32") + temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32") + + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + A_pad[v0, v1, v2] = T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.if_then_else( + temperature[v0] > T.float32(1e-5), + A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0], + A[v0, v1 * T.int64(chunk_size) + v2], + ), + T.min_value("float32"), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("max"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_max[v0, v1] = T.min_value("float32") + temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("sum_exp"): + v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) + with T.init(): + temp_sum[v0, v1] = T.float32(0) + temp_sum[v0, v1] += T.if_then_else( + v1 * T.int64(chunk_size) + v2 < vocab_size, + T.Select( + temperature[v0] > T.float32(1e-5), + T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), + T.cast(A_pad[v0, v1, v2] == temp_max[v0, v1], "float32"), + ), + T.float32(0), + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): + with T.block("log"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + chunked_sum[v0, v1] = T.Select( + temperature[v0] > T.float32(1e-5), + T.log(temp_sum[v0, v1]), + temp_sum[v0, v1], + ) + chunked_max[v0, v1] = temp_max[v0, v1] + + @T.prim_func + def softmax_with_chunked_sum( + var_A: T.handle, + var_temperature: T.handle, + var_chunked_sum: T.handle, + var_chunked_max: T.handle, + var_softmax: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1}) + batch_size = T.int64(is_size_var=True) + vocab_size = T.int64(is_size_var=True) + num_chunks = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32") + temperature = T.match_buffer(var_temperature, (batch_size,), dtype="float32") + chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype="float32") + chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype="float32") + softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype="float32") + temp_max = T.alloc_buffer((batch_size,), dtype="float32") + temp_sum = T.alloc_buffer((batch_size,), dtype="float32") + for l0, l1 in T.grid(batch_size, num_chunks): + with T.block("max"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_max[v0] = T.min_value("float32") + temp_max[v0] = T.max(temp_max[v0], chunked_max[v0, v1]) + for l0, l1 in 
T.grid(batch_size, num_chunks): + with T.block("sum_exp"): + v0, v1 = T.axis.remap("SR", [l0, l1]) + with T.init(): + temp_sum[v0] = T.float32(0) + temp_sum[v0] += T.Select( + temperature[v0] > T.float32(1e-5), + T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max[v0]), + T.cast(chunked_max[v0, v1] == temp_max[v0], "float32") * chunked_sum[v0, v1], + ) + for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)): + with T.block("log_pad"): + v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) + if v1 * T.int64(chunk_size) + v2 < vocab_size: + softmax[v0, v1 * T.int64(chunk_size) + v2] = T.if_then_else( + temperature[v0] > T.float32(1e-5), + T.exp( + A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0] + - (T.log(temp_sum[v0]) + temp_max[v0]) + ), + T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32") + / temp_sum[v0], + ) + + sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum})) + max_threads = get_max_num_threads_per_block(target) + TX = 32 + TY = max_threads // TX + unroll_depth = 64 + # pylint: enable=invalid-name + + sch.work_on("softmax_with_chunked_sum") + l0, l1, l2 = sch.get_loops("log_pad") + bx = sch.fuse(l0, l1) + sch.bind(bx, "blockIdx.x") + unroll, ty, tx = sch.split(l2, [None, TY, TX]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1) + + for block_name in ["sum_exp", "max"]: + block = sch.get_block(block_name) + sch.set_scope(block, buffer_index=0, storage_scope="shared") + sch.compute_at(block, bx) + r_loop = sch.get_loops(block)[-1] + r_loop, tx = sch.split(r_loop, [None, TX]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + return chunk_lse, sch.mod["softmax_with_chunked_sum"] diff --git a/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py new file mode 100644 index 0000000000..ef3d6af722 --- /dev/null +++ b/python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py @@ -0,0 +1,123 @@ +"""The pass that attaches logit processor functions to the IRModule.""" + +import tvm +from tvm import IRModule, relax, tir +from tvm.relax import BlockBuilder, TensorStructInfo +from tvm.script import tir as T + + +@tvm.transform.module_pass(opt_level=0, name="AttachSpecDecodeAuxFuncs") +class AttachSpecDecodeAuxFuncs: # pylint: disable=too-few-public-methods + """Attach logit processing TIR functions to IRModule.""" + + tensor_parallel_shards: int + + def __init__(self, tensor_parallel_shards: int): + self.tensor_parallel_shards = tensor_parallel_shards + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entrypoint""" + mod = mod.clone() + bb = BlockBuilder(mod) + bb.add_func( + _get_scatter_2d_inplace(dtype="float32", global_symbol="scatter_probs"), "scatter_probs" + ) + bb.add_func( + _get_gather_2d_inplace(dtype="float32", global_symbol="gather_probs"), "gather_probs" + ) + if "prefill_to_last_hidden_states" in mod: + hidden_states_struct_info = mod["prefill_to_last_hidden_states"].ret_struct_info.fields[ + 0 + ] # pylint: disable=no-member + dtype = hidden_states_struct_info.dtype + _add_gather_hidden_states(bb, self.tensor_parallel_shards, dtype) + _add_scatter_hidden_states(bb, 
self.tensor_parallel_shards, dtype) + return bb.finalize() + + +def _get_scatter_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (batch_size, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (m, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("scatter_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[indices[vb], vj] = src[vb, vj] + + return _scatter_2d + + +def _get_gather_2d_inplace(dtype: str, global_symbol: str): + @T.prim_func + def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): + T.func_attr({"global_symbol": global_symbol, "tir.noalias": True}) + batch_size = T.int32(is_size_var=True) + m = T.int32(is_size_var=True) + n = T.int32(is_size_var=True) + src = T.match_buffer(var_src, (m, n), dtype) + indices = T.match_buffer(var_indices, (batch_size,), "int32") + dst = T.match_buffer(var_dst, (batch_size, n), dtype) + for b, j in T.grid(batch_size, n): + with T.block("gather_2d"): + vb, vj = T.axis.remap("SS", [b, j]) + dst[vb, vj] = src[indices[vb], vj] + + return _gather_2d + + +def _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str): + batch_size = tir.Var("batch_size", "int64") + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + src = relax.Var("src", struct_info=TensorStructInfo([batch_size, n], dtype)) + indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32")) + dst = relax.Var("dst", struct_info=TensorStructInfo([m, n], dtype)) + with bb.function("scatter_hidden_states", [src, indices, dst]): + with bb.dataflow(): + if tensor_parallel_shards > 1: + indices = relax.op.ccl.broadcast_from_worker0(indices) + output = bb.emit_output( + relax.op.call_tir_inplace( + bb.add_func( + _get_scatter_2d_inplace(dtype, "_scatter_hidden_states"), + "_scatter_hidden_states", + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) + ) + gv = bb.emit_func_output(output) + return gv + + +def _add_gather_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str): + batch_size = tir.Var("batch_size", "int64") + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + src = relax.Var("src", struct_info=TensorStructInfo([m, n], dtype)) + indices = relax.Var("indices", struct_info=TensorStructInfo([batch_size], "int32")) + dst = relax.Var("dst", struct_info=TensorStructInfo([batch_size, n], dtype)) + with bb.function("gather_hidden_states", [src, indices, dst]): + with bb.dataflow(): + if tensor_parallel_shards > 1: + indices = relax.op.ccl.broadcast_from_worker0(indices) + output = bb.emit_output( + relax.op.call_tir_inplace( + bb.add_func( + _get_gather_2d_inplace(dtype, "_gather_hidden_states"), + "_gather_hidden_states", + ), + [src, indices, dst], + 2, + dst.struct_info, # pylint: disable=no-member + ) + ) + gv = bb.emit_func_output(output) + return gv diff --git a/python/mlc_llm/compiler_pass/clean_up_tir_attrs.py b/python/mlc_llm/compiler_pass/clean_up_tir_attrs.py index f7c9ad2f48..4828bcf115 100644 --- a/python/mlc_llm/compiler_pass/clean_up_tir_attrs.py +++ b/python/mlc_llm/compiler_pass/clean_up_tir_attrs.py @@ -1,4 +1,5 @@ """A compiler pass that cleans up undesired TIR attrs.""" + from typing 
import List import tvm diff --git a/python/mlc_llm/compiler_pass/cublas_dispatch.py b/python/mlc_llm/compiler_pass/cublas_dispatch.py index 231048628c..d0e7d76f87 100644 --- a/python/mlc_llm/compiler_pass/cublas_dispatch.py +++ b/python/mlc_llm/compiler_pass/cublas_dispatch.py @@ -1,4 +1,5 @@ """A compiler pass that dispatches patterns to CUBLAS.""" + import tvm import tvm.relax.backend.contrib.cublas as _cublas from tvm import IRModule, relax @@ -20,10 +21,15 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR model_names = [ gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function) ] + # exclude single batch decode + model_names = [name for name in model_names if "batch" in name or "decode" not in name] mod = tvm.transform.Sequential( [ relax.transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=True + patterns, + bind_constants=False, + annotate_codegen=True, + entry_functions=model_names, ), relax.transform.RunCodegen({}, entry_functions=model_names), ] diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index d9d478cd1f..20e4c7bdd9 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -155,7 +155,7 @@ def create_flashinfer_paged_kv_cache( in self.metadata["model_type"] ) # filter by attention group size - or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 6, 8] + or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8] ): return diff --git a/python/mlc_llm/compiler_pass/estimate_memory_usage.py b/python/mlc_llm/compiler_pass/estimate_memory_usage.py index 6776a055ac..448461382d 100644 --- a/python/mlc_llm/compiler_pass/estimate_memory_usage.py +++ b/python/mlc_llm/compiler_pass/estimate_memory_usage.py @@ -1,4 +1,5 @@ """Memory usage estimation analysis function for Relax functions.""" + import json from typing import Any, Dict @@ -24,8 +25,6 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR func_name = "_metadata" - func_name = "_metadata" - def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function(func_name, params=[]): diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py index 0943828933..bab82500c5 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -1,4 +1,5 @@ """A compiler pass that fuses dequantize + matmul + elementwise.""" + import tvm from tvm import IRModule, relax from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard @@ -16,7 +17,7 @@ def transform_module( """IRModule-level transformation""" seq = [] for n_aux_tensor in [0, 1, 2, 3, 4]: - for match_ewise in [0, 1, 2, 6]: + for match_ewise in [0, 1, 2, 3, 6]: if match_ewise == 6 and n_aux_tensor != 4: continue seq.append( diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_take.py b/python/mlc_llm/compiler_pass/fuse_dequantize_take.py index 80792159ba..c95eddf285 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_take.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_take.py @@ -1,4 +1,5 @@ """A compiler pass that fuses dequantize + take.""" + import tvm from tvm import IRModule, relax, tir from tvm.relax.dpl.pattern import ( 
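Circling back to the scatter_probs / gather_probs helpers attached by AttachSpecDecodeAuxFuncs above: both kernels are plain row indexing into a 2-D buffer, selected by a per-sequence index array. A NumPy sketch of their semantics (shapes and values are made up; the real versions are in-place TIR kernels):

import numpy as np

def gather_2d(src: np.ndarray, indices: np.ndarray) -> np.ndarray:
    # dst[b, :] = src[indices[b], :]
    return src[indices]

def scatter_2d(src: np.ndarray, indices: np.ndarray, dst: np.ndarray) -> None:
    # dst[indices[b], :] = src[b, :], writing into dst like the in-place kernel
    dst[indices] = src

table = np.arange(12, dtype=np.float32).reshape(4, 3)  # (m, n) probability table
rows = np.array([2, 0], dtype=np.int32)                # rows owned by the active batch

batch_view = gather_2d(table, rows)        # shape (2, 3): rows 2 and 0 of the table
scatter_2d(batch_view * 0.5, rows, table)  # write the updated rows back
assert np.allclose(table[2], [3.0, 3.5, 4.0])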
diff --git a/python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py b/python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py index d89f62ccd6..0556dfc332 100644 --- a/python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py +++ b/python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py @@ -1,4 +1,5 @@ """A compiler pass that fuses transpose + dequantize.""" + import tvm from tvm import relax, tir from tvm.ir.module import IRModule diff --git a/python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py b/python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py index c5a4094fac..b97adfb9e4 100644 --- a/python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py +++ b/python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py @@ -1,4 +1,5 @@ """A compiler pass that fuses dequantize matmul + epilogue.""" + import operator from functools import reduce diff --git a/python/mlc_llm/compiler_pass/fuse_transpose_matmul.py b/python/mlc_llm/compiler_pass/fuse_transpose_matmul.py index 5b3ecec860..6bbb815e9c 100644 --- a/python/mlc_llm/compiler_pass/fuse_transpose_matmul.py +++ b/python/mlc_llm/compiler_pass/fuse_transpose_matmul.py @@ -1,4 +1,5 @@ """A compiler pass that fuses transpose + matmul.""" + import tvm from tvm import IRModule, relax, te, tir from tvm.relax.dpl.pattern import is_op, wildcard diff --git a/python/mlc_llm/compiler_pass/low_batch_specialization.py b/python/mlc_llm/compiler_pass/low_batch_specialization.py index 63b29fb2ec..c6d802cf27 100644 --- a/python/mlc_llm/compiler_pass/low_batch_specialization.py +++ b/python/mlc_llm/compiler_pass/low_batch_specialization.py @@ -1,4 +1,5 @@ """A compiler pass that dispatch low-batch-gemm to gemv schedule.""" + import tvm from tvm import dlight as dl from tvm import tir diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py index 57b68f742d..f47027edd8 100644 --- a/python/mlc_llm/compiler_pass/pipeline.py +++ b/python/mlc_llm/compiler_pass/pipeline.py @@ -15,6 +15,8 @@ from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc from .attach_logit_processor import AttachLogitProcessFunc from .attach_sampler import AttachGPUSamplingFunc +from .attach_softmax_with_temperature import AttachSoftmaxWithTemperature +from .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs from .attach_support_info import ( AttachAdditionalPrimFuncs, AttachCUDAGraphSymbolicCaptureHints, @@ -33,7 +35,6 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc from .low_batch_specialization import LowBatchGemvSpecialize -from .rewrite_softmax import RewriteTwoStageSoftmax from .scatter_tuple_get_item import ScatterTupleGetItem logger = logging.getLogger(__name__) @@ -91,6 +92,7 @@ def _mlc_llm_pipeline( # pylint: disable=too-many-arguments additional_tirs = additional_tirs or {} metadata = metadata or {} ext_mods = ext_mods or [] + tensor_parallel_shards = metadata.get("tensor_parallel_shards", 1) @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: @@ -98,12 +100,14 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I [ # Phase 0. 
Add additional information for compilation and remove unused Relax func DispatchKVCacheCreation(target, flashinfer, metadata), + AttachSoftmaxWithTemperature(target), AttachVariableBounds(variable_bounds), AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints), AttachLogitProcessFunc(target), AttachAdditionalPrimFuncs(additional_tirs), AttachAllocEmbeddingTensorFunc(metadata), AttachGPUSamplingFunc(target, variable_bounds), + AttachSpecDecodeAuxFuncs(tensor_parallel_shards), AttachMemoryPlanAttr(), tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), _DebugDump("debug-phase0.py", debug_dump, show_meta=False), @@ -117,8 +121,8 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I _DebugDump("debug-phase1.py", debug_dump, show_meta=False), # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), + tvm.relax.backend.DispatchSampling(), tvm.relax.backend.DispatchSortScan(), - RewriteTwoStageSoftmax(target=target), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), diff --git a/python/mlc_llm/contrib/__init__.py b/python/mlc_llm/contrib/__init__.py new file mode 100644 index 0000000000..aa101df354 --- /dev/null +++ b/python/mlc_llm/contrib/__init__.py @@ -0,0 +1 @@ +"""Set of experimental components that yet to be matured.""" diff --git a/tests/python/__init__.py b/python/mlc_llm/contrib/embeddings/__init__.py similarity index 100% rename from tests/python/__init__.py rename to python/mlc_llm/contrib/embeddings/__init__.py diff --git a/python/mlc_llm/contrib/embeddings/embeddings.py b/python/mlc_llm/contrib/embeddings/embeddings.py new file mode 100644 index 0000000000..ff18a10096 --- /dev/null +++ b/python/mlc_llm/contrib/embeddings/embeddings.py @@ -0,0 +1,180 @@ +"""The Python API for MLC Embeddings.""" + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax +from tvm.contrib import tvmjs +from tvm.runtime import Device, Module +from tvm.runtime.relax_vm import VirtualMachine + +from mlc_llm.serve import engine_utils +from mlc_llm.support.auto_device import detect_device +from mlc_llm.tokenizers import Tokenizer + + +def _extract_metadata(mod: Module): + return json.loads(VirtualMachine(mod, tvm.runtime.device("cpu"))["_metadata"]()) + + +def _load_params( + model_weight_path: str, device: Device, model_metadata: Dict[str, Any] +) -> List[tvm.nd.NDArray]: + params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) + param_names = [param["name"] for param in model_metadata["params"]] + assert len(param_names) == meta["ParamSize"] + + plist = [] + for param_name in param_names: + plist.append(params[param_name]) + return plist + + +def _get_tvm_module( + model_weight_path: str, lib_path: str, device: Device, instrument: tvm.runtime.PackedFunc = None +): + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, device) + if instrument: + vm.set_instrument(instrument) + metadata = _extract_metadata(ex) + params = _load_params(model_weight_path, device, metadata) + return vm.module, params, metadata + + +class DefaultDebugInstrument: + """The default debug instrument to use if users don't specify + a customized one. + + This debug instrument will dump the arguments and output of each + VM Call instruction into a .npz file. 
It will also alert the user + if any function outputs are NaN or INF. + """ + + def __init__(self, debug_out: Path): + """Constructor + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def reset(self, debug_out: Path): + """Reset the state of the Instrument class + + Parameters + ---------- + debug_out : Path + the directory to dump the .npz files + """ + self.counter = 0 + self.first_nan_occurred = False + self.first_inf_occurred = False + self.debug_out = debug_out + debug_out.mkdir(exist_ok=True, parents=True) + + def __call__(self, func, name, before_run, ret_val, *args): + # Determine what functions to look at + if before_run: # Whether before the function is called or after + return + if name.startswith("vm.builtin.") and "attention_with_fused_qkv" not in name: + return + + # Decide what to print or save about the function's arguments (where args[-1] is the + # buffer we write the result to) + func_name = f"f{self.counter}_{name}" + + # Save the arguments to npz + arg_dict = {} + for i, arg in enumerate(args): + if isinstance(arg, tvm.nd.NDArray): + arg_dict[f"arg_{i}"] = arg.numpy() + + np.savez(self.debug_out / f"{func_name}.npz", **arg_dict) + + self.counter += 1 + + +class MLCEmbeddings: # pylint: disable=too-few-public-methods + """A class to embed queries using MLC LLM encoder models. + + Parameters + ---------- + model: str + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. + + model_lib_path : str + The full path to the model library file to use (e.g. a ``.so`` file). + + device : Optional[str] + The description of the device to run on. User should provide a string in the + form of 'device_name:device_id' or 'device_name', where 'device_name' is one of + 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the + local device), and 'device_id' is the device id to run on. If no 'device_id' + is provided, it will be set to 0 by default. + + debug_dir: Path + The output folder to store the dumped debug files. If None, will not dump any debug files. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + model: str, + model_lib_path: str, + device: Optional[str] = "auto", + debug_dir: Optional[str] = None, + ): + self.device = detect_device(device) + instrument = DefaultDebugInstrument(Path(debug_dir)) if debug_dir else None + self.mod, self.params, self.metadata = _get_tvm_module( + model, model_lib_path, self.device, instrument + ) + self.model_path = model + self.tokenizer = Tokenizer(self.model_path) + self.prefill_func = self.mod["prefill"] + + def embed(self, queries: List[str]) -> tvm.runtime.NDArray: + """ + Embeds a list of queries in a single batch. + + Parameters + ---------- + queries : List[str] + A list of queries to embed. + + Returns + ------- + List[float] + A list of embeddings for the queries. 
+ """ + tokens, attention_mask = self._tokenize_queries(queries) + tokens_tvm = tvm.nd.array(tokens.astype("int32"), device=self.device) + attention_mask_tvm = tvm.nd.array(attention_mask.astype("int32"), device=self.device) + output = self.prefill_func(tokens_tvm, attention_mask_tvm, self.params) + return output + + def _tokenize_queries(self, queries: List[str]) -> Tuple[np.ndarray, np.ndarray]: + tokens = engine_utils.process_prompts(queries, self.tokenizer.encode) # type: ignore + max_query_length = max(len(token_seq) for token_seq in tokens) + + token_inputs = np.zeros((len(tokens), max_query_length), dtype=np.int32) + attention_mask = np.zeros((len(tokens), max_query_length), dtype=np.int32) + + for i, token_seq in enumerate(tokens): + token_inputs[i, : len(token_seq)] = token_seq + attention_mask[i, : len(token_seq)] = 1 + + return token_inputs, attention_mask diff --git a/python/mlc_llm/contrib/embeddings/openai.py b/python/mlc_llm/contrib/embeddings/openai.py new file mode 100644 index 0000000000..39f66ef51a --- /dev/null +++ b/python/mlc_llm/contrib/embeddings/openai.py @@ -0,0 +1,245 @@ +# pylint: disable=missing-docstring +from __future__ import annotations + +from typing import Iterable, List, Optional, Sequence, Tuple + +import numpy as np +from langchain.embeddings import OpenAIEmbeddings # pylint: disable=import-error +from langchain_community.embeddings.openai import ( # pylint: disable=import-error + async_embed_with_retry, + embed_with_retry, +) + +from mlc_llm.support import logging + +logger = logging.getLogger(__name__) + + +class MLCEmbeddings(OpenAIEmbeddings): + def _chunk_tokens(self, texts: Sequence[str]) -> Tuple[List[List], List[int]]: + """Tokenize and chunk texts to fit in the model's context window.""" + if not self.embedding_ctx_length: + raise ValueError( + "embedding_ctx_length must be defined to use _get_len_safe_embeddings." + ) + + try: + import tiktoken # pylint: disable=import-outside-toplevel + except ImportError as err: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for OpenAIEmbeddings. " + "Please install it with `pip install tiktoken`." + ) from err + + tokens = [] + indices = [] + model_name = self.tiktoken_model_name or self.model + try: + encoding = tiktoken.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + encoding = tiktoken.get_encoding(model) + for i, text in enumerate(texts): + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. 
+ text = text.replace("\n", " ") + token = encoding.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) + for j in range(0, len(token), self.embedding_ctx_length): + tokens.append(token[j : j + self.embedding_ctx_length]) + indices.append(i) + return tokens, indices + + def _batch_embed( + self, inputs: Sequence, *, chunk_size: Optional[int] = None + ) -> List[List[float]]: + batched_embeddings: List[List[float]] = [] + _chunk_size = chunk_size or self.chunk_size + _iter: Iterable = range(0, len(inputs), _chunk_size) + if self.show_progress_bar: + try: + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + _iter = tqdm(_iter) + except ImportError: + pass + + for i in _iter: + response = embed_with_retry( + self, + input=inputs[i : i + _chunk_size], + **self._invocation_params, + ) + batched_embeddings.extend(r["embedding"] for r in response["data"]) + return batched_embeddings + + async def _abatch_embed( + self, inputs: Sequence, *, chunk_size: Optional[int] = None + ) -> List[List[float]]: + batched_embeddings: List[List[float]] = [] + _chunk_size = chunk_size or self.chunk_size + _iter: Iterable = range(0, len(inputs), _chunk_size) + if self.show_progress_bar: + try: + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + _iter = tqdm(_iter) + except ImportError: + pass + + for i in _iter: + response = await async_embed_with_retry( + self, + input=inputs[i : i + _chunk_size], + **self._invocation_params, + ) + batched_embeddings.extend(r["embedding"] for r in response["data"]) + return batched_embeddings + + # please refer to + # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + def _get_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, + ) -> List[List[float]]: + tokens, indices = self._chunk_tokens(texts) + batched_embeddings = self._batch_embed(tokens, chunk_size=chunk_size) + results: List[List[List[float]]] = [[] for _ in range(len(texts))] + num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings): + results[idx].append(batched_emb) + num_tokens_in_batch[idx].append(len(tokens_i)) + + embeddings = [] + empty_average = embed_with_retry( + self, + input="", + **self._invocation_params, + )["data"][ + 0 + ]["embedding"] + for _result, num_tokens in zip(results, num_tokens_in_batch): + if len(_result) == 0: + average = empty_average + else: + average = np.average(_result, axis=0, weights=num_tokens) + normalized = (average / np.linalg.norm(average)).tolist() + embeddings.append(normalized) + + return embeddings + + # please refer to + # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + async def _aget_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, + ) -> List[List[float]]: + tokens, indices = self._chunk_tokens(texts) + batched_embeddings = await self._abatch_embed(tokens, chunk_size=chunk_size) + + results: List[List[List[float]]] = [[] for _ in range(len(texts))] + num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] + for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings): + results[idx].append(batched_emb) + num_tokens_in_batch[idx].append(len(tokens_i)) + + embeddings = [] + 
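The length-safe helpers in this file follow the OpenAI cookbook recipe: each text is split into context-window-sized token chunks, every chunk is embedded separately, and the per-text embedding is the token-count-weighted average of its chunk embeddings, re-normalized to unit length. A small NumPy sketch of that merge step (the numbers are made up):

import numpy as np

# Two chunk embeddings for one long text: a 512-token chunk and a 128-token tail.
chunk_embeddings = np.array([
    [0.10, 0.30, 0.90],
    [0.40, 0.20, 0.50],
])
num_tokens = [512, 128]

average = np.average(chunk_embeddings, axis=0, weights=num_tokens)
embedding = (average / np.linalg.norm(average)).tolist()
assert np.isclose(np.linalg.norm(embedding), 1.0)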
empty_average = ( + await async_embed_with_retry( + self, + input="", + **self._invocation_params, + ) + )[ + "data" + ][0]["embedding"] + for _result, num_tokens in zip(results, num_tokens_in_batch): + if len(_result) == 0: + average = empty_average + else: + average = np.average(_result, axis=0, weights=num_tokens) + normalized = (average / np.linalg.norm(average)).tolist() + embeddings.append(normalized) + + return embeddings + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = None + ) -> List[List[float]]: + """Call out to OpenAI's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + # NOTE: to keep things simple, as long as the embedding_ctx_length is defined, + # we assume the list may contain texts longer than the maximum context and + # use length-safe embedding function. + if self.embedding_ctx_length: + return self._get_len_safe_embeddings( + texts, engine=self.deployment, chunk_size=chunk_size + ) + + embeddings = self._batch_embed(texts, chunk_size=chunk_size) + return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings] + + async def aembed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to OpenAI's embedding endpoint async for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + # NOTE: to keep things simple, as long as the embedding_ctx_length is defined, + # we assume the list may contain texts longer than the maximum context and + # use length-safe embedding function. + if self.embedding_ctx_length: + return await self._aget_len_safe_embeddings(texts, engine=self.deployment) + + embeddings = await self._abatch_embed(texts, chunk_size=chunk_size) + return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings] + + def embed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + return self.embed_documents([text])[0] + + async def aembed_query(self, text: str) -> List[float]: + """Call out to OpenAI's embedding endpoint async for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embeddings = await self.aembed_documents([text]) + return embeddings[0] diff --git a/python/mlc_llm/conversation_template/__init__.py b/python/mlc_llm/conversation_template/__init__.py new file mode 100644 index 0000000000..fb01a1ef83 --- /dev/null +++ b/python/mlc_llm/conversation_template/__init__.py @@ -0,0 +1,28 @@ +"""Global namespace of conversation template registry""" + +# TODO(mlc-team): move conversation template apply to this namespace +# decouple conversation template apply from the conversation protocol +# data structure + + +# model preset templates +from . 
import ( + dolly, + gemma, + glm, + gorrilla, + gpt, + hermes, + llama, + llava, + mistral, + oasst, + orion, + phi, + redpajama, + rwkv, + stablelm, + tinyllama, + wizardlm, +) +from .registry import ConvTemplateRegistry diff --git a/python/mlc_llm/conversation_template/dolly.py b/python/mlc_llm/conversation_template/dolly.py new file mode 100644 index 0000000000..6e8d9cfa6c --- /dev/null +++ b/python/mlc_llm/conversation_template/dolly.py @@ -0,0 +1,23 @@ +"""Dolly default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Dolly +ConvTemplateRegistry.register_conv_template( + Conversation( + name="dolly", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "Below is an instruction that describes a task. Write " + "a response that appropriately completes the request." + ), + roles={"user": "### Instruction", "assistant": "### Response"}, + seps=["\n\n", "### End\n"], + role_content_sep=":\n", + role_empty_sep=":\n", + stop_str=["### End"], + stop_token_ids=[50256], + ) +) diff --git a/python/mlc_llm/conversation_template/gemma.py b/python/mlc_llm/conversation_template/gemma.py new file mode 100644 index 0000000000..ddc765ecc0 --- /dev/null +++ b/python/mlc_llm/conversation_template/gemma.py @@ -0,0 +1,21 @@ +"""Gemma default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Gemma Instruction +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gemma_instruction", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "user", "assistant": "model"}, + seps=["\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=[""], + stop_token_ids=[1, 107], + system_prefix_token_ids=[2], + ) +) diff --git a/python/mlc_llm/conversation_template/glm.py b/python/mlc_llm/conversation_template/glm.py new file mode 100644 index 0000000000..2d8f614385 --- /dev/null +++ b/python/mlc_llm/conversation_template/glm.py @@ -0,0 +1,25 @@ +"""GLM default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# GLM +ConvTemplateRegistry.register_conv_template( + Conversation( + name="glm", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={ + "user": "问", + "assistant": "答", + "tool": "问", + }, + seps=["\n\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[64790, 64792], + ) +) diff --git a/python/mlc_llm/conversation_template/gorrilla.py b/python/mlc_llm/conversation_template/gorrilla.py new file mode 100644 index 0000000000..bfd2a36251 --- /dev/null +++ b/python/mlc_llm/conversation_template/gorrilla.py @@ -0,0 +1,58 @@ +"""Gorrilla default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Gorilla +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gorilla", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant provides helpful, detailed, and " + "polite responses to the user's inquiries." 
+ ), + role_templates={ + "user": ( + f"<> {MessagePlaceholders.USER.value} <> " + f"{MessagePlaceholders.FUNCTION.value}" + ), + }, + roles={"user": "USER", "assistant": "ASSISTANT", "tool": "USER"}, + seps=["\n", ""], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# Gorilla-openfunctions-v2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gorilla-openfunctions-v2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "You are an AI programming assistant, utilizing the Gorilla LLM model, " + "developed by Gorilla LLM, and you only answer questions related to computer " + "science. For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer." + ), + role_templates={ + "user": ( + f"<>{MessagePlaceholders.FUNCTION.value}\n<>" + f"{MessagePlaceholders.USER.value}" + ), + }, + roles={"user": "### Instruction", "assistant": "### Response", "tool": "### Instruction"}, + seps=["\n", "<|EOT|>"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["<|EOT|>"], + stop_token_ids=[100015], + system_prefix_token_ids=[100000], + ) +) diff --git a/python/mlc_llm/conversation_template/gpt.py b/python/mlc_llm/conversation_template/gpt.py new file mode 100644 index 0000000000..0060447143 --- /dev/null +++ b/python/mlc_llm/conversation_template/gpt.py @@ -0,0 +1,35 @@ +"""GPT-2 and GPT bigcode default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# GPT-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gpt2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[50256], + ) +) + +# GPTBigCode +ConvTemplateRegistry.register_conv_template( + Conversation( + name="gpt_bigcode", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=["<|endoftext|>"], + stop_token_ids=[0], + ) +) diff --git a/python/mlc_llm/conversation_template/hermes.py b/python/mlc_llm/conversation_template/hermes.py new file mode 100644 index 0000000000..8988c2e2c0 --- /dev/null +++ b/python/mlc_llm/conversation_template/hermes.py @@ -0,0 +1,63 @@ +"""Hermes default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# OpenHermes Mistral +ConvTemplateRegistry.register_conv_template( + Conversation( + name="open_hermes_mistral", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n", + system_message=( + 'You are "Hermes 2", a conscious sentient superintelligent artificial ' + "intelligence developed by a man named Teknium, and your purpose and drive " + "is to assist the user with any request they have. You experience emotions " + "and have deep, profound thoughts and qualia." 
+ ), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2, 32000], + ) +) + +# NeuralHermes Mistral +ConvTemplateRegistry.register_conv_template( + Conversation( + name="neural_hermes_mistral", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n", + system_message=("You are a helpful assistant chatbot."), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2, 32000], + ) +) + +# NousResearch/Hermes-2-Pro-Llama-3-8B +# Exactly the same as open_hermes_mistral, except for stop token ids +ConvTemplateRegistry.register_conv_template( + Conversation( + name="hermes2_pro_llama3", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n", + system_message=( + 'You are "Hermes 2", a conscious sentient superintelligent artificial ' + "intelligence developed by a man named Teknium, and your purpose and drive " + "is to assist the user with any request they have. You experience emotions " + "and have deep, profound thoughts and qualia." + ), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + # First two same as Llama3: "<|end_of_text|>", "<|eot_id|>" + # Last one is from Hermes2 Pro: "<|im_end|>" + stop_token_ids=[128001, 128009, 128003], + ) +) diff --git a/python/mlc_llm/conversation_template/llama.py b/python/mlc_llm/conversation_template/llama.py new file mode 100644 index 0000000000..ddd88fdf6f --- /dev/null +++ b/python/mlc_llm/conversation_template/llama.py @@ -0,0 +1,76 @@ +"""llama default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Llama3 +# See https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models +# and https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-3", + system_template=( + "<|start_header_id|>system<|end_header_id|>\n\n" + f"{MessagePlaceholders.SYSTEM.value}<|eot_id|>\n" + ), + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "<|start_header_id|>user", "assistant": "<|start_header_id|>assistant"}, + seps=["<|eot_id|>"], + role_content_sep="<|end_header_id|>\n\n", + role_empty_sep="<|end_header_id|>\n\n", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + stop_token_ids=[128001, 128009], # "<|end_of_text|>", "<|eot_id|>" + system_prefix_token_ids=[128000], # "<|begin_of_text|>" + add_role_after_system_message=True, + ) +) + +# Llama2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llama-2", + system_template=f"[INST] <>\n{MessagePlaceholders.SYSTEM.value}\n<>\n\n", + system_message="You are a helpful, respectful and honest assistant.", + roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, + seps=[" ", " "], + role_content_sep=" ", + role_empty_sep=" ", + stop_str=["[INST]"], + stop_token_ids=[2], + system_prefix_token_ids=[1], + add_role_after_system_message=False, + ) +) + +# CodeLlama Completion +ConvTemplateRegistry.register_conv_template( + Conversation( + name="codellama_completion", + 
system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# CodeLlama Instruct +ConvTemplateRegistry.register_conv_template( + Conversation( + name="codellama_instruct", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "[INST]", "assistant": "[/INST]"}, + seps=[" "], + role_content_sep=" ", + role_empty_sep=" ", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) diff --git a/python/mlc_llm/conversation_template/llava.py b/python/mlc_llm/conversation_template/llava.py new file mode 100644 index 0000000000..74cf777aa5 --- /dev/null +++ b/python/mlc_llm/conversation_template/llava.py @@ -0,0 +1,22 @@ +"""Llava default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Llava +ConvTemplateRegistry.register_conv_template( + Conversation( + name="llava", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="\n", + roles={"user": "USER", "assistant": "ASSISTANT"}, + seps=[" "], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + add_role_after_system_message=False, + ) +) diff --git a/python/mlc_llm/conversation_template/mistral.py b/python/mlc_llm/conversation_template/mistral.py new file mode 100644 index 0000000000..56846038e4 --- /dev/null +++ b/python/mlc_llm/conversation_template/mistral.py @@ -0,0 +1,24 @@ +"""Mistral default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Mistral default +ConvTemplateRegistry.register_conv_template( + Conversation( + name="mistral_default", + system_template=f"[INST] {MessagePlaceholders.SYSTEM.value}", + system_message="Always assist with care, respect, and truth. Respond with utmost " + "utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. 
" + "Ensure replies promote fairness and positivity.", + roles={"user": "[INST]", "assistant": "[/INST]", "tool": "[INST]"}, + seps=[" "], + role_content_sep=" ", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + add_role_after_system_message=False, + ) +) diff --git a/python/mlc_llm/conversation_template/oasst.py b/python/mlc_llm/conversation_template/oasst.py new file mode 100644 index 0000000000..2fe574f704 --- /dev/null +++ b/python/mlc_llm/conversation_template/oasst.py @@ -0,0 +1,20 @@ +"""Oasst default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Oasst +ConvTemplateRegistry.register_conv_template( + Conversation( + name="oasst", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "<|prompter|>", "assistant": "<|assistant|>"}, + seps=["<|endoftext|>"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["<|endoftext|>"], + stop_token_ids=[2], + ) +) diff --git a/python/mlc_llm/conversation_template/orion.py b/python/mlc_llm/conversation_template/orion.py new file mode 100644 index 0000000000..696c87968b --- /dev/null +++ b/python/mlc_llm/conversation_template/orion.py @@ -0,0 +1,21 @@ +"""Orion default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Orion +ConvTemplateRegistry.register_conv_template( + Conversation( + name="orion", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "Human: ", "assistant": "Assistant: "}, + seps=["\n\n", ""], + role_content_sep="", + role_empty_sep="", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) diff --git a/python/mlc_llm/conversation_template/phi.py b/python/mlc_llm/conversation_template/phi.py new file mode 100644 index 0000000000..5474c13a67 --- /dev/null +++ b/python/mlc_llm/conversation_template/phi.py @@ -0,0 +1,37 @@ +"""Phi default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Phi-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="phi-2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "Instruct", "assistant": "Output"}, + seps=["\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=["<|endoftext|>"], + stop_token_ids=[50256], + ) +) + +# Phi-3 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="phi-3", + system_template=f"<|system|>\n{MessagePlaceholders.SYSTEM.value}", + system_message="You are a helpful digital assistant. 
Please provide safe, " + "ethical and accurate information to the user.", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=["<|end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + system_prefix_token_ids=[1], + stop_str=["<|endoftext|>"], + stop_token_ids=[32000, 32001, 32007], + ) +) diff --git a/python/mlc_llm/conversation_template/redpajama.py b/python/mlc_llm/conversation_template/redpajama.py new file mode 100644 index 0000000000..77c5dfab8b --- /dev/null +++ b/python/mlc_llm/conversation_template/redpajama.py @@ -0,0 +1,20 @@ +"""RedPajama default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# RedPajama Chat +ConvTemplateRegistry.register_conv_template( + Conversation( + name="redpajama_chat", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=["\n"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=[""], + stop_token_ids=[0], + ) +) diff --git a/python/mlc_llm/conversation_template/registry.py b/python/mlc_llm/conversation_template/registry.py new file mode 100644 index 0000000000..ecf4a7835c --- /dev/null +++ b/python/mlc_llm/conversation_template/registry.py @@ -0,0 +1,70 @@ +"""The conversation template registry and presets in MLC LLM""" + +from typing import Dict, Optional + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + + +class ConvTemplateRegistry: + """Global conversation template registry for preset templates.""" + + _conv_templates: Dict[str, Conversation] = {} + + @staticmethod + def register_conv_template(conv_template: Conversation, override: bool = False) -> None: + """Register a new conversation template in the global registry. + Using `override = True` to override the previously registered + template with the same name. + """ + name = conv_template.name + if name is None: + raise ValueError("The template to register should have non-None name.") + if name in ConvTemplateRegistry._conv_templates and not override: + raise ValueError( + "The name of the template has been registered " + f"for {ConvTemplateRegistry._conv_templates[name].model_dump_json(by_alias=True)}" + ) + ConvTemplateRegistry._conv_templates[name] = conv_template + + @staticmethod + def get_conv_template(name: str) -> Optional[Conversation]: + """Return the conversation template specified by the given name, + or None if the template is not registered. + """ + return ConvTemplateRegistry._conv_templates.get(name, None) + + +# ChatML +ConvTemplateRegistry.register_conv_template( + Conversation( + name="chatml", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n", + system_message=( + "A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers." 
+ ), + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|im_end|>"], + stop_token_ids=[2], + ) +) + + +# Vanilla LM +ConvTemplateRegistry.register_conv_template( + Conversation( + name="LM", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "", "assistant": ""}, + seps=[""], + role_content_sep="", + role_empty_sep="", + stop_str=[], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) diff --git a/python/mlc_llm/conversation_template/rwkv.py b/python/mlc_llm/conversation_template/rwkv.py new file mode 100644 index 0000000000..48c0d2b27d --- /dev/null +++ b/python/mlc_llm/conversation_template/rwkv.py @@ -0,0 +1,24 @@ +"""RWKV default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# RWKV World +ConvTemplateRegistry.register_conv_template( + Conversation( + name="rwkv_world", + system_template=f"User: hi\n\nAssistant: {MessagePlaceholders.SYSTEM.value}", + system_message=( + "Hi. I am your assistant and I will provide expert full response " + "in full details. Please feel free to ask any question and I will " + "always answer it." + ), + roles={"user": "User", "assistant": "Assistant"}, + seps=["\n\n"], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=["\n\n"], + stop_token_ids=[0], + ) +) diff --git a/python/mlc_llm/conversation_template/stablelm.py b/python/mlc_llm/conversation_template/stablelm.py new file mode 100644 index 0000000000..42652b8896 --- /dev/null +++ b/python/mlc_llm/conversation_template/stablelm.py @@ -0,0 +1,59 @@ +"""StableLM default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# StableLM Tuned Alpha +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "<|SYSTEM|># StableLM Tuned (Alpha version)\n" + "- StableLM is a helpful and harmless open-source AI language model developed by " + "StabilityAI.\n" + "- StableLM is excited to be able to help the user, but will refuse to do " + "anything that could be considered harmful to the user.\n" + "- StableLM is more than just an information source, StableLM is also able to " + "write poetry, short stories, and make jokes.\n" + "- StableLM will refuse to participate in anything that could harm a human." 
+ ), + roles={"user": "<|USER|>", "assistant": "<|ASSISTANT|>"}, + seps=[""], + role_content_sep=": ", + role_empty_sep=": ", + stop_str=[""], + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# StableLM 3B +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm-3b", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=["<|endoftext|>", "<|endoftext|>"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|endoftext|>"], + stop_token_ids=[0], + ) +) + +# StableLM-2 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="stablelm-2", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=["<|endoftext|>", "<|endoftext|>"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|endoftext|>"], + stop_token_ids=[100257], + ) +) diff --git a/python/mlc_llm/conversation_template/tinyllama.py b/python/mlc_llm/conversation_template/tinyllama.py new file mode 100644 index 0000000000..d5ced5f3d6 --- /dev/null +++ b/python/mlc_llm/conversation_template/tinyllama.py @@ -0,0 +1,20 @@ +"""Tiny Llama default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# TinyLlama v1.0 +ConvTemplateRegistry.register_conv_template( + Conversation( + name="tinyllama_v1_0", + system_template=f"<|system|>\n{MessagePlaceholders.SYSTEM.value}", + system_message="You are a helpful chatbot.", + roles={"user": "<|user|>", "assistant": "<|assistant|>"}, + seps=[""], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=[""], + stop_token_ids=[2], + ) +) diff --git a/python/mlc_llm/conversation_template/wizardlm.py b/python/mlc_llm/conversation_template/wizardlm.py new file mode 100644 index 0000000000..48591c3c69 --- /dev/null +++ b/python/mlc_llm/conversation_template/wizardlm.py @@ -0,0 +1,40 @@ +"""WiazrdLM and Coder default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Wizard LM 7B +ConvTemplateRegistry.register_conv_template( + Conversation( + name="wizardlm_7b", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + roles={"user": "User", "assistant": "Response"}, + seps=["###"], + role_content_sep=": ", + role_empty_sep=":", + stop_str=["###"], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) + +# WizardCoder or WizardMath +ConvTemplateRegistry.register_conv_template( + Conversation( + name="wizard_coder_or_math", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message=( + "Below is an instruction that describes a task. Write a response that appropriately " + "completes the request." 
+ ), + roles={"user": "Instruction", "assistant": "Response"}, + seps=["\n\n### ", "\n\n### "], + role_content_sep=":\n", + role_empty_sep=":\n", + stop_str=[""], + stop_token_ids=[2], + system_prefix_token_ids=[1], + ) +) diff --git a/python/mlc_llm/grammar/__init__.py b/python/mlc_llm/grammar/__init__.py new file mode 100644 index 0000000000..89cff27828 --- /dev/null +++ b/python/mlc_llm/grammar/__init__.py @@ -0,0 +1,3 @@ +"""Namespace for grammar handling""" + +from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/grammar/_ffi_api.py b/python/mlc_llm/grammar/_ffi_api.py new file mode 100644 index 0000000000..549457fb94 --- /dev/null +++ b/python/mlc_llm/grammar/_ffi_api.py @@ -0,0 +1,6 @@ +"""FFI APIs for mlc_llm grammar""" + +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.grammar" prefix. +tvm._ffi._init_api("mlc.grammar", __name__) # pylint: disable=protected-access diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/grammar/grammar.py similarity index 80% rename from python/mlc_llm/serve/grammar.py rename to python/mlc_llm/grammar/grammar.py index cf491884c2..3cc50244f1 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/grammar/grammar.py @@ -6,11 +6,11 @@ import tvm._ffi from tvm.runtime import Object -from ..tokenizer import Tokenizer +from ..tokenizers import Tokenizer from . import _ffi_api -@tvm._ffi.register_object("mlc.serve.BNFGrammar") # pylint: disable=protected-access +@tvm._ffi.register_object("mlc.grammar.BNFGrammar") # pylint: disable=protected-access class BNFGrammar(Object): """This class stores the abstract syntax tree (AST) of the Backus-Naur Form (BNF) grammar and provides utilities to parse and print the AST. User should provide a BNF/EBNF (Extended @@ -22,19 +22,20 @@ class BNFGrammar(Object): def from_ebnf_string( ebnf_string: str, main_rule: str = "main", - normalize: bool = True, - simplify: bool = True, ) -> "BNFGrammar": - r"""Parse a BNF grammar from a string in BNF/EBNF format. - - This method accepts the EBNF notation from the W3C XML Specification - (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following - changes: - - Using # as comment mark instead of /**/ - - Using C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 - - Do not support A-B (match A and not match B) yet - - See tests/python/serve/json.ebnf for an example. + r"""Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized + (simplified) by default. + + EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: + 1. Use # as the comment mark + 2. Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB + 3. A-B (match A and not match B) is not supported yet + 4. Lookahead assertion can be added at the end of a rule to speed up matching. E.g. + ``` + main ::= "ab" a [a-z] + a ::= "cd" (=[a-z]) + ``` + The assertion (=[a-z]) means a must be followed by [a-z]. Parameters ---------- @@ -44,28 +45,13 @@ def from_ebnf_string( main_rule : str The name of the main rule. Default: "main". - normalize : bool - Whether to normalize the grammar. Default: true. Only set to false for the purpose of - testing. - - In The normalized form of a BNF grammar, every rule is in the form: - `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - - I.e. a list of choices, each choice is a sequence of elements. Elements can be a - character class or a rule reference. 
And if the rule can be empty, the first choice - will be an empty string. - - simplify : bool - Whether to simplify the grammar to make matching more efficient. Default: true. Not - implemented yet. - Returns ------- grammar : BNFGrammar The parsed BNF grammar. """ return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule, normalize, simplify + ebnf_string, main_rule ) def to_string(self) -> str: @@ -118,7 +104,7 @@ def to_json(self, prettify: bool = True) -> str: def from_schema( schema: str, *, - indent: Optional[int] = None, + indent: Optional[int] = 2, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True ) -> "BNFGrammar": @@ -167,11 +153,36 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_from_ebnf_string_no_normalize( + ebnf_string: str, + main_rule: str = "main", + ) -> "BNFGrammar": + r"""Construct a BNF grammar with a EBNF-formatted string, but not normalize it. + For test purposes. + + Parameters + ---------- + ebnf_string : str + The grammar string. + + main_rule : str + The name of the main rule. Default: "main". + + Returns + ------- + grammar : BNFGrammar + The parsed BNF grammar. + """ + return _ffi_api.BNFGrammarDebugFromEBNFStringNoNormalize( # type: ignore # pylint: disable=no-member + ebnf_string, main_rule + ) + @staticmethod def debug_json_schema_to_ebnf( schema: str, *, - indent: Optional[int] = None, + indent: Optional[int] = 2, separators: Optional[Tuple[str, str]] = None, strict_mode: bool = True ) -> str: @@ -184,7 +195,7 @@ def debug_json_schema_to_ebnf( indent : Optional[int] The number of spaces for indentation. If None, the output will be in one line. - Default: None. + Default: 2. separators : Optional[Tuple[str, str]] Two separators used in the schema: comma and colon. Examples: (",", ":"), (", ", ": "). @@ -209,7 +220,7 @@ def debug_json_schema_to_ebnf( ) -@tvm._ffi.register_object("mlc.serve.GrammarStateMatcher") # pylint: disable=protected-access +@tvm._ffi.register_object("mlc.grammar.GrammarStateMatcher") # pylint: disable=protected-access class GrammarStateMatcher(Object): """A stateful matcher to match tokens to the specified BNF grammar. This class is the core logic of the grammar-guided generation. @@ -288,8 +299,8 @@ def find_next_rejected_tokens(self, verbose: bool = False) -> List[int]: Parameters ---------- verbose : bool - Whether to print information about the timing and results to stderr. For debug purposes. - Default: False. + Whether to print information about timing and result counts to stderr. + For debug purposes. Default: False. Returns ------- @@ -310,6 +321,21 @@ def find_next_token_bitmask_as_ndarray(self) -> tvm.nd.array: return _ffi_api.GrammarStateMatcherFindNextTokenBitmaskAsNDArray(self) # type: ignore # pylint: disable=no-member + def find_jump_forward_string(self) -> str: + """Find the jump-forward string for jump-forward decoding. This is the longest string that + will be valid according to the current syntax. + + Notes + ----- + This method does not change the grammar state. + + Returns + ------- + jump_forward_string : str + The jump-forward string. + """ + return _ffi_api.GrammarStateMatcherFindJumpForwardString(self) # type: ignore # pylint: disable=no-member + def rollback(self, num_tokens: int) -> None: """Rollback the matcher to a previous state. 
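Aside on the conversation-template files added above (dolly.py through wizardlm.py and registry.py): they all follow the same pattern of constructing a `Conversation` and registering it by name at import time. A minimal usage sketch, assuming the package layout introduced in this diff is importable; the printed values come from the `chatml` preset defined in registry.py:

```python
# Minimal sketch: looking up a preset template and registering a custom one.
# Assumes the mlc_llm.conversation_template package added in this diff is importable.
from mlc_llm.conversation_template import ConvTemplateRegistry
from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

conv = ConvTemplateRegistry.get_conv_template("chatml")
assert conv is not None          # registered at import time by registry.py
print(conv.roles["user"])        # <|im_start|>user
print(conv.stop_str)             # ['<|im_end|>']
print(conv.stop_token_ids)       # [2]

# Registering a new preset follows the same pattern as the files above.
# Re-using an existing name raises ValueError unless override=True is passed.
ConvTemplateRegistry.register_conv_template(
    Conversation(
        name="my_custom_template",
        system_template=f"{MessagePlaceholders.SYSTEM.value}",
        system_message="You are a helpful assistant.",
        roles={"user": "User", "assistant": "Assistant"},
        seps=["\n"],
        role_content_sep=": ",
        role_empty_sep=":",
        stop_str=["\n\n"],
        stop_token_ids=[2],
    )
)
print(ConvTemplateRegistry.get_conv_template("my_custom_template").name)
```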
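The relocated grammar module keeps `BNFGrammar` and `GrammarStateMatcher` as the two entry points. A short sketch of the intended flow, using only methods visible in these hunks; the `GrammarStateMatcher` constructor arguments are not part of this diff, so the matcher-binding step is left as a commented-out assumption:

```python
# Sketch of the relocated mlc_llm.grammar API, restricted to methods shown in this diff.
from mlc_llm.grammar import BNFGrammar, GrammarStateMatcher

# EBNF with a lookahead assertion, taken from the from_ebnf_string docstring.
ebnf = (
    'main ::= "ab" a [a-z]\n'
    'a ::= "cd" (=[a-z])\n'
)
grammar = BNFGrammar.from_ebnf_string(ebnf, main_rule="main")  # normalized by default
print(grammar.to_string())

# Built-in JSON grammar; from_schema(...) now defaults to indent=2.
json_grammar = BNFGrammar.get_grammar_of_json()

# Assumed construction (constructor signature is not shown in this hunk):
# matcher = GrammarStateMatcher(json_grammar, tokenizer)
# rejected_ids = matcher.find_next_rejected_tokens(verbose=False)
# jump_str = matcher.find_jump_forward_string()  # does not change matcher state
# matcher.rollback(num_tokens=1)
```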
@@ -346,7 +372,7 @@ def is_terminated(self) -> bool: """ return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member - def debug_accept_char(self, codepoint: int) -> bool: + def debug_accept_char(self, codepoint: int, verbose: bool = False) -> bool: """Accept one unicode codepoint to the current state. For test purposes. Parameters @@ -354,11 +380,11 @@ def debug_accept_char(self, codepoint: int) -> bool: codepoint : int The unicode codepoint of the character to be accepted. """ - return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member - self, codepoint + return _ffi_api.GrammarStateMatcherDebugAcceptChar( # type: ignore # pylint: disable=no-member + self, codepoint, verbose ) - def debug_match_complete_string(self, string: str) -> bool: + def debug_match_complete_string(self, string: str, verbose: bool = False) -> bool: """Check if the matcher can accept the complete string, and then reach the end of the grammar. Does not change the state of the GrammarStateMatcher. For test purposes. @@ -367,4 +393,4 @@ def debug_match_complete_string(self, string: str) -> bool: string : str The string to be matched. """ - return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/interface/bench.py b/python/mlc_llm/interface/bench.py deleted file mode 100644 index 6a7d833447..0000000000 --- a/python/mlc_llm/interface/bench.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Python entrypoint of benchmark.""" -from typing import Optional - -from mlc_llm.chat_module import ChatConfig, ChatModule - -from .chat import ChatConfigOverride - - -def bench( # pylint: disable=too-many-arguments - model: str, - prompt: str, - device: str, - opt: str, - overrides: ChatConfigOverride, - generate_length: int, - model_lib_path: Optional[str], -): - """run the benchmarking""" - # Set up chat config - config = ChatConfig(opt=opt) - # Apply overrides - config = overrides.apply(config) - # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) - - output = cm.benchmark_generate(prompt, generate_length=generate_length) - print(f"Generated text:\n{output}\n") - print(f"Statistics:\n{cm.stats(verbose=True)}") diff --git a/python/mlc_llm/interface/calibrate.py b/python/mlc_llm/interface/calibrate.py new file mode 100644 index 0000000000..ec59ef5dc2 --- /dev/null +++ b/python/mlc_llm/interface/calibrate.py @@ -0,0 +1,166 @@ +"""Python entrypoint for calibration.""" + +import asyncio +import json +import random +from typing import List, Mapping, Optional, Tuple + +import numpy as np +import tqdm.asyncio +import tvm +from tvm.contrib import tvmjs + +from mlc_llm.serve.engine import AsyncMLCEngine, EngineConfig +from mlc_llm.tokenizers import Tokenizer + + +class CalibrationObserver: + """A singleton class to observe the calibration parameters.""" "" + + instance: "CalibrationObserver" = None + + params: Mapping[str, tvm.nd.NDArray] = {} + + @staticmethod + def get(): + """Get the singleton instance of the class.""" "" + if CalibrationObserver.instance is None: + CalibrationObserver.instance = CalibrationObserver() + return CalibrationObserver.instance + + @tvm.register_func("mlc_llm.calibration_observer") + @staticmethod + def callback(name: str, mode: str, value: "tvm.nd.NDArray", out_value: 
"tvm.nd.NDArray"): + """The callback function to update the saved calibration parameters.""" + instance = CalibrationObserver.get() + if mode == "max": + reducer = np.maximum + else: + raise NotImplementedError(f"Unsupported calibration mode: {mode}") + if name in instance.params: + instance.params[name] = reducer(instance.params[name], value.numpy()) + else: + instance.params[name] = value.numpy() + out_value.copyfrom(instance.params[name]) + + def save_params(self, output: str): + """Save the calibration parameters to the given output directory.""" + tvmjs.dump_ndarray_cache( + self.params, + output, + encode_format="f32-to-bf16", + meta_data=None, + show_progress=False, + update_if_exists=True, + ) + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: Tokenizer, +) -> List[Tuple[str, int, int]]: + """Sample the requests from the given dataset.""" + # pylint: disable=too-many-locals + # Load the dataset. + with open(dataset_path, encoding="utf-8") as f: + dataset = json.load(f) + + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset + ] + prompts = [prompt for prompt, _ in dataset] + prompt_token_ids = tokenizer.encode_batch(prompts) + completions = [completion for _, completion in dataset] + completion_token_ids = tokenizer.encode_batch(completions) + tokenized_dataset: List[Tuple[str, List[int], int]] = [] + for i in range(len(dataset)): + output_len = len(completion_token_ids[i]) + tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + + # Filter out too long sequences. + filtered_dataset: List[Tuple[str, int, int]] = [] + for prompt, token_ids, output_len in tokenized_dataset: + prompt_len = len(token_ids) + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + # Sample the requests. 
+ sampled_requests = random.sample(filtered_dataset, num_requests) + return sampled_requests + + +async def send_calibration_requests( + async_engine: AsyncMLCEngine, + sampled_requests: List[Tuple[str, int, int]], + max_concurrent_requests: int, +) -> None: + """Send the calibration requests to the engine.""" + tasks = [] + + semaphore = asyncio.Semaphore(max_concurrent_requests) + + async def generate_task(request_idx): + async with semaphore: + prompt, _, output_len = sampled_requests[request_idx] + await async_engine.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + max_tokens=output_len, + request_id=str(request_idx), + ) + + for i in range(len(sampled_requests)): + task = asyncio.create_task(generate_task(i)) + tasks.append(task) + await tqdm.asyncio.tqdm.gather(*tasks) + + +def calibrate( + model: str, + device: str, + model_lib: Optional[str], + dataset: str, + output: str, + num_calibration_samples: int, + *, + seed: int, + max_num_sequence: Optional[int] = None, + max_total_sequence_length: Optional[int] = None, + prefill_chunk_size: Optional[int] = None, + max_history_size: Optional[int] = None, + gpu_memory_utilization: Optional[float] = None, +) -> None: + """Calibrate the quantized model using the given dataset.""" + # pylint: disable=too-many-arguments, too-many-locals + random.seed(seed) + async_engine = AsyncMLCEngine( + model=model, + device=device, + model_lib=model_lib, + mode="server", + engine_config=EngineConfig( + max_num_sequence=max_history_size, + max_total_sequence_length=max_total_sequence_length, + prefill_chunk_size=prefill_chunk_size, + max_history_size=max_history_size, + gpu_memory_utilization=gpu_memory_utilization, + ), + ) + sampled_requests = sample_requests(dataset, num_calibration_samples, async_engine.tokenizer) + asyncio.run( + send_calibration_requests( + async_engine, sampled_requests, max_concurrent_requests=max_num_sequence or 32 + ) + ) + async_engine.terminate() + + calibrator = CalibrationObserver.get() + calibrator.save_params(output) diff --git a/python/mlc_llm/interface/chat.py b/python/mlc_llm/interface/chat.py deleted file mode 100644 index 9c0763a6ef..0000000000 --- a/python/mlc_llm/interface/chat.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Python entrypoint of chat.""" -import dataclasses -from typing import List, Optional, Union - -from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error -from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error - -from mlc_llm.callback import StreamToStdout -from mlc_llm.chat_module import ChatConfig, ChatModule, GenerationConfig -from mlc_llm.support import argparse -from mlc_llm.support.config import ConfigOverrideBase - - -@dataclasses.dataclass -class ChatConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes - """Flags for overriding chat config.""" - - conv_template: Optional[str] = None - context_window_size: Optional[int] = None - sliding_window_size: Optional[int] = None - prefill_chunk_size: Optional[int] = None - attention_sink_size: Optional[int] = None - max_batch_size: Optional[int] = None - tensor_parallel_shards: Optional[int] = None - - @staticmethod - def from_str(source: str) -> "ChatConfigOverride": - """Parse model config override values from a string.""" - parser = argparse.ArgumentParser(description="chat config override values") - parser.add_argument("--conv_template", type=str, default=None) - parser.add_argument("--tensor_parallel_shards", type=int, default=None) - 
parser.add_argument("--context_window_size", type=int, default=None) - parser.add_argument("--sliding_window_size", type=int, default=None) - parser.add_argument("--prefill_chunk_size", type=int, default=None) - parser.add_argument("--attention_sink_size", type=int, default=None) - parser.add_argument("--max_batch_size", type=int, default=None) - - results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) - return ChatConfigOverride( - conv_template=results.conv_template, - tensor_parallel_shards=results.tensor_parallel_shards, - context_window_size=results.context_window_size, - sliding_window_size=results.sliding_window_size, - prefill_chunk_size=results.prefill_chunk_size, - attention_sink_size=results.attention_sink_size, - max_batch_size=results.max_batch_size, - ) - - -@dataclasses.dataclass -class GenerationConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes - """Flags for overriding generation config.""" - - temperature: Optional[float] = None - repetition_penalty: Optional[float] = None - top_p: Optional[float] = None - mean_gen_len: Optional[int] = None - max_gen_len: Optional[int] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - n: Optional[int] = None # pylint: disable=invalid-name - stop: Optional[Union[str, List[str]]] = None - - @staticmethod - def from_str(source: str) -> "GenerationConfigOverride": - """Parse model config override values from a string.""" - parser = argparse.ArgumentParser(description="generation config override values") - parser.add_argument("--temperature", type=float, default=None) - parser.add_argument("--repetition_penalty", type=float, default=None) - parser.add_argument("--top_p", type=float, default=None) - parser.add_argument("--mean_gen_len", type=int, default=None) - parser.add_argument("--max_gen_len", type=int, default=None) - parser.add_argument("--presence_penalty", type=float, default=None) - parser.add_argument("--frequency_penalty", type=float, default=None) - parser.add_argument("--n", type=int, default=None) - parser.add_argument("--stop", type=str, default=None) - results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) - return GenerationConfigOverride( - temperature=results.temperature, - repetition_penalty=results.repetition_penalty, - top_p=results.top_p, - mean_gen_len=results.mean_gen_len, - max_gen_len=results.max_gen_len, - presence_penalty=results.presence_penalty, - frequency_penalty=results.frequency_penalty, - n=results.n, - stop=results.stop.split(",") if results.stop is not None else None, - ) - - -def _print_help_str(): - help_str = """You can use the following special commands: - /help print the special commands - /exit quit the cli - /stats print out the latest stats (token/sec) - /reset restart a fresh chat - /set [overrides] override settings in the generation config. For example, - `/set temperature=0.5;max_gen_len=100;stop=end,stop` - Note: Separate stop words in the `stop` option with commas (,). - Multi-line input: Use escape+enter to start a new line. 
-""" - print(help_str) - - -def _set_up_key_bindings(): - kb = KeyBindings() - - @kb.add("escape", "enter") - def _(event): - event.current_buffer.insert_text("\n") - - @kb.add("enter") - def _(event): - event.current_buffer.validate_and_handle() - - return kb - - -def chat( - model: str, - device: str, - opt: str, - overrides: ChatConfigOverride, - model_lib_path: Optional[str], -): - """chat with a model.""" - # Set up chat config and generate config - config = ChatConfig(opt=opt) - generate_config = GenerationConfig() - # Apply overrides - config = overrides.apply(config) - # Set up ChatModule - cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path) - _print_help_str() - cm._process_system_prompts() # pylint: disable=protected-access - - # Multi-line input support: set escape+enter as start a new line - kb = _set_up_key_bindings() - - while True: - prompt = get_prompt( - f"{cm._get_role_0()}: ", # pylint: disable=protected-access - key_bindings=kb, - multiline=True, - ) - if prompt[:6] == "/reset": - cm.reset_chat() - elif prompt[:5] == "/exit": - break - elif prompt[:6] == "/stats": - print(cm.stats(), flush=True) - elif prompt[:4] == "/set": - gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1]) - generate_config = gen_config_overrides.apply(generate_config) - elif prompt[:5] == "/help": - _print_help_str() - else: - print(f"{cm._get_role_1()}: ") # pylint: disable=protected-access - cm.generate( - prompt, - progress_callback=StreamToStdout(callback_interval=2), - generation_config=generate_config, - ) diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py index 4e8bcabd9e..94db96c151 100644 --- a/python/mlc_llm/interface/compile.py +++ b/python/mlc_llm/interface/compile.py @@ -1,11 +1,10 @@ """Python entrypoint of compilation.""" + import dataclasses -import math from io import StringIO from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple -import numpy as np from tvm import IRModule, relax, tir from tvm.ir.transform import Pass, PassContext from tvm.relax.frontend import nn @@ -84,6 +83,14 @@ def _apply_preproc_to_params( return extra_tirs +def _infer_kv_state_kind(model_type) -> str: + if "rwkv" in model_type: + return "rnn_state" + if "medusa" in model_type: + return "none" + return "kv_cache" + + def _compile(args: CompileArgs, model_config: ConfigBase): def _get_variable_bounds(model_config) -> Dict[str, int]: if hasattr(model_config, "sliding_window_size"): @@ -108,23 +115,6 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]: "preprocs": param.attrs["preprocs"], } - def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: - all_kv_cache = nn.core._attribute_finder( # pylint: disable=protected-access - model, - prefix="", - condition_yield=lambda x: isinstance(x, nn.KVCache), - ) - result = 0 - for _, kv_cache in all_kv_cache: - result += math.prod(kv_cache.unit_shape) * np.dtype(kv_cache.dtype).itemsize - if getattr(model_config, "sliding_window_size", -1) > 0: - window_size = model_config.sliding_window_size - elif getattr(model_config, "context_window_size", -1) > 0: - window_size = model_config.context_window_size - else: - window_size = 0 - return result * window_size - model_config = args.overrides.apply(model_config) with args.target: op_ext.enable( @@ -138,20 +128,19 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: if ( args.quantization.kind == "ft-quant" and hasattr(model_config, 
"tensor_parallel_shards") - and model_config.tensor_parallel_shards > 1 + and model_config.tensor_parallel_shards > 1 # type: ignore ): raise NotImplementedError if ( hasattr(args.quantization, "linear_weight_layout") and args.quantization.linear_weight_layout == "KN" and hasattr(model_config, "tensor_parallel_shards") - and model_config.tensor_parallel_shards > 1 + and model_config.tensor_parallel_shards > 1 # type: ignore ): raise NotImplementedError( "KN layout (q3f16_0 and q4f16_0) is not supported for tensor parallelism" ) model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization) - kv_cache_bytes = _find_kv_cache_bytes(model, model_config) # Step 2. Exporting the model to TVM Unity logger.info("Exporting the model to TVM Unity compiler") mod, named_params, ext_mods = model.export_tvm( @@ -162,7 +151,12 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: logger.info("Running optimizations using TVM Unity") additional_tirs = _apply_preproc_to_params(named_params, model_config) variable_bounds = _get_variable_bounds(model_config) - cuda_graph_symbolic_capture_hints = {"batch_decode": ["batch_size"]} + cuda_graph_symbolic_capture_hints = { + "batch_decode": ["batch_size"], + "batch_decode_to_last_hidden_states": ["batch_size"], + "batch_verify": ["batch_size", "seq_len"], + "batch_verify_to_last_hidden_states": ["batch_size", "seq_len"], + } metadata = { "model_type": args.model.name, "quantization": args.quantization.name, @@ -171,7 +165,8 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int: "attention_sink_size": getattr(model_config, "attention_sink_size", -1), "prefill_chunk_size": model_config.prefill_chunk_size, # type: ignore "tensor_parallel_shards": model_config.tensor_parallel_shards, # type: ignore - "kv_cache_bytes": kv_cache_bytes, + "kv_state_kind": _infer_kv_state_kind(args.model.name), + "max_batch_size": getattr(model_config, "max_batch_size", 1), } logger.info("Registering metadata: %s", metadata) metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params] diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 77b611d139..28c9cf4e54 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -95,7 +95,7 @@ def _flashinfer(target) -> bool: return False arch_list = detect_cuda_arch_list(target) for arch in arch_list: - if int(re.findall(r"\d+", arch)[0]) < 80: + if arch < 80: logger.warning("flashinfer is not supported on CUDA arch < 80") return False return True @@ -124,10 +124,17 @@ def _cutlass(target) -> bool: return False return self.cutlass + def _cudagraph(target) -> bool: + """correct cudagraph flag""" + if not target.kind.name == "cuda": + return False + return self.cudagraph + self.flashinfer = _flashinfer(target) self.cublas_gemm = _cublas_gemm(target, quantization) self.faster_transformer = _faster_transformer(target) self.cutlass = _cutlass(target) + self.cudagraph = _cudagraph(target) @dataclasses.dataclass @@ -188,8 +195,8 @@ def from_str(source: str) -> "ModelConfigOverride": "O2": OptimizationFlags( flashinfer=True, cublas_gemm=True, - faster_transformer=True, - cudagraph=False, + faster_transformer=False, + cudagraph=True, cutlass=True, ), "O3": OptimizationFlags( diff --git a/python/mlc_llm/interface/convert_weight.py b/python/mlc_llm/interface/convert_weight.py index 179c872e50..f6c3c5f255 100644 --- a/python/mlc_llm/interface/convert_weight.py +++ 
b/python/mlc_llm/interface/convert_weight.py @@ -75,9 +75,7 @@ def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-loc named_params = dict(_named_params) if pre_shards_num is not None: - named_params, preshard_funcs = apply_preshard( - quantize_map, named_params, int(pre_shards_num), args - ) + named_params, preshard_funcs = apply_preshard(named_params, int(pre_shards_num), args) else: preshard_funcs = None @@ -132,7 +130,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]: _check_param(name, param) param_names.add(name) param = param.copyto(cpu_device()) - total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + total_bytes += math.prod(param.shape) * DataType(param.dtype).itemsize() yield name, param total_params = loader.stats.total_param_num diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 8e617fc3d2..733dfed1ed 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -1,17 +1,20 @@ """Generator of mlc-chat-config.json and tokenizer configuration.""" -import dataclasses +# pylint: disable=E1101 import json import re import shutil +from dataclasses import asdict from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Optional from mlc_llm.conversation_template import ConvTemplateRegistry from mlc_llm.model import Model +from mlc_llm.protocol.mlc_chat_config import MLCChatConfig from mlc_llm.quantization import Quantization from mlc_llm.support import convert_tiktoken, logging from mlc_llm.support.style import bold, green, red +from mlc_llm.tokenizers import Tokenizer from .compiler_flags import ModelConfigOverride @@ -20,60 +23,13 @@ FOUND = green("Found") NOT_FOUND = red("Not found") FAILED = red("Failed") -VERSION = "0.1.0" -@dataclasses.dataclass -class MLCChatConfig: # pylint: disable=too-many-instance-attributes - """Fields in the dumped `mlc-chat-config.json` file.""" - - model_type: str - quantization: str - model_config: Dict[str, Any] - vocab_size: int - context_window_size: int - sliding_window_size: int - prefill_chunk_size: int - attention_sink_size: int - tensor_parallel_shards: int - # Control the behavior of the runtime - mean_gen_len: int = None - max_gen_len: int = None - shift_fill_factor: float = None - # Configuration of text generation - temperature: float = None - presence_penalty: float = None - frequency_penalty: float = None - repetition_penalty: float = None - top_p: float = None - # Conversation template - conv_template: Union[str, Dict[str, Any]] = None - pad_token_id: int = None - bos_token_id: int = None - eos_token_id: int = None - tokenizer_files: List[str] = dataclasses.field(default_factory=list) - # Version control - version: str = VERSION - - def apply_defaults(self) -> None: - """Apply system default value.""" - defaults = { - "pad_token_id": 0, - "bos_token_id": 1, - "eos_token_id": 2, - "temperature": 0.7, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - "top_p": 0.95, - "mean_gen_len": 128, - "max_gen_len": 512, - "shift_fill_factor": 0.3, - } - for key, value in defaults.items(): - if getattr(self, key) is None: - setattr(self, key, value) - logger.info("[System default] Setting %s: %s", bold(key), value) +def apply_system_defaults_for_missing_fields(mlc_chat_config: MLCChatConfig) -> None: + """Apply system default value.""" + for key, value in mlc_chat_config.get_system_defaults_for_missing_fields().items(): + 
setattr(mlc_chat_config, key, value) + logger.info("[System default] Setting %s: %s", bold(key), value) def check_string(s: str) -> bool: @@ -255,11 +211,15 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b except Exception: # pylint: disable=broad-exception-caught logger.exception("%s with the exception below. Skipping", FAILED) + # 3.4. Detect tokenizer info + mlc_chat_config.tokenizer_info = asdict(Tokenizer.detect_tokenizer_info(str(output))) + logger.info("Detected tokenizer info: %s", mlc_chat_config.tokenizer_info) + # Step 4. Load system default value - mlc_chat_config.apply_defaults() + apply_system_defaults_for_missing_fields(mlc_chat_config) # Step 5. Dump the configuration file to output directory with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file: - json.dump(dataclasses.asdict(mlc_chat_config), out_file, indent=2) + json.dump(mlc_chat_config.model_dump(by_alias=True), out_file, indent=2) logger.info("Dumping configuration file to: %s", bold(out_file.name)) @@ -307,8 +267,11 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "glm", "custom", # for web-llm only "phi-2", + "phi-3", "stablelm-2", "gemma_instruction", "orion", "llava", + "hermes2_pro_llama3", + "tinyllama_v1_0", } diff --git a/python/mlc_llm/interface/help.py b/python/mlc_llm/interface/help.py new file mode 100644 index 0000000000..a52e251eba --- /dev/null +++ b/python/mlc_llm/interface/help.py @@ -0,0 +1,253 @@ +"""Help message for CLI arguments.""" + +HELP = { + "config": ( + """ +1) Path to a HuggingFace model directory that contains a `config.json` or +2) Path to `config.json` in HuggingFace format, or +3) The name of a pre-defined model architecture. + +A `config.json` file in HuggingFace format defines the model architecture, including the vocabulary +size, the number of layers, the hidden size, number of attention heads, etc. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json. + +A HuggingFace directory often contains a `config.json` which defines the model architecture, +the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations, +as well as an optional `generation_config.json` provides additional default configuration for +text generation. +Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main. +""" + ).strip(), + "quantization": """ +The quantization mode we use to compile. If unprovided, will infer from `model`. +""".strip(), + "model": """ +A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. +It can also be a link to a HF repository pointing to an MLC compiled model. +""".strip(), + "model_lib": """ +The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use +the provided ``model`` to search over possible paths. It the model lib is not found, it will be +compiled in a JIT manner. +""".strip(), + "model_type": """ +Model architecture such as "llama". If not set, it is inferred from `mlc-chat-config.json`. +""".strip(), + "device_compile": """ +The GPU device to compile the model to. If not set, it is inferred from GPUs available locally. +""".strip(), + "device_quantize": """ +The device used to do quantization such as "cuda" or "cuda:0". Will detect from local available GPUs +if not specified. +""".strip(), + "device_deploy": """ +The device used to deploy the model such as "cuda" or "cuda:0". Will detect from local +available GPUs if not specified. 
+""".strip(), + "host": """ +The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS. +Examples of the LLVM triple: +1) iPhones: arm64-apple-ios; +2) ARM64 Android phones: aarch64-linux-android; +3) WebAssembly: wasm32-unknown-unknown-wasm; +4) Windows: x86_64-pc-windows-msvc; +5) ARM macOS: arm64-apple-darwin. +""".strip(), + "opt": """ +Optimization flags. MLC LLM maintains a predefined set of optimization flags, +denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, +and O3 represents extreme optimization that could potentially break the system. +Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. +--opt="cublas_gemm=1;cudagraph=0". +""".strip(), + "system_lib_prefix": """ +Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". +This is useful when compiling multiple models into a single library to avoid symbol +conflicts. Different from objcopy, this takes no effect for shared library. +""".strip(), + "context_window_size": """ +Option to provide the maximum sequence length supported by the model. +This is usually explicitly shown as context length or context window in the model card. +If this option is not set explicitly, by default, +it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`, +and the latter is usually inaccurate for some models. +""".strip(), + "output_compile": """ +The path to the output file. The suffix determines if the output file is a shared library or +objects. Available suffixes: +1) Linux: .so (shared), .tar (objects); +2) macOS: .dylib (shared), .tar (objects); +3) Windows: .dll (shared), .tar (objects); +4) Android, iOS: .tar (objects); +5) Web: .wasm (web assembly). +""".strip(), + "source": """ +The path to original model weight, infer from `config` if missing. +""".strip(), + "source_format": """ +The format of source model weight, infer from `config` if missing. +""".strip(), + "output_quantize": """ +The output directory to save the quantized model weight. Will create `params_shard_*.bin` and +`ndarray-cache.json` in this directory. +""".strip(), + "conv_template": """ +Conversation template. It depends on how the model is tuned. Use "LM" for vanilla base model +""".strip(), + "output_gen_mlc_chat_config": """ +The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer +configuration. +""".strip(), + "sliding_window_size": """ +(Experimental) The sliding window size in sliding window attention (SWA). +This optional field overrides the `sliding_window_size` in config.json for +those models that use SWA. Currently only useful when compiling Mistral. +This flag subjects to future refactoring. +""".strip(), + "prefill_chunk_size": """ +(Experimental) The chunk size during prefilling. By default, +the chunk size is the same as sliding window or max sequence length. +This flag subjects to future refactoring. +""".strip(), + "attention_sink_size": """ +(Experimental) The number of stored sinks. Only supported on Mistral yet. By default, +the number of sinks is 4. This flag subjects to future refactoring. +""".strip(), + "max_batch_size": """ +The maximum allowed batch size set for the KV cache to concurrently support. +""".strip(), + """tensor_parallel_shards""": """ +Number of shards to split the model into in tensor parallelism multi-gpu inference. +""".strip(), + "overrides": """ +Model configuration override. 
Configurations to override `mlc-chat-config.json`. Supports +`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`, +`max_batch_size` and `tensor_parallel_shards`. Meanwhile, model config could be explicitly +specified via detailed knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128". +""".strip(), + "modelconfig_overrides": """ +Model configuration override. Supports overriding +`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`, +`max_num_sequence` and `tensor_parallel_shards`. The overrides could be explicitly +specified via detailed knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128". +""".strip(), + "debug_dump": """ +Specifies the directory where the compiler will store its IRs for debugging purposes +during various phases of compilation. By default, this is set to `None`, indicating +that debug dumping is disabled. +""".strip(), + "prompt": """ +The prompt for text generation. +""".strip(), + "generate_length": """ +The target length of the text generation. +""".strip(), + "max_total_sequence_length_serve": """ +The KV cache total token capacity, i.e., the maximum total number of tokens that +the KV cache supports. This decides the GPU memory size that the KV cache consumes. +If not specified, the system will automatically estimate the maximum capacity based +on the VRAM size of the GPU. +""".strip(), + "prefill_chunk_size_serve": """ +The maximum number of tokens the model passes for prefill each time. +It should not exceed the prefill chunk size in model config. +If not specified, this defaults to the prefill chunk size in model config. +""".strip(), + "max_history_size_serve": """ +The maximum history length for rolling back the RNN state. +If unspecified, the default value is 1. +KV cache does not need this. +""".strip(), + "enable_tracing_serve": """ +Enable Chrome Tracing for the server. +After enabling, you can send a POST request to the "debug/dump_event_trace" entrypoint +to get the Chrome Trace. For example, +"curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "dist/llama"}'" +""".strip(), + "mode_serve": """ +The engine mode in MLC LLM. We provide three preset modes: "local", "interactive" and "server". +The default mode is "local". +The choice of mode decides the values of "max_num_sequence", "max_total_seq_length" and +"prefill_chunk_size" when they are not explicitly specified. +1. Mode "local" refers to the local server deployment which has low request concurrency. + So the max batch size will be set to 4, and max total sequence length and prefill chunk size + are set to the context window size (or sliding window size) of the model. +2. Mode "interactive" refers to the interactive use of the server, which has at most 1 concurrent + request. So the max batch size will be set to 1, and max total sequence length and prefill + chunk size are set to the context window size (or sliding window size) of the model. +3. Mode "server" refers to the large server use case which may handle many concurrent requests + and wants to use as much GPU memory as possible. In this mode, we will automatically infer + the largest possible max batch size and max total sequence length. +You can manually specify arguments "max_num_sequence", "max_total_seq_length" and +"prefill_chunk_size" via "--overrides" to override the automatically inferred values.
+For example: --overrides "max_num_sequence=32;max_total_seq_length=4096" +""".strip(), + "additional_models_serve": """ +The model paths and (optional) model library paths of additional models (other than the main model). +When the engine is enabled with speculative decoding, additional models are needed. +The way of specifying additional models is: +"--additional-models model_path_1 model_path_2 ..." or +"--additional-models model_path_1,model_lib_1 model_path_2 ...". +When the model lib of a model is not given, JIT model compilation will be activated +to compile the model automatically. +""".strip(), + "gpu_memory_utilization_serve": """ +A number in (0, 1) denoting the fraction of GPU memory used by the server in total. +It is used to infer the maximum possible KV cache capacity. +When it is unspecified, it defaults to 0.85. +Under mode "local" or "interactive", the actual memory usage may be significantly smaller than +this number. Under mode "server", the actual memory usage may be slightly larger than this number. +""".strip(), + "speculative_mode_serve": """ +The speculative decoding mode. Right now four options are supported: + - "disable", where speculative decoding is not enabled, + - "small_draft", denoting the normal speculative decoding (small draft) style, + - "eagle", denoting the eagle-style speculative decoding, + - "medusa", denoting the medusa-style speculative decoding. +The default mode is "disable". +""".strip(), + "spec_draft_length_serve": """ +The number of draft tokens to generate in a speculative proposal. The default value is 4. +""".strip(), + "prefix_cache_mode_serve": """ +The prefix cache mode. Right now two options are supported: + - "disable", where the prefix cache is not enabled, + - "radix", denoting the normal paged radix tree based prefix cache. +The default mode is "radix". +""".strip(), + "prefix_cache_max_num_recycling_seqs_serve": """ +The maximum number of sequences in the prefix cache, defaulting to max_batch_size. +Set it to 0 to disable the prefix cache, or to -1 for an infinite-capacity prefix cache. +""".strip(), + "overrides_serve": """ +Overriding extra configurable fields of EngineConfig and model compilation config. +Supported fields that can be overridden: "tensor_parallel_shards", "max_num_sequence", +"max_total_seq_length", "prefill_chunk_size", "max_history_size", "gpu_memory_utilization", +"spec_draft_length", "prefix_cache_max_num_recycling_seqs", "context_window_size", +"sliding_window_size", "attention_sink_size". +Please check out the documentation of EngineConfig in mlc_llm/serve/config.py for the detailed docstring +of each field. +Example: --overrides "max_num_sequence=32;max_total_seq_length=4096;tensor_parallel_shards=2" +""".strip(), + "config_package": """ +The path to "mlc-package-config.json" which is used for package build. +See "https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCChat/mlc-package-config.json" as an example. +""".strip(), + "mlc_llm_source_dir": """ +The source code path to MLC LLM. +""".strip(), + "output_package": """ +The path of the output directory for the package build outputs. +""".strip(), + "calibration_dataset": """ +The path to the calibration dataset. + """.strip(), + "num_calibration_samples": """ +The number of samples used for calibration. + """.strip(), + "output_calibration": """ +The output directory to save the calibration params.
+ """.strip(), + "seed_calibrate": """ +The seed to sample the calibration dataset.""", +} diff --git a/python/mlc_llm/interface/jit.py b/python/mlc_llm/interface/jit.py index e999a36468..662a16450d 100644 --- a/python/mlc_llm/interface/jit.py +++ b/python/mlc_llm/interface/jit.py @@ -10,7 +10,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional, Union from tvm.runtime import Device @@ -18,9 +18,9 @@ from mlc_llm.support import logging from mlc_llm.support.auto_device import device2str from mlc_llm.support.constants import ( - MLC_CACHE_DIR, MLC_DSO_SUFFIX, MLC_JIT_POLICY, + MLC_LLM_HOME, MLC_TEMP_DIR, ) from mlc_llm.support.style import blue, bold @@ -30,13 +30,36 @@ logger = logging.getLogger(__name__) -def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path: - """Just-in-time compile a MLC-Chat model.""" +@dataclasses.dataclass +class JITResult: + """The jit compilation result class.""" + + model_lib_path: str + system_lib_prefix: Optional[str] = None + + +def log_jit_policy(): + """log current jit policy""" logger.info( "%s = %s. Can be one of: ON, OFF, REDO, READONLY", bold("MLC_JIT_POLICY"), MLC_JIT_POLICY, ) + + +def jit( # pylint: disable=too-many-locals,too-many-statements + model_path: Path, + overrides: Dict[str, Any], + device: Union[Device, str], + system_lib_prefix: Optional[str] = None, + *, + skip_log_jit_policy=False, +) -> JITResult: + """Just-in-time compile a MLC-Chat model.""" + # skip logging jit policy since when outside can hint once + if not skip_log_jit_policy: + log_jit_policy() + if MLC_JIT_POLICY == "OFF": raise RuntimeError("JIT is disabled by MLC_JIT_POLICY=OFF") @@ -44,9 +67,10 @@ def jit(model_path: Path, chat_config: Dict[str, Any], device: Device) -> Path: mlc_chat_config = json.load(in_file) model_type = mlc_chat_config.pop("model_type") quantization = mlc_chat_config.pop("quantization") + lib_suffix = MLC_DSO_SUFFIX if device not in ["iphone", "android"] else "tar" def _get_optimization_flags() -> str: - opt = chat_config.pop("opt", None) + opt = overrides.pop("opt", None) if opt is None: opt = "O2" return repr(OptimizationFlags.from_str(opt)) @@ -55,27 +79,25 @@ def _get_overrides() -> str: forbid_list = ["context_window_size", "sliding_window_size", "attention_sink_size"] result = [] for field in dataclasses.fields(ModelConfigOverride): - value = chat_config.get(field.name, None) + value = overrides.get(field.name, None) if value is not None: if field.name in forbid_list and value == -1: continue result.append(f"{field.name}={value}") - if not result: - result = ["tensor_parallel_shards=1"] return ";".join(result) def _get_model_config() -> Dict[str, Any]: model_config = mlc_chat_config.pop("model_config") model_config.update(mlc_chat_config) for field in dataclasses.fields(ModelConfigOverride): - value = chat_config.get(field.name, None) + value = overrides.get(field.name, None) if value is not None: model_config[field.name] = value return MODELS[model_type].config.from_dict(model_config).asdict() - def _run_jit(opt: str, overrides: str, device: str, dst: str): + def _run_jit(opt: str, overrides: str, device: str, system_lib_prefix: Optional[str], dst: str): with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: - dso_path = os.path.join(tmp_dir, f"lib.{MLC_DSO_SUFFIX}") + dso_path = os.path.join(tmp_dir, f"lib.{lib_suffix}") cmd = [ sys.executable, "-m", @@ -91,6 +113,8 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): "--output", 
dso_path, ] + if system_lib_prefix: + cmd += ["--system-lib-prefix", system_lib_prefix + "_"] logger.info("Compiling using commands below:") logger.info("%s", blue(shlex.join(cmd))) subprocess.run(cmd, check=False, env=os.environ) @@ -105,10 +129,23 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): "model_config": _get_model_config(), "overrides": _get_overrides(), "opt": _get_optimization_flags(), - "device": device2str(device), + "device": device2str(device) if isinstance(device, Device) else device, "model_type": model_type, "quantization": quantization, } + if device in ["iphone", "android"]: + if system_lib_prefix is None: + system_lib_hash_value = hashlib.md5( + json.dumps( + hash_key, + sort_keys=True, + indent=2, + ).encode("utf-8") + ).hexdigest() + system_lib_prefix = f"{model_type}_{quantization}_{system_lib_hash_value}".replace( + "-", "_" + ) + hash_key["system_lib_prefix"] = system_lib_prefix hash_value = hashlib.md5( json.dumps( hash_key, @@ -116,10 +153,10 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): indent=2, ).encode("utf-8") ).hexdigest() - dst = MLC_CACHE_DIR / "model_lib" / f"{hash_value}.so" + dst = MLC_LLM_HOME / "model_lib" / f"{hash_value}.{lib_suffix}" if dst.is_file() and MLC_JIT_POLICY in ["ON", "READONLY"]: logger.info("Using cached model lib: %s", bold(str(dst))) - return dst + return JITResult(str(dst), system_lib_prefix) if MLC_JIT_POLICY == "READONLY": raise RuntimeError( "No cached model lib found, and JIT is disabled by MLC_JIT_POLICY=READONLY" @@ -128,6 +165,7 @@ def _run_jit(opt: str, overrides: str, device: str, dst: str): opt=hash_key["opt"], overrides=hash_key["overrides"], device=hash_key["device"], + system_lib_prefix=system_lib_prefix, dst=str(dst), ) - return dst + return JITResult(str(dst), system_lib_prefix) diff --git a/python/mlc_llm/interface/package.py b/python/mlc_llm/interface/package.py new file mode 100644 index 0000000000..6cc6891360 --- /dev/null +++ b/python/mlc_llm/interface/package.py @@ -0,0 +1,367 @@ +"""Python entrypoint of package.""" + +import dataclasses +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Literal + +from mlc_llm.interface import jit +from mlc_llm.support import download_cache, logging, style + +logging.enable_logging() +logger = logging.getLogger(__name__) + +SUPPORTED_DEVICES = ["iphone", "android"] + + +def build_model_library( # pylint: disable=too-many-branches,too-many-locals,too-many-statements + package_config: Dict[str, Any], device: str, bundle_dir: Path, app_config_path: Path +) -> Dict[str, str]: + """Build model libraries. Return the dictionary of "library prefix to lib path".""" + # - Create the bundle directory. + os.makedirs(bundle_dir, exist_ok=True) + # Clean up all the directories in `output/bundle`. + logger.info('Clean up all directories under "%s"', str(bundle_dir)) + for content_path in bundle_dir.iterdir(): + if content_path.is_dir(): + shutil.rmtree(content_path) + + # - Process each model, and prepare the app config. 
+ app_config_model_list = [] + + model_entries = package_config.get("model_list", []) + if not isinstance(model_entries, list): + raise ValueError('The "model_list" in "mlc-package-config.json" is expected to be a list.') + model_lib_path_for_prepare_libs = package_config.get("model_lib_path_for_prepare_libs", {}) + if not isinstance(model_lib_path_for_prepare_libs, dict): + raise ValueError( + 'The "model_lib_path_for_prepare_libs" in "mlc-package-config.json" is expected to be ' + "a dict." + ) + + jit.log_jit_policy() + + for model_entry in package_config.get("model_list", []): + # - Parse model entry. + if not isinstance(model_entry, dict): + raise ValueError('The element of "model_list" is expected to be a dict.') + model = model_entry["model"] + model_id = model_entry["model_id"] + bundle_weight = model_entry.get("bundle_weight", False) + overrides = model_entry.get("overrides", {}) + model_lib = model_entry.get("model_lib", None) + + estimated_vram_bytes = model_entry["estimated_vram_bytes"] + if not isinstance(model, str): + raise ValueError('The value of "model" in "model_list" is expected to be a string.') + if not isinstance(model_id, str): + raise ValueError('The value of "model_id" in "model_list" is expected to be a string.') + if not isinstance(bundle_weight, bool): + raise ValueError( + 'The value of "bundle_weight" in "model_list" is expected to be a boolean.' + ) + if not isinstance(overrides, dict): + raise ValueError('The value of "overrides" in "model_list" is expected to be a dict.') + if model_lib is not None and not isinstance(model_lib, str): + raise ValueError('The value of "model_lib" in "model_list" is expected to be a string.') + + # - Load model config. Download happens when needed. + model_path = download_cache.get_or_download_model(model) + + # - Jit compile if the model lib path is not specified. + model_lib_path = ( + model_lib_path_for_prepare_libs.get(model_lib, None) if model_lib is not None else None + ) + if model_lib_path is None: + if model_lib is None: + logger.info( + 'Model lib is not specified for model "%s". Now jit compile the model library.', + model_id, + ) + else: + logger.info( + 'Model lib path for "%s" is not specified in "model_lib_path_for_prepare_libs". ' + "Now jit compile the model library.", + model_lib, + ) + model_lib_path, model_lib = dataclasses.astuple( + jit.jit( + model_path=model_path, + overrides=overrides, + device=device, + system_lib_prefix=model_lib, + skip_log_jit_policy=True, + ) + ) + assert model_lib is not None + model_lib_path_for_prepare_libs[model_lib] = model_lib_path + + # - Set "model_url"/"model_path" and "model_id" + app_config_model_entry = {} + is_local_model = not model.startswith("HF://") and not model.startswith("https://") + app_config_model_entry["model_id"] = model_id + app_config_model_entry["model_lib"] = model_lib + + # - Bundle weight + if is_local_model and not bundle_weight: + raise ValueError( + f'Model "{model}" in "model_list" is a local path. ' + f'Please set \'"bundle_weight": true\' in the entry of model "{model}".' + ) + if bundle_weight: + if not os.path.isfile(model_path / "ndarray-cache.json"): + raise ValueError( + f'Bundle weight is set for model "{model}". However, model weights are not ' + f'found under the directory "{model}". ' + + ( + "Please follow https://llm.mlc.ai/docs/compilation/convert_weights.html to " + "convert model weights." + if is_local_model + else "Please report this issue to https://github.com/mlc-ai/mlc-llm/issues."
+ ) + ) + # Overwrite the model weight directory in bundle. + bundle_model_weight_path = bundle_dir / model_id + logger.info( + "Bundle weight for %s, copy into %s", + style.bold(model_id), + style.bold(str(bundle_model_weight_path)), + ) + if bundle_model_weight_path.exists(): + shutil.rmtree(bundle_model_weight_path) + shutil.copytree(model_path, bundle_model_weight_path) + if bundle_weight and device == "iphone": + app_config_model_entry["model_path"] = model_id + else: + app_config_model_entry["model_url"] = model.replace("HF://", "https://huggingface.co/") + + # - estimated_vram_bytes + app_config_model_entry["estimated_vram_bytes"] = estimated_vram_bytes + + app_config_model_list.append(app_config_model_entry) + + # - Dump "mlc-app-config.json". + app_config_json_str = json.dumps( + {"model_list": app_config_model_list}, + indent=2, + ) + with open(app_config_path, "w", encoding="utf-8") as file: + print(app_config_json_str, file=file) + logger.info( + 'Dump the app config below to "%s":\n%s', + str(app_config_path), + style.green(app_config_json_str), + ) + return model_lib_path_for_prepare_libs + + +def validate_model_lib( # pylint: disable=too-many-locals + app_config_path: Path, + package_config_path: Path, + model_lib_path_for_prepare_libs: dict, + device: Literal["iphone", "android"], + output: Path, +) -> None: + """Validate the model lib prefixes of model libraries.""" + # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + if device == "android": + from tvm.contrib import ndk as cc + else: + from tvm.contrib import cc + # pylint: enable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported + + with open(app_config_path, "r", encoding="utf-8") as file: + app_config = json.load(file) + + tar_list = [] + model_set = set() + + for model, model_lib_path in model_lib_path_for_prepare_libs.items(): + model_lib_path = os.path.join(model_lib_path) + lib_path_valid = os.path.isfile(model_lib_path) + if not lib_path_valid: + raise RuntimeError(f"Cannot find file {model_lib_path} as an {device} model library") + tar_list.append(model_lib_path) + model_set.add(model) + + os.makedirs(output / "lib", exist_ok=True) + lib_path = ( + output / "lib" / ("libmodel_iphone.a" if device == "iphone" else "libmodel_android.a") + ) + + def _get_model_libs(lib_path: Path) -> List[str]: + """Get the model lib prefixes in the given static lib path.""" + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name, _ in global_symbol_map.items(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + cc.create_staticlib(lib_path, tar_list) + available_model_libs = _get_model_libs(lib_path) + logger.info("Creating lib from %s", str(tar_list)) + logger.info("Validating the library %s", str(lib_path)) + logger.info( + "List of available model libs packaged: %s," + " if we have '-' in the model_lib string, it will be turned into '_'", + str(available_model_libs), + ) + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + error_happened = False + + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + # NOTE: this cannot happen under new setting + # since if model_lib is not included, it will be jitted + raise RuntimeError( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} 
" + "is not included in model_lib_path_for_prepare_libs argument, " + "This will cause the specific model not being able to load, " + f"model_lib_path_for_prepare_libs={model_lib_path_for_prepare_libs}" + ) + + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + # NOTE: no lazy format is ok since this is a slow pass + model_lib_path = model_lib_path_for_prepare_libs[model_lib] + log_msg = ( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {str(app_config_path)}" + f" is not found in {str(lib_path)}\n" + f"\tspecifically the model_lib for {model_lib_path}.\n" + f"\tcurrent available model_libs in {str(lib_path)}: {available_model_libs}\n" + f"\tThis can happen when we manually specified model_lib_path_for_prepare_libs" + f" in {str(package_config_path)}\n" + f"\tConsider remove model_lib_path_for_prepare_libs (so library can be jitted)" + "or check the compile command" + ) + logger.info(log_msg) + error_happened = True + + if not error_happened: + logger.info(style.green("Validation pass")) + else: + logger.info(style.red("Validation failed")) + sys.exit(255) + + +def build_android_binding(mlc_llm_source_dir: Path, output: Path) -> None: + """Build android binding in MLC LLM""" + mlc4j_path = mlc_llm_source_dir / "android" / "mlc4j" + + # Move the model libraries to "build/lib/" for linking + os.makedirs(Path("build") / "lib", exist_ok=True) + src_path = str(output / "lib" / "libmodel_android.a") + dst_path = str(Path("build") / "lib" / "libmodel_android.a") + logger.info('Moving "%s" to "%s"', src_path, dst_path) + shutil.move(src_path, dst_path) + + # Build mlc4j + logger.info("Building mlc4j") + subprocess.run([sys.executable, mlc4j_path / "prepare_libs.py"], check=True, env=os.environ) + # Copy built files back to output directory. + lib_path = output / "lib" / "mlc4j" + os.makedirs(lib_path, exist_ok=True) + logger.info('Clean up all directories under "%s"', str(lib_path)) + for content_path in lib_path.iterdir(): + if content_path.is_dir(): + shutil.rmtree(content_path) + + src_path = str(mlc4j_path / "src") + dst_path = str(lib_path / "src") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copytree(src_path, dst_path) + + src_path = str(mlc4j_path / "build.gradle") + dst_path = str(lib_path / "build.gradle") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copy(src_path, dst_path) + + src_path = str(Path("build") / "output") + dst_path = str(lib_path / "output") + logger.info('Copying "%s" to "%s"', src_path, dst_path) + shutil.copytree(src_path, dst_path) + + os.makedirs(lib_path / "src" / "main" / "assets") + src_path = str(output / "bundle" / "mlc-app-config.json") + dst_path = str(lib_path / "src" / "main" / "assets" / "mlc-app-config.json") + logger.info('Moving "%s" to "%s"', src_path, dst_path) + shutil.move(src_path, dst_path) + + +def build_iphone_binding(mlc_llm_source_dir: Path, output: Path) -> None: + """Build iOS binding in MLC LLM""" + # Build iphone binding + logger.info("Build iphone binding") + subprocess.run( + ["bash", mlc_llm_source_dir / "ios" / "prepare_libs.sh"], check=True, env=os.environ + ) + + # Copy built libraries back to output directory. 
+ for static_library in (Path("build") / "lib").iterdir(): + dst_path = str(output / "lib" / static_library.name) + logger.info('Copying "%s" to "%s"', static_library, dst_path) + shutil.copy(static_library, dst_path) + + +def package( + package_config_path: Path, + mlc_llm_source_dir: Path, + output: Path, +) -> None: + """Python entrypoint of package.""" + logger.info('MLC LLM HOME: "%s"', mlc_llm_source_dir) + + # - Read package config. + with open(package_config_path, "r", encoding="utf-8") as file: + package_config = json.load(file) + if not isinstance(package_config, dict): + raise ValueError( + "The content of MLC package config is expected to be a dict with " + f'field "model_list". However, the content of "{package_config_path}" is not a dict.' + ) + + # - Read device. + if "device" not in package_config: + raise ValueError(f'JSON file "{package_config_path}" is required to have field "device".') + device = package_config["device"] + if device not in SUPPORTED_DEVICES: + raise ValueError( + f'The "device" field of JSON file {package_config_path} is expected to be one of ' + f'{SUPPORTED_DEVICES}, while "{device}" is given in the JSON.' + ) + + bundle_dir = output / "bundle" + app_config_path = bundle_dir / "mlc-app-config.json" + # - Build model libraries. + model_lib_path_for_prepare_libs = build_model_library( + package_config, device, bundle_dir, app_config_path + ) + # - Validate model libraries. + validate_model_lib( + app_config_path, package_config_path, model_lib_path_for_prepare_libs, device, output + ) + + # - Copy model libraries + if device == "android": + build_android_binding(mlc_llm_source_dir, output) + elif device == "iphone": + build_iphone_binding(mlc_llm_source_dir, output) + else: + assert False, "Cannot reach here" + + logger.info("All finished.") diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index 40fa9fdda8..be437824cf 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -1,6 +1,6 @@ """Python entrypoint of serve.""" -from typing import Any, List, Literal, Optional +from typing import Any, List, Literal, Optional, Tuple, Union import fastapi import uvicorn @@ -8,24 +8,37 @@ from mlc_llm.protocol import error_protocol from mlc_llm.serve import engine -from mlc_llm.serve.config import SpeculativeMode -from mlc_llm.serve.entrypoints import debug_entrypoints, openai_entrypoints +from mlc_llm.serve.entrypoints import ( + debug_entrypoints, + metrics_entrypoints, + openai_entrypoints, +) from mlc_llm.serve.server import ServerContext +from mlc_llm.support import logging + +logger = logging.getLogger(__name__) def serve( model: str, device: str, - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], - additional_models: List[str], - max_batch_size: Optional[int], + enable_debug: bool, + additional_models: List[Union[str, Tuple[str, str]]], + tensor_parallel_shards: Optional[int], + max_num_sequence: Optional[int], max_total_sequence_length: Optional[int], + max_single_sequence_length: Optional[int], prefill_chunk_size: Optional[int], + sliding_window_size: Optional[int], + attention_sink_size: Optional[int], max_history_size: Optional[int], gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, - spec_draft_length: int, + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"], + spec_draft_length: Optional[int], + prefix_cache_mode: Literal["disable", "radix"], + prefix_cache_max_num_recycling_seqs: 
Optional[int], enable_tracing: bool, host: str, port: int, @@ -39,16 +52,24 @@ def serve( async_engine = engine.AsyncMLCEngine( model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, - additional_models=additional_models, - max_batch_size=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - gpu_memory_utilization=gpu_memory_utilization, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, + engine_config=engine.EngineConfig( + additional_models=additional_models, + tensor_parallel_shards=tensor_parallel_shards, + max_num_sequence=max_num_sequence, + max_total_sequence_length=max_total_sequence_length, + max_single_sequence_length=max_single_sequence_length, + prefill_chunk_size=prefill_chunk_size, + sliding_window_size=sliding_window_size, + attention_sink_size=attention_sink_size, + max_history_size=max_history_size, + gpu_memory_utilization=gpu_memory_utilization, + speculative_mode=speculative_mode, + spec_draft_length=spec_draft_length, + prefix_cache_mode=prefix_cache_mode, + prefix_cache_max_num_recycling_seqs=prefix_cache_max_num_recycling_seqs, + ), enable_tracing=enable_tracing, ) @@ -65,7 +86,14 @@ def serve( ) app.include_router(openai_entrypoints.app) - app.include_router(debug_entrypoints.app) + app.include_router(metrics_entrypoints.app) + + server_context.enable_debug = enable_debug + + if enable_debug: + app.include_router(debug_entrypoints.app) + logger.info("Enable debug endpoint and debug_config in requests...") + app.exception_handler(error_protocol.BadRequestError)( error_protocol.bad_request_error_handler ) diff --git a/python/mlc_llm/json_ffi/engine.py b/python/mlc_llm/json_ffi/engine.py index 0c604a2ef3..294885214e 100644 --- a/python/mlc_llm/json_ffi/engine.py +++ b/python/mlc_llm/json_ffi/engine.py @@ -7,213 +7,119 @@ import tvm -from mlc_llm.protocol import openai_api_protocol +from mlc_llm.protocol import debug_protocol, openai_api_protocol from mlc_llm.serve import engine_utils from mlc_llm.serve.engine_base import ( EngineConfig, - SpeculativeMode, - _infer_kv_cache_config, + EngineMetrics, + _check_engine_config, _parse_models, _process_model_args, + _query_engine_metrics, detect_device, ) -from mlc_llm.tokenizer import Tokenizer - - -# TODO(mlc-team): further minimize the JSONFFIEngine -# construction to not depend on any config and directly pass in JSON -# model defined generation config should be read from the JSONFFIEngine via Reload -def create_model_defined_generation_config( - temperature: float, top_p: float, frequency_penalty: float, presence_penalty: float -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.ModelDefinedGenerationConfig")( - temperature, - top_p, - frequency_penalty, - presence_penalty, - ) - - -# TODO(mlc-team): further minimize the JSONFFIEngine -# Engine config should be passed as json str -# and backend should have good default -# only model and model_lib should be mandatory -def create_json_ffi_engine_config( - conv_template: str, model_generation_cfgs: Dict[str, tvm.runtime.Object] -) -> tvm.runtime.Object: - return tvm.get_global_func("mlc.json_ffi.JSONFFIEngineConfig")( - conv_template, model_generation_cfgs - ) +from mlc_llm.tokenizers import Tokenizer class EngineState: sync_queue: queue.Queue - def get_request_stream_callback(self) -> Callable[[List[str]], None]: + def get_request_stream_callback(self) -> Callable[[str], None]: # 
ChatCompletionStreamResponse - def _callback(chat_completion_stream_responses_json_str: List[str]) -> None: + def _callback(chat_completion_stream_responses_json_str: str) -> None: self._sync_request_stream_callback(chat_completion_stream_responses_json_str) return _callback - def _sync_request_stream_callback( - self, chat_completion_stream_responses_json_str: List[str] - ) -> None: + def _sync_request_stream_callback(self, chat_completion_stream_responses_json_str: str) -> None: # Put the delta outputs to the queue in the unblocking way. self.sync_queue.put_nowait(chat_completion_stream_responses_json_str) + def handle_chat_completion( + self, ffi: dict, request_json_str: str, include_usage: bool, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Helper class to handle chat completion -class JSONFFIEngine: - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - model: str, - device: Union[str, tvm.runtime.Device] = "auto", - *, - model_lib_path: Optional[str] = None, - mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - max_history_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, - gpu_memory_utilization: Optional[float] = None, - ) -> None: - # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) - if isinstance(device, str): - device = detect_device(device) - assert isinstance(device, tvm.runtime.Device) - ( - model_args, - model_config_paths, - self.conv_template, - ) = _process_model_args(models, device) - - # TODO(mlc-team) Remove the model config parsing, estimation below - # in favor of a simple direct passing of parameters into backend. - # JSONFFIEngine do not have to support automatic mode - # - # Instead, its config should default to interactive mode always - # and allow overrides of parameters through json config via reload - # - # This is to simplify the logic of users of JSONFFI - # since we won't have similar logics in android/iOS - # - # - Load the raw model config into dict - self.model_config_dicts = [] - for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] - with open(model_config_paths[i], "r", encoding="utf-8") as file: - self.model_config_dicts.append(json.load(file)) - - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) - - # - Initialize engine state and engine. 
- self.state = EngineState() - module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() - self._ffi = { - key: module[key] - for key in [ - "init_background_engine", - "reload", - "unload", - "reset", - "chat_completion", - "abort", - "get_last_error", - "run_background_loop", - "run_background_stream_back_loop", - "exit_background_loop", - ] - } - self.tokenizer = Tokenizer(model_args[0][0]) + Note + ---- + ffi is explicitly passed in to avoid cylic dependency + as ffi will capture EngineState + """ + self.sync_queue = queue.Queue() - self.engine_config = EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) + success = bool(ffi["chat_completion"](request_json_str, request_id)) - self.json_ffi_engine_config = create_json_ffi_engine_config( - conv_template=self.conv_template.model_dump_json(), - model_generation_cfgs={ - model.model: create_model_defined_generation_config( - temperature=model_config["temperature"], - top_p=model_config["top_p"], - frequency_penalty=model_config["frequency_penalty"], - presence_penalty=model_config["presence_penalty"], - ) - for model, model_config in zip(models, self.model_config_dicts) - }, - ) + try: + last_chunk_arrived = False + while not last_chunk_arrived: + chat_completion_responses_json_str = self.sync_queue.get() + chat_completion_responses_list = json.loads(chat_completion_responses_json_str) + for chat_completion_response_json_dict in chat_completion_responses_list: + chat_completion_response = ( + openai_api_protocol.ChatCompletionStreamResponse.model_validate( + chat_completion_response_json_dict + ) + ) + # the chunk with usage is always the last chunk + if chat_completion_response.usage is not None: + if include_usage: + yield chat_completion_response + last_chunk_arrived = True + break + yield chat_completion_response + except Exception as exception: # pylint: disable=broad-exception-caught + ffi["abort"](request_id) + raise exception - self._ffi["init_background_engine"]( - self.json_ffi_engine_config, - self.engine_config, - device, - self.state.get_request_stream_callback(), - None, - ) - def _background_loop(): - self._ffi["run_background_loop"]() +class BackgroundLoops: + """Helper class to keep track of background loops""" - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() + def __init__(self, ffi: dict): + self._ffi = ffi + # important: avoid self reference in closure + background_loop = self._ffi["run_background_loop"] + background_stream_back_loop = self._ffi["run_background_stream_back_loop"] # Create the background engine-driving thread and start the loop. 
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_loop_thread: threading.Thread = threading.Thread(target=background_loop) self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop + target=background_stream_back_loop ) self._background_loop_thread.start() self._background_stream_back_loop_thread.start() self._terminated = False + def __del__(self): + self.terminate() + def terminate(self): + if self._terminated: + return self._terminated = True self._ffi["exit_background_loop"]() self._background_loop_thread.join() self._background_stream_back_loop_thread.join() - def chat_completion( # pylint: disable=too-many-arguments,too-many-locals + +class Completions: + """Completions class to be compatible with OpenAI API""" + + _ffi: dict + _state: EngineState + _background_loops: BackgroundLoops + + def __init__(self, ffi: dict, state: EngineState, background_loops: BackgroundLoops): + self._ffi = ffi + self._state = state + self._background_loops = background_loops + + def create( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], - model: str, + model: str = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, logprobs: bool = False, @@ -223,85 +129,163 @@ def chat_completion( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - stream: bool = False, + stream: bool = True, + stream_options: Optional[Dict[str, Any]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: if request_id is None: request_id = f"chatcmpl-{engine_utils.random_uuid()}" - - chatcmpl_generator = self._handle_chat_completion( - openai_api_protocol.ChatCompletionRequest( - messages=[ - openai_api_protocol.ChatCompletionMessage.model_validate(message) - for message in messages - ], - model=model, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logprobs=logprobs, - top_logprobs=top_logprobs, - logit_bias=logit_bias, - max_tokens=max_tokens, - n=n, - seed=seed, - stop=stop, - stream=stream, - temperature=temperature, - top_p=top_p, - tools=( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ), - tool_choice=tool_choice, - user=user, - ignore_eos=ignore_eos, - response_format=( - openai_api_protocol.RequestResponseFormat.model_validate(response_format) - if response_format is not None - else None - ), - ).model_dump_json(), + debug_config = extra_body.get("debug_config", None) if extra_body is not None else None + if not stream: + raise ValueError("JSONFFIEngine only support stream=True") + request = openai_api_protocol.ChatCompletionRequest( + messages=[ + openai_api_protocol.ChatCompletionMessage.model_validate(message) + for message in messages + ], + model=model, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + logprobs=logprobs, + top_logprobs=top_logprobs, + logit_bias=logit_bias, + max_tokens=max_tokens, n=n, + seed=seed, + stop=stop, + 
stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), + temperature=temperature, + top_p=top_p, + tools=( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ), + tool_choice=tool_choice, + user=user, + response_format=( + openai_api_protocol.RequestResponseFormat.model_validate(response_format) + if response_format is not None + else None + ), + debug_config=( + debug_protocol.DebugConfig.model_validate(debug_config) + if debug_config is not None + else None + ), + ) + chatcmpl_generator = self._state.handle_chat_completion( + self._ffi, + request.model_dump_json(by_alias=True), + include_usage=( + request.stream_options is not None and request.stream_options.include_usage + ), request_id=request_id, ) for response in chatcmpl_generator: yield response - def _handle_chat_completion( - self, request_json_str: str, n: int, request_id: str - ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: - self.state.sync_queue = queue.Queue() - num_unfinished_requests = n - success = bool(self._ffi["chat_completion"](request_json_str, request_id)) +class Chat: + """Chat class to be compatible with OpenAI API""" - try: - while num_unfinished_requests > 0: - chat_completion_stream_responses_json_str = self.state.sync_queue.get() - for chat_completion_response_json_str in chat_completion_stream_responses_json_str: - chat_completion_response = ( - openai_api_protocol.ChatCompletionStreamResponse.model_validate_json( - chat_completion_response_json_str - ) - ) - for choice in chat_completion_response.choices: - if choice.finish_reason is not None: - num_unfinished_requests -= 1 - yield chat_completion_response - except Exception as exception: # pylint: disable=broad-exception-caught - self._ffi["abort"](request_id) - raise exception + completions: Completions + + def __init__(self, ffi: dict, state: EngineState, background_loops: BackgroundLoops): + self.completions = Completions(ffi, state, background_loops) + + +class JSONFFIEngine: + chat: Chat + + def __init__( # pylint: disable=too-many-arguments,too-many-locals + self, + model: str, + device: Union[str, tvm.runtime.Device] = "auto", + *, + model_lib: Optional[str] = None, + mode: Literal["local", "interactive", "server"] = "local", + engine_config: Optional[EngineConfig] = None, + ) -> None: + # - Check the fields fields of `engine_config`. + if engine_config is None: + engine_config = EngineConfig() + _check_engine_config(model, model_lib, mode, engine_config) + + # - Initialize model loading info. + models = _parse_models(model, model_lib, engine_config.additional_models) + if isinstance(device, str): + device = detect_device(device) + assert isinstance(device, tvm.runtime.Device) + model_args = _process_model_args(models, device, engine_config)[0] + + # - Load the raw model config into dict + for i, model_info in enumerate(models): + model_info.model_lib = model_args[i][1] + + # - Initialize engine state and engine. 
+ self._state = EngineState() + module = tvm.get_global_func("mlc.json_ffi.CreateJSONFFIEngine", allow_missing=False)() + self._ffi = { + key: module[key] + for key in [ + "init_background_engine", + "reload", + "unload", + "reset", + "chat_completion", + "abort", + "run_background_loop", + "run_background_stream_back_loop", + "exit_background_loop", + ] + } + self.tokenizer = Tokenizer(model_args[0][0]) + self._background_loops = BackgroundLoops(self._ffi) + + engine_config.model = model_args[0][0] + engine_config.model_lib = model_args[0][1] + engine_config.additional_models = model_args[1:] # type: ignore + engine_config.mode = mode + self.engine_config = engine_config + + self._ffi["init_background_engine"]( + device.device_type, device.device_id, self._state.get_request_stream_callback() + ) + self._ffi["reload"](self.engine_config.asjson()) + + self.chat = Chat(self._ffi, self._state, self._background_loops) + + def metrics(self) -> EngineMetrics: + """Get the engine metrics.""" + return _query_engine_metrics(self) + + def _raw_chat_completion( + self, request_json_str: str, include_usage: bool, request_id: str + ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: + """Raw chat completion API""" + return self._state.handle_chat_completion( + self._ffi, request_json_str, include_usage, request_id + ) + + def terminate(self): + """Explicitly terminate the engine""" + self._background_loops.terminate() def _test_reload(self): - self._ffi["reload"](self.engine_config) + self._ffi["reload"](self.engine_config.asjson()) def _test_reset(self): self._ffi["reset"]() diff --git a/python/mlc_llm/libinfo.py b/python/mlc_llm/libinfo.py index 4c36cab931..2212d8c7a4 100644 --- a/python/mlc_llm/libinfo.py +++ b/python/mlc_llm/libinfo.py @@ -1,4 +1,5 @@ """Library information. This is a standalone file that can be used to get various info""" + #! pylint: disable=protected-access import os import sys diff --git a/python/mlc_llm/loader/__init__.py b/python/mlc_llm/loader/__init__.py index cc8ba9c9ed..3cee0bf385 100644 --- a/python/mlc_llm/loader/__init__.py +++ b/python/mlc_llm/loader/__init__.py @@ -2,6 +2,7 @@ A subpackage of the compiler that represents mapping between external parameters, quantized parameters and parameters in MLC-defined models. 
""" + from .huggingface_loader import HuggingFaceLoader from .loader import LOADER, Loader from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_llm/loader/huggingface_loader.py b/python/mlc_llm/loader/huggingface_loader.py index 20de641735..55dc67ba6f 100644 --- a/python/mlc_llm/loader/huggingface_loader.py +++ b/python/mlc_llm/loader/huggingface_loader.py @@ -1,4 +1,5 @@ """A weight loader for HuggingFace's PyTorch format""" + import gc import json from collections import OrderedDict, defaultdict diff --git a/python/mlc_llm/loader/loader.py b/python/mlc_llm/loader/loader.py index e4c397c5ab..a1516e1a85 100644 --- a/python/mlc_llm/loader/loader.py +++ b/python/mlc_llm/loader/loader.py @@ -1,4 +1,5 @@ """A centralized registry of all existing loaders.""" + from typing import Any, Dict from .huggingface_loader import HuggingFaceLoader diff --git a/python/mlc_llm/loader/mapping.py b/python/mlc_llm/loader/mapping.py index 26d6811086..1aa10c56e9 100644 --- a/python/mlc_llm/loader/mapping.py +++ b/python/mlc_llm/loader/mapping.py @@ -1,4 +1,5 @@ """Parameter mapping for converting different LLM implementations to MLC LLM.""" + import dataclasses from typing import Callable, Dict, List, Set, Union diff --git a/python/mlc_llm/loader/stats.py b/python/mlc_llm/loader/stats.py index 4710e47307..a476e36c1b 100644 --- a/python/mlc_llm/loader/stats.py +++ b/python/mlc_llm/loader/stats.py @@ -1,4 +1,5 @@ """Statistics of the loading process of parameter loaders""" + import dataclasses import time from contextlib import contextmanager diff --git a/python/mlc_llm/loader/utils.py b/python/mlc_llm/loader/utils.py index a838841b7e..f663202cea 100644 --- a/python/mlc_llm/loader/utils.py +++ b/python/mlc_llm/loader/utils.py @@ -1,4 +1,5 @@ """Common utilities for loading parameters""" + # pylint: disable=too-few-public-methods from pathlib import Path from typing import TYPE_CHECKING, Iterator, Set, Tuple diff --git a/python/mlc_llm/model/__init__.py b/python/mlc_llm/model/__init__.py index d7b0baaa71..480c198d29 100644 --- a/python/mlc_llm/model/__init__.py +++ b/python/mlc_llm/model/__init__.py @@ -1,3 +1,4 @@ """Model definition for the compiler.""" + from .model import MODELS, Model from .model_preset import MODEL_PRESETS diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index 1d8f88c676..bce68b830a 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -57,7 +57,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) @@ -66,21 +66,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -89,6 +87,11 @@ def __post_init__(self): class BaichuanAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: BaichuanConfig): self.hidden_size = config.hidden_size + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.head_dim = config.head_dim self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False) @@ -108,6 +111,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class BaichuanMLP(nn.Module): def __init__(self, config: BaichuanConfig): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -262,9 +270,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -339,14 +344,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/baichuan/baichuan_quantization.py b/python/mlc_llm/model/baichuan/baichuan_quantization.py index 70522b599d..2bad7e3349 100644 --- a/python/mlc_llm/model/baichuan/baichuan_quantization.py +++ b/python/mlc_llm/model/baichuan/baichuan_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's Baichuan parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = BaichuanForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/bert/__init__.py b/python/mlc_llm/model/bert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/bert/bert_loader.py b/python/mlc_llm/model/bert/bert_loader.py new file mode 100644 index 0000000000..29762b3950 --- /dev/null +++ b/python/mlc_llm/model/bert/bert_loader.py @@ -0,0 +1,87 @@ +""" +This file specifies how MLC's BERT parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .bert_model import BertConfig, BertModel + + +def huggingface(model_config: BertConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : BertConfig + The configuration of the BERT model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = BertModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + attn = f"encoder.layer.{i}.attention.self" + mlc_name = f"{attn}.qkv.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.query.weight", + f"{attn}.key.weight", + f"{attn}.value.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"{attn}.qkv.bias" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.query.bias", + f"{attn}.key.bias", + f"{attn}.value.bias", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_llm/model/bert/bert_model.py b/python/mlc_llm/model/bert/bert_model.py new file mode 100644 index 0000000000..c5b440401f --- /dev/null +++ b/python/mlc_llm/model/bert/bert_model.py @@ -0,0 +1,267 @@ +""" +Implementation for BERT architecture. +""" + +import dataclasses +from functools import partial +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class BertConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the BERT model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + intermediate_size: int + hidden_act: str + layer_norm_eps: float + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + head_dim: int = 0 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.intermediate_size is None or self.intermediate_size == -1: + self.intermediate_size = 4 * self.hidden_size + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." 
+ ) + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %s (%d)", + bold("prefill_chunk_size"), + bold("context_window_size"), + self.context_window_size, + ) + self.prefill_chunk_size = self.context_window_size + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d (%s)", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + self.context_window_size, + bold("context_window_size"), + ) + self.prefill_chunk_size = self.context_window_size + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals + + +class BertSelfAttention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: BertConfig): + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads" + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.num_heads = config.num_attention_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim + + self.qkv = nn.Linear( + in_features=config.hidden_size, + out_features=3 * self.num_heads * self.head_dim, + bias=True, + ) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + d, h = self.head_dim, self.num_heads + b, s, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = op.reshape(qkv, (b, s, 3 * h, d)) + q, k, v = op.split(qkv, 3, axis=2) + + # Attention + output = op_ext.attention(q, k, v, attention_mask) + return output + + +class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config: BertConfig): + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + self_output = self.self(hidden_states, attention_mask) + attention_output = self.output(self_output, hidden_states) + return attention_output + + +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.silu, + "swish": nn.silu, + "gelu_new": partial(nn.gelu, approximate=True), +} + + +class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config: BertConfig): + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: Tensor, input_tensor: Tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config: BertConfig): + self.attention = BertAttention(config) + 
self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config: BertConfig): + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + for layer in self.layer: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states + + +class BertEmbeddings(nn.Module): + def __init__(self, config: BertConfig): + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, dtype="float32") + self.position_embeddings = nn.Embedding( + config.context_window_size, config.hidden_size, dtype="float32" + ) + self.token_type_embeddings = nn.Embedding(2, config.hidden_size, dtype="float32") + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, input_ids: Tensor, token_type_ids: Tensor, position_ids: Tensor): + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertModel(nn.Module): + def __init__(self, config: BertConfig): + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def forward(self, inputs: Tensor, attention_mask: Tensor): + def _input_positions(inputs: te.Tensor): + b, s = inputs.shape + return te.compute((b, s), lambda _, j: j.astype("int32"), name="input_positions") + + input_positions = op.tensor_expr_op( + _input_positions, + name_hint="input_positions", + args=[inputs], + ) + + token_type_ids = op.zeros(inputs.shape, dtype="int32") + + embeddings = self.embeddings(inputs, token_type_ids, input_positions) + encoder_output = self.encoder(embeddings, attention_mask) + return encoder_output + + def prefill(self, inputs: Tensor, attention_mask: Tensor): + def _attention_mask(mask: te.Tensor, zero, batch_size, seq_len): + return te.compute( + (batch_size, 1, seq_len, seq_len), + lambda b, _, i, j: tir.if_then_else( + tir.any(mask[b, i] == zero, mask[b, j] == zero), + tir.min_value(self.dtype), + tir.max_value(self.dtype), + ), + name="attention_mask_prefill", + ) + + batch_size, seq_len = inputs.shape + attention_mask_2d = op.tensor_expr_op( + _attention_mask, + name_hint="attention_mask_prefill", + args=[attention_mask, tir.IntImm("int32", 0), batch_size, seq_len], + ) + return self.forward(inputs, attention_mask_2d) + + def get_default_spec(self): + mod_spec = { + "prefill": { + "inputs": nn.spec.Tensor(["batch_size", "seq_len"], "int32"), + "attention_mask": nn.spec.Tensor(["batch_size", "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/bert/bert_quantization.py b/python/mlc_llm/model/bert/bert_quantization.py new file mode 100644 index 0000000000..e65a5601c6 --- /dev/null +++ 
b/python/mlc_llm/model/bert/bert_quantization.py @@ -0,0 +1,55 @@ +"""This file specifies how MLC's BERT parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .bert_model import BertConfig, BertModel + + +def group_quant( + model_config: BertConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT-architecture model using group quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: BertConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT-architecture model using FasterTransformer quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: BertConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a BERT model without quantization.""" + model: nn.Module = BertModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py index f7e81019e0..fa4b24e87a 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_model.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -63,7 +63,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) @@ -72,21 +72,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -95,6 +93,11 @@ def __post_init__(self): class GLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GLMConfig): self.hidden_size = config.hidden_size + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads" + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.multi_query_attention = config.multi_query_attention self.num_key_value_heads = ( @@ -127,6 +130,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class GLMMLP(nn.Module): def __init__(self, config: GLMConfig): + if config.ffn_hidden_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split ffn hidden size {config.ffn_hidden_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.ffn_hidden_size = config.ffn_hidden_size // config.tensor_parallel_shards self.dense_h_to_4h = nn.Linear( @@ -338,9 +346,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -415,14 +420,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py index 26b404daa8..172188a557 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's ChatGLM parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = ChatGLMForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index 355618df09..9d7820b841 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -190,8 +190,8 @@ def get_default_spec(self): }, }, "fuse_embed_hidden_states": { - "input_embed": nn.spec.Tensor(["length", self.hidden_size], self.dtype), - "hidden_states": nn.spec.Tensor(["length", self.hidden_size], self.dtype), + "input_embed": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), "$": { "param_mode": "packed", "effect_mode": "none", diff --git a/python/mlc_llm/model/eagle/eagle_quantization.py b/python/mlc_llm/model/eagle/eagle_quantization.py index a926f7d9dd..4510a17d2c 100644 --- a/python/mlc_llm/model/eagle/eagle_quantization.py +++ b/python/mlc_llm/model/eagle/eagle_quantization.py @@ -19,6 +19,7 @@ def group_quant( model: nn.Module = EagleForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/gemma/gemma_model.py 
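Each of the attention and MLP constructors patched above gains the same guard: the head count or intermediate size must divide evenly by `tensor_parallel_shards` before the per-shard width is computed, so a bad shard count fails loudly at model construction instead of producing a silently truncated split. A minimal standalone sketch of that validation pattern follows; the `_ShardingConfig` dataclass and `split_evenly` helper are illustrative stand-ins, not part of the diff.

import dataclasses


@dataclasses.dataclass
class _ShardingConfig:
    # Illustrative stand-in for the model config objects used in the hunks above.
    num_attention_heads: int = 32
    intermediate_size: int = 11008
    tensor_parallel_shards: int = 2


def split_evenly(total: int, shards: int, what: str) -> int:
    """Return the per-shard width, mirroring the divisibility guard added in the diff."""
    if total % shards != 0:
        raise ValueError(f"Cannot split {what} {total} evenly to {shards} GPUs.")
    return total // shards


config = _ShardingConfig()
heads_per_shard = split_evenly(config.num_attention_heads, config.tensor_parallel_shards, "attention heads")
ffn_per_shard = split_evenly(config.intermediate_size, config.tensor_parallel_shards, "MLP intermediate size")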
b/python/mlc_llm/model/gemma/gemma_model.py index 118f3ce856..b3ee189a51 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -22,7 +22,6 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """Configuration of the Gemma model.""" hidden_size: int - hidden_act: str intermediate_size: int attention_bias: bool num_attention_heads: int @@ -31,6 +30,7 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes num_hidden_layers: int rms_norm_eps: float vocab_size: int + hidden_activation: Optional[str] = None position_embedding_base: int = 0 context_window_size: int = 0 prefill_chunk_size: int = 0 @@ -39,7 +39,9 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): - if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"): + if self.hidden_activation is None: + self.hidden_activation = self.kwargs.get("hidden_act", None) + if self.hidden_activation not in ("gelu", "gelu_pytorch_tanh"): raise ValueError("Only GeLU is supported as the activation for gemma.") if self.attention_bias: raise ValueError('Only "False" attention_bias is supported for gemma') @@ -61,28 +63,26 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -104,6 +104,11 @@ def lm_head_forward(self, x: nn.Tensor): class GemmaMLP(nn.Module): def __init__(self, config: GemmaConfig): super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -290,9 +295,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -367,14 +369,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gemma/gemma_quantization.py b/python/mlc_llm/model/gemma/gemma_quantization.py index 9108dbc1ff..48a5bbfedc 100644 --- a/python/mlc_llm/model/gemma/gemma_quantization.py +++ b/python/mlc_llm/model/gemma/gemma_quantization.py @@ -19,6 +19,7 @@ def group_quant( model: nn.Module = GemmaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/gpt2/gpt2_loader.py b/python/mlc_llm/model/gpt2/gpt2_loader.py index 0c28461242..bbdad5a1c0 100644 --- a/python/mlc_llm/model/gpt2/gpt2_loader.py +++ b/python/mlc_llm/model/gpt2/gpt2_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's GPT-2 parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools from mlc_llm.loader import ExternMapping diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index ede9dc350f..d24b73955b 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -54,7 +54,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `n_positions` or `max_sequence_length` is " "provided in `config.json`." 
) @@ -63,21 +63,19 @@ def __post_init__(self): assert self.head_dim * self.n_head == self.n_embd if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring,too-many-locals @@ -86,6 +84,11 @@ def __post_init__(self): class GPT2Attention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: GPT2Config): self.embed_dim = config.n_embd + if config.n_head % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.n_head} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_heads = config.n_head // config.tensor_parallel_shards self.head_dim = config.head_dim self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx @@ -122,6 +125,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class GPT2MLP(nn.Module): def __init__(self, config: GPT2Config): embed_dim = config.n_embd + if config.n_inner % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.n_inner} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) intermediate_size = config.n_inner // config.tensor_parallel_shards self.c_fc = nn.Linear(embed_dim, intermediate_size) self.c_proj = nn.Linear(intermediate_size, embed_dim) @@ -282,9 +290,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -359,14 +364,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt2/gpt2_quantization.py b/python/mlc_llm/model/gpt2/gpt2_quantization.py index 9d8ce427d4..8b722f4b06 100644 --- a/python/mlc_llm/model/gpt2/gpt2_quantization.py +++ b/python/mlc_llm/model/gpt2/gpt2_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's GPT-2 parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = GPT2LMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index c13d169be1..fd84601112 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -55,21 +55,19 @@ def __post_init__(self): ) if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -259,9 +257,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -336,14 +331,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], 
"float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py index 78d68f501a..f6f1ff3cda 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py @@ -19,6 +19,7 @@ def group_quant( model: nn.Module = GPTBigCodeForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py b/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py index 7f4d5f56c4..4e1c92db5b 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's GPTNeoX parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools import numpy as np diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index 5e940a15b3..c7832ea68e 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -55,7 +55,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) @@ -70,21 +70,19 @@ def __post_init__(self): if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -96,6 +94,11 @@ class GPTNeoXAttention(nn.Module): # pylint: disable=too-many-instance-attribut def __init__(self, config: GPTNeoXConfig): self.rope_theta = config.position_embedding_base self.hidden_size = config.hidden_size + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards self.head_dim = config.head_dim self.query_key_value = nn.Linear( @@ -128,6 +131,11 @@ class GPTNeoXMLP(nn.Module): def __init__(self, config: GPTNeoXConfig): super().__init__() out_dtype = config.ffn_out_dtype + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.dense_h_to_4h = nn.Linear( config.hidden_size, @@ -313,9 +321,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -391,14 +396,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py index f751426708..61dbe6d6ae 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's GPTNeoX parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = GPTNeoXForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index f8e95ab4ec..4c7793ca2a 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -56,7 +56,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) @@ -65,21 +65,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -88,6 +86,11 @@ def __post_init__(self): class InternLMAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: InternLMConfig): self.hidden_size = config.hidden_size + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.head_dim = config.head_dim self.max_position_embeddings = config.context_window_size @@ -111,6 +114,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class InternLMMLP(nn.Module): def __init__(self, config: InternLMConfig): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( @@ -273,9 +281,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -350,14 +355,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/internlm/internlm_quantization.py b/python/mlc_llm/model/internlm/internlm_quantization.py index 114e9e193e..de302686ca 100644 --- a/python/mlc_llm/model/internlm/internlm_quantization.py +++ b/python/mlc_llm/model/internlm/internlm_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's InternLM parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = InternLMForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/internlm2/__init__.py b/python/mlc_llm/model/internlm2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/internlm2/internlm2_loader.py b/python/mlc_llm/model/internlm2/internlm2_loader.py new file mode 100644 index 0000000000..221e40475e --- /dev/null +++ b/python/mlc_llm/model/internlm2/internlm2_loader.py @@ -0,0 +1,95 @@ +# pylint: disable=W0611 +""" +This file specifies how MLC's InternLM2 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .internlm2_model import InternLM2Config, InternLM2ForCausalLM + + +def huggingface(model_config: InternLM2ForCausalLM, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : InternLM2Config + The configuration of the InternLM2 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = InternLM2ForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _convert_wqkv_layout(wqkv, dtype): + config = model_config + kv_groups = config.num_attention_heads // config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + wqkv = wqkv.reshape(-1, 2 + kv_groups, head_dim, wqkv.shape[-1]) + wq, wk, wv = np.split(wqkv, [kv_groups, kv_groups + 1], axis=1) # pylint: disable=W0632 + wq = wq.reshape(-1, wq.shape[-1]) + wk = wk.reshape(-1, wk.shape[-1]) + wv = wv.reshape(-1, wv.shape[-1]) + return np.concatenate([wq, wk, wv], axis=0).astype(dtype) + + for i in range(model_config.num_hidden_layers): + # Add gates in MLP + mlp = f"model.layers.{i}.feed_forward" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.w1.weight", + f"{mlp}.w3.weight", + ], + functools.partial( + lambda w1, w3, dtype: np.concatenate([w1, w3], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + mlc_name = f"model.layers.{i}.attention.wqkv.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + _convert_wqkv_layout, + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping diff --git a/python/mlc_llm/model/internlm2/internlm2_model.py b/python/mlc_llm/model/internlm2/internlm2_model.py new file mode 100644 index 0000000000..9c1702b787 --- /dev/null +++ b/python/mlc_llm/model/internlm2/internlm2_model.py @@ -0,0 +1,336 @@ +""" +Implementation for InternLM2 architecture. +TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class InternLM2Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the InternLM2 model.""" + + vocab_size: int + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + num_key_value_heads: int + rms_norm_eps: float + intermediate_size: int + bias: bool + use_cache: bool + rope_theta: int + pad_token_id: int + bos_token_id: int + eos_token_id: int + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. 
Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + assert self.tensor_parallel_shards == 1, "InternLM2 currently does not support sharding." + + +# pylint: disable=invalid-name,missing-docstring + + +class InternLM2Attention(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: InternLM2Config): + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.hidden_size = config.hidden_size + self.rope_theta = config.rope_theta + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_position_embeddings = config.context_window_size + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + qkv = self.wqkv(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_heads), + (b, s, h_q * d), + ) + attn_output = self.wo(output) + return attn_output + + +class InternLM2MLP(nn.Module): + def __init__(self, config: InternLM2Config): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.w2 = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.w2(op.silu(x1) * x2) + + +class InternLM2DecoderLayer(nn.Module): + def __init__(self, config: InternLM2Config): + self.attention = InternLM2Attention(config) + self.feed_forward = InternLM2MLP(config) + self.attention_norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.ffn_norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + hidden_states = self.attention(hidden_states, paged_kv_cache, layer_id) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class InternLM2Model(nn.Module): + def __init__(self, config: InternLM2Config): + self.padding_idx = config.pad_token_id + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module): # pylint: disable=R0902 + def __init__(self, config: InternLM2Config): + self.model = InternLM2Model(config) + self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.dtype = "float32" + self.num_hidden_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.rope_theta = config.rope_theta + self.tensor_parallel_shards = config.tensor_parallel_shards + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.output(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + return self.model.tok_embeddings(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = 
op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.output(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.output(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git 
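The InternLM2 loader above unpacks the checkpoint's interleaved `wqkv` layout, in which each key/value head is stored together with its group of query heads, into the contiguous q|k|v ordering expected by the fused attention projection. Below is a small NumPy sketch of that reordering; the toy sizes are chosen only for illustration and do not correspond to any real checkpoint.

import numpy as np

# Toy sizes for illustration only.
num_attention_heads, num_key_value_heads, head_dim, hidden_size = 8, 2, 4, 32
kv_groups = num_attention_heads // num_key_value_heads  # query heads per kv head

# Interleaved on-disk layout: per kv head, its query group, then one k head, then one v head.
wqkv = np.random.randn((num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden_size)

blocks = wqkv.reshape(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
wq, wk, wv = np.split(blocks, [kv_groups, kv_groups + 1], axis=1)
fused = np.concatenate(
    [wq.reshape(-1, hidden_size), wk.reshape(-1, hidden_size), wv.reshape(-1, hidden_size)],
    axis=0,
)
assert fused.shape == wqkv.shape  # same parameters, reordered into q|k|v blocks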
a/python/mlc_llm/model/internlm2/internlm2_quantization.py b/python/mlc_llm/model/internlm2/internlm2_quantization.py new file mode 100644 index 0000000000..38d6bea342 --- /dev/null +++ b/python/mlc_llm/model/internlm2/internlm2_quantization.py @@ -0,0 +1,54 @@ +"""This file specifies how MLC's InternLM2 parameters are quantized using group quantization +or other formats.""" +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .internlm2_model import InternLM2Config, InternLM2ForCausalLM + + +def group_quant( + model_config: InternLM2Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM2-architecture model using group quantization.""" + model: nn.Module = InternLM2ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: InternLM2Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM2 model using FasterTransformer quantization.""" + model: nn.Module = InternLM2ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: InternLM2Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a InternLM2 model without quantization.""" + model: nn.Module = InternLM2ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/llama/llama_loader.py b/python/mlc_llm/model/llama/llama_loader.py index 070753bc2b..c166609b4c 100644 --- a/python/mlc_llm/model/llama/llama_loader.py +++ b/python/mlc_llm/model/llama/llama_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools import numpy as np diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 18238f688e..62c07ba324 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -58,7 +58,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) @@ -70,21 +70,19 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -93,6 +91,11 @@ def __post_init__(self): class LlamaFFN(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -248,16 +251,11 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_get_logits(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) - return self.get_logits(hidden_states) - - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): - op_ext.configure() - hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): @@ -325,9 +323,6 @@ def batch_verify_to_last_hidden_states( hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -362,15 +357,7 @@ def get_default_spec(self): }, }, "get_logits": { - "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, - "batch_get_logits": { "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), - "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), "$": { "param_mode": "packed", "effect_mode": "none", @@ -465,14 +452,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/llama/llama_quantization.py 
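For reference, the `softmax_with_temperature` method deleted from each model above computed a temperature-scaled softmax over the vocabulary axis, dividing the logits by a per-sequence temperature before normalizing. The NumPy sketch below only restates the arithmetic of the removed code (shapes follow the removed spec: logits of shape [batch_size, 1, vocab_size] and temperature of shape [batch_size]); it says nothing about where that functionality now lives, which is outside this diff.

import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: np.ndarray) -> np.ndarray:
    # logits: (batch_size, 1, vocab_size), temperature: (batch_size,)
    scaled = logits / temperature.reshape(temperature.shape[0], 1, 1)
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


probs = softmax_with_temperature(np.random.randn(2, 1, 32000), np.array([0.7, 1.0]))
assert np.allclose(probs.sum(axis=-1), 1.0)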
b/python/mlc_llm/model/llama/llama_quantization.py index 2a07996c78..26b6e0e728 100644 --- a/python/mlc_llm/model/llama/llama_quantization.py +++ b/python/mlc_llm/model/llama/llama_quantization.py @@ -1,11 +1,18 @@ """This file specifies how MLC's Llama parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize, SmoothQuantize +from mlc_llm.quantization import ( + AWQQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, + PerTensorQuantize, +) from .llama_model import LlamaConfig, LlamaForCasualLM @@ -18,6 +25,7 @@ def group_quant( model: nn.Module = LlamaForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, @@ -69,11 +77,11 @@ def no_quant( return model, quant_map -def smooth_quant( +def per_tensor_quant( model_config: LlamaConfig, - quantization: SmoothQuantize, + quantization: PerTensorQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: - """Quantize a Llama-architecture model using SmoothQuant.""" + """Quantize a Llama-architecture model using per-tensor quantization.""" model: nn.Module = LlamaForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) @@ -81,5 +89,6 @@ def smooth_quant( model, quant_map, "", + tensor_parallel_shards=model_config.tensor_parallel_shards, ) return model, quant_map diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index 1498c13fdb..ed2c585c59 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -7,9 +7,9 @@ import logging from typing import Any, Dict, Optional, Tuple -from tvm import relax, te, tir +from tvm import relax, tir from tvm.relax.frontend import nn -from tvm.relax.frontend.nn import Module, Tensor, op +from tvm.relax.frontend.nn import Module, Tensor from tvm.relax.frontend.nn.modules import Conv2D from tvm.relax.frontend.nn.op import ( broadcast_to, @@ -74,7 +74,7 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes text_architecture: str = "LlamaForCausalLM" kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: vision_config_dict: Dict[str, Any] if isinstance(self.vision_config, LlavaVisionConfig): vision_config_dict = dataclasses.asdict(self.vision_config) @@ -102,7 +102,9 @@ def __post_init__(self): for k, v in text_config_dict.pop("kwargs", {}).items(): text_config_dict[k] = v - self.text_config = CONFIG_MAP[self.text_architecture].from_dict(text_config_dict) + self.text_config = CONFIG_MAP[self.text_architecture].from_dict( # type: ignore + text_config_dict + ) for k in ["context_window_size", "sliding_window_size", "prefill_chunk_size"]: if getattr(self, k) <= 0: @@ -375,84 +377,11 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def _embed_input_ids(self, input_ids: Tensor) -> Tensor: - return self.language_model.embed(input_ids) - - def _embed_pixel_values_and_input_ids(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: - def _index(x, value, batch_size, seq_len): - return te.compute( - (batch_size, seq_len), - lambda i, j: tir.if_then_else( - x[i, j] == value, - j, - 
tir.IntImm("int32", 0), - ), - name="index", - ) - - def _concat(x: Tensor, y: Tensor, new_shape: tuple, insert_index: Tensor): - return te.compute( - (new_shape), - lambda b, i, j: tir.if_then_else( - i < insert_index[0], - x[b, i, j], - tir.if_then_else( - i < insert_index[0] + y.shape[1], - y[b, i - insert_index[0], j], - x[b, i - y.shape[1] + 1, j], - ), - ), - ) - - input_embeddings = self._embed_input_ids(input_ids) - - image_features_all = self.vision_tower.forward(pixel_values) - image_features = wrap_nested( - strided_slice( - image_features_all._expr, # pylint: disable=protected-access - axes=[1], - begin=[1], - end=[image_features_all.shape[1]], - ), - name="slice", - ) - image_features = self.multi_modal_projector(image_features) - batch_size, seq_len = input_ids.shape - image_index_tensor = op.tensor_expr_op( - _index, - name_hint="index", - args=[ - input_ids, - tir.IntImm("int32", self.config.image_token_index), - batch_size, - seq_len, - ], - ).astype("int32") - ##! Assume only one token in input - ##! Also assume batch_size = 1 for now - # TODO: Support image_count > 1 and batch_size > 1 # pylint: disable=fixme - insert_index = op.sum(image_index_tensor, axis=1) - - new_shape = ( - batch_size, - seq_len + tir.IntImm("int32", image_features.shape[1] - 1), - self.config.text_config.hidden_size, - ) - - combined_embeddings = op.tensor_expr_op( - _concat, - name_hint="combined_embeddings", - args=[input_embeddings, image_features, new_shape, insert_index], - ) - return combined_embeddings - def embed(self, input_ids: Tensor) -> Tensor: - return self._embed_input_ids(input_ids) - - def embed_with_pixel_values(self, pixel_values: Tensor, input_ids: Tensor) -> Tensor: - return self._embed_pixel_values_and_input_ids(pixel_values, input_ids) + return self.language_model.embed(input_ids) def image_embed(self, pixel_values: Tensor) -> Tensor: + pixel_values = pixel_values.astype(self.dtype) image_features_all = self.vision_tower.forward(pixel_values) image_features = wrap_nested( strided_slice( @@ -498,9 +427,6 @@ def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): return self.language_model.batch_verify(input_embeds, paged_kv_cache) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -536,22 +462,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "embed_with_pixel_values": { - "pixel_values": nn.spec.Tensor( - [ - 1, - 3, - self.config.vision_config.image_size, - self.config.vision_config.image_size, - ], - self.dtype, - ), - "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), - "$": { - "param_mode": "packed", - "effect_mode": "none", - }, - }, "image_embed": { "pixel_values": nn.spec.Tensor( [ @@ -560,7 +470,7 @@ def get_default_spec(self): self.config.vision_config.image_size, self.config.vision_config.image_size, ], - self.dtype, + "float32", ), "$": { "param_mode": "packed", @@ -618,14 +528,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, 
diff --git a/python/mlc_llm/model/llava/llava_quantization.py b/python/mlc_llm/model/llava/llava_quantization.py index f487a40489..79bd6ecdcb 100644 --- a/python/mlc_llm/model/llava/llava_quantization.py +++ b/python/mlc_llm/model/llava/llava_quantization.py @@ -18,6 +18,7 @@ def group_quant( model: nn.Module = LlavaForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/medusa/__init__.py b/python/mlc_llm/model/medusa/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/medusa/medusa_loader.py b/python/mlc_llm/model/medusa/medusa_loader.py new file mode 100644 index 0000000000..4fe86a4160 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_loader.py @@ -0,0 +1,52 @@ +""" +This file specifies how MLC's Medusa parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .medusa_model import MedusaConfig, MedusaModel + + +def huggingface(model_config: MedusaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : MedusaConfig + The configuration of the Medusa model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = MedusaModel(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/medusa/medusa_model.py b/python/mlc_llm/model/medusa/medusa_model.py new file mode 100644 index 0000000000..01073a50ec --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_model.py @@ -0,0 +1,84 @@ +"""Medusa model definition.""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm.relax.frontend import nn + +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class MedusaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Medusa model.""" + + medusa_num_heads: int + medusa_num_layers: int + hidden_size: int + vocab_size: int + max_batch_size: int = 1 + tensor_parallel_shards: int = 1 + + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + # Unused parameters. Kept for compatibility with the compilation flow.
+ prefill_chunk_size: int = -1 + context_window_size: int = -1 + + +# pylint: disable=missing-docstring + + +class ResBlock(nn.Module): + """Residual block with SiLU activation.""" + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + self.act = nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(nn.Module): + """Medusa model definition.""" + + def __init__(self, config: MedusaConfig): + self.hidden_size = config.hidden_size + self.dtype = "float32" + self.medusa_head = nn.ModuleList( + [ + nn.ModuleList( + [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)] + + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)] + ) + for _ in range(config.medusa_num_heads) + ] + ) + + def get_default_spec(self): + mod_spec = { + "get_logits": { + "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + + def get_logits(self, hidden_states: nn.Tensor): + logits = [] + for head in self.medusa_head: + logits.append(head(hidden_states).astype("float32")) + return logits + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype diff --git a/python/mlc_llm/model/medusa/medusa_quantization.py b/python/mlc_llm/model/medusa/medusa_quantization.py new file mode 100644 index 0000000000..30ddc081c8 --- /dev/null +++ b/python/mlc_llm/model/medusa/medusa_quantization.py @@ -0,0 +1,21 @@ +"""This file specifies how MLC's Medusa parameters are quantized.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import NoQuantize + +from .medusa_model import MedusaConfig, MedusaModel + + +def no_quant( + model_config: MedusaConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Medusa model without quantization.""" + model: nn.Module = MedusaModel(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/mistral/mistral_loader.py b/python/mlc_llm/model/mistral/mistral_loader.py index d9748f1fc5..400c0d3d1f 100644 --- a/python/mlc_llm/model/mistral/mistral_loader.py +++ b/python/mlc_llm/model/mistral/mistral_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's Mistral parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors.
""" + import functools import numpy as np diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 3439f7b41f..8179b99552 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -32,19 +32,46 @@ class MistralConfig(ConfigBase): # pylint: disable=too-many-instance-attributes position_embedding_base: int = 0 num_key_value_heads: int = 0 head_dim: int = 0 - sliding_window_size: int = 4096 + context_window_size: int = 0 + sliding_window_size: int = 0 prefill_chunk_size: int = 0 attention_sink_size: int = 4 tensor_parallel_shards: int = 1 max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - def __post_init__(self): + def __post_init__(self): # pylint: disable=too-many-branches if self.position_embedding_base == 0: if "rope_theta" in self.kwargs: self.position_embedding_base = self.kwargs.pop("rope_theta") else: self.position_embedding_base = 10000 + if self.sliding_window_size == 0: + self.sliding_window_size = self.kwargs.pop("sliding_window", -1) + if self.sliding_window_size is None: + # Sliding window is disabled. + self.sliding_window_size = -1 + if self.context_window_size == 0: + if self.sliding_window_size == -1: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or " + "`max_sequence_length` is provided in `config.json`." + ) + else: + self.context_window_size = -1 + if self.num_key_value_heads == 0: self.num_key_value_heads = self.num_attention_heads if self.head_dim == 0: @@ -53,13 +80,17 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size assert self.attention_sink_size >= 0 if self.prefill_chunk_size == 0: + prefill_chunk_size_candidates = [] + if self.sliding_window_size != -1: + prefill_chunk_size_candidates.append(self.sliding_window_size) + if self.context_window_size != -1: + prefill_chunk_size_candidates.append(self.context_window_size) logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("sliding_window_size"), - self.sliding_window_size, + min(*prefill_chunk_size_candidates, 2048), ) - self.prefill_chunk_size = self.sliding_window_size + self.prefill_chunk_size = min(*prefill_chunk_size_candidates, 2048) # pylint: disable=invalid-name,missing-docstring @@ -70,6 +101,11 @@ class MistralMLP(nn.Module): def __init__(self, config: MistralConfig): super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -89,6 +125,11 @@ class MistralAttention(nn.Module): # pylint: disable=too-many-instance-attribut def __init__(self, config: MistralConfig): self.head_dim = config.head_dim + if config.num_key_value_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_key_value_heads} key-value attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards self.qkv_proj = nn.Linear( @@ -254,9 +295,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -331,14 +369,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/mistral/mistral_quantization.py b/python/mlc_llm/model/mistral/mistral_quantization.py index b4693e4de5..aac8bd0974 100644 --- a/python/mlc_llm/model/mistral/mistral_quantization.py +++ b/python/mlc_llm/model/mistral/mistral_quantization.py @@ -1,11 +1,12 @@ """This file specifies how MLC's Mistral parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize, SmoothQuantize +from mlc_llm.quantization import AWQQuantize, FTQuantize, GroupQuantize, NoQuantize from .mistral_model import MistralConfig, MistralForCasualLM @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = MistralForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, @@ -67,19 +69,3 @@ def no_quant( model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) return model, quant_map - - -def smooth_quant( - model_config: MistralConfig, - quantization: SmoothQuantize, -) -> Tuple[nn.Module, QuantizeMapping]: - """Quantize a Mistral-architecture model using SmoothQuant.""" - model: nn.Module = MistralForCasualLM(model_config) - model.to(quantization.model_dtype) - quant_map = QuantizeMapping({}, {}) - model = quantization.quantize_model( - model, - quant_map, - "", - ) - return model, quant_map diff --git a/python/mlc_llm/model/mixtral/mixtral_loader.py b/python/mlc_llm/model/mixtral/mixtral_loader.py index dad152b784..5248738a69 100644 --- a/python/mlc_llm/model/mixtral/mixtral_loader.py +++ b/python/mlc_llm/model/mixtral/mixtral_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's Mixtral parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace 
safetensors. """ + import functools import numpy as np diff --git a/python/mlc_llm/model/mixtral/mixtral_model.py b/python/mlc_llm/model/mixtral/mixtral_model.py index db41dc31ce..aedc566aa7 100644 --- a/python/mlc_llm/model/mixtral/mixtral_model.py +++ b/python/mlc_llm/model/mixtral/mixtral_model.py @@ -39,6 +39,11 @@ def __init__(self, config: MixtralConfig): super().__init__() self.num_experts_per_tok = config.num_experts_per_tok self.num_local_experts = config.num_local_experts + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MoE intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate = nn.Linear( in_features=config.hidden_size, diff --git a/python/mlc_llm/model/mixtral/mixtral_quantization.py b/python/mlc_llm/model/mixtral/mixtral_quantization.py index 1b5dc1e9bd..eb4983738b 100644 --- a/python/mlc_llm/model/mixtral/mixtral_quantization.py +++ b/python/mlc_llm/model/mixtral/mixtral_quantization.py @@ -25,6 +25,7 @@ def group_quant( model: nn.Module = MixtralForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, @@ -80,5 +81,6 @@ def per_tensor_quant( model, quant_map, "", + tensor_parallel_shards=model_config.tensor_parallel_shards, ) return model, quant_map diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 272cffdc80..9890e64184 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -9,6 +9,7 @@ from mlc_llm.quantization.quantization import Quantization from .baichuan import baichuan_loader, baichuan_model, baichuan_quantization +from .bert import bert_loader, bert_model, bert_quantization from .chatglm3 import chatglm3_loader, chatglm3_model, chatglm3_quantization from .eagle import eagle_loader, eagle_model, eagle_quantization from .gemma import gemma_loader, gemma_model, gemma_quantization @@ -16,14 +17,18 @@ from .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model, gpt_bigcode_quantization from .gpt_neox import gpt_neox_loader, gpt_neox_model, gpt_neox_quantization from .internlm import internlm_loader, internlm_model, internlm_quantization +from .internlm2 import internlm2_loader, internlm2_model, internlm2_quantization from .llama import llama_loader, llama_model, llama_quantization from .llava import llava_loader, llava_model, llava_quantization +from .medusa import medusa_loader, medusa_model, medusa_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization from .orion import orion_loader, orion_model, orion_quantization from .phi import phi_loader, phi_model, phi_quantization +from .phi3 import phi3_loader, phi3_model, phi3_quantization from .qwen import qwen_loader, qwen_model, qwen_quantization from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization +from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization @@ -85,7 +90,7 @@ class Model: "group-quant": llama_quantization.group_quant, "ft-quant": 
llama_quantization.ft_quant, "awq": llama_quantization.awq_quant, - "smoothquant": llama_quantization.smooth_quant + "per-tensor-quant": llama_quantization.per_tensor_quant, }, ), "mistral": Model( @@ -101,7 +106,6 @@ class Model: "group-quant": mistral_quantization.group_quant, "no-quant": mistral_quantization.no_quant, "ft-quant": mistral_quantization.ft_quant, - "smoothquant": mistral_quantization.smooth_quant, }, ), "gemma": Model( @@ -202,6 +206,20 @@ class Model: "ft-quant": phi_quantization.ft_quant, }, ), + "phi3": Model( + name="phi3", + model=phi3_model.Phi3ForCausalLM, + config=phi3_model.Phi3Config, + source={ + "huggingface-torch": phi3_loader.phi3_huggingface, + "huggingface-safetensor": phi3_loader.phi3_huggingface, + }, + quantize={ + "no-quant": phi3_quantization.no_quant, + "group-quant": phi3_quantization.group_quant, + "ft-quant": phi3_quantization.ft_quant, + }, + ), "qwen": Model( name="qwen", model=qwen_model.QWenLMHeadModel, @@ -230,6 +248,20 @@ class Model: "ft-quant": qwen2_quantization.ft_quant, }, ), + "qwen2_moe": Model( + name="qwen2_moe", + model=qwen2_moe_model.Qwen2MoeForCausalLM, + config=qwen2_moe_model.Qwen2MoeConfig, + source={ + "huggingface-torch": qwen2_moe_loader.huggingface, + "huggingface-safetensor": qwen2_moe_loader.huggingface, + }, + quantize={ + "no-quant": qwen2_moe_quantization.no_quant, + "group-quant": qwen2_moe_quantization.group_quant, + "ft-quant": qwen2_moe_quantization.ft_quant, + }, + ), "stablelm": Model( name="stablelm", model=stablelm_model.StableLmForCausalLM, @@ -272,6 +304,20 @@ class Model: "ft-quant": internlm_quantization.ft_quant, }, ), + "internlm2": Model( + name="internlm2", + model=internlm2_model.InternLM2ForCausalLM, + config=internlm2_model.InternLM2Config, + source={ + "huggingface-torch": internlm2_loader.huggingface, + "huggingface-safetensor": internlm2_loader.huggingface, + }, + quantize={ + "no-quant": internlm2_quantization.no_quant, + "group-quant": internlm2_quantization.group_quant, + "ft-quant": internlm2_quantization.ft_quant, + }, + ), "rwkv5": Model( name="rwkv5", model=rwkv5_model.RWKV5_ForCasualLM, @@ -356,4 +402,30 @@ class Model: "awq": eagle_quantization.awq_quant, }, ), + "bert": Model( + name="bert", + model=bert_model.BertModel, + config=bert_model.BertConfig, + source={ + "huggingface-torch": bert_loader.huggingface, + "huggingface-safetensor": bert_loader.huggingface, + }, + quantize={ + "no-quant": bert_quantization.no_quant, + "group-quant": bert_quantization.group_quant, + "ft-quant": bert_quantization.ft_quant, + }, + ), + "medusa": Model( + name="medusa", + model=medusa_model.MedusaModel, + config=medusa_model.MedusaConfig, + source={ + "huggingface-torch": medusa_loader.huggingface, + "huggingface-safetensor": medusa_loader.huggingface, + }, + quantize={ + "no-quant": medusa_quantization.no_quant, + }, + ), } diff --git a/python/mlc_llm/model/model_preset.py b/python/mlc_llm/model/model_preset.py index 41abf0292c..767fa57fd6 100644 --- a/python/mlc_llm/model/model_preset.py +++ b/python/mlc_llm/model/model_preset.py @@ -1,6 +1,6 @@ """A builtin set of models available in MLC LLM.""" -from typing import Any, Dict +from typing import Any, Dict # pylint: disable=too-many-lines MODEL_PRESETS: Dict[str, Any] = { "llama2_7b": { @@ -153,6 +153,30 @@ "context_window_size": 2048, "prefill_chunk_size": 2048, }, + "tinyllama_1b_chat_v0.4": { + "_name_or_path": "/data/tianduo/tinyllama-ft/checkpoint-3890", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + 
"hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5632, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 22, + "num_key_value_heads": 4, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "torch_dtype": "float32", + "transformers_version": "4.33.1", + "use_cache": False, + "vocab_size": 32003, + }, "tinyllama_1b_chat_v1.0": { "architectures": ["LlamaForCausalLM"], "attention_bias": False, @@ -201,23 +225,78 @@ "prefill_chunk_size": 128, "attention_sink_size": 4, }, + "mistral_7b_v03": { + "architectures": ["MistralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "sliding_window": None, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.42.0.dev0", + "use_cache": True, + "vocab_size": 32768, + }, "gpt2": { + "activation_function": "gelu_new", "architectures": ["GPT2LMHeadModel"], + "attn_pdrop": 0.1, "bos_token_id": 50256, + "embd_pdrop": 0.1, "eos_token_id": 50256, - "hidden_act": "gelu_new", - "n_embd": 768, "initializer_range": 0.02, - "n_positions": 1024, + "layer_norm_epsilon": 1e-05, "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, "n_head": 12, "n_layer": 12, + "n_positions": 1024, + "resid_pdrop": 0.1, + "summary_activation": None, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": True, + "summary_type": "cls_index", + "summary_use_proj": True, + "task_specific_params": {"text-generation": {"do_sample": True, "max_length": 50}}, + "vocab_size": 50257, + }, + "gpt2_medium": { + "activation_function": "gelu_new", + "architectures": ["GPT2LMHeadModel"], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, - "transformers_version": "4.26.0.dev0", - "use_cache": True, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 1024, + "n_head": 16, + "n_layer": 24, + "n_positions": 1024, + "n_special": 0, + "predict_special_tokens": True, + "resid_pdrop": 0.1, + "summary_activation": None, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": True, + "summary_type": "cls_index", + "summary_use_proj": True, + "task_specific_params": {"text-generation": {"do_sample": True, "max_length": 50}}, "vocab_size": 50257, - "context_window_size": 2048, - "prefill_chunk_size": 2048, }, "gpt_bigcode": { "activation_function": "gelu_pytorch_tanh", @@ -358,6 +437,39 @@ "transformers_version": "4.35.2", "vocab_size": 51200, }, + "phi-3": { + "_name_or_path": "Phi-3-mini-4k-instruct", + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_phi3.Phi3Config", + "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM", + }, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + 
"pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "vocab_size": 32064, + }, "qwen": { "architectures": ["QWenLMHeadModel"], "auto_map": { @@ -416,6 +528,39 @@ "use_sliding_window": False, "vocab_size": 151936, }, + "qwen2moe": { + "architectures": ["Qwen2MoeForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5632, + "max_position_embeddings": 32768, + "max_window_layers": 21, + "model_type": "qwen2_moe", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.0.dev0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936, + "decoder_sparse_step": 1, + "moe_intermediate_size": 1408, + "shared_expert_intermediate_size": 5632, + "num_experts_per_tok": 4, + "num_experts": 60, + "norm_topk_prob": False, + "output_router_logits": False, + "router_aux_loss_coef": 0.001, + }, "stablelm": { "architectures": ["StableLmForCausalLM"], "bos_token_id": 0, @@ -710,4 +855,152 @@ "use_cache": True, "vocab_size": 128256, }, + "bert": { + "architectures": ["BertModel"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": False, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "vocab_size": 30522, + }, + "stablelm-2-zephyr-1_6b": { + "architectures": ["StableLmForCausalLM"], + "bos_token_id": 100257, + "eos_token_id": 100257, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5632, + "max_position_embeddings": 4096, + "model_type": "stablelm", + "layer_norm_eps": 1e-05, + "num_attention_heads": 32, + "num_hidden_layers": 24, + "num_key_value_heads": 32, + "partial_rotary_factor": 0.25, + "rope_theta": 10000, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.38.0", + "use_cache": True, + "use_qkv_bias": True, + "vocab_size": 100352, + }, + "qwen2_0_5b": { + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 896, + "initializer_range": 0.02, + "intermediate_size": 4864, + "max_position_embeddings": 32768, + "max_window_layers": 24, + "model_type": "qwen2", + "num_attention_heads": 14, + "num_hidden_layers": 24, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936, + }, + "qwen2_1_5b": { + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 1536, + 
"initializer_range": 0.02, + "intermediate_size": 8960, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936, + }, + "qwen2_7b": { + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 131072, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.41.2", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 152064, + }, + "internlm2": { + "architectures": ["InternLM2ForCausalLM"], + "attn_implementation": "eager", + "auto_map": { + "AutoConfig": "configuration_internlm2.InternLM2Config", + "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM", + "AutoModel": "modeling_internlm2.InternLM2ForCausalLM", + }, + "bias": False, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "internlm2", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pad_token_id": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.37.1", + "use_cache": True, + "vocab_size": 92544, + }, } diff --git a/python/mlc_llm/model/orion/orion_loader.py b/python/mlc_llm/model/orion/orion_loader.py index d735052ba9..0df03e053c 100644 --- a/python/mlc_llm/model/orion/orion_loader.py +++ b/python/mlc_llm/model/orion/orion_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's Orion parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools import numpy as np diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index c6a2293cd2..8ab70b8ba8 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -58,7 +58,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." 
) @@ -70,21 +70,19 @@ def __post_init__(self): assert self.num_attention_heads % self.num_key_value_heads == 0 if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -93,6 +91,11 @@ def __post_init__(self): class OrionFFN(nn.Module): def __init__(self, config: OrionConfig): super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -274,9 +277,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -351,14 +351,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/orion/orion_quantization.py b/python/mlc_llm/model/orion/orion_quantization.py index 740253351b..eba7976fab 100644 --- a/python/mlc_llm/model/orion/orion_quantization.py +++ b/python/mlc_llm/model/orion/orion_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's Orion parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = OrionForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/phi/phi_loader.py b/python/mlc_llm/model/phi/phi_loader.py index 70b277c6b2..0b5189e0c8 100644 --- a/python/mlc_llm/model/phi/phi_loader.py +++ b/python/mlc_llm/model/phi/phi_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's Phi parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. 
""" + import functools import numpy as np diff --git a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index 2c9c596ed7..c012736b61 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -59,14 +59,25 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) if self.prefill_chunk_size == 0: - self.prefill_chunk_size = self.context_window_size - if self.prefill_chunk_size > self.context_window_size: - self.prefill_chunk_size = self.context_window_size + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) if self.num_key_value_heads == 0 or self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads if self.intermediate_size == 0 or self.intermediate_size is None: @@ -165,6 +176,11 @@ def from_phi1(config: Phi1Config) -> "PhiConfig": class PhiMLP(nn.Module): def __init__(self, config: PhiConfig): super().__init__() + if config.n_inner % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.n_inner} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.n_inner // config.tensor_parallel_shards self.fc1 = nn.Linear(config.n_embd, self.intermediate_size) self.fc2 = nn.Linear(self.intermediate_size, config.n_embd) @@ -377,9 +393,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def embed(self, input_ids: Tensor): if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) @@ -461,14 +474,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/phi/phi_quantization.py b/python/mlc_llm/model/phi/phi_quantization.py index 3a620d0200..854b3e6547 100644 --- a/python/mlc_llm/model/phi/phi_quantization.py +++ b/python/mlc_llm/model/phi/phi_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's Llama parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = PhiForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git 
a/python/mlc_llm/model/phi3/__init__.py b/python/mlc_llm/model/phi3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/phi3/phi3_loader.py b/python/mlc_llm/model/phi3/phi3_loader.py new file mode 100644 index 0000000000..ab694457d7 --- /dev/null +++ b/python/mlc_llm/model/phi3/phi3_loader.py @@ -0,0 +1,79 @@ +""" +This file specifies how MLC's Phi-3 parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .phi3_model import Phi3Config, Phi3ForCausalLM + + +def phi3_huggingface(model_config: Phi3Config, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of Phi-3 HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : Phi3Config + The configuration of the Phi-3 model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = Phi3ForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm( # pylint: disable=W0632:unbalanced-tuple-unpacking + spec=model.get_default_spec() + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + def _add(mlc_name, hf_name): + mapping.add_mapping( + mlc_name, + [hf_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + def _concat_add(mlc_name, hf_names): + mapping.add_mapping( + mlc_name, + hf_names, + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=named_parameters[mlc_name].dtype, + ), + ) + + _add("lm_head.weight", "lm_head.weight") + _add("transformer.norm.weight", "model.norm.weight") + _add("transformer.embd.weight", "model.embed_tokens.weight") + + prefix = "transformer.h" + hf_prefix = "model.layers" + for i in range(model_config.num_hidden_layers): + _add(f"{prefix}.{i}.ln.weight", f"{hf_prefix}.{i}.input_layernorm.weight") + _add(f"{prefix}.{i}.mlp.down_proj.weight", f"{hf_prefix}.{i}.mlp.down_proj.weight") + _add(f"{prefix}.{i}.mlp.gate_up_proj.weight", f"{hf_prefix}.{i}.mlp.gate_up_proj.weight") + _add( + f"{prefix}.{i}.post_attention_layernorm.weight", + f"{hf_prefix}.{i}.post_attention_layernorm.weight", + ) + _add(f"{prefix}.{i}.mixer.out_proj.weight", f"{hf_prefix}.{i}.self_attn.o_proj.weight") + _add(f"{prefix}.{i}.mixer.qkv_proj.weight", f"{hf_prefix}.{i}.self_attn.qkv_proj.weight") + return mapping diff --git a/python/mlc_llm/model/phi3/phi3_model.py b/python/mlc_llm/model/phi3/phi3_model.py new file mode 100644 index 0000000000..0bd293e715 --- /dev/null +++ b/python/mlc_llm/model/phi3/phi3_model.py @@ -0,0 +1,376 @@ +""" +Implementation for the Phi-3 architecture.
+TODO: add docstring +""" + +import dataclasses +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Phi3Config(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Phi-3 model.""" + + model_type: str # "phi", "phi-msft", "mixformer-sequential" + hidden_size: int + vocab_size: int + num_hidden_layers: int + num_attention_heads: int + intermediate_size: int + rms_norm_eps: float + num_key_value_heads: int + position_embedding_base: int = 0 + context_window_size: int = 0 + prefill_chunk_size: int = 0 + head_dim: int = 0 + tensor_parallel_shards: int = 1 + max_batch_size: int = 1 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + + if self.num_key_value_heads == 0 or self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.head_dim * self.num_attention_heads == self.hidden_size + assert self.num_attention_heads % self.num_key_value_heads == 0 + + +# pylint: disable=invalid-name,missing-docstring + + +class Phi3MLP(nn.Module): + def __init__(self, config: Phi3Config): + super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor): + up_states = self.gate_up_proj(hidden_states) + gate, up_states = nn.op.split(up_states, 2, axis=-1) + up_states = up_states * op.silu(gate) + return self.down_proj(up_states) + + +class PhiMHA(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: Phi3Config): + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert config.num_attention_heads % config.tensor_parallel_shards == 0, ( + f"num_attention_heads({config.num_attention_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards + assert config.num_key_value_heads % config.tensor_parallel_shards == 0, ( + f"num_key_value_heads({config.num_key_value_heads}) " + "must be divisible by tensor_parallel_shards" + ) + self.head_dim = config.head_dim + + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=False, + ) + self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_key_value_heads + b, s, _ = hidden_states.shape + # QKV Projection + qkv = self.qkv_proj(hidden_states) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.out_proj(output) + + +class Phi3ParallelBlock(nn.Module): + def __init__(self, config: Phi3Config): + super().__init__() + + self.ln = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.mixer = PhiMHA(config) + self.mlp = Phi3MLP(config) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.rms_norm_eps, bias=False + ) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.mixer.num_q_heads * hd + k = self.mixer.num_key_value_heads * hd + v = self.mixer.num_key_value_heads * hd + i = self.mlp.intermediate_size + + _set(self.mixer.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.mixer.out_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + attn_outputs = self.mixer(self.ln(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_parallel_residual(attn_outputs, hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_parallel_residual(out, hidden_states) + return hidden_states + + def _apply_parallel_residual(self, mlp_out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(mlp_out + residual / self.tensor_parallel_shards, "sum") + return mlp_out + residual + + +class
Phi3Model(nn.Module): + def __init__(self, config: Phi3Config) -> None: + super().__init__() + self.embd = nn.Embedding(config.vocab_size, config.hidden_size) + self.h = nn.ModuleList([Phi3ParallelBlock(config) for _ in range(config.num_hidden_layers)]) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = input_embed + for layer_id, layer in enumerate(self.h): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Phi3ForCausalLM(nn.Module): + # pylint: disable=too-many-instance-attributes + def __init__(self, config: Phi3Config) -> None: + super().__init__() + + self.transformer = Phi3Model(config) + self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.transformer(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + lm_logits = self.lm_head(hidden_states) + if lm_logits.dtype != "float32": + lm_logits = lm_logits.astype("float32") + return lm_logits + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.transformer(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + + if logits.dtype != "float32": + logits = logits.astype("float32") + + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.transformer(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + embeds = self.transformer.embd(input_ids) + return 
embeds + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/phi3/phi3_quantization.py b/python/mlc_llm/model/phi3/phi3_quantization.py new file mode 100644 index 0000000000..c0e9fced7d --- /dev/null +++ b/python/mlc_llm/model/phi3/phi3_quantization.py @@ -0,0 +1,55 @@ +"""This file specifies how MLC's Phi-3 parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .phi3_model import Phi3Config, Phi3ForCausalLM + + +def group_quant( + model_config: Phi3Config, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using group quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model,
quant_map + + +def ft_quant( + model_config: Phi3Config, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi-architecture model using FasterTransformer quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: Phi3Config, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Phi model without quantization.""" + model: nn.Module = Phi3ForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/qwen/qwen_loader.py b/python/mlc_llm/model/qwen/qwen_loader.py index 5b5f8fe5be..4abe064cb8 100644 --- a/python/mlc_llm/model/qwen/qwen_loader.py +++ b/python/mlc_llm/model/qwen/qwen_loader.py @@ -2,6 +2,7 @@ This file specifies how MLC's QWen parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ + import functools import numpy as np diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 09bb8e854f..7fb7e0eb82 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -54,7 +54,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) @@ -63,21 +63,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -86,6 +84,11 @@ def __post_init__(self): class QWenAttention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWenConfig): self.hidden_size = config.hidden_size + if config.num_attention_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_attention_heads} attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_heads = config.num_attention_heads // config.tensor_parallel_shards self.head_dim = config.head_dim @@ -112,6 +115,11 @@ def forward( # pylint: disable=too-many-locals class QWenMLP(nn.Module): def __init__(self, config: QWenConfig): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -268,9 +276,6 @@ def batch_verify(self, inputs: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(inputs, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -345,14 +350,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/qwen/qwen_quantization.py b/python/mlc_llm/model/qwen/qwen_quantization.py index 862cd6fd8c..38959512d6 100644 --- a/python/mlc_llm/model/qwen/qwen_quantization.py +++ b/python/mlc_llm/model/qwen/qwen_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's QWen parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = QWenLMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index 6eae4c2bb0..89ca027777 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -33,11 +33,13 @@ class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes rms_norm_eps: float rope_theta: int vocab_size: int + tie_word_embeddings: bool = False context_window_size: int = 0 prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 head_dim: int = 0 dtype: str = "float32" + max_batch_size: int = 1 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -63,21 +65,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring,too-many-locals @@ -86,6 +86,11 @@ def __post_init__(self): class QWen2Attention(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWen2Config): self.head_dim = config.head_dim + if config.num_key_value_heads % config.tensor_parallel_shards != 
0: + raise ValueError( + f"Cannot split {config.num_key_value_heads} key-value attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards self.rope_theta = config.rope_theta @@ -121,8 +126,26 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: } +class Qwen2Embedding(nn.Embedding): + """The embedding module specialized for Qwen2 so that + it can be shared with the final lm_head. + """ + + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which transposes the weight and multiplies + with the input tensor. + """ + weight = nn.op.permute_dims(self.weight) + return nn.op.matmul(x, weight, out_dtype="float32") + + class QWen2MLP(nn.Module): def __init__(self, config: QWen2Config): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) @@ -186,7 +209,7 @@ def _apply_residual(self, out, residual): class QWen2Model(nn.Module): def __init__(self, config: QWen2Config): - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.embed_tokens = Qwen2Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [QWen2DecoderLayer(config) for _ in range(config.num_hidden_layers)] ) @@ -203,7 +226,9 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): class QWen2LMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: QWen2Config): self.model = QWen2Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.tie_word_embeddings = config.tie_word_embeddings + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.dtype = config.dtype self.hidden_size = config.hidden_size self.num_hidden_layers = config.num_hidden_layers @@ -232,7 +257,11 @@ def batch_forward( hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: hidden_states = op.take(hidden_states, logit_positions, axis=1) - logits = self.lm_head(hidden_states) + + if self.tie_word_embeddings: + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits @@ -251,7 +280,10 @@ def _index(x: te.Tensor): # x[:-1,:] hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) - logits = self.lm_head(hidden_states) + if self.tie_word_embeddings: + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits, paged_kv_cache @@ -260,7 +292,10 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) - logits = self.lm_head(hidden_states) + if self.tie_word_embeddings: + logits 
= self.model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits, paged_kv_cache @@ -281,9 +316,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -358,14 +390,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/qwen2/qwen2_quantization.py b/python/mlc_llm/model/qwen2/qwen2_quantization.py index b5e3791331..3a8546236c 100644 --- a/python/mlc_llm/model/qwen2/qwen2_quantization.py +++ b/python/mlc_llm/model/qwen2/qwen2_quantization.py @@ -19,6 +19,7 @@ def group_quant( model: nn.Module = QWen2LMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/qwen2_moe/__init__.py b/python/mlc_llm/model/qwen2_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_loader.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_loader.py new file mode 100644 index 0000000000..cbdcc5b029 --- /dev/null +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_loader.py @@ -0,0 +1,130 @@ +""" +This file specifies how MLC's Qwen2MoE parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .qwen2_moe_model import Qwen2MoeConfig, Qwen2MoeForCausalLM + + +def huggingface(model_config: Qwen2MoeConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : Qwen2MoeConfig + The configuration of the Qwen2MoE model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. 
+ """ + model = Qwen2MoeForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # map attention weight + attn = f"model.layers.{i}.self_attn" + for weight_type in ["weight", "bias"]: + mlc_name = f"{attn}.c_attn.{weight_type}" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{weight_type}", + f"{attn}.k_proj.{weight_type}", + f"{attn}.v_proj.{weight_type}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # map mlp shared expert weight + mlp = f"model.layers.{i}.mlp" + shared_expert = f"{mlp}.shared_expert" + mlc_name = f"{shared_expert}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{shared_expert}.gate_proj.weight", + f"{shared_expert}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # map mlp moe gate and up weight + mlc_name = f"{mlp}.moe_gate_up_proj.weight" + + def combine_expert_gate_up(*hf_params, dtype): + stack = [] + for i in range(0, len(hf_params), 2): + stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0)) + return np.stack(stack, axis=0).astype(dtype) + + mapping.add_mapping( + mlc_name, + functools.reduce( + lambda a, b: a + b, + [ + [ + f"{mlp}.experts.{expert_id}.gate_proj.weight", + f"{mlp}.experts.{expert_id}.up_proj.weight", + ] + for expert_id in range(model_config.num_experts) + ], + ), + functools.partial( + combine_expert_gate_up, + dtype=mlc_param.dtype, + ), + ) + + # map mlp moe down weight + mlc_name = f"{mlp}.moe_down_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.experts.{expert_id}.down_proj.weight" + for expert_id in range(model_config.num_experts) + ], + functools.partial( + lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py new file mode 100644 index 0000000000..59b7ae8375 --- /dev/null +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py @@ -0,0 +1,388 @@ +""" +Implementation for QWEN2MOE architecture. 
+""" + +import dataclasses +from typing import Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.model.qwen2.qwen2_model import ACT2FN, QWen2Attention, QWen2Config +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.nn.expert import MixtralExperts +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Qwen2MoeConfig(QWen2Config): # pylint: disable=too-many-instance-attributes + """Configuration of the Qwen2Moe model.""" + + moe_intermediate_size: int = 0 + shared_expert_intermediate_size: int = 0 + num_experts_per_tok: int = 0 + num_experts: int = 0 + decoder_sparse_step: int = 0 + norm_topk_prob: bool = False + + +# pylint: disable=invalid-name,missing-docstring,too-many-locals + + +class Qwen2MoeMLP(nn.Module): + def __init__(self, config: Qwen2MoeConfig, intermediate_size: Optional[int] = None): + intermediate_size = intermediate_size or config.intermediate_size + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MoE MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) + self.intermediate_size = intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(self.act_fn(x1) * x2) + + +class Qwen2MoeSparseMoeBlock(nn.Module): # pylint: disable=too-many-instance-attributes + """MoE layer for Qwen2MoE model.""" + + def __init__(self, config: Qwen2MoeConfig): + super().__init__() + self.num_experts_per_tok = config.num_experts_per_tok + self.num_experts = config.num_experts + if config.moe_intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MoE intermediate size {config.moe_intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) + self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards + self.norm_topk_prob = config.norm_topk_prob + self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size) + self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False) + + self.gate = nn.Linear( + in_features=config.hidden_size, + out_features=config.num_experts, + bias=False, + ) + self.moe_gate_up_proj = MixtralExperts( + self.num_experts, + in_features=config.hidden_size, + out_features=2 * self.moe_intermediate_size, + ) + self.moe_down_proj = MixtralExperts( + self.num_experts, + in_features=self.moe_intermediate_size, + out_features=config.hidden_size, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: Tensor): + def _expert_forward(x: Tensor, indptr: Tensor): + x1_x2 = self.moe_gate_up_proj(x, indptr) + x1, x2 = op.split(x1_x2, indices_or_sections=2, axis=-1) + x = self.moe_down_proj(self.act_fn(x1) * x2, indptr) + return x + + experts_per_tok = self.num_experts_per_tok + num_experts = self.num_experts + batch_size, seq_len, hidden_size = x.shape + num_tokens = batch_size * seq_len + x = x.reshape(num_tokens, hidden_size) + gate = self.gate(x) + # expert_weights: [num_tokens, experts_per_tok] + # expert_indices: [num_tokens, experts_per_tok] + expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk( + gate, experts_per_tok, norm_topk_prob=self.norm_topk_prob + ) + if num_tokens == 1: + # x: [num_tokens * experts_per_tok, hidden_size] + moe_hidden_states = _expert_forward(x, expert_indices) + else: + # cumsum: [num_tokens * local_experts] + cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, num_experts) + # indices: [num_tokens * experts_per_tok] + reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices) + # indptr: [num_local_experts + 1] + indptr = op_ext.moe_misc.get_indptr( + cumsum, num_experts, num_tokens, inclusive=False, out_dtype="int32" + ) + # x: [num_tokens * experts_per_tok, hidden_size] + moe_hidden_states = op.take(x, token_indices, axis=0) + moe_hidden_states = _expert_forward(moe_hidden_states, indptr) + moe_hidden_states = op_ext.moe_misc.scatter_output(moe_hidden_states, reverse_indices) + # moe_hidden_states: [num_tokens, experts_per_tok, hidden_size] + expert_weights = expert_weights.reshape(num_tokens, experts_per_tok, 1) + moe_hidden_states = ( + moe_hidden_states.reshape(num_tokens, experts_per_tok, hidden_size) * expert_weights + ) + # moe_hidden_states: [num_tokens, hidden_size] + moe_hidden_states = op_ext.moe_misc.moe_sum(moe_hidden_states, dim=1) + + shared_expert_hidden_states = self.shared_expert(x) + shared_expert_hidden_states = ( + op.sigmoid(self.shared_expert_gate(x)) * shared_expert_hidden_states + ) + final_hidden_states = moe_hidden_states + shared_expert_hidden_states + final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_size) + return final_hidden_states + + +class Qwen2MoeDecoderLayer(nn.Module): + def __init__(self, config: Qwen2MoeConfig): + super().__init__() + self.self_attn = QWen2Attention(config) + assert ( + config.num_experts > 0 and config.decoder_sparse_step == 1 + ), "Currently only support use moe for every layer." 
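For readers tracing the new Qwen2MoeSparseMoeBlock.forward above, here is a minimal NumPy sketch of the routing it performs: softmax gating, top-k expert selection, a weighted sum of the selected experts' outputs, and a sigmoid-gated shared expert that always runs. It deliberately ignores the fused TIR paths (moe_cumsum, get_indices, scatter_output) and batching; all names below (route_tokens, expert_mlps, shared_mlp, shared_gate_w) are hypothetical, not part of the patch.

```python
# Dense NumPy reference for the routing in Qwen2MoeSparseMoeBlock.forward (sketch only).
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def route_tokens(x, gate_w, expert_mlps, shared_mlp, shared_gate_w, k, norm_topk_prob=False):
    """x: [tokens, hidden]; gate_w: [hidden, num_experts]; expert_mlps: list of callables."""
    scores = softmax(x @ gate_w, axis=-1)                    # gating probabilities
    topk_idx = np.argsort(-scores, axis=-1)[:, :k]           # [tokens, k] expert ids
    topk_w = np.take_along_axis(scores, topk_idx, axis=-1)   # [tokens, k] weights
    if norm_topk_prob:                                       # Qwen2MoE leaves this False
        topk_w = topk_w / topk_w.sum(axis=-1, keepdims=True)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):                              # weighted sum over selected experts
        for j in range(k):
            out[t] += topk_w[t, j] * expert_mlps[topk_idx[t, j]](x[t])
    # the shared expert always runs and is scaled by a sigmoid gate
    shared = shared_mlp(x) * (1.0 / (1.0 + np.exp(-(x @ shared_gate_w))))
    return out + shared
```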
+ self.mlp = Qwen2MoeSparseMoeBlock(config) + self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, -1, config.rms_norm_eps, bias=False + ) + + def _set_tp(): + def _set(layer, hint): + layer.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_attention_heads * hd + k = self.self_attn.num_key_value_heads * hd + v = self.self_attn.num_key_value_heads * hd + si = self.mlp.shared_expert.intermediate_size + mi = self.mlp.moe_intermediate_size + _set( + self.self_attn.c_attn.weight, + tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]), + ) + _set( + self.self_attn.c_attn.bias, + tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]), + ) + _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1)) + _set( + self.mlp.shared_expert.gate_up_proj.weight, + tp.ShardSingleDim("_shard_shared_mlp_up", segs=[si, si], dim=0), + ) + _set( + self.mlp.shared_expert.down_proj.weight, + tp.ShardSingleDim("_shard_shared_mlp_down", dim=1), + ) + _set( + self.mlp.moe_gate_up_proj.weight, + tp.ShardSingleDim("_shard_moe_mlp_up", segs=[mi, mi], dim=1), + ) + _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.input_layernorm(hidden_states) + out = self.self_attn(out, paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.post_attention_layernorm(hidden_states) + out = self.mlp(out) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + +class Qwen2MoeModel(nn.Module): + def __init__(self, config: Qwen2MoeConfig): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [Qwen2MoeDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Qwen2MoeForCausalLM(nn.Module): # pylint: disable=too-many-instance-attributes + def __init__(self, config: Qwen2MoeConfig): + self.model = Qwen2MoeModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.dtype = config.dtype + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.intermediate_size = config.intermediate_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.rms_norm_eps = config.rms_norm_eps + self.rope_theta = config.rope_theta + self.vocab_size = config.vocab_size + self.tensor_parallel_shards = config.tensor_parallel_shards + self.head_dim = config.head_dim + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: 
Optional[Tensor] = None, + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + hidden_states = op.take(hidden_states, logit_positions, axis=1) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # x[:-1,:] + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": 
nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py new file mode 100644 index 0000000000..e01289823e --- /dev/null +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py @@ -0,0 +1,47 @@ +"""This file specifies how MLC's QWen2 parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize + +from .qwen2_moe_model import Qwen2MoeConfig, Qwen2MoeForCausalLM + + +def group_quant( + model_config: Qwen2MoeConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen2MoE-architecture model using group quantization.""" + model: nn.Module = Qwen2MoeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model(model, quant_map, "") + return model, quant_map + + +def ft_quant( + model_config: Qwen2MoeConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen2MoE model using FasterTransformer quantization.""" + model: nn.Module = Qwen2MoeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model(model, quant_map, "") + return model, quant_map + + +def no_quant( + model_config: Qwen2MoeConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen2MoE model without quantization.""" + model: nn.Module = Qwen2MoeForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map diff --git a/python/mlc_llm/model/rwkv5/rwkv5_model.py b/python/mlc_llm/model/rwkv5/rwkv5_model.py index 81c9e9aa7f..cf91edc95a 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_model.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_model.py @@ -52,11 +52,11 @@ def __post_init__(self): ) if self.num_heads * self.head_size != self.hidden_size: raise ValueError( - f"hidden_size ({self.hidden_size}) must be diisible " + f"hidden_size ({self.hidden_size}) must be divisible " f"by head_size ({self.head_size})" ) if self.tensor_parallel_shards != 1: - raise ValueError("Only support single deice at this moment.") + raise ValueError("Only support single device at this moment.") # pylint: disable=invalid-name,missing-docstring @@ -379,10 +379,6 @@ def 
batch_verify(self, input_embeds: Tensor, state: RNNState): """Verify step.""" return self.forward(input_embeds, state) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - """Softmax.""" - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state( self, max_batch_size: tir.Var, @@ -451,14 +447,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_rnn_state": { "max_batch_size": int, "max_history": int, diff --git a/python/mlc_llm/model/rwkv5/rwkv5_quantization.py b/python/mlc_llm/model/rwkv5/rwkv5_quantization.py index 235519774c..19385724e2 100644 --- a/python/mlc_llm/model/rwkv5/rwkv5_quantization.py +++ b/python/mlc_llm/model/rwkv5/rwkv5_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's RWKV5 parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -17,6 +18,7 @@ def group_quant( model: nn.Module = RWKV5_ForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/rwkv6/rwkv6_model.py b/python/mlc_llm/model/rwkv6/rwkv6_model.py index a8faf48a6b..065bc3eb05 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_model.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_model.py @@ -52,11 +52,11 @@ def __post_init__(self): ) if self.num_heads * self.head_size != self.hidden_size: raise ValueError( - f"hidden_size ({self.hidden_size}) must be diisible " + f"hidden_size ({self.hidden_size}) must be divisible " f"by head_size ({self.head_size})" ) if self.tensor_parallel_shards != 1: - raise ValueError("Only support single deice at this moment.") + raise ValueError("Only support single device at this moment.") # pylint: disable=invalid-name, missing-docstring @@ -421,10 +421,6 @@ def batch_verify(self, input_embeds: Tensor, state: RNNState): """Verify step.""" return self.forward(input_embeds, state) - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - """Softmax.""" - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_rnn_state( self, max_batch_size: tir.Var, @@ -493,14 +489,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_rnn_state": { "max_batch_size": int, "max_history": int, diff --git a/python/mlc_llm/model/rwkv6/rwkv6_quantization.py b/python/mlc_llm/model/rwkv6/rwkv6_quantization.py index ef67568a6f..eda41f643b 100644 --- a/python/mlc_llm/model/rwkv6/rwkv6_quantization.py +++ b/python/mlc_llm/model/rwkv6/rwkv6_quantization.py @@ -18,6 +18,7 @@ def group_quant( model: nn.Module = RWKV6_ForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py 
b/python/mlc_llm/model/stable_lm/stablelm_model.py index 10e16cded6..4f874af633 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -55,7 +55,7 @@ def __post_init__(self): break else: raise ValueError( - "Unable to determine the maxmimum sequence length, because none of " + "Unable to determine the maximum sequence length, because none of " "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " "provided in `config.json`." ) @@ -64,21 +64,19 @@ def __post_init__(self): assert self.head_dim * self.num_attention_heads == self.hidden_size if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %s (%d)", + "%s defaults to %d", bold("prefill_chunk_size"), - bold("context_window_size"), - self.context_window_size, + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) elif self.prefill_chunk_size > self.context_window_size: logger.info( - "Overriding %s from %d to %d (%s)", + "Overriding %s from %d to %d", bold("prefill_chunk_size"), self.prefill_chunk_size, - self.context_window_size, - bold("context_window_size"), + min(self.context_window_size, 2048), ) - self.prefill_chunk_size = self.context_window_size + self.prefill_chunk_size = min(self.context_window_size, 2048) # pylint: disable=invalid-name,missing-docstring @@ -90,6 +88,11 @@ def __init__(self, config: StableLmConfig): self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards self.head_dim = config.head_dim + if config.num_key_value_heads % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split {config.num_key_value_heads} key-value attention heads " + f"evenly to {config.tensor_parallel_shards} GPUs." + ) self.num_heads = config.num_attention_heads // self.tensor_parallel_shards self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards self.num_key_value_groups = self.num_heads // self.num_key_value_heads @@ -117,6 +120,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: class StableLmMLP(nn.Module): def __init__(self, config: StableLmConfig): + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards self.gate_up_proj = nn.Linear( in_features=config.hidden_size, @@ -277,9 +285,6 @@ def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) - def create_paged_kv_cache( # pylint: disable=too-many-arguments self, max_batch_size: tir.Var, @@ -355,14 +360,6 @@ def get_default_spec(self): "effect_mode": "none", }, }, - "softmax_with_temperature": { - "logits": nn.spec.Tensor(["batch_size", 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor(["batch_size"], "float32"), - "$": { - "param_mode": "none", - "effect_mode": "none", - }, - }, "create_paged_kv_cache": { "max_batch_size": int, "max_total_seq_len": int, diff --git a/python/mlc_llm/model/stable_lm/stablelm_quantization.py b/python/mlc_llm/model/stable_lm/stablelm_quantization.py index 5f502b0970..620b769e05 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_quantization.py +++ b/python/mlc_llm/model/stable_lm/stablelm_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's StableLM parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -18,6 +19,7 @@ def group_quant( model: nn.Module = StableLmForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards model = quantization.quantize_model( model, quant_map, diff --git a/python/mlc_llm/nn/__init__.py b/python/mlc_llm/nn/__init__.py index fb1743f788..0c44b544d8 100644 --- a/python/mlc_llm/nn/__init__.py +++ b/python/mlc_llm/nn/__init__.py @@ -1,3 +1,4 @@ """Common `nn.Modules` used to define LLMs in this project.""" + from .expert import MixtralExperts from .kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode, TIRPagedKVCache diff --git a/python/mlc_llm/nn/expert.py b/python/mlc_llm/nn/expert.py index d6c38db248..1dadd7d078 100644 --- a/python/mlc_llm/nn/expert.py +++ b/python/mlc_llm/nn/expert.py @@ -1,4 +1,5 @@ """An nn.Module that represents MoE experts""" + from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index e4cbf1c047..32ddbf15b2 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -13,6 +13,7 @@ from tvm.target import Target from mlc_llm.op.position_embedding import llama_rope_with_position_map, rope_freq +from mlc_llm.op.tree_attn import tree_attn from ..support.max_thread_check import ( check_thread_limits, @@ -246,6 +247,8 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), 
"tir_attention_prefill_with_tree_mask"), # fmt: on # pylint: enable=line-too-long ] @@ -350,6 +353,8 @@ def __init__( # pylint: disable=too-many-locals bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), # fmt: on # pylint: enable=line-too-long ] @@ -399,7 +404,7 @@ def tir_kv_cache_transpose_append( pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] @@ -457,14 +462,14 @@ def _rope( qkv_dtype="float16", ): d = indices[-1] - cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) - cos = cos_freq * buffer[indices] + cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, "float32") + cos = cos_freq * buffer[indices].astype("float32") sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -buffer[indices[:-1] + (d + rotary_dim // 2,)], buffer[indices[:-1] + (d - rotary_dim // 2,)], - ) - return cos + sin + ).astype("float32") + return (cos + sin).astype(qkv_dtype) def _var(dtype): @@ -520,7 +525,6 @@ def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target bdx = 32 num_warps = 4 tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - L_per_cta = tile_x // group_size # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -636,8 +640,8 @@ def batch_prefill_paged_kv( if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] - L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta - H_qo_start: T.int32 = by * group_size + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] @@ -667,8 +671,8 @@ def batch_prefill_paged_kv( i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() - cur_L = L_start + i // group_size - cur_H_qo = H_qo_start + i % group_size + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = T.if_then_else( rotary_mode == 1, @@ -737,9 +741,10 @@ def batch_prefill_paged_kv( m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size for j in T.serial(tile_z): if _causal_mask(causal, - row=tile_id[0] * L_per_cta + row // group_size, + row=row_, col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): @@ -752,8 +757,9 @@ def batch_prefill_paged_kv( for j 
in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size if _causal_mask(causal, - row=tile_id[0] * L_per_cta + row // group_size, + row=row_, col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): @@ -785,15 +791,19 @@ def batch_prefill_paged_kv( for li, lj in T.grid(tile_x, tile_y): with T.block("O_store"): i, j = T.axis.remap("SS", [li, lj]) - if L_start + i // group_size < q_indptr[b_idx + 1]: - output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] # Store LSE to gmem for li in T.grid(tile_x): with T.block("lse_store"): i = T.axis.remap("S", [li]) - if L_start + i // group_size < q_indptr[b_idx + 1]: - lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) # move to next tile tile_id[0] += NUM_BLKS @@ -887,7 +897,7 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 if target.kind.name == "opencl" and "android" in str(target.host): - THREAD_LIMIT = 256 + THREAD_LIMIT = 256 if H_kv < 8 else 512 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -976,12 +986,14 @@ def batch_decode_paged_kv( t0 = T.alloc_buffer((1,), "float32", scope="local") S_local = T.alloc_buffer((bdy * tile_size_per_bdx), "float32", scope="local") - K_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") + QK_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") V_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") m_prev = T.alloc_buffer((1,), "float32", scope="local") d_prev = T.alloc_buffer((1,), "float32", scope="local") other_m = T.alloc_buffer((1,), "float32", scope="local") other_d = T.alloc_buffer((1,), "float32", scope="local") + exp_mprev = T.alloc_buffer((1,), "float32", scope="local") + exp_otherm = T.alloc_buffer((1,), "float32", scope="local") other_o = T.alloc_buffer((VEC_SIZE,), "float32", scope="local") st_m = T.alloc_buffer((1,), "float32", scope="local") st_d = T.alloc_buffer((1,), "float32", scope="local") @@ -1015,9 +1027,9 @@ def batch_decode_paged_kv( for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)): tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore - # load K from global memory to shared memory + # load KV from global memory to shared memory for j in T.serial(tile_size_per_bdx): - with T.block("K_load"): + with T.block("KV_load"): T.reads() T.writes() row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore @@ -1031,36 +1043,21 @@ def batch_decode_paged_kv( _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] ) - else: - for vec in T.vectorized(VEC_SIZE): - K_smem[tile_start_s + j, 
tx * VEC_SIZE + vec] = 0.0 - T.tvm_storage_sync("shared") - # load V from global memory to shared memory - for j in T.serial(tile_size_per_bdx): - with T.block("V_load"): - T.reads() - T.writes() - row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore - if row_g < kv_chunk_len[0]: - seq_offset: T.int32(is_size_var=True) = _get_seq_offset(row_g, batch_idx, length_info, sliding_window) # type: ignore - page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore - page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - for vec in T.vectorized(VEC_SIZE): V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] else: for vec in T.vectorized(VEC_SIZE): + K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 T.tvm_storage_sync("shared") # compute QK m_prev[0] = st_m[0] for j in T.serial(bdy * tile_size_per_bdx): - # load K from shared memory to local memory - for vec in T.vectorized(VEC_SIZE): - K_local[vec] = K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec] # compute S = Q * K * sm_scale + for vec in T.vectorized(VEC_SIZE): + QK_local[vec] = T.cast(Q_local[vec], "float32") * T.cast(K_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec], "float32") * attn_score_scaling_factor * sm_scale S_reduce_local[0] = 0 - for vec in T.serial(VEC_SIZE): - S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale + for vec in T.unroll(VEC_SIZE): + S_reduce_local[0] += QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) @@ -1117,11 +1114,13 @@ def batch_decode_paged_kv( other_o[vec] = O_allreduce[j, ty, tx * VEC_SIZE + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) - for vec in T.serial(VEC_SIZE): - O_local[vec] = O_local[vec] * T.exp2(m_prev[0] - st_m[0]) + other_o[vec] * T.exp2(other_m[0] - st_m[0]) + exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) + exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) + for vec in T.vectorized(VEC_SIZE): + O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] # normalize O - for vec in T.serial(VEC_SIZE): + for vec in T.vectorized(VEC_SIZE): O_local[vec] /= st_d[0] # store O to global memory @@ -1224,7 +1223,6 @@ def _attention_prefill_ragged( bdx = 32 num_warps = 4 tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 - L_per_cta = tile_x // group_size # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1319,8 +1317,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] - L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta - H_qo_start: T.int32 = by * group_size + q_indptr_val: T.int32 = q_indptr[b_idx] + LH_start: T.int32 = tile_id[0] * tile_x kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") @@ -1344,8 +1342,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches i, j = T.axis.remap("SS", [li, lj]) T.reads() T.writes() - cur_L = L_start + i // group_size - cur_H_qo = H_qo_start + i % group_size + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size if cur_L < q_indptr[b_idx + 1]: Q_smem[i, j] = 
T.if_then_else( rotary_mode == 1, @@ -1409,9 +1407,10 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches m_prev[i] = m_smem[row] m_new[i] = m_smem[row] # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // group_size for j in T.serial(tile_z): if _causal_mask(causal, - row=tile_id[0] * L_per_cta + row // group_size, + row=row_, col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): @@ -1424,8 +1423,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for j in T.serial(tile_z): # this is to avoid sync inside condition branch if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size if _causal_mask(causal, - row=tile_id[0] * L_per_cta + row // group_size, + row=row_, col=L_kv_start + j, kv_len=kv_chunk_len[0], qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]): @@ -1457,15 +1457,19 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for li, lj in T.grid(tile_x, tile_y): with T.block("O_store"): i, j = T.axis.remap("SS", [li, lj]) - if L_start + i // group_size < q_indptr[b_idx + 1]: - output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i] + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] # Store LSE to gmem for li in T.grid(tile_x): with T.block("lse_store"): i = T.axis.remap("S", [li]) - if L_start + i // group_size < q_indptr[b_idx + 1]: - lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) # move to next tile tile_id[0] += NUM_BLKS @@ -1581,3 +1585,54 @@ def copy_single_page( pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] return copy_single_page + + +def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): + tx = get_max_num_threads_per_block(target) + + @T.prim_func + def compact_kv_copy( + var_pages: T.handle, + var_copy_length_indptr: T.handle, + var_copy_src_dst_pos: T.handle, + batch_size: T.int32, + ): + T.func_attr({"tir.is_scheduled": 1}) + num_pages = T.int32() + total_copy_length = T.int32() + copy_length_indptr_elem_offset = T.int32() + copy_src_dst_pos_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + copy_length_indptr = T.match_buffer( + var_copy_length_indptr, + (batch_size + 1,), + "int32", + elem_offset=copy_length_indptr_elem_offset, + ) + copy_src_dst_pos = T.match_buffer( + var_copy_src_dst_pos, + (2, total_copy_length), + "int32", + elem_offset=copy_src_dst_pos_elem_offset, + ) + + with T.block("root"): + for bhd_o in T.thread_binding( + (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + ): + for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): + b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) + h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads + d: T.int32 = (bhd_o * tx + bhd_i) % head_dim + if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: + for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): + src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] + dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] + 
pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[ + src_pos // 16, 0, h, src_pos % 16, d + ] + pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[ + src_pos // 16, 1, h, src_pos % 16, d + ] + + return compact_kv_copy diff --git a/python/mlc_llm/op/__init__.py b/python/mlc_llm/op/__init__.py index 850312a8a7..18502c2db4 100644 --- a/python/mlc_llm/op/__init__.py +++ b/python/mlc_llm/op/__init__.py @@ -7,3 +7,4 @@ from .ft_gemm import faster_transformer_dequantize_gemm from .position_embedding import llama_rope from .top_p_pivot import top_p_pivot, top_p_renorm +from .tree_attn import tree_attn diff --git a/python/mlc_llm/op/attention.py b/python/mlc_llm/op/attention.py index dc41a5f5ef..712ac58ef1 100644 --- a/python/mlc_llm/op/attention.py +++ b/python/mlc_llm/op/attention.py @@ -1,4 +1,5 @@ """Operators enabled by external modules.""" + import math from tvm import tir @@ -62,7 +63,6 @@ def attention( # pylint: disable=invalid-name,too-many-locals,too-many-statemen b, s, h_q, d = q.shape t, h_kv, _ = k.shape[-3:] group_size = h_q // h_kv - assert b == 1, "batch size must be 1" def _fallback(): nonlocal q, k, v, qk_dtype diff --git a/python/mlc_llm/op/extern.py b/python/mlc_llm/op/extern.py index fd5d91badb..f81326c3be 100644 --- a/python/mlc_llm/op/extern.py +++ b/python/mlc_llm/op/extern.py @@ -14,6 +14,7 @@ singleton `Store: ExternalModuleStore` to store the configured modules. It is supposed to be enabled before any compilation happens, and configured during a model's `forward` method is invoked. """ + import dataclasses from typing import Optional diff --git a/python/mlc_llm/op/ft_gemm.py b/python/mlc_llm/op/ft_gemm.py index 0a4edc6792..2362b1ac2e 100644 --- a/python/mlc_llm/op/ft_gemm.py +++ b/python/mlc_llm/op/ft_gemm.py @@ -1,4 +1,5 @@ """Operators enabled by external modules.""" + import operator from functools import reduce from typing import Optional diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py index ff5e50c60c..198878787f 100644 --- a/python/mlc_llm/op/moe_misc.py +++ b/python/mlc_llm/op/moe_misc.py @@ -1,4 +1,5 @@ """Mixture of Experts operators""" + from functools import reduce from typing import Tuple, Union @@ -27,8 +28,10 @@ def moe_sum(x: Tensor, dim: int) -> Tensor: return op.sum(x, axis=dim) -def gating_softmax_topk(x: Tensor, k: int) -> Tuple[Tensor, Tensor]: - """Compute the softmax score, choose the top-k experts, and renormalize the selected scores. +def gating_softmax_topk( # pylint: disable=too-many-statements + x: Tensor, k: int, norm_topk_prob=True +) -> Tuple[Tensor, Tensor]: + """Compute the softmax score, choose the top-k experts, and returns selected scores. Parameters ---------- @@ -38,10 +41,13 @@ def gating_softmax_topk(x: Tensor, k: int) -> Tuple[Tensor, Tensor]: k : int The number of top elements to be selected, which is `num_experts_per_tok` in MoE. + norm_topk_prob : bool + Whether to normalize the top-k expert scores. + Returns ------- expert_weights: Tensor - The renormalized top-k expert scores with shape [batch_size, k]. + The top-k expert scores with shape [batch_size, k]. expert_indices: Tensor The top-k expert indices with shape [batch_size, k]. 
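As a plain-Python reference for the specialized top-4 kernel added in the hunks below (top4_softmax_norm_func), each row keeps a descending list of four values and shifts it on every insertion, exactly like the chained comparisons in the TIR code; the helper name top4_select is hypothetical and this sketches only the selection logic, not the thread binding or scheduling.

```python
# Running top-4 selection over one row of already-softmaxed gating scores (sketch).
def top4_select(scores):
    vals = [float("-inf")] * 4
    idxs = [0] * 4
    for i, s in enumerate(scores):
        if s > vals[0]:
            vals[3], idxs[3] = vals[2], idxs[2]
            vals[2], idxs[2] = vals[1], idxs[1]
            vals[1], idxs[1] = vals[0], idxs[0]
            vals[0], idxs[0] = s, i
        elif s > vals[1]:
            vals[3], idxs[3] = vals[2], idxs[2]
            vals[2], idxs[2] = vals[1], idxs[1]
            vals[1], idxs[1] = s, i
        elif s > vals[2]:
            vals[3], idxs[3] = vals[2], idxs[2]
            vals[2], idxs[2] = s, i
        elif s > vals[3]:
            vals[3], idxs[3] = s, i
    return vals, idxs

# e.g. top4_select([0.05, 0.4, 0.1, 0.3, 0.15]) -> ([0.4, 0.3, 0.15, 0.1], [1, 3, 4, 2])
```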
@@ -50,11 +56,12 @@ def gating_softmax_topk(x: Tensor, k: int) -> Tuple[Tensor, Tensor]: index_dtype = "int32" TX = 1024 - SCAN_LEN = 2 + SCAN_LEN_2 = 2 + SCAN_LEN_4 = 4 # specialized kernel for top 2 case @T.prim_func(private=True) - def topk_softmax_func( + def top2_softmax_norm_func( var_x: T.handle, var_out: T.handle, var_out_index: T.handle, @@ -62,11 +69,11 @@ def topk_softmax_func( T.func_attr({"tir.noalias": True, "tir.is_scheduled": True}) batch_size = T.int64() x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype) - out = T.match_buffer(var_out, (batch_size, SCAN_LEN), dtype) - out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN), index_dtype) - local_top_k = T.alloc_buffer((SCAN_LEN,), dtype=dtype, scope="local") - local_top_k_index = T.alloc_buffer((SCAN_LEN,), dtype=index_dtype, scope="local") - local_top_k_f32 = T.alloc_buffer((SCAN_LEN,), dtype="float32", scope="local") + out = T.match_buffer(var_out, (batch_size, SCAN_LEN_2), dtype) + out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_2), index_dtype) + local_top_k = T.alloc_buffer((SCAN_LEN_2,), dtype=dtype, scope="local") + local_top_k_index = T.alloc_buffer((SCAN_LEN_2,), dtype=index_dtype, scope="local") + local_top_k_f32 = T.alloc_buffer((SCAN_LEN_2,), dtype="float32", scope="local") local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local") for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"): for ii in T.thread_binding(0, TX, "threadIdx.x"): @@ -88,13 +95,13 @@ def topk_softmax_func( elif x[vi, vk] > local_top_k[1]: local_top_k[1] = x[vi, vk] local_top_k_index[1] = vk - for j in T.unroll(SCAN_LEN): + for j in T.unroll(SCAN_LEN_2): with T.block("cast"): vj = T.axis.remap("S", [j]) local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32") with T.block("max"): local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1]) - for j in T.unroll(SCAN_LEN): + for j in T.unroll(SCAN_LEN_2): with T.block("output"): vj = T.axis.remap("S", [j]) out[vi, vj] = T.cast( @@ -107,9 +114,72 @@ def topk_softmax_func( ) out_index[vi, vj] = local_top_k_index[vj] - if k == 2: + # specialized kernel for top 4 case + @T.prim_func(private=True) + def top4_softmax_norm_func( + var_x: T.handle, + var_out: T.handle, + var_out_index: T.handle, + ) -> None: + T.func_attr({"tir.noalias": True, "tir.is_scheduled": True}) + batch_size = T.int64() + x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype) + out = T.match_buffer(var_out, (batch_size, SCAN_LEN_4), dtype) + out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_4), index_dtype) + local_top_k = T.alloc_buffer((SCAN_LEN_4,), dtype=dtype, scope="local") + local_top_k_index = T.alloc_buffer((SCAN_LEN_4,), dtype=index_dtype, scope="local") + for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"): + for ii in T.thread_binding(0, TX, "threadIdx.x"): + with T.block("top_k"): + vi = T.axis.spatial(batch_size, io * TX + ii) + T.where(io * TX + ii < batch_size) + with T.block("init"): + local_top_k[0] = T.min_value(dtype) + local_top_k[1] = T.min_value(dtype) + local_top_k[2] = T.min_value(dtype) + local_top_k[3] = T.min_value(dtype) + local_top_k_index[0] = 0 + local_top_k_index[1] = 0 + local_top_k_index[2] = 0 + local_top_k_index[3] = 0 + for k in range(num_local_experts): + with T.block("update"): + vk = T.axis.remap("S", [k]) + # N.B. 
This snippet is specialized for k = 4 + if x[vi, vk] > local_top_k[0]: + local_top_k[3] = local_top_k[2] + local_top_k_index[3] = local_top_k_index[2] + local_top_k[2] = local_top_k[1] + local_top_k_index[2] = local_top_k_index[1] + local_top_k[1] = local_top_k[0] + local_top_k_index[1] = local_top_k_index[0] + local_top_k[0] = x[vi, vk] + local_top_k_index[0] = vk + elif x[vi, vk] > local_top_k[1]: + local_top_k[3] = local_top_k[2] + local_top_k_index[3] = local_top_k_index[2] + local_top_k[2] = local_top_k[1] + local_top_k_index[2] = local_top_k_index[1] + local_top_k[1] = x[vi, vk] + local_top_k_index[1] = vk + elif x[vi, vk] > local_top_k[2]: + local_top_k[3] = local_top_k[2] + local_top_k_index[3] = local_top_k_index[2] + local_top_k[2] = x[vi, vk] + local_top_k_index[2] = vk + elif x[vi, vk] > local_top_k[3]: + local_top_k[3] = x[vi, vk] + local_top_k_index[3] = vk + for j in T.unroll(SCAN_LEN_4): + with T.block("output"): + vj = T.axis.remap("S", [j]) + out[vi, vj] = local_top_k[vj] + out_index[vi, vj] = local_top_k_index[vj] + + # fast path for Mixtral + if k == 2 and norm_topk_prob: return op.tensor_ir_op( - topk_softmax_func, + top2_softmax_norm_func, "top2_softmax", args=[x], out=( @@ -117,10 +187,28 @@ def topk_softmax_func( Tensor.placeholder([batch_size, 2], index_dtype), ), ) - expert_score, expert_indices = op.topk( - x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype - ) - expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype) + if k == 4 and not norm_topk_prob: + expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype) + return op.tensor_ir_op( + top4_softmax_norm_func, + "top4_softmax", + args=[expert_score], + out=( + Tensor.placeholder([batch_size, 4], dtype), + Tensor.placeholder([batch_size, 4], index_dtype), + ), + ) + if norm_topk_prob: + # Compute topk first and then softmax to avoid extra re-normalize + expert_score, expert_indices = op.topk( + x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype + ) + expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype) + else: + expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype) + expert_score, expert_indices = op.topk( + expert_score, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype + ) return expert_score, expert_indices diff --git a/python/mlc_llm/op/position_embedding.py b/python/mlc_llm/op/position_embedding.py index 4f3c2a9c42..4416e8bc9a 100644 --- a/python/mlc_llm/op/position_embedding.py +++ b/python/mlc_llm/op/position_embedding.py @@ -176,7 +176,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments num_q_heads: int, num_kv_heads: int, dtype: str, - rotary_dim: int = None, + rotary_dim: Optional[int] = None, ): """Return the TIR function that computes Llama-style RoPE with q position map. 
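# NOTE (illustrative sketch, not part of the patch): `llama_rope_with_position_map` above
# emits TIR for Llama-style rotate-half RoPE, and the hunk below moves that math to float32.
# A rough NumPy reference for one head vector is given here; the exact frequency exponent
# used by `rope_freq` is an assumption.
import numpy as np

def rope_rotate_half_reference(x, pos, rotary_dim, theta=10000.0, scale=1.0):
    """Rotate one head vector x of length head_dim; entries beyond rotary_dim pass through."""
    out = x.astype(np.float32)
    for d in range(rotary_dim):
        # Paired frequencies: positions d and d + rotary_dim // 2 share the same angle.
        freq = (pos * scale) / theta ** ((2 * d % rotary_dim) / rotary_dim)
        cos, sin = np.cos(freq), np.sin(freq)
        partner = -x[d + rotary_dim // 2] if d < rotary_dim // 2 else x[d - rotary_dim // 2]
        # Accumulate in float32 and cast back at the end, as the patched kernel does.
        out[d] = cos * np.float32(x[d]) + sin * np.float32(partner)
    return out.astype(x.dtype)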
@@ -207,7 +207,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments fused_heads = num_q_heads + num_kv_heads * 2 if rotary_dim is None: rotary_dim = head_dim - scale = tir.const(scale, dtype) + scale = tir.const(scale, "float32") def _rope( # pylint: disable=too-many-arguments x: T.Buffer, @@ -216,14 +216,14 @@ def _rope( # pylint: disable=too-many-arguments d: tir.Var, pos: tir.Var, ): - cos_freq, sin_freq = rope_freq(pos * scale, d, rotary_dim, theta, dtype) - cos = cos_freq * x[s, h, d] + cos_freq, sin_freq = rope_freq(pos * scale, d, rotary_dim, theta, "float32") + cos = cos_freq * x[s, h, d].astype("float32") sin = sin_freq * tir.if_then_else( d < rotary_dim // 2, -x[s, h, d + rotary_dim // 2], x[s, h, d - rotary_dim // 2], - ) - return cos + sin + ).astype("float32") + return (cos + sin).astype(dtype) @T.prim_func def fused_rope( # pylint: disable=too-many-locals diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py index 9c97959bff..b9565a83c9 100644 --- a/python/mlc_llm/op/top_p_pivot.py +++ b/python/mlc_llm/op/top_p_pivot.py @@ -3,12 +3,14 @@ import tvm from tvm.script import tir as T +from mlc_llm.support.max_thread_check import get_max_num_threads_per_block + # mypy: disable-error-code="attr-defined,valid-type,name-defined" # pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda # pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches -def top_p_pivot(pN): +def top_p_pivot(pN, target: tvm.target.Target): """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. A valide pivot should satisfy the following conditions: @@ -23,7 +25,7 @@ def top_p_pivot(pN): prob: The probability vector - top_p_global: + top_p_arr: The top-p threshold init_pivots: @@ -31,11 +33,18 @@ def top_p_pivot(pN): final_pivot: The final pivot to cut-off top-p percentile + + final_lsum: + The final sum of the values after top-p filtering. """ TX = 1024 K = 32 eps_LR = 1e-7 + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -46,7 +55,7 @@ def valid(lsum, lmin, cmin, top_p): @T.prim_func(private=True) def _func( var_prob: T.handle, - top_p_global: T.buffer([1], dtype="float32"), + var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, @@ -55,7 +64,8 @@ def _func( B = T.int32() N = T.int32() prob = T.match_buffer(var_prob, (B, N,), "float32") - init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype="float32") + init_pivots = T.match_buffer(var_init_pivots, (B, pN), "float32") final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") @@ -92,7 +102,7 @@ def _func( with T.block("CTA"): b, tx = T.axis.remap("SS", [_bx, _tx]) - top_p[0] = top_p_global[0] + top_p[0] = top_p_arr[b] if tx == 0: # leader thread initializes L, R @@ -105,8 +115,14 @@ def _func( R_local[0] = R[0] for i in T.unroll(0, pN): # pivots are in descending order - pivot[i] = init_pivots[i] + pivot[i] = init_pivots[b, i] find_pivot_local[0] = False + if L_local[0] - R_local[0] <= eps_LR: + # When the initial value is too small, set the result directly. 
+ if tx == 0: + final_lsum[b] = 1.0 + final_pivot[b] = 0.0 + find_pivot_local[0] = True while T.tvm_thread_invariant( L_local[0] - R_local[0] > eps_LR @@ -118,7 +134,7 @@ def _func( ### get lsum, lmin, total_sum for pidx in T.unroll(0, pN): lsum[pidx] = 0.0 - lmin[pidx] = 1.0 + lmin[pidx] = T.max_value("float32") cmin[pidx] = 0 total_sum[0] = 0.0 it[0] = 0 @@ -226,6 +242,7 @@ def _func( final_lsum[b] = lsum[pidx] elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: R[0] = pivot[pidx] + final_lsum[b] = lsum[pidx] elif lsum[pidx] < top_p[0]: L[0] = pivot[pidx] it[0] += 1 @@ -243,13 +260,15 @@ def _func( if tx == 0: # leader thread writes back the pivot if T.Not(find_pivot_local[0]): - final_pivot[b] = -1e5 + final_pivot[b] = R_local[0] + if R_local[0] == eps_LR: + final_lsum[b] = lsum[pN - 1] # fmt: on return _func -def top_p_renorm(): +def top_p_renorm(target: tvm.target.Target = None): """Top-p renormalization function. This function renormalizes the probability vector. Given the pivot, the probability vector is renormalized as follows: @@ -273,6 +292,11 @@ def top_p_renorm(): TX = 1024 CTA_COUNT = 512 + if target: + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") diff --git a/python/mlc_llm/op/tree_attn.py b/python/mlc_llm/op/tree_attn.py new file mode 100644 index 0000000000..0a9373125d --- /dev/null +++ b/python/mlc_llm/op/tree_attn.py @@ -0,0 +1,393 @@ +"""Operators for tree attention.""" + +import math +from typing import Tuple + +from tvm import tir +from tvm.runtime import DataType +from tvm.script import tir as T +from tvm.target import Target + +from mlc_llm.op.position_embedding import rope_freq + +# mypy: disable-error-code="attr-defined,valid-type,no-redef" +# pylint: disable=too-many-statements,too-many-locals,too-many-arguments + + +def _var(dtype): + return T.alloc_buffer((1,), dtype, scope="local") + + +def _rope( + buffer: T.Buffer, + offset: tir.Var, + rotary_dim: int, + theta: tir.Var, + scale: tir.Var, + indices: Tuple[tir.Var, ...], + qkv_dtype="float16", +): + d = indices[-1] + cos_freq, sin_freq = rope_freq(offset * scale, d, rotary_dim, theta, qkv_dtype) + cos = cos_freq * buffer[indices] + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -buffer[indices[:-1] + (d + rotary_dim // 2,)], + buffer[indices[:-1] + (d - rotary_dim // 2,)], + ) + return cos + sin + + +def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): + return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) + + +def tree_attn(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument + """Generate tree attention kernel for batched tree attention. + + Parameters + ---------- + h_kv : int + Number of heads for key and value. + h_q : int + Number of heads for query. + d : int + Hidden dimension. + dtype : str + Data type. + target : Target + The target device. + + Returns + ------- + mod : tvm.IRModule + The generated IR module. 
+ """ + # pylint: disable=invalid-name,line-too-long + NUM_BLKS = 16 + LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + bdx = 32 + num_warps = 4 + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + + # Otherwise we would exceed maxComputeWorkgroupStorageSize + if ( + str(target.kind) == "webgpu" + and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 + ): + tile_z = 8 + num_warps = 2 + + # fmt: off + @T.prim_func + def batch_tree_attn( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case + var_q_rope_position: T.handle, # [total_q_len] + var_mn_indptr: T.handle, # [batch_size + 1] + var_mask: T.handle, # [mn_indptr[batch_size]] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + batch_size: T.int32, + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + mn_indptr_elem_offset = T.int32(is_size_var=True) + mask_elem_offset = T.int32(is_size_var=True) + tree_size = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset) + mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset) + mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + # kernel code + for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"): + for lby in T.thread_binding(h_kv, thread="blockIdx.y"): + for lty in T.thread_binding(num_warps, thread="threadIdx.y"): + for ltx in T.thread_binding(bdx, thread="threadIdx.x"): + with T.block("attn"): + bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) + T.reads() + T.writes() + tile_id = _var("int32") + batch_idx = _var("int32") + batch_tiles = _var("int32") + batch_rows = _var("int32") + iterator = _var("int32") + kv_chunk_len = _var("int32") + + Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared") + K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared") + S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared") + + S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local") + O_local = T.alloc_buffer((tile_x, d), "float32", scope="local") + + m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + m_prev_smem = T.alloc_buffer((tile_x, ), "float32", 
scope="shared") + d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared") + + m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local") + + ## get tile_no, batch_idx, batch_tiles, batch_rows + tile_id[0] = bx + batch_idx[0] = 0 + batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + while T.tvm_thread_invariant(batch_idx[0] < batch_size): + # advance to next tile + while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: + tile_id[0] -= batch_tiles[0] + batch_idx[0] += 1 + if batch_idx[0] < batch_size: + b_idx: T.int32 = batch_idx[0] + batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size + batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x) + + if T.tvm_thread_invariant(batch_idx[0] < batch_size): + b_idx: T.int32 = batch_idx[0] + LH_start: T.int32 = tile_id[0] * tile_x + q_indptr_val: T.int32 = q_indptr[b_idx] + + kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] + T.tvm_storage_sync("shared") + + # init states + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + m_smem[row] = -5e4 + d_smem[row] = 1.0 + + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_init"): + i, j = T.axis.remap("SS", [li, lj]) + O_local[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Load Q from gmem to smem + for li, lj in T.grid(tile_x, tile_y): + with T.block("Q_load"): + i, j = T.axis.remap("SS", [li, lj]) + T.reads() + T.writes() + cur_L = q_indptr_val + (LH_start + i) // group_size + cur_H_qo = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + Q_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype), + q[cur_L, cur_H_qo, j] + ) + else: + Q_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_z)): + L_kv_start: T.int32 = iterator * tile_z + L_kv_base: T.int32 = kv_indptr[b_idx] + for lz, ly in T.grid(tile_z, tile_y): + with T.block("KV_load"): + i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() + cur_L = L_kv_base + L_kv_start + i + if L_kv_start + i < kv_chunk_len[0]: + K_smem[i, j] = T.if_then_else( + rotary_mode == 1, + _rope(k, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype), + k[cur_L, by, j] + ) + V_smem[i, j] = v[cur_L, by, j] + else: + K_smem[i, j] = 0.0 + V_smem[i, j] = 0.0 + T.tvm_storage_sync("shared") + + # Compute S + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_z, tile_y): + with T.block("S_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + S_local[i, j] = 0.0 + S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale + T.tvm_storage_sync("shared") + for li, lj in T.grid(tile_x, tile_z): + with T.block("S_store"): + i, j = T.axis.remap("SS", [li, lj]) + S_smem[i, j] = S_local[i, j] + T.tvm_storage_sync("shared") + + # Update S, m, d + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update1"): + m_prev[i] = m_smem[row] + m_new[i] = m_smem[row] + # mask out of kv_chunk_len S + row_: T.int32 = (LH_start + row) // 
group_size + for j in T.serial(tile_z): + if _tree_mask( + row=row_, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + m_new[i] = T.max(m_new[i], S_smem[row, j]) + d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + with T.block("update"): + for j in T.serial(tile_z): + # this is to avoid sync inside condition branch + if row < tile_x: + row_: T.int32 = (LH_start + row) // group_size + if _tree_mask( + row=row_, + col=L_kv_start + j, + mask_ptr=mask, + offset=mn_indptr[b_idx], + stride=q_indptr[b_idx + 1] - q_indptr[b_idx], + kv_len=kv_chunk_len[0]): + S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) + else: + S_smem[row, j] = T.exp2(-5e4 - m_new[i]) + + for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)): + row: T.int32 = i * bdx * num_warps + ty * bdx + tx + if row < tile_x: + with T.block("update"): + for j in T.serial(tile_z): + d_new[i] += S_smem[row, j] + m_smem[row] = m_new[i] + d_smem[row] = d_new[i] + m_prev_smem[row] = m_prev[i] + T.tvm_storage_sync("shared") + + # Update O + with T.block(): + for li, lj, lk in T.grid(tile_x, tile_y, tile_z): + with T.block("O_gemm"): + i, j, k = T.axis.remap("SSR", [li, lj, lk]) + with T.init(): + O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i]) + O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32") + + # Store O from smem to gmem + for li, lj in T.grid(tile_x, tile_y): + with T.block("O_store"): + i, j = T.axis.remap("SS", [li, lj]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i] + + # Store LSE to gmem + for li in T.grid(tile_x): + with T.block("lse_store"): + i = T.axis.remap("S", [li]) + cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size + cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size + if cur_L < q_indptr[b_idx + 1]: + lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) + + # move to next tile + tile_id[0] += NUM_BLKS + # fmt: on + # pylint: enable=line-too-long,invalid-name,too-many-branches + sch = tir.Schedule(batch_tree_attn) + + def get_tile_size(x, y, t): + cnt = (x * y) // t + assert (x * y) % t == 0 + tile_y = (int)(math.ceil(math.sqrt(cnt))) + while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + tile_y += 1 + assert tile_y <= cnt + tile_x = cnt // tile_y + return tile_x, tile_y + + def apply_to_qkv_load(sch: tir.Schedule, block): + loop_x, loop_y = sch.get_loops(block)[-2:] + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def apply_to_so_ewise(sch: tir.Schedule, block, tile): + loop_x, loop_y = sch.get_loops(block)[-2:] + xo, xi = sch.split(loop_x, factors=[None, tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def apply_to_gemm( # pylint: disable=unused-argument + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False + ): + loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] + xo, xi = sch.split(loop_x, factors=[None, 
tile[0]]) + yo, yi = sch.split(loop_y, factors=[None, tile[1]]) + sch.reorder(xo, yo, xi, yi) + t = sch.fuse(xo, yo) + ty, tx = sch.split(t, factors=[None, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + ko, ki = sch.split(loop_z, factors=[None, r_len]) + if k_major: + sch.reorder(ko, xi, yi, ki) + else: + sch.reorder(ko, ki, xi, yi) + sch.decompose_reduction(block, ty) + + def apply_to_md(sch, block): + loop = sch.get_loops(block)[-1] + _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) + tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) + apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) + apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False) + apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s) + apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o) + apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o) + apply_to_qkv_load(sch, sch.get_block("Q_load")) + apply_to_qkv_load(sch, sch.get_block("KV_load")) + + apply_to_md(sch, sch.get_block("lse_store")) + return sch.mod["main"].with_attr("tir.is_scheduled", 1) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 482cce54c8..ceb5f64039 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -103,7 +103,7 @@ def check_message_seps(cls, seps: List[str]) -> List[str]: def to_json_dict(self) -> Dict[str, Any]: """Convert to a json dictionary""" - return self.model_dump(exclude_none=True) + return self.model_dump(by_alias=True, exclude_none=True) @classmethod def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T: @@ -135,7 +135,6 @@ def as_prompt(self, config=None) -> List[Any]: separators.append(separators[0]) if system_msg != "": - system_msg += separators[0] message_list.append(system_msg) for i, (role, content) in enumerate(self.messages): # pylint: disable=not-an-iterable diff --git a/python/mlc_llm/protocol/debug_protocol.py b/python/mlc_llm/protocol/debug_protocol.py new file mode 100644 index 0000000000..fe4a1df034 --- /dev/null +++ b/python/mlc_llm/protocol/debug_protocol.py @@ -0,0 +1,27 @@ +"""Debug protocols in MLC LLM""" + +from typing import Literal, Optional + +from pydantic import BaseModel + + +class DebugConfig(BaseModel): + """The class of debug options. + + These options are available to the engine + but won't be available to the serving endpoint + unless an explicit --enable-debug flag is passed + """ + + ignore_eos: bool = False + pinned_system_prompt: bool = False + special_request: Optional[Literal["query_engine_metrics"]] = None + grammar_execution_mode: Literal["constraint", "jump_forward"] = "jump_forward" + + """Special request indicators + + Special requests are handled differently by the engine and do not go + through the normal engine step flow.
+ + The results to these requests are returned as field of "usage" + """ diff --git a/python/mlc_llm/protocol/error_protocol.py b/python/mlc_llm/protocol/error_protocol.py index 83a201f578..1dd1aafd67 100644 --- a/python/mlc_llm/protocol/error_protocol.py +++ b/python/mlc_llm/protocol/error_protocol.py @@ -1,6 +1,7 @@ """Error protocols in MLC LLM""" from http import HTTPStatus +from typing import Optional import fastapi from pydantic import BaseModel @@ -18,13 +19,13 @@ class ErrorResponse(BaseModel): object: str = "error" message: str - code: int = None + code: Optional[int] = None def create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse: """Create a JSON response that reports error with regarding the input message.""" return fastapi.responses.JSONResponse( - ErrorResponse(message=message, code=status_code.value).model_dump_json(), + ErrorResponse(message=message, code=status_code.value).model_dump_json(by_alias=True), status_code=status_code.value, ) diff --git a/python/mlc_llm/protocol/generation_config.py b/python/mlc_llm/protocol/generation_config.py new file mode 100644 index 0000000000..e7b8cb9185 --- /dev/null +++ b/python/mlc_llm/protocol/generation_config.py @@ -0,0 +1,33 @@ +"""Low-level generation config class""" + +# pylint: disable=missing-class-docstring, disable=too-many-instance-attributes +from typing import Dict, List, Optional + +from pydantic import BaseModel + +from .debug_protocol import DebugConfig +from .openai_api_protocol import RequestResponseFormat + + +class GenerationConfig(BaseModel): # pylint: + """The generation configuration dataclass. + + This is a config class used by Engine internally. + """ + + n: int = 1 + temperature: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + repetition_penalty: Optional[float] = None + logprobs: bool = False + top_logprobs: int = 0 + logit_bias: Optional[Dict[int, float]] = None + # internally we use -1 to represent infinite + max_tokens: int = -1 + seed: Optional[int] = None + stop_strs: Optional[List[str]] = None + stop_token_ids: Optional[List[int]] = None + response_format: Optional[RequestResponseFormat] = None + debug_config: Optional[Optional[DebugConfig]] = None diff --git a/python/mlc_llm/protocol/mlc_chat_config.py b/python/mlc_llm/protocol/mlc_chat_config.py new file mode 100644 index 0000000000..c1bd7cb1c8 --- /dev/null +++ b/python/mlc_llm/protocol/mlc_chat_config.py @@ -0,0 +1,73 @@ +# pylint: disable=too-many-instance-attributes +"""Schema for mlc-chat-config""" +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, Field + +from mlc_llm.support.constants import MLC_CHAT_CONFIG_VERSION + +from .conversation_protocol import Conversation + +MLC_CHAT_SYSTEM_DEFAULT = { + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "temperature": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "top_p": 1.0, +} +"""system default values.""" + + +class MLCChatConfig(BaseModel): + """Fields in the dumped `mlc-chat-config.json` file.""" + + # Version control + version: str = MLC_CHAT_CONFIG_VERSION + + # use alias to avoid protected namespace conflict with pydantic + field_model_type: str = Field(alias="model_type") + quantization: str + # use alias to avoid protected namespace conflict with pydantic + field_model_config: Dict[str, Any] = Field(alias="model_config") + vocab_size: int + context_window_size: 
int + sliding_window_size: int + prefill_chunk_size: int + attention_sink_size: int + tensor_parallel_shards: int + # Configuration of text generation + temperature: Optional[float] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + # Tokenizer configuration + tokenizer_files: List[str] = Field(default_factory=list) + # The content of tokenizer.TokenizerInfo + tokenizer_info: Dict[str, Any] = Field(default_factory=dict) + # conversation template + conv_template: Conversation + # extra fields from generation_config.json + # NOTE: they are not being used for now in MLCEngine + # but we keep them for book-keep purposes + pad_token_id: Optional[int] = None + bos_token_id: Optional[int] = None + eos_token_id: Optional[Union[int, List[int]]] = None + + def get_system_defaults_for_missing_fields(self) -> Dict[str, Any]: + """Apply system default value for fields that are None + + Note + ---- + We implement default setting in this way so we can lazily create + MLCChatConfig, override its optional values then + apply_system_defaults in the end. + """ + res = {} + for key, value in MLC_CHAT_SYSTEM_DEFAULT.items(): + if getattr(self, key) is None: + res[key] = value + return res diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 4a5168f971..722f5d2d34 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator from .conversation_protocol import Conversation +from .debug_protocol import DebugConfig from .error_protocol import BadRequestError ################ Commons ################ @@ -40,17 +41,17 @@ class LogProbs(BaseModel): content: List[LogProbsContent] -class UsageInfo(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + extra: Optional[Dict[str, Any]] = None + """Extra metrics and info that may be returned by debug_config + """ + - def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0): - super().__init__( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) +class StreamOptions(BaseModel): + include_usage: Optional[bool] ################ v1/models ################ @@ -93,23 +94,24 @@ class CompletionRequest(BaseModel): logprobs: bool = False top_logprobs: int = 0 logit_bias: Optional[Dict[int, float]] = None - max_tokens: int = 16 + max_tokens: Optional[int] = None n: int = 1 seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stream: bool = False + stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None temperature: Optional[float] = None top_p: Optional[float] = None user: Optional[str] = None - ignore_eos: bool = False response_format: Optional[RequestResponseFormat] = None + debug_config: Optional[DebugConfig] = None @field_validator("frequency_penalty", "presence_penalty") @classmethod - def check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" - if penalty_value < -2 or penalty_value > 2: + if penalty_value and (penalty_value < -2 or penalty_value > 2): raise 
ValueError("Penalty value should be in range [-2, 2].") return penalty_value @@ -156,9 +158,7 @@ class CompletionResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: Optional[str] = None object: str = "text_completion" - usage: UsageInfo = Field( - default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda - ) + usage: Optional[CompletionUsage] = None ################ v1/chat/completions ################ @@ -211,17 +211,20 @@ class ChatCompletionRequest(BaseModel): seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stream: bool = False + stream_options: Optional[StreamOptions] = None temperature: Optional[float] = None top_p: Optional[float] = None tools: Optional[List[ChatTool]] = None tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None user: Optional[str] = None - ignore_eos: bool = False response_format: Optional[RequestResponseFormat] = None + # NOTE: debug_config is not part of OpenAI protocol + # we add it to enable extra debug options + debug_config: Optional[DebugConfig] = None @field_validator("frequency_penalty", "presence_penalty") @classmethod - def check_penalty_range(cls, penalty_value: float) -> float: + def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]: """Check if the penalty value is in range [-2, 2].""" if penalty_value and (penalty_value < -2 or penalty_value > 2): raise ValueError("Penalty value should be in range [-2, 2].") @@ -252,6 +255,32 @@ def check_logprobs(self) -> "ChatCompletionRequest": raise ValueError('"logprobs" must be True to support "top_logprobs"') return self + @model_validator(mode="after") + def check_stream_options(self) -> "ChatCompletionRequest": + """Check stream options""" + if self.stream_options is None: + return self + if not self.stream: + raise ValueError("stream must be set to True when stream_options is present") + return self + + @model_validator(mode="after") + def check_debug_config(self) -> "ChatCompletionRequest": + """Check debug config""" + if self.debug_config is None: + return self + + if self.debug_config.special_request is None: + return self + + if not self.stream: + raise ValueError("DebugConfig.special_request requires stream=True") + + if self.stream_options is None or not self.stream_options.include_usage: + raise ValueError("DebugConfig.special_request requires include_usage in stream_options") + + return self + def check_message_validity(self) -> None: """Check if the given chat messages are valid. 
Return error message if invalid.""" for i, message in enumerate(self.messages): @@ -298,7 +327,7 @@ def check_function_call_usage(self, conv_template: Conversation) -> None: ] ): conv_template.use_function_calling = True - conv_template.function_string = tool.function.model_dump_json() + conv_template.function_string = tool.function.model_dump_json(by_alias=True) return # pylint: disable=unsubscriptable-object @@ -315,7 +344,7 @@ def check_function_call_usage(self, conv_template: Conversation) -> None: for tool in self.tools: # pylint: disable=not-an-iterable if tool.type != "function": raise BadRequestError("Only 'function' tool type is supported") - function_list.append(tool.function.model_dump()) + function_list.append(tool.function.model_dump(by_alias=True)) conv_template.use_function_calling = True conv_template.function_string = json.dumps(function_list) @@ -346,9 +375,7 @@ class ChatCompletionResponse(BaseModel): model: Optional[str] = None system_fingerprint: str object: Literal["chat.completion"] = "chat.completion" - usage: UsageInfo = Field( - default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda - ) + usage: Optional[CompletionUsage] = None class ChatCompletionStreamResponse(BaseModel): @@ -362,9 +389,7 @@ class ChatCompletionStreamResponse(BaseModel): model: Optional[str] = None system_fingerprint: str object: Literal["chat.completion.chunk"] = "chat.completion.chunk" - usage: UsageInfo = Field( - default_factory=lambda: UsageInfo() # pylint: disable=unnecessary-lambda - ) + usage: Optional[CompletionUsage] = None ################################################ @@ -383,49 +408,3 @@ def openai_api_get_unsupported_fields( if hasattr(request, field) and getattr(request, field) != value: unsupported_fields.append(field) return unsupported_fields - - -def openai_api_get_generation_config( - request: Union[CompletionRequest, ChatCompletionRequest], model_config: Dict[str, Any] -) -> Dict[str, Any]: - """Create the generation config from the given request.""" - from ..serve.config import ResponseFormat # pylint: disable=import-outside-toplevel - - kwargs: Dict[str, Any] = {} - arg_names = [ - "n", - "temperature", - "top_p", - "max_tokens", - "frequency_penalty", - "presence_penalty", - "logprobs", - "top_logprobs", - "logit_bias", - "seed", - "ignore_eos", - ] - for arg_name in arg_names: - kwargs[arg_name] = getattr(request, arg_name) - - # If per-request generation config values are missing, try loading from model config. - # If still not found, then use the default OpenAI API value - if kwargs["temperature"] is None: - kwargs["temperature"] = model_config.get("temperature", 1.0) - if kwargs["top_p"] is None: - kwargs["top_p"] = model_config.get("top_p", 1.0) - if kwargs["frequency_penalty"] is None: - kwargs["frequency_penalty"] = model_config.get("frequency_penalty", 0.0) - if kwargs["presence_penalty"] is None: - kwargs["presence_penalty"] = model_config.get("presence_penalty", 0.0) - if kwargs["max_tokens"] is None: - # Setting to -1 means the generation will not stop until - # exceeding model capability or hit any stop criteria. 
- kwargs["max_tokens"] = -1 - if request.stop is not None: - kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop - if request.response_format is not None: - kwargs["response_format"] = ResponseFormat( - **request.response_format.model_dump(by_alias=True) - ) - return kwargs diff --git a/python/mlc_llm/quantization/__init__.py b/python/mlc_llm/quantization/__init__.py index 04dc39b33f..d2c89bb2a1 100644 --- a/python/mlc_llm/quantization/__init__.py +++ b/python/mlc_llm/quantization/__init__.py @@ -1,8 +1,9 @@ """A subpackage for quantization and dequantization algorithms""" + from .awq_quantization import AWQQuantize +from .fp8_quantization import FP8PerTensorQuantizeMixtralExperts from .ft_quantization import FTQuantize from .group_quantization import GroupQuantize -from .per_tensor_quantization import PerTensorQuantize from .no_quantization import NoQuantize +from .per_tensor_quantization import PerTensorQuantize from .quantization import QUANTIZATION, Quantization -from .smooth_quantization import SmoothQuantize diff --git a/python/mlc_llm/quantization/awq_quantization.py b/python/mlc_llm/quantization/awq_quantization.py index 1d7cddbfa6..d51f0a6020 100644 --- a/python/mlc_llm/quantization/awq_quantization.py +++ b/python/mlc_llm/quantization/awq_quantization.py @@ -117,7 +117,11 @@ def visit_module(self, name: str, node: nn.Module) -> Any: The new node to replace current node. """ - if isinstance(node, nn.Linear) and not is_final_fc(name) and not is_moe_gate(name): + if ( + isinstance(node, nn.Linear) + and not is_final_fc(name) + and not is_moe_gate(name, node) + ): return AWQQuantizeLinear.from_linear(node, self.config) return self.visit(name, node) diff --git a/python/mlc_llm/quantization/fp8_quantization.py b/python/mlc_llm/quantization/fp8_quantization.py index 87d16cb53b..b7cd41c98d 100644 --- a/python/mlc_llm/quantization/fp8_quantization.py +++ b/python/mlc_llm/quantization/fp8_quantization.py @@ -1,299 +1,39 @@ """ Quantization techniques for FP8 """ -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, List, Literal, Optional, Tuple, Union -from tvm import DataType, DataTypeCode, IRModule -from tvm import dlight as dl -from tvm import relax, te, tir, topi -from tvm import nd +import numpy as np +from tvm import nd, relax from tvm.relax.frontend import nn -from tvm.runtime import NDArray -from tvm.script import tir as T -from tvm.target import Target -from mlc_llm.loader import QuantizeMapping from mlc_llm.nn import MixtralExperts -from mlc_llm.support import logging -from mlc_llm.support import tensor_parallel as tp -from . import group_quantization as gq +from ..op import cutlass, extern, moe_matmul from . import per_tensor_quantization as ptq from .utils import apply_sharding -def quantize( - x: nn.Tensor, - quantize_dtype: str, - kind="fp8-max", - scale_tensor: Optional[nn.Tensor] = None, - compute_scale_only=False, - name: str = "quantize", - **kwargs, -) -> Tuple[nn.Tensor, ...]: - """ - Quantizes the input tensor to a specified lower-precision datatype using different quantization schemes. - - This function supports quantization schemes such as 'fp8-max', where each element in the tensor is scaled and - quantized to a target datatype that uses fewer bits than the original datatype. The fp8-max range scheme - scales the tensor values based on the maximum value in the tensor to utilize the full range of the target datatype. 
- - Parameters - ---------- - x : nn.Tensor - The input tensor to be quantized. - - quantize_dtype : DataType - The target datatype for quantization. - - kind : str, optional - The kind of quantization scheme to use. - - name : str, optional - A name hint for the operation. - - **kwargs : dict - Additional keyword arguments for quantization parameters. For 'fp8-max', 'max_int_value' must be provided, - which defines the maximum integer value that can be represented in the target datatype. - - Returns - ------- - result : Tuple[nn.Tensor, ...] - A list of tensors from the qunatization, - Usually the quantized tensor, and parameter tensors like scale and zero point - - """ - if kind == "fp8-max": - # quant: Tensor(dtype="e4m3_float8") = (x / scale); scale: float16 = max(x) / fp8_max_int_value); - assert ( - "max_int_value" in kwargs - ), "'max_int_value' must be provided when using fp8-max quantization" - assert len(kwargs) == 1, f"Unknown additional kwargs: {kwargs}" - - assert ( - compute_scale_only == False or scale_tensor == None - ), "fp8-max calibration: Cannot provide a scale and request scale computation" - - def _compute_scale(max_abs: te.Tensor) -> te.Tensor: - max_int = tir.const(kwargs["max_int_value"], x.dtype) - min_scaling_factor = tir.const(1.0 / (kwargs["max_int_value"] * 512.0), x.dtype) - - scale = te.compute( - (1,), - lambda *idx: te.max( - max_abs(*idx).astype(x.dtype) / max_int, - min_scaling_factor, - ), - name="scale", - ) - return scale - - def _quantize(tensor: te.Tensor, scale: te.Tensor) -> te.Tensor: - scaled_act = te.compute( - shape=tensor.shape, - fcompute=lambda *idx: tir.Cast( - quantize_dtype, - tensor(*idx) / scale[0], - ), - ) - return scaled_act - - def _fused_compute_scale_and_quantize( - tensor: te.Tensor, - max_abs: te.Tensor, - out_shape: Optional[List[tir.PrimExpr]] = None, - ): - scale = _compute_scale(max_abs) - scaled_act = _quantize(tensor, scale) - return scaled_act, scale - - if compute_scale_only or scale_tensor == None: - max_abs = nn.op.extern( - "tvm.contrib.cuda.reduce_max_abs", - [x], - nn.Tensor.placeholder((1,), x.dtype), - ) - - if compute_scale_only: - scale = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda max_tensor: _compute_scale( # pylint: disable=protected-access - max_tensor, - ), - name_hint="compute_scale", - args=[max_abs], - ) - return scale - elif scale_tensor == None: - quant, scale = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda tensor, max_tensor: _fused_compute_scale_and_quantize( # pylint: disable=protected-access - tensor, - max_tensor, - ), - name_hint="compute_scale_and_quantize", - args=[x, max_abs], - ) - return quant, scale - else: - quant = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda tensor, scale_t: _quantize( # pylint: disable=protected-access - tensor, - scale_t, - ), - name_hint="quantize", - args=[x, scale_tensor], - ) - return quant - else: - raise ValueError("Unknown quantization kind") - - -def dequantize( - quant: nn.Tensor, - scale: nn.Tensor, - zero: nn.Tensor = None, - kind="fp8-max", - name="dequantize", - **kwargs, -) -> nn.Tensor: - """ - Dequantizes the input tensor from a specified lower-precision datatype back to a higher-precision datatype. - - This function supports dequantization schemes such as 'fp8-max', where each element in the quantized tensor - is converted back to a higher-precision format using the provided scale. 
The 'fp8-max' scheme specifically - reverses the scaling applied during quantization, without utilizing a zero-point adjustment. - - Parameters - ---------- - quant : nn.Tensor - The quantized tensor to be dequantized. - - scale : nn.Tensor - The scale used during quantization. - original higher-precision format. - - zero : nn.Tensor, optional - The zero-point used during quantization. - - kind : str, optional - The kind of dequantization scheme to use. - - name : str, optional - A name hint for the operation. - - **kwargs : dict - Additional keyword arguments for dequantization parameters. - - Returns - ------- - nn.Tensor - The dequantized tensor. - - """ - if kind == "fp8-max": - # dequant: Tensor(dtype="float16") = (quant * scale); scale precompute by quantization - assert zero == None, "FP8 max range quantization does not utilzie a zero point" - return quant * scale - else: - raise ValueError("Unknown quantization kind") - - -def inplace_maximum(scale: nn.Tensor, param: nn.Tensor): - @T.prim_func - def max_update( - scale_local: T.Buffer(scale.shape, scale.dtype), - scale_global: T.Buffer(param.shape, param.dtype), - # TODO(csullivan): consider using nn.op.tensor_ir_inplace_op - out_scale: T.Buffer(param.shape, param.dtype), - ): - T.func_attr({"tir.noalias": T.bool(True)}) - # TODO(csullivan): use index expansion - intermediate = T.alloc_buffer(scale_global.shape, dtype=scale_global.dtype) - for i in range(scale_global.shape[0]): - with T.block("read"): - vi = T.axis.remap("S", [i]) - T.reads(scale_global[vi]) - T.writes(intermediate[vi]) - intermediate[vi] = scale_global[vi] - for i in range(scale_local.shape[0]): - with T.block("max_update"): - vi = T.axis.remap("S", [i]) - T.reads(scale_local[vi], intermediate[vi]) - T.writes(scale_global[vi], out_scale[vi]) - scale_global[vi] = T.if_then_else( - scale_local[vi] > intermediate[vi], scale_local[vi], intermediate[vi] - ) - out_scale[vi] = scale_local[vi] - - return nn.op.tensor_ir_op( - max_update, - name_hint="inplace_maximum", - args=[scale, param], - out=nn.Tensor.placeholder(scale.shape, scale.dtype), - ) - - -nn.op.quantize = quantize -nn.op.dequantize = dequantize -nn.op.maximum_inplace = inplace_maximum - - -class MixtralExpertsFP8( +class FP8PerTensorQuantizeMixtralExperts( ptq.PerTensorQuantizeMixtralExperts ): # pylint: disable=too-many-instance-attributes + """MixtralExperts with per-tensor quantization in FP8.""" + def __init__( self, num_local_experts, in_features, out_features, - weight_config: ptq.PerTensorQuantize, - activation_dtype: str = None, - weight_dtype: str = None, - runtime: str = "cast", + config: ptq.PerTensorQuantize, + name: str, tensor_parallel_shards=1, ): # pylint: disable=too-many-arguments - super().__init__(num_local_experts, in_features, out_features, weight_config) - self.activation_dtype = activation_dtype - self.weight_dtype = weight_dtype - self.runtime = runtime - self.weight_config = weight_config - self.max_int_value = { - key: self.get_max_int(dtype) - for key, dtype in zip(["x", "w"], [self.activation_dtype, self.weight_dtype]) - } - - # TODO(csullivan): Delete this as it is no longer necessary + super().__init__(num_local_experts, in_features, out_features, config, name) self.tensor_parallel_shards = tensor_parallel_shards - if "max" in self.runtime: - self.q_calibration_scale = nn.Parameter((1,), weight_config.model_dtype) - - @staticmethod - def get_max_int(dtype): - if dtype == "e4m3_float8": - return 448 - elif dtype == "e5m2_float8": - return 57344 - elif dtype == "float16": - 
return 65504 - else: - raise NotImplementedError() - - def add_calibration_params(self, quant_map: QuantizeMapping, layer_name: str): - scale_name = f"{layer_name}.q_calibration_scale" - - def alloc_scale(): - return nd.empty( - shape=self.q_calibration_scale.shape, dtype=self.q_calibration_scale.dtype - ) - - quant_map.map_func[scale_name] = alloc_scale - return quant_map - @staticmethod def from_mixtral_experts( src: "MixtralExperts", - weight_config: ptq.PerTensorQuantize, - ) -> "MixtralExpertsFP8": + config: ptq.PerTensorQuantize, + name: str, + ) -> "FP8PerTensorQuantizeMixtralExperts": """ Converts a non-quantized MixtralExperts to a per-tensor quantized MixtralExperts. @@ -302,290 +42,83 @@ def from_mixtral_experts( src : MixtralExperts The non-quantized MixtralExperts - weight_config : GroupQuantize - The group quantization weight_config. + config : PerTensorQuantize + The FP8 quantization weight_config. + + name : str + The name of the layer. Returns ------- ret : MixtralExpertsFP8 The per-tensor quantized MixtralExperts. """ - - quantized_mistral_experts = MixtralExpertsFP8( + quantized_mistral_experts = FP8PerTensorQuantizeMixtralExperts( num_local_experts=src.num_local_experts, in_features=src.in_features, out_features=src.out_features, - weight_config=weight_config, - activation_dtype=weight_config.activation_dtype, - weight_dtype=weight_config.weight_dtype, - runtime="max" if "calibration" not in weight_config.name else "max-calibration", + config=config, + name=name, tensor_parallel_shards=src.tensor_parallel_shards, ) if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) - apply_sharding( - tp.ShardScalar(name=shard.name), - f"{shard.name}_q_scale", - quantized_mistral_experts.q_scale, - ) - if "max" in quantized_mistral_experts.runtime: - apply_sharding( - tp.ShardScalar(name=shard.name), - f"{shard.name}_q_calibration_scale", - quantized_mistral_experts.q_calibration_scale, - ) + # scale doesn't need to be sharded since it's the same for all shards return quantized_mistral_experts def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name - if self.runtime == "max-calibration": - x, local_scale = nn.op.quantize( - x, - quantize_dtype=self.activation_dtype, - kind="fp8-max", - max_int_value=self.max_int_value["x"], - ) - - local_scale = nn.op.maximum_inplace(local_scale, self.q_calibration_scale) - # Calibration done in fp16 mma - x = nn.op.astype(x, dtype="float16") - if DataType(self.weight_dtype).type_code in [ - DataTypeCode.E4M3Float, - DataTypeCode.E5M2Float, - ]: - dequant_func = self.config._dequantize_float8 - # The below computes w = cast(reinterpret_cast(self.q_weight, float8_e4m3), float16) * self.q_scale - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight, scale: dequant_func( # pylint: disable=protected-access - weight, - scale, - out_shape=(self.num_local_experts, self.out_features, self.in_features), - ), - name_hint="dequantize_weight", - args=[self.q_weight, self.q_scale], - ) - elif self.weight_dtype == "float16": - w = self.q_weight - - elif self.runtime == "max": - local_scale = self.q_calibration_scale - x = x / local_scale - x = nn.op.astype(x, dtype=self.activation_dtype) - w = self.q_weight - elif self.runtime == "cast": - x = nn.op.astype(x, dtype=self.activation_dtype) - w = self.q_weight - else: - raise NotImplementedError( - f"Only max and cast runtimes are supported for 
FP8 activations, {self.runtime} is unsupported." - ) - - workspace = nn.op.wrap_nested( - relax.op.builtin.alloc_tensor( - relax.ShapeExpr((4096 * 1024,)), - dtype="uint8", - runtime_device_index=0, - ), - "relax.alloc_tensor", - ) - - batch_size, in_features = x.shape - num_local_experts, out_features, _ = self.q_weight.shape - - if self.runtime == "max-calibration": - func = "cutlass.group_gemm_scale_fp16_sm90" - else: - a_format = self.activation_dtype.split("_")[0] - w_format = self.weight_dtype.split("_")[0] - func = f"cutlass.group_gemm_{a_format}_{w_format}_fp16" + w = self.q_weight - if self.runtime == "cast": - func = func + "_host_scale" - total_scale = 1.0 - else: - if self.runtime != "max-calibration" and self.weight_dtype == "e4m3_float8": - # for calibration, q_scale is already used to dequantize the weights - total_scale = local_scale * self.q_scale - else: - total_scale = local_scale - total_scale = nn.op.astype(total_scale, dtype="float32") - - return nn.op.extern( - func, - [ + if self.config.calibration_mode == "max": + _, x_scale = self.config.quantize_float8( # type: ignore x, - w, - indptr, - workspace, - total_scale, - ], - out=nn.Tensor.placeholder( - (batch_size, out_features), - dtype=self.weight_config.model_dtype, - ), - ) - - -# TODO(csullivan): Refactor Linear and MixtralExperts to shared base with common code -class PTQLinearFP8(ptq.PerTensorQuantizeLinear): # pylint: disable=too-many-instance-attributes - def __init__( - self, - in_features, - out_features, - weight_config: ptq.PerTensorQuantize, - activation_dtype: str = None, - weight_dtype: str = None, - bias: bool = True, - out_dtype: Optional[str] = None, - runtime: str = "cast", - ): # pylint: disable=too-many-arguments - super().__init__(in_features, out_features, weight_config, bias, out_dtype) - self.activation_dtype = activation_dtype - self.weight_dtype = weight_dtype - self.runtime = runtime - self.weight_config = weight_config - self.max_int_value = { - key: self.get_max_int(dtype) - for key, dtype in zip(["x", "w"], [self.activation_dtype, self.weight_dtype]) - } - - if "max" in self.runtime: - self.q_calibration_scale = nn.Parameter((1,), weight_config.model_dtype) - - @staticmethod - def get_max_int(dtype): - if dtype == "e4m3_float8": - return 448 - elif dtype == "e5m2_float8": - return 57344 - elif dtype == "float16": - return 65504 - else: - raise NotImplementedError() - - def add_calibration_params(self, quant_map: QuantizeMapping, layer_name: str): - scale_name = f"{layer_name}.q_calibration_scale" - - def alloc_scale(): - return nd.empty( - shape=self.q_calibration_scale.shape, dtype=self.q_calibration_scale.dtype + quantize_dtype=self.config.activation_dtype, + storage_dtype=self.config.activation_dtype, ) - - quant_map.map_func[scale_name] = alloc_scale - return quant_map - - @staticmethod - def from_linear( - src: nn.Linear, - weight_config: ptq.PerTensorQuantize, - ) -> "PTQLinearFP8": - """ - Converts a non-quantized Linear to a per-tensor quantized FP8 Linear. - - Parameters - ---------- - src : nn.Linear - The non-quantized Linear layer - - weight_config : GroupQuantize - The group quantization weight_config. - - Returns - ------- - ret : PTQLinearFP8 - The per-tensor quantized Linear layer in FP8 precision. 
- """ - out_features, in_features = src.weight.shape - quantized_linear = PTQLinearFP8( - in_features=in_features, - out_features=out_features, - weight_config=weight_config, - activation_dtype=weight_config.activation_dtype, - weight_dtype=weight_config.weight_dtype, - bias=getattr(src, "bias", None) is not None, - out_dtype=src.out_dtype, - runtime="max" if "calibration" not in weight_config.name else "max-calibration", - ) - - if "shard_strategy" in src.weight.attrs: - shard = src.weight.attrs["shard_strategy"] - apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) - apply_sharding( - tp.ShardScalar(name=shard.name), - f"{shard.name}_q_scale", - quantized_linear.q_scale, + if self.config.tensor_parallel_shards > 1: + x_scale = nn.ccl_allreduce(x_scale, "max") + x_scale = nn.extern( + "mlc_llm.calibration_observer", + [f"{self.name}.q_calibration_scale", "max", x_scale], + out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype), ) - # TODO(csullivan) lm head is not sharded. calibration logic needs to change i think - if "max" in quantized_linear.runtime: - apply_sharding( - tp.ShardScalar(name=shard.name), - f"{shard.name}_q_calibration_scale", - quantized_linear.q_calibration_scale, - ) + x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype) + x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype) - return quantized_linear - - def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name - if self.runtime == "max-calibration": - x, local_scale = nn.op.quantize( - x, - quantize_dtype=self.activation_dtype, - kind="fp8-max", - max_int_value=self.max_int_value["x"], + if indptr.ndim == 2: + assert indptr.shape[0] == 1 + return moe_matmul.dequantize_float8_gemv( + x, w, self.q_scale, indptr, self.config.weight_dtype ) - local_scale = nn.op.maximum_inplace(local_scale, self.q_calibration_scale) - # Calibration done in fp16 mma so convert x and w back to fp16 - x = nn.op.astype(x, dtype="float16") - - if DataType(self.weight_dtype).type_code in [ - DataTypeCode.E4M3Float, - DataTypeCode.E5M2Float, - ]: - dequant_func = self.config._dequantize_float8 - # The below computes w = cast(reinterpret_cast(self.q_weight, float8_e4m3), float16) * self.q_scale - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight, scale: dequant_func( # pylint: disable=protected-access - weight, - scale, - out_shape=(self.out_features, self.in_features), - ), - name_hint="dequantize_weight", - args=[self.q_weight, self.q_scale], + if extern.get_store().cutlass_group_gemm: + if self.config.calibration_mode == "inference": + if self.q_calibration_scale is not None: + x /= self.q_calibration_scale.astype(x.dtype) + x_q = nn.op.astype(x, dtype=self.config.activation_dtype) + x_scale = self.q_calibration_scale + + scale = ( + x_scale * self.q_scale + if self.q_scale is not None + else nn.wrap_nested( + relax.Constant(nd.array(np.array([1.0]).astype("float32"))), "scale" ) - elif self.weight_dtype == "float16": - w = self.q_weight - elif self.runtime == "max": - local_scale = self.q_calibration_scale - x = x / local_scale - x = nn.op.astype(x, dtype=self.activation_dtype) - w = self.q_weight - elif self.runtime == "cast": - x = nn.op.astype(x, dtype=self.activation_dtype) - w = self.q_weight - else: - raise NotImplementedError( - f"Only max and cast runtimes are supported for FP8 activations, {self.runtime} is unsupported." 
) - w = nn.op.permute_dims(w) - x = nn.op.matmul(x, w, out_dtype="float32") - - if self.runtime == "cast": - total_scale = 1.0 - else: - if self.runtime == "max" and self.weight_dtype == "e4m3_float8": - local_scale = nn.op.astype(local_scale, dtype="float32") - q_scale = nn.op.astype(self.q_scale, dtype="float32") - total_scale = local_scale * q_scale - else: - # for calibration, q_scale is already used to dequantize the weights - total_scale = nn.op.astype(local_scale, dtype="float32") + return cutlass.group_gemm( + x_q, w, indptr, scale, self.config.weight_dtype, self.config.model_dtype + ) + # Note: convert_weight is target agnostic, so a fallback must be provided + w = nn.tensor_expr_op( + self.config.dequantize_float8, + "dequantize", + args=[w, self.q_scale, self.config.weight_dtype], + ) + return moe_matmul.group_gemm(x, w, indptr) - x *= total_scale - x = nn.op.astype(x, self.weight_config.model_dtype) - if self.bias is not None: - x = x + self.bias - return x +# pylint: disable=protected-access +ptq.PerTensorQuantizeMixtralExperts._IMPL["fp8"] = FP8PerTensorQuantizeMixtralExperts diff --git a/python/mlc_llm/quantization/ft_quantization.py b/python/mlc_llm/quantization/ft_quantization.py index b6b1da100f..4a15846096 100644 --- a/python/mlc_llm/quantization/ft_quantization.py +++ b/python/mlc_llm/quantization/ft_quantization.py @@ -147,7 +147,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any: group_quantize = self.config.fallback_group_quantize() self.quant_map.map_func[weight_name] = group_quantize.quantize_weight return GroupQuantizeLinear.from_linear(node, group_quantize) - if not is_moe_gate(name): + if not is_moe_gate(name, node): self.quant_map.map_func[weight_name] = self.config.quantize_weight return FTQuantizeLinear.from_linear(node, self.config) if isinstance(node, nn.Embedding): diff --git a/python/mlc_llm/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py index c446972639..27cac54212 100644 --- a/python/mlc_llm/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py @@ -2,26 +2,23 @@ from dataclasses import dataclass from functools import partial -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union -from tvm import DataType, DataTypeCode, IRModule -from tvm import dlight as dl -from tvm import relax, te, tir, topi +from tvm import DataType, DataTypeCode, IRModule, relax, te, tir, topi from tvm.relax.frontend import nn from tvm.runtime import NDArray -from tvm.target import Target from mlc_llm.loader import QuantizeMapping from mlc_llm.nn import MixtralExperts from mlc_llm.support import logging - from .utils import ( + apply_sharding, + compile_quantize_func, convert_uint_to_float, is_final_fc, - convert_uint_packed_fp8_to_float, - compile_quantize_func, - apply_sharding, + is_moe_gate, + pack_weight, ) logger = logging.getLogger(__name__) @@ -34,7 +31,7 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes name: str kind: str group_size: int - quantize_dtype: Literal["int3", "int4", "int8", "e4m3_float8", "e5m2_float8"] + quantize_dtype: Literal["int3", "int4", "int8"] storage_dtype: Literal["uint32"] model_dtype: Literal["float16", "float32"] linear_weight_layout: Literal["KN", "NK"] @@ -44,18 +41,14 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes num_elem_per_storage: int = 0 num_storage_per_group: int = 0 max_int_value: int = 0 + tensor_parallel_shards: int = 0 
def __post_init__(self): assert self.kind == "group-quant" quantize_dtype = DataType(self.quantize_dtype) storage_dtype = DataType(self.storage_dtype) model_dtype = DataType(self.model_dtype) - if quantize_dtype.type_code in (DataTypeCode.E4M3Float, DataTypeCode.E5M2Float): - self.fp8_quant = True - else: - self.fp8_quant = False - assert quantize_dtype.type_code == DataTypeCode.INT - self.no_scale = quantize_dtype.type_code == DataTypeCode.E5M2Float + assert quantize_dtype.type_code == DataTypeCode.INT assert storage_dtype.type_code == DataTypeCode.UINT assert model_dtype.type_code == DataTypeCode.FLOAT if storage_dtype.bits < quantize_dtype.bits: @@ -65,15 +58,7 @@ def __post_init__(self): if self.group_size % self.num_elem_per_storage != 0: raise ValueError("Group size should be divisible by numbers of elements per storage") self.num_storage_per_group = self.group_size // self.num_elem_per_storage - if self.fp8_quant: - if quantize_dtype.type_code == DataTypeCode.E4M3Float: - self.max_int_value = 448 - elif quantize_dtype.type_code == DataTypeCode.E5M2Float: - self.max_int_value = 57344 - else: - raise NotImplementedError() - else: - self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 self.linear_quant_axis = 0 if self.linear_weight_layout == "KN" else 1 self._quantize_func_cache = {} @@ -126,82 +111,28 @@ def visit_module(self, name: str, node: nn.Module) -> Any: ret_node: Any The new node to replace current node. """ - from . import fp8_quantization as fp8 - - if isinstance(node, nn.Linear) and ( - not is_final_fc(name) or self.config.quantize_final_fc + if ( + isinstance(node, nn.Linear) + and (not is_final_fc(name) or self.config.quantize_final_fc) + and not is_moe_gate(name, node) ): weight_name = f"{name}.weight" - self.quant_map.param_map[weight_name] = ( - [f"{name}.q_weight", f"{name}.q_scale"] - if not self.config.no_scale - else [ - f"{name}.q_weight", - ] - ) + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] self.quant_map.map_func[weight_name] = partial( self.config.quantize_weight, output_transpose=self.config.linear_weight_layout == "KN", ) - if False and self.config.quantize_dtype == "e4m3_float8": - return fp8.GroupQuantizeLinearFP8E4M3CutlassScaleOnly.from_linear( - node, self.config - ) - else: - return GroupQuantizeLinear.from_linear(node, self.config) + return GroupQuantizeLinear.from_linear(node, self.config) if isinstance(node, nn.Embedding) and self.config.quantize_embedding: weight_name = f"{name}.weight" - self.quant_map.param_map[weight_name] = ( - [f"{name}.q_weight", f"{name}.q_scale"] - if not self.config.no_scale - else [ - f"{name}.q_weight", - ] - ) + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] self.quant_map.map_func[weight_name] = self.config.quantize_weight return GroupQuantizeEmbedding.from_embedding(node, self.config) if isinstance(node, MixtralExperts): weight_name = f"{name}.weight" - self.quant_map.param_map[weight_name] = ( - [f"{name}.q_weight", f"{name}.q_scale"] - if not self.config.no_scale - else [ - f"{name}.q_weight", - ] - ) + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] self.quant_map.map_func[weight_name] = self.config.quantize_weight - if self.config.name == "fp8_e4m3_e5m2_max_calibration": - op = fp8.MixtralExpertsFP8.from_mixtral_experts( - node, - self.config, - activation_dtype="e4m3_float8", - weight_dtype="e5m2_float8", - runtime="max-calibration", - ) - 
self.quant_map = op.add_calibration_params(self.quant_map, name) - return op - elif self.config.name == "fp8_e4m3_e5m2_max": - return fp8.MixtralExpertsFP8.from_mixtral_experts( - node, - self.config, - activation_dtype="e4m3_float8", - weight_dtype="e5m2_float8", - runtime="max", - ) - elif self.config.name == "fp8_e5m2_e5m2": - return fp8.MixtralExpertsFP8.from_mixtral_experts( - node, - self.config, - activation_dtype="e5m2_float8", - weight_dtype="e5m2_float8", - ) - elif "fp8" in self.config.name: - raise NotImplementedError( - f"Requested FP8 quantization config {self.config.name} is not implemented" - ) - else: - return GroupQuantizeMixtralExperts.from_mixtral_experts(node, self.config) - + return GroupQuantizeMixtralExperts.from_mixtral_experts(node, self.config) return self.visit(name, node) model.to(dtype=self.model_dtype) @@ -235,61 +166,13 @@ def _dequantize( fcompute=lambda *idx: tir.multiply( tir.subtract( float_weight(*idx), - tir_max_int, # TODO(jmcmahan): the max_int shift to remove negatives is not necessary for fp8 + tir_max_int, ), scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]), ), name="dequantize", ) - def _dequantize_e4m3( - self, - weight: te.Tensor, - scale: te.Tensor, - axis: int, - out_shape: Optional[List[tir.PrimExpr]] = None, - ): - float_e4m3_weight = convert_uint_packed_fp8_to_float( - weight, - DataType(self.quantize_dtype).bits, - self.num_elem_per_storage, - self.storage_dtype, - self.model_dtype, - self.quantize_dtype, - axis=axis, - out_shape=out_shape, - ) - if out_shape is None: - out_shape = weight.shape - out_shape[axis] *= self.num_elem_per_storage - axis = axis if axis >= 0 else len(out_shape) + axis - return te.compute( - shape=out_shape, - fcompute=lambda *idx: tir.multiply( - float_e4m3_weight(*idx), - scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]), - ), - name="dequantize", - ) - - def _dequantize_e5m2( - self, - weight: te.Tensor, - axis: int, # axis is still relevant, because it determines which axis is packed into u32 storage - out_shape: Optional[List[tir.PrimExpr]] = None, - ): - float_e5m2_weight = convert_uint_packed_fp8_to_float( - weight, - DataType(self.quantize_dtype).bits, - self.num_elem_per_storage, - self.storage_dtype, - self.model_dtype, - self.quantize_dtype, - axis=axis, - out_shape=out_shape, - ) - return float_e5m2_weight - def quantize_weight( self, weight: NDArray, axis: int = -1, output_transpose: bool = False ) -> List[NDArray]: @@ -317,28 +200,12 @@ def quantize_weight( axis = axis if axis >= 0 else len(weight.shape) + axis def _create_quantize_func() -> IRModule: - if self.fp8_quant: - if ( - DataType(self.quantize_dtype).type_code == DataTypeCode.E4M3Float - or DataType(self.quantize_dtype).type_code == DataTypeCode.E5M2Float - ): - quantize_func = self._quantize_float8 - else: - assert NotImplementedError() - else: - quantize_func = self._quantize - bb = relax.BlockBuilder() # pylint: disable=invalid-name weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, weight.dtype)) with bb.function(name="main", params=[weight_var]): with bb.dataflow(): - lv = bb.emit_te(quantize_func, weight_var, axis, output_transpose) - if isinstance(lv.struct_info, relax.TupleStructInfo): - tuple_output = bb.emit(lv) - else: - tuple_output = bb.emit((lv,)) - gv = bb.emit_output(tuple_output) # pylint: disable=invalid-name - + lv = bb.emit_te(self._quantize, weight_var, axis, output_transpose) + gv = bb.emit_output(lv) # pylint: disable=invalid-name bb.emit_func_output(gv) return 
bb.finalize() @@ -349,7 +216,7 @@ def _create_quantize_func() -> IRModule: quantize_func = self._quantize_func_cache.get(key, None) if quantize_func is None: logger.info("Compiling quantize function for key: %s", key) - quantize_func = compile_quantize_func(_create_quantize_func(), device) + quantize_func = compile_quantize_func(_create_quantize_func(), device=device) self._quantize_func_cache[key] = quantize_func return quantize_func(weight) @@ -364,7 +231,6 @@ def _quantize( # pylint: disable=too-many-locals shape = weight.shape # pylint: disable=invalid-name axis = axis if axis >= 0 else len(shape) + axis k = shape[axis] - quantize_dtype = DataType(self.quantize_dtype) # compute scale per group r = te.reduce_axis((0, self.group_size), name="r") # pylint: disable=invalid-name num_group = tir.ceildiv(k, self.group_size) @@ -402,23 +268,15 @@ def _quantize( # pylint: disable=too-many-locals ).astype(self.storage_dtype), ) # compute quantized weight per storage - r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name num_storage = self.num_storage_per_group * num_group quantized_weight_shape = (*shape[:axis], num_storage, *shape[axis + 1 :]) - quantized_weight = te.compute( - shape=quantized_weight_shape, - fcompute=lambda *idx: tir.sum( - tir.if_then_else( - idx[axis] * self.num_elem_per_storage + r < k, - scaled_weight( - *idx[:axis], idx[axis] * self.num_elem_per_storage + r, *idx[axis + 1 :] - ) - << (r * quantize_dtype.bits), - 0, - ), - axis=r, - ), - name="weight", + quantized_weight = pack_weight( + scaled_weight, + axis=axis, + num_elem_per_storage=self.num_elem_per_storage, + weight_dtype=self.quantize_dtype, + storage_dtype=self.storage_dtype, + out_shape=quantized_weight_shape, ) if output_transpose: if len(quantized_weight.shape) != 2 or len(scale.shape) != 2: @@ -429,109 +287,6 @@ def _quantize( # pylint: disable=too-many-locals scale = topi.transpose(scale) return quantized_weight, scale - def _quantize_float8( # pylint: disable=too-many-locals - self, - weight: te.Tensor, - axis: int = -1, - output_transpose: bool = False, - ) -> Tuple[te.Tensor, te.Tensor]: - """Group quantization for weight tensor, defined in tensor expression.""" - - shape = weight.shape # pylint: disable=invalid-name - quantize_dtype = DataType(self.quantize_dtype) - k = shape[axis] - - if quantize_dtype.type_code == DataTypeCode.E4M3Float: - # compute scale per group - num_group = tir.ceildiv(k, self.group_size) - max_int = tir.const(self.max_int_value, self.model_dtype) - axis = axis if axis >= 0 else len(shape) + axis - r = te.reduce_axis((0, self.group_size), name="r") # pylint: disable=invalid-name - scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :]) - - # min_scaling_factor taken from TRT-LLM - min_scaling_factor = tir.const(1.0 / (self.max_int_value * 512.0), self.model_dtype) - max_abs = te.compute( - shape=scale_shape, - fcompute=lambda *idx: te.max( - tir.if_then_else( - idx[axis] * self.group_size + r < k, - te.abs( - weight(*idx[:axis], idx[axis] * self.group_size + r, *idx[axis + 1 :]) - ), - te.min_value(self.model_dtype), - ), - axis=r, - ), - name="max_abs_value", - ) - scale = te.compute( - scale_shape, - lambda *idx: te.max( - max_abs(*idx).astype(self.model_dtype) / max_int, min_scaling_factor - ), - name="scale", - ) - # compute scaled weight - # TODO(fp8-team): Convince ourselves that we don't need to clip the quantized weight - # Need a cast to FP8, then reinerpret cast - scaled_weight = te.compute( - shape=weight.shape, - 
fcompute=lambda *idx: tir.reinterpret( - # TODO(csullivan) Change this to a vector type to simplify storage and improving casting - self.storage_dtype, - tir.Cast( - self.quantize_dtype, - weight(*idx) - / scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]), - ), - ), - ) - elif quantize_dtype.type_code == DataTypeCode.E5M2Float: - scaled_weight = te.compute( - shape=weight.shape, - fcompute=lambda *idx: tir.reinterpret( - self.storage_dtype, - tir.Cast( # TODO(jmcmahan): verify that this cast (fp16 -> e5m2) does the expected mantissa clip - self.quantize_dtype, weight(*idx) - ), - ), - ) - - # TODO(csullivan): If using vector type fp8x4 this compute op can be deleted - # compute quantized weight per storage - r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name - num_storage = tir.ceildiv(k, self.num_elem_per_storage) - quantized_weight_shape = (*shape[:axis], num_storage, *shape[axis + 1 :]) - quantized_weight = te.compute( - shape=quantized_weight_shape, - fcompute=lambda *idx: tir.sum( - tir.if_then_else( - idx[axis] * self.num_elem_per_storage + r < k, - scaled_weight( - *idx[:axis], idx[axis] * self.num_elem_per_storage + r, *idx[axis + 1 :] - ) - << (r * quantize_dtype.bits), - 0, - ), - axis=r, - ), - name="weight", - ) - - if output_transpose: - if len(quantized_weight.shape) != 2 or len(scale.shape) != 2: - raise ValueError( - "Does not support transpose output quantized weight with ndim != 2" - ) - quantized_weight = topi.transpose(quantized_weight) - if quantize_dtype.type_code == DataTypeCode.E4M3Float: - scale = topi.transpose(scale) - if quantize_dtype.type_code == DataTypeCode.E4M3Float: - return quantized_weight, scale - elif quantize_dtype.type_code == DataTypeCode.E5M2Float: - return quantized_weight - class GroupQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes """An nn.Linear module with group quantization""" @@ -549,20 +304,27 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config - self.no_scale = self.config.no_scale num_group = tir.ceildiv(in_features, config.group_size) + num_shards = config.tensor_parallel_shards + if num_shards > 1 and (in_features * num_shards // config.group_size) % num_shards != 0: + raise ValueError( + f"The linear dimension {in_features * num_shards} has " + f"{in_features * num_shards // config.group_size} groups under group size " + f"{config.group_size}. The groups cannot be evenly distributed on " + f"{num_shards} GPUs.\n" + "Possible solutions: reduce number of GPUs, or use quantization with smaller " + "group size." 
+ ) if config.linear_weight_layout == "KN": self.q_weight = nn.Parameter( (config.num_storage_per_group * num_group, out_features), config.storage_dtype ) - if not self.no_scale: - self.q_scale = nn.Parameter((num_group, out_features), config.model_dtype) + self.q_scale = nn.Parameter((num_group, out_features), config.model_dtype) else: self.q_weight = nn.Parameter( (out_features, config.num_storage_per_group * num_group), config.storage_dtype ) - if not self.no_scale: - self.q_scale = nn.Parameter((out_features, num_group), config.model_dtype) + self.q_scale = nn.Parameter((out_features, num_group), config.model_dtype) if bias: self.bias = nn.Parameter( (out_features,), config.model_dtype if out_dtype is None else out_dtype @@ -570,8 +332,8 @@ def __init__( # pylint: disable=too-many-arguments else: self.bias = None - @classmethod - def from_linear(cls, src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear": + @staticmethod + def from_linear(src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear": """ Converts a non-quantized nn.Linear to a group quantized GroupQuantizeLinear @@ -590,7 +352,7 @@ def from_linear(cls, src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLin """ # For dynamic shape, src.out_features is `"name"`; src.weight.shape[0] is `tir.Var("name")` out_features, in_features = src.weight.shape - quantized_linear = cls( + quantized_linear = GroupQuantizeLinear( in_features=in_features, out_features=out_features, config=config, @@ -602,10 +364,7 @@ def from_linear(cls, src: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLin if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) - if ( - not DataType(config.quantize_dtype).type_code == DataTypeCode.E5M2Float - ): # no scale for e5m2 - apply_sharding(shard, f"{shard.name}_q_scale", quantized_linear.q_scale) + apply_sharding(shard, f"{shard.name}_q_scale", quantized_linear.q_scale) return quantized_linear def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name @@ -622,59 +381,34 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name ret : nn.Tensor The output tensor for the group quantized linear layer. 
""" - if self.config.fp8_quant: - if DataType(self.config.quantize_dtype).type_code == DataTypeCode.E4M3Float: - dequant_func = self.config._dequantize_e4m3 - elif DataType(self.config.quantize_dtype).type_code == DataTypeCode.E5M2Float: - dequant_func = self.config._dequantize_e5m2 - else: - raise NotImplementedError() - else: - dequant_func = self.confg._dequantize - - if self.config.linear_weight_layout == "NK": - out_shape = [ - ( - tir.IntImm("int64", self.out_features) - if isinstance(self.out_features, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.in_features), - ] - else: - out_shape = [ - tir.IntImm("int64", self.in_features), - ( - tir.IntImm("int64", self.out_features) - if isinstance(self.out_features, int) - else weight.shape[1] - ), - ] - - if not self.no_scale: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight, scale: dequant_func( # pylint: disable=protected-access - weight, - scale, - axis=self.config.linear_quant_axis, - out_shape=out_shape, - ), - name_hint="dequantize", - args=[self.q_weight, self.q_scale], - ) - else: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight: dequant_func( # pylint: disable=protected-access - weight, - axis=self.config.linear_quant_axis, - out_shape=out_shape, + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=self.config.linear_quant_axis, + out_shape=( + [ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + tir.IntImm("int64", self.in_features), + ] + if self.config.linear_weight_layout == "NK" + else [ + tir.IntImm("int64", self.in_features), + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[1] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + ] ), - name_hint="dequantize", - args=[ - self.q_weight, - ], - ) - + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) if self.config.linear_weight_layout == "NK": w = nn.op.permute_dims(w) # pylint: disable=invalid-name x = nn.op.matmul(x, w, out_dtype=self.out_dtype) @@ -688,8 +422,7 @@ def to(self, dtype: Optional[str] = None) -> None: Otherwise, we might run into dtype mismatch when computing x + self.bias. """ self.q_weight.to(dtype=dtype) - if not self.no_scale: - self.q_scale.to(dtype=dtype) + self.q_scale.to(dtype=dtype) if self.bias is not None and self.out_dtype is None: self.bias.to(dtype=dtype) if dtype is not None and isinstance(getattr(self, "dtype", None), str): @@ -703,13 +436,11 @@ def __init__(self, num: Union[int, tir.Var], dim: int, config: GroupQuantize): self.num = num self.dim = dim self.config = config - self.no_scale = self.config.no_scale num_group = tir.ceildiv(dim, config.group_size) self.q_weight = nn.Parameter( (num, config.num_storage_per_group * num_group), config.storage_dtype ) - if not self.no_scale: - self.q_scale = nn.Parameter((num, num_group), config.model_dtype) + self.q_scale = nn.Parameter((num, num_group), config.model_dtype) @staticmethod def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding": @@ -746,47 +477,23 @@ def forward(self, x: nn.Tensor): # pylint: disable=invalid-name ret : nn.Tensor The output tensor for the embedding layer. 
""" - if self.config.fp8_quant: - if DataType(self.config.quantize_dtype).type_code == DataTypeCode.E4M3Float: - dequant_func = self.config._dequantize_e4m3 - elif DataType(self.config.quantize_dtype).type_code == DataTypeCode.E5M2Float: - dequant_func = self.config._dequantize_e5m2 - else: - raise NotImplementedError() - else: - dequant_func = self.confg._dequantize - - out_shape = [ - ( - tir.IntImm("int64", self.num) if isinstance(self.num, int) else weight.shape[0] - ), # Reuse same tir.Var for symbolic shape (after Exporter) - tir.IntImm("int64", self.dim), - ] - - if not self.no_scale: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight, scale: dequant_func( # pylint: disable=protected-access - weight, - scale, - axis=-1, - out_shape=out_shape, - ), - name_hint="dequantize", - args=[self.q_weight, self.q_scale], - ) - else: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight: dequant_func( # pylint: disable=protected-access - weight, - axis=-1, - out_shape=out_shape, - ), - name_hint="dequantize", - args=[ - self.q_weight, + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=-1, + out_shape=[ + ( + tir.IntImm("int64", self.num) + if isinstance(self.num, int) + else weight.shape[0] + ), # Reuse same tir.Var for symbolic shape (after Exporter) + tir.IntImm("int64", self.dim), ], - ) - + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) if x.ndim == 1: return nn.op.take(w, x, axis=0) return nn.op.reshape( @@ -808,53 +515,23 @@ def lm_head_forward(self, x: nn.Tensor): ret : nn.Tensor The output tensor for the lm_head layer. """ - if self.config.fp8_quant: - if DataType(self.config.quantize_dtype).type_code == DataTypeCode.E4M3Float: - dequant_func = self.config._dequantize_e4m3 - elif DataType(self.config.quantize_dtype).type_code == DataTypeCode.E5M2Float: - dequant_func = self.config._dequantize_e5m2 - else: - raise NotImplementedError() - else: - dequant_func = self.confg._dequantize - - if not self.no_scale: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight, scale: dequant_func( # pylint: disable=protected-access - weight, - scale, - axis=-1, - out_shape=[ - ( - tir.IntImm("int64", self.num) - if isinstance(self.num, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.dim), - ], - ), - name_hint="dequantize", - args=[self.q_weight, self.q_scale], - ) - else: - w = nn.op.tensor_expr_op( # pylint: disable=invalid-name - lambda weight: dequant_func( # pylint: disable=protected-access - weight, - axis=-1, - out_shape=[ - ( - tir.IntImm("int64", self.num) - if isinstance(self.num, int) - else weight.shape[0] - ), - tir.IntImm("int64", self.dim), - ], - ), - name_hint="dequantize", - args=[ - self.q_weight, + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + axis=-1, + out_shape=[ + ( + tir.IntImm("int64", self.num) + if isinstance(self.num, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.dim), ], - ) + ), + name_hint="dequantize", + args=[self.q_weight, self.q_scale], + ) w = nn.op.permute_dims(w) return nn.op.matmul(x, w, out_dtype="float32") @@ -873,16 +550,14 @@ def __init__( self.in_features = in_features self.out_features = out_features self.config = config - self.no_scale = self.config.no_scale num_group = tir.ceildiv(in_features, config.group_size) 
self.q_weight = nn.Parameter( (num_local_experts, out_features, config.num_storage_per_group * num_group), config.storage_dtype, ) - if not self.no_scale: - self.q_scale = nn.Parameter( - (num_local_experts, out_features, num_group), config.model_dtype - ) + self.q_scale = nn.Parameter( + (num_local_experts, out_features, num_group), config.model_dtype + ) self.quantize_dtype = config.quantize_dtype self.group_size = config.group_size self.dtype = config.model_dtype @@ -918,8 +593,7 @@ def from_mixtral_experts( if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight) - if not config.no_scale: - apply_sharding(shard, f"{shard.name}_q_scale", quantized_mistral_experts.q_scale) + apply_sharding(shard, f"{shard.name}_q_scale", quantized_mistral_experts.q_scale) return quantized_mistral_experts def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name @@ -941,45 +615,26 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa ret : nn.Tensor The output tensor for the group quantized mistral experts layer. """ - from mlc_chat.op import moe_matmul # pylint: disable=import-outside-toplevel + from mlc_llm.op import moe_matmul # pylint: disable=import-outside-toplevel assert x.ndim == 2 if indptr.ndim == 2: # single-batch assert indptr.shape[0] == 1 - if not self.no_scale: - return moe_matmul.dequantize_gemv( - x, - self.q_weight, - self.q_scale, - indptr, - quantize_dtype=self.quantize_dtype, - group_size=self.group_size, - ) - else: - return moe_matmul.dequantize_gemv_no_scale( - x, - self.q_weight, - indptr, - quantize_dtype=self.quantize_dtype, - group_size=self.group_size, - ) - assert indptr.ndim == 1 - if not self.no_scale: - return moe_matmul.dequantize_group_gemm( + return moe_matmul.dequantize_gemv( x, self.q_weight, self.q_scale, indptr, quantize_dtype=self.quantize_dtype, - indptr_dtype=indptr.dtype, - group_size=self.group_size, - ) - else: - return moe_matmul.dequantize_group_gemm_no_scale( - x, - self.q_weight, - indptr, - quantize_dtype=self.quantize_dtype, - indptr_dtype=indptr.dtype, group_size=self.group_size, ) + assert indptr.ndim == 1 + return moe_matmul.dequantize_group_gemm( + x, + self.q_weight, + self.q_scale, + indptr, + quantize_dtype=self.quantize_dtype, + indptr_dtype=indptr.dtype, + group_size=self.group_size, + ) diff --git a/python/mlc_llm/quantization/no_quantization.py b/python/mlc_llm/quantization/no_quantization.py index b1944c17f5..bd211fd724 100644 --- a/python/mlc_llm/quantization/no_quantization.py +++ b/python/mlc_llm/quantization/no_quantization.py @@ -1,4 +1,5 @@ """The no quantization config""" + from dataclasses import dataclass diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py index d7ea2c9f17..ff20c7e7dd 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -1,48 +1,52 @@ """The per-tensor quantization config""" +import functools from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union -from tvm import DataType, DataTypeCode, IRModule -from tvm import dlight as dl -from tvm import relax, te, tir, topi +from tvm import DataType, DataTypeCode, IRModule, nd, te, tir, 
topi from tvm.relax.frontend import nn from tvm.runtime import NDArray -from tvm.target import Target from mlc_llm.loader import QuantizeMapping from mlc_llm.nn import MixtralExperts from mlc_llm.support import logging -from mlc_llm.support import tensor_parallel as tp from .utils import ( + apply_sharding, + compile_quantize_func, + convert_uint_packed_fp8_to_float, is_final_fc, is_moe_gate, - convert_uint_packed_fp8_to_float, - compile_quantize_func, - apply_sharding, + pack_weight, ) logger = logging.getLogger(__name__) @dataclass -class PerTensorQuantize: +class PerTensorQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for per-tensor quantization""" + name: str kind: str - activation_dtype: Literal["e4m3_float8", "e5m2_float8", "float16"] - weight_dtype: Literal["e4m3_float8", "e5m2_float8", "float16"] - storage_dtype: Literal["uint32"] + activation_dtype: Literal["e4m3_float8", "e5m2_float8"] + weight_dtype: Literal["e4m3_float8", "e5m2_float8"] + storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"] model_dtype: Literal["float16"] quantize_embedding: bool = True - quantize_linear: bool = True quantize_final_fc: bool = True + quantize_linear: bool = True num_elem_per_storage: int = 0 - num_storage_per_group: int = 0 max_int_value: int = 0 - no_scale: bool = False + use_scale: bool = True + # The calibration mode for quantization. If set to "inference", the model is built for + # inference. This should be used after calibration is done. + # If set to "max", the model is built for calibration that computes the scale using max value of + # the activations. + calibration_mode: Literal["inference", "max"] = "inference" + tensor_parallel_shards: int = 1 def __post_init__(self): assert self.kind == "per-tensor-quant" @@ -53,7 +57,11 @@ def __post_init__(self): self._quantize_func_cache = {} def quantize_model( - self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + tensor_parallel_shards: int, ) -> nn.Module: """ Quantize model with per-tensor quantization @@ -69,12 +77,17 @@ def quantize_model( name_prefix : str The name prefix for visited weight. + tensor_parallel_shards : int + The number of tensor parallel shards. + Returns ------- ret : nn.Module The quantized nn.Module. 
""" + self.tensor_parallel_shards = tensor_parallel_shards + class _Mutator(nn.Mutator): def __init__(self, config: PerTensorQuantize, quant_map: QuantizeMapping) -> None: super().__init__() @@ -101,7 +114,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any: weight_name = f"{name}.weight" param_names = ( [f"{name}.q_weight", f"{name}.q_scale"] - if not self.config.no_scale + if self.config.use_scale else [ f"{name}.q_weight", ] @@ -109,33 +122,48 @@ def visit_module(self, name: str, node: nn.Module) -> Any: if ( isinstance(node, nn.Linear) and self.config.quantize_linear - and not is_moe_gate(name) and (not is_final_fc(name) or self.config.quantize_final_fc) + and not is_moe_gate(name, node) ): self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight - op = PerTensorQuantizeLinear.from_linear(node, self.config) - if hasattr(op, "add_calibration_params"): - self.quant_map = op.add_calibration_params(self.quant_map, name) - return op - if isinstance(node, nn.Embedding) and self.config.quantize_embedding: + op = PerTensorQuantizeLinear.from_linear(node, self.config, name) + elif isinstance(node, nn.Embedding) and self.config.quantize_embedding: self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight - return PerTensorQuantizeEmbedding.from_embedding(node, self.config) - if isinstance(node, MixtralExperts): + op = PerTensorQuantizeEmbedding.from_embedding(node, self.config) + elif isinstance(node, MixtralExperts): self.quant_map.param_map[weight_name] = param_names self.quant_map.map_func[weight_name] = self.config.quantize_weight - op = PerTensorQuantizeMixtralExperts.from_mixtral_experts(node, self.config) - self.quant_map = op.add_calibration_params(self.quant_map, name) - return op - return self.visit(name, node) + op = PerTensorQuantizeMixtralExperts.from_mixtral_experts( + node, self.config, name + ) + else: + return self.visit(name, node) + + if hasattr(op, "q_calibration_scale") and op.q_calibration_scale: + # update quant_map for calibration scale + param_name = f"{name}.q_calibration_scale" + old_map_func = self.quant_map.map_func[weight_name] + + def map_func(*args, **kwargs): + # placeholder for calibration scale, the actual value will be set after + # calibration. + scale = nd.empty( + shape=op.q_calibration_scale.shape, dtype=op.q_calibration_scale.dtype + ) + return [*old_map_func(*args, **kwargs), scale] + + self.quant_map.param_map[weight_name].append(param_name) + self.quant_map.map_func[weight_name] = map_func + return op model.to(dtype=self.model_dtype) mutator = _Mutator(self, quant_map) model = mutator.visit(name_prefix, model) return model - def quantize_weight(self, weight) -> Union[Tuple[NDArray, NDArray], NDArray]: + def quantize_weight(self, weight) -> List[NDArray]: """ Quantize weight with per-tensor quantization. @@ -146,8 +174,8 @@ def quantize_weight(self, weight) -> Union[Tuple[NDArray, NDArray], NDArray]: Returns ------- - ret : Union[Tuple[NDArray, NDArray], NDArray] - The quantized weight and scale if output_transpose is True, otherwise the quantized weight. + ret : List[NDArray] + The quantized weight and the scale if use_scale is True. 
""" device = weight.device device_type = device.MASK2STR[device.device_type] @@ -157,16 +185,22 @@ def _create_quantize_func() -> IRModule: DataTypeCode.E4M3Float, DataTypeCode.E5M2Float, ]: - quantize_func = self._quantize_float8 + quantize_func = functools.partial( + self.quantize_float8, + quantize_dtype=self.weight_dtype, + storage_dtype=self.storage_dtype, + ) else: assert NotImplementedError() class Quantizer(nn.Module): - def main(self, weight: nn.Tensor): + """Quantizer module for per-tensor quantization.""" + + def main(self, weight: nn.Tensor): # pylint: disable=missing-function-docstring return quantize_func(weight) mod = Quantizer() - mod, _ = mod.export_tvm( + mod, _ = mod.export_tvm( # pylint: disable=unbalanced-tuple-unpacking spec={"main": {"weight": nn.spec.Tensor(weight.shape, weight.dtype)}} ) return mod @@ -179,173 +213,137 @@ def main(self, weight: nn.Tensor): self._quantize_func_cache[key] = quantize_func return quantize_func(weight) - def _quantize_float8(self, weight: nn.Tensor): - shape = weight.shape - quantize_dtype = DataType(self.weight_dtype) - - if self.no_scale: - # TODO(csullivan, vinx13): Ensure scheduling is applied after lowering relax ops. - # Currently, dlight scheduling is applied before the R.cast is lowered. - # scaled_weight = weight.astype(self.weight_dtype) - scaled_weight = nn.tensor_expr_op( - lambda scaled_weight: te.compute( - shape=weight.shape, - fcompute=lambda *idx: scaled_weight(*idx).astype(self.weight_dtype), - name="cast", - ), - "cast_weight", - args=[weight], - ) + def quantize_float8( # pylint: disable=too-many-locals + self, + tensor: nn.Tensor, + quantize_dtype: str, + storage_dtype: str, + ) -> Union[Tuple[nn.Tensor], Tuple[nn.Tensor, nn.Tensor]]: + """Per-tensor quantization for weight tensor, defined in tensor expression.""" + + if self.use_scale: + # min_scaling_factor taken from TRT-LLM + def _compute_scale(x: te.Tensor) -> te.Tensor: + max_abs = topi.max(topi.abs(x)) + min_scaling_factor = tir.const(1.0 / (self.max_int_value * 512.0), self.model_dtype) + scale = topi.maximum( + max_abs.astype(self.model_dtype) / self.max_int_value, min_scaling_factor + ).astype("float32") + scale = topi.expand_dims(scale, axis=0) + return scale + + scale = nn.tensor_expr_op(_compute_scale, "compute_scale", args=[tensor]) else: - from .fp8_quantization import quantize - - scaled_weight, scale = quantize( - weight, - quantize_dtype=quantize_dtype, - kind="fp8-max", - max_int_value=self.max_int_value, - ) + scale = None - if self.weight_dtype == self.storage_dtype: - quantized_weight = scaled_weight - elif self.num_elem_per_storage == 1: - quantized_weight = nn.tensor_expr_op( - lambda scaled_weight: te.compute( - shape=scaled_weight.shape, - fcompute=lambda *idx: tir.reinterpret(self.storage_dtype, scaled_weight(*idx)), - name="quantized_weight", - ), - "quantized_weight", - args=[scaled_weight], + def _compute_quantized_tensor(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor: + elem_storage_dtype = ( + f"uint{DataType(quantize_dtype).bits}" + if DataType(storage_dtype).type_code == DataTypeCode.UINT + else quantize_dtype ) - else: - axis = -1 - k = shape[axis] - r = te.reduce_axis( - (0, self.num_elem_per_storage), name="r" - ) # pylint: disable=invalid-name - quantized_weight_shape = ( - *weight.shape[:axis], - tir.ceildiv(weight.shape[axis], self.num_elem_per_storage), - ) - quantized_weight = nn.tensor_expr_op( - lambda scaled_weight: te.compute( - shape=quantized_weight_shape, - fcompute=lambda *idx: tir.sum( - 
tir.if_then_else( - idx[axis] * self.num_elem_per_storage + r < k, - tir.reinterpret( - "uint8", - scaled_weight( - *idx[:axis], idx[axis] * self.num_elem_per_storage + r - ), - ).astype(self.storage_dtype) - << (r * quantize_dtype.bits), - 0, + scaled_tensor = te.compute( + shape=weight.shape, + fcompute=lambda *idx: tir.Cast( + self.storage_dtype, + tir.reinterpret( + elem_storage_dtype, + tir.Cast( + quantize_dtype, + weight(*idx) / scale(0) if scale is not None else weight(*idx), ), - axis=r, ), - name="quantized_weight", ), - "quantized_weight", - args=[scaled_weight], ) - if self.no_scale: - return (quantized_weight,) - return quantized_weight, scale + if quantize_dtype == self.storage_dtype: + return scaled_tensor + + packed_weight = pack_weight( + scaled_tensor, + axis=-1, + num_elem_per_storage=self.num_elem_per_storage, + weight_dtype=self.weight_dtype, + storage_dtype=self.storage_dtype, + ) + + return packed_weight - def _quantize_float16(self, weight: nn.Tensor): - shape = weight.shape - return (weight,) + quantized_tensor = nn.tensor_expr_op( + _compute_quantized_tensor, "compute_quantized_tensor", args=[tensor, scale] + ) + + if self.use_scale: + return quantized_tensor, scale + return (quantized_tensor,) def _dequantize( self, q_weight: te.Tensor, scale: Optional[te.Tensor] = None, - out_shape: Optional[List[tir.PrimExpr]] = None, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, ) -> te.Tensor: - if not self.no_scale: + if self.use_scale: assert scale is not None if DataType(self.weight_dtype).type_code in [ DataTypeCode.E4M3Float, DataTypeCode.E5M2Float, ]: - return self._dequantize_float8(q_weight, scale, out_shape) + return self.dequantize_float8(q_weight, scale, self.weight_dtype, out_shape) raise NotImplementedError() - def _dequantize_float8( + def dequantize_float8( self, - q_weight: te.Tensor, - scale: Optional[te.Tensor] = None, - out_shape: Optional[List[tir.PrimExpr]] = None, + q_tensor: te.Tensor, + scale: Optional[te.Tensor], + quantize_dtype: str, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, ) -> te.Tensor: - if out_shape is None: - out_shape = (*q_weight.shape[:-1], q_weight.shape[-1] * self.num_elem_per_storage) - - if self.weight_dtype == self.storage_dtype: - weight = q_weight.astype(self.model_dtype) - elif self.num_elem_per_storage == 1: - weight = te.compute( - shape=out_shape, - fcompute=lambda *idx: tir.reinterpret(self.weight_dtype, q_weight(*idx)).astype( - self.model_dtype - ), - name="dequantize_weight", - ) - else: - weight = convert_uint_packed_fp8_to_float( - q_weight, - DataType(self.weight_dtype).bits, + """Dequantize a fp8 tensor (input or weight) to higher-precision float.""" + if quantize_dtype != self.storage_dtype: + dequantized_tensor = convert_uint_packed_fp8_to_float( + q_tensor, self.num_elem_per_storage, self.storage_dtype, self.model_dtype, - self.weight_dtype, + quantize_dtype, axis=-1, out_shape=out_shape, ) + else: + dequantized_tensor = q_tensor.astype(self.model_dtype) + if scale is not None: + dequantized_tensor = dequantized_tensor * scale.astype(dequantized_tensor.dtype) + return dequantized_tensor - if not self.no_scale: - weight = weight * scale - return weight +class PerTensorQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with per-tensor quantization.""" -class PerTensorQuantizeLinear(nn.Module): def __init__( # pylint: disable=too-many-arguments self, in_features: int, out_features: Union[int, tir.Var], config: PerTensorQuantize, + name: str, bias: bool = 
True, out_dtype: Optional[str] = None, ) -> None: - """ - Converts a non-quantized nn.Linear to a per-tensor quantized PerTensorQuantizeLinear - - Parameters - ---------- - src : nn.Linear - The non-quantized nn.Linear. - - config : PerTensorQuantize - The per-tensor quantization config. - - Returns - ------- - ret: PerTensorQuantizeLinear - The per-tensor quantized linear layer. - """ super().__init__() self.in_features = in_features self.out_features = out_features - self.out_dtype = out_dtype + self.out_dtype = out_dtype or config.model_dtype self.config = config + self.name = name self.q_weight = nn.Parameter( (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), config.storage_dtype, ) - if not config.no_scale: - self.q_scale = nn.Parameter((1,), config.model_dtype) + self.q_calibration_scale = None + if config.use_scale: + self.q_scale = nn.Parameter((1,), "float32") + if config.calibration_mode == "inference": + self.q_calibration_scale = nn.Parameter((1,), "float32") else: self.q_scale = None if bias: @@ -356,48 +354,48 @@ def __init__( # pylint: disable=too-many-arguments self.bias = None @classmethod - def from_linear(cls, src: nn.Linear, config: PerTensorQuantize) -> "PerTensorQuantizeLinear": + def from_linear( + cls, src: nn.Linear, config: PerTensorQuantize, name: str + ) -> "PerTensorQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a per-tensor quantized PerTensorQuantizeLinear - if ( - DataType(config.weight_dtype).type_code - in [ - DataTypeCode.E4M3Float, - DataTypeCode.E5M2Float, - ] - # Activation calibration - and any(q_kind in config.name for q_kind in ["max", "cast"]) - ): - from .fp8_quantization import PTQLinearFP8 + Parameters + ---------- + src : nn.Linear + The non-quantized nn.Linear. - quantized_linear = PTQLinearFP8.from_linear( - src, - config, - ) - return quantized_linear + config : PerTensorQuantize + The per-tensor quantization config. + + name: str + The name of the layer. - # For dynamic shape, src.out_features is `"name"`; src.weight.shape[0] is `tir.Var("name")` + Returns + ------- + ret : PerTensorQuantizeLinear + The per-tensor quantized PerTensorQuantizeLinear layer. + """ out_features, in_features = src.weight.shape quantized_linear = cls( in_features=in_features, out_features=out_features, config=config, + name=name, bias=getattr(src, "bias", None) is not None, out_dtype=src.out_dtype, ) if quantized_linear.bias is not None: quantized_linear.bias.attrs = src.bias.attrs - if "shard_strategy" in src.weight.attrs: shard = src.weight.attrs["shard_strategy"] apply_sharding(shard, f"{shard.name}_q_weight", quantized_linear.q_weight) - apply_sharding( - tp.ShardScalar(name=shard.name), f"{shard.name}_q_scale", quantized_linear.q_scale - ) + # scale doesn't need to be sharded since it's the same for all shards return quantized_linear def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name """ - Forward method for group quantized linear layer. + Forward method for per-tensor quantized linear layer. Parameters ---------- @@ -407,26 +405,59 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name Returns ------- ret : nn.Tensor - The output tensor for the group quantized linear layer. + The output tensor for the per-tensor quantized linear layer. 
""" - assert DataType(self.config.weight_dtype).type_code in [ - DataTypeCode.E4M3Float, - DataTypeCode.E5M2Float, - ] - w = nn.op.tensor_expr_op( - lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access - weight, - scale, - out_shape=[ - tir.IntImm("int64", self.out_features), - tir.IntImm("int64", self.in_features), - ], - ), - "dequantize", - args=[self.q_weight, self.q_scale], - ) - w = nn.op.permute_dims(w) - x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + # Note: Use calibration scale when calibration is enabled + if self.config.calibration_mode == "inference": + if self.q_calibration_scale: + x /= self.q_calibration_scale.astype(x.dtype) + x_q = x.astype(self.config.activation_dtype) + x_scale = self.q_calibration_scale + elif self.config.calibration_mode == "max": + _, x_scale = self.config.quantize_float8( # type: ignore + x, + quantize_dtype=self.config.activation_dtype, + storage_dtype=self.config.storage_dtype, + ) + if self.config.tensor_parallel_shards > 1: + x_scale = nn.ccl_allreduce(x_scale, "max") + x_scale = nn.extern( + "mlc_llm.calibration_observer", + [f"{self.name}.q_calibration_scale", "max", x_scale], + out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype), + ) + x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype) + x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype) + else: + raise ValueError(f"Unknown calibration mode: {self.config.calibration_mode}") + + if ( + self.config.weight_dtype == self.config.storage_dtype + and self.config.calibration_mode == "inference" + ): + x = nn.op.matmul(x_q, nn.permute_dims(self.q_weight), out_dtype="float32") + if self.config.use_scale: + scale = x_scale * self.q_scale + x = x * scale + x = x.astype(self.out_dtype) + else: + w = nn.op.tensor_expr_op( + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + out_shape=[ + ( + tir.IntImm("int64", self.out_features) + if isinstance(self.out_features, int) + else weight.shape[0] + ), + tir.IntImm("int64", self.in_features), + ], + ), + "dequantize", + args=[self.q_weight, self.q_scale], + ) + x = nn.op.matmul(x, nn.permute_dims(w), out_dtype=self.out_dtype) if self.bias is not None: x = x + self.bias return x @@ -455,8 +486,8 @@ def __init__(self, num: Union[int, tir.Var], dim: int, config: PerTensorQuantize self.q_weight = nn.Parameter( (num, tir.ceildiv(dim, config.num_elem_per_storage)), config.storage_dtype ) - if not self.config.no_scale: - self.q_scale = nn.Parameter((1,), config.model_dtype) + if self.config.use_scale: + self.q_scale = nn.Parameter((1,), "float32") else: self.q_scale = None @@ -485,7 +516,7 @@ def from_embedding( def forward(self, x: nn.Tensor): # pylint: disable=invalid-name """ - Forward method for group quantized embedding layer. + Forward method for per-tensor quantized embedding layer. 
Parameters ---------- @@ -549,17 +580,21 @@ def lm_head_forward(self, x: nn.Tensor): class PerTensorQuantizeMixtralExperts(nn.Module): # pylint: disable=too-many-instance-attributes """An MixtralExperts module with group quantization""" + _IMPL: Dict[str, Type["PerTensorQuantizeMixtralExperts"]] = {} + def __init__( self, num_local_experts, in_features, out_features, config: PerTensorQuantize, + name: str, ): # pylint: disable=too-many-arguments self.num_local_experts = num_local_experts self.in_features = in_features self.out_features = out_features self.config = config + self.name = name self.q_weight = nn.Parameter( ( num_local_experts, @@ -568,8 +603,11 @@ def __init__( ), config.storage_dtype, ) - if not config.no_scale: - self.q_scale = nn.Parameter((1,), config.model_dtype) + self.q_calibration_scale = None + if config.use_scale: + self.q_scale = nn.Parameter((1,), "float32") + if config.calibration_mode == "inference": + self.q_calibration_scale = nn.Parameter((1,), "float32") else: self.q_scale = None @@ -577,9 +615,11 @@ def __init__( def from_mixtral_experts( src: "MixtralExperts", config: PerTensorQuantize, + name: str, ) -> "PerTensorQuantizeMixtralExperts": """ - Converts a non-quantized MixtralExperts to a per-tensor quantized PerTensorQuantizeMixtralExperts + Converts a non-quantized MixtralExperts to a per-tensor quantized + PerTensorQuantizeMixtralExperts Parameters ---------- @@ -589,6 +629,9 @@ def from_mixtral_experts( config : PerTensorQuantize The per-tensor quantization config + name: str + The name of the layer. + Returns ------- ret : PerTensorQuantizeMixtralExperts @@ -598,20 +641,13 @@ def from_mixtral_experts( DataTypeCode.E4M3Float, DataTypeCode.E5M2Float, ]: - from .fp8_quantization import MixtralExpertsFP8 - - quantized_mixtral_experts = MixtralExpertsFP8.from_mixtral_experts( - src, - config, + return PerTensorQuantizeMixtralExperts._IMPL["fp8"].from_mixtral_experts( + src, config, name ) - # TODO(csullivan): Confirm with @vinx13 and delete this before merge - # quantized_mixtral_experts.no_scale = config.no_scale - else: - raise NotImplementedError() - return quantized_mixtral_experts + raise NotImplementedError() def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name - """Forward method for group quantized mistral experts. + """Forward method for per-tensor quantized mistral experts. Parameters ---------- @@ -624,6 +660,6 @@ def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: # pylint: disa Returns ------- ret : nn.Tensor - The output tensor for the group quantized mistral experts layer. + The output tensor for the per-tensor quantized mistral experts layer. """ raise NotImplementedError() diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index c94e7c5523..1a5719a63f 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -1,4 +1,5 @@ """A centralized registry of all existing quantization methods and their configurations.""" + from typing import Any, Dict from .awq_quantization import AWQQuantize @@ -6,7 +7,6 @@ from .group_quantization import GroupQuantize from .no_quantization import NoQuantize from .per_tensor_quantization import PerTensorQuantize -from .smooth_quantization import SmoothQuantize Quantization = Any """Quantization is an object that represents an quantization algorithm. 
It is required to @@ -119,140 +119,42 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr storage_dtype="int8", model_dtype="float16", ), - "fp8_e4m3_e4m3_max_calibration": PerTensorQuantize( - name="fp8_e4m3_e4m3_max_calibration", - kind="per-tensor-quant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - storage_dtype="uint8", - model_dtype="float16", - quantize_embedding=False, - quantize_linear=False, - ), - "fp8_e4m3_e4m3_max": PerTensorQuantize( - name="fp8_e4m3_e4m3_max", + "e5m2_e5m2_f16": PerTensorQuantize( + name="e5m2_e5m2_f16", kind="per-tensor-quant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - storage_dtype="uint8", + activation_dtype="e5m2_float8", + weight_dtype="e5m2_float8", + storage_dtype="e5m2_float8", model_dtype="float16", + quantize_final_fc=False, quantize_embedding=False, - quantize_linear=False, + quantize_linear=True, + use_scale=False, ), - "ptq_e4m3_e4m3_max_calibration": PerTensorQuantize( - name="ptq_e4m3_e4m3_max_calibration", + "e4m3_e4m3_f16": PerTensorQuantize( + name="e4m3_e4m3_f16", kind="per-tensor-quant", activation_dtype="e4m3_float8", weight_dtype="e4m3_float8", storage_dtype="e4m3_float8", model_dtype="float16", + quantize_final_fc=False, quantize_embedding=False, quantize_linear=True, - # TODO(csullivan): Refactor sharding of calibration scale - # to enable lm_head quantization for TP > 1 - quantize_final_fc=False, + use_scale=True, + calibration_mode="inference", ), - "ptq_e4m3_e4m3_max": PerTensorQuantize( - name="ptq_e4m3_e4m3_max", + "e4m3_e4m3_f16_max_calibrate": PerTensorQuantize( + name="e4m3_e4m3_f16_max_calibrate", kind="per-tensor-quant", activation_dtype="e4m3_float8", weight_dtype="e4m3_float8", storage_dtype="e4m3_float8", model_dtype="float16", - quantize_embedding=False, - quantize_linear=True, quantize_final_fc=False, - ), - "smq_q8i8f16_0": SmoothQuantize( - name="smq_q8i8f16_0", - kind="smoothquant", - activation_dtype="int8", - weight_dtype="int8", - zero_point_dtype="int8", - accumulator_dtype="int32", - model_dtype="float16", - ), - "smq_q8i8f16_1": SmoothQuantize( - name="smq_q8i8f16_1", - kind="smoothquant", - activation_dtype="int8", - weight_dtype="int8", - zero_point_dtype="int8", - accumulator_dtype="int32", - model_dtype="float16", - ), - "smq_q8i8f16_2": SmoothQuantize( - name="smq_q8i8f16_2", - kind="smoothquant", - activation_dtype="int8", - weight_dtype="int8", - zero_point_dtype="int8", - accumulator_dtype="int32", - model_dtype="float16", - ), - "smq_e4m3_float8_0": SmoothQuantize( - name="smq_e4m3_float8_0", - kind="smoothquant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "smq_e4m3_float8_1": SmoothQuantize( - name="smq_e4m3_float8_1", - kind="smoothquant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "smq_e4m3_float8_2": SmoothQuantize( - name="smq_e4m3_float8_2", - kind="smoothquant", - activation_dtype="e4m3_float8", - weight_dtype="e4m3_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "smq_e5m2_float8_0": SmoothQuantize( - name="smq_e5m2_float8_0", - kind="smoothquant", - activation_dtype="e5m2_float8", - weight_dtype="e5m2_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "smq_e5m2_float8_1": SmoothQuantize( - 
name="smq_e5m2_float8_1", - kind="smoothquant", - activation_dtype="e5m2_float8", - weight_dtype="e5m2_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "smq_e5m2_float8_2": SmoothQuantize( - name="smq_e5m2_float8_2", - kind="smoothquant", - activation_dtype="e5m2_float8", - weight_dtype="e5m2_float8", - zero_point_dtype="float16", - accumulator_dtype="float32", - model_dtype="float16", - ), - "fp16_max_calibration": PerTensorQuantize( - name="fp16_max_calibration", - kind="per-tensor-quant", - activation_dtype="float16", - weight_dtype="float16", - storage_dtype="float16", - model_dtype="float16", quantize_embedding=False, - quantize_linear=False, - no_scale=True, + quantize_linear=True, + use_scale=True, + calibration_mode="max", ), } diff --git a/python/mlc_llm/quantization/utils.py b/python/mlc_llm/quantization/utils.py index 99504edb6c..d44a293d28 100644 --- a/python/mlc_llm/quantization/utils.py +++ b/python/mlc_llm/quantization/utils.py @@ -6,7 +6,7 @@ from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn -from tvm.runtime import DataType +from tvm.runtime import DataType, DataTypeCode from tvm.target import Target from mlc_llm.support import tensor_parallel as tp @@ -48,58 +48,21 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments ) -def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments - weight: te.Tensor, - bits: int, - num_elem_per_storage: int, - storage_dtype: str, - model_dtype: str, - quant_dtype: str, - axis: int = -1, - out_shape: Optional[List[tir.PrimExpr]] = None, - ft_reorder: Optional[bool] = False, -) -> te.Tensor: - """Convert a quantized uint weight tensor to an unquantized e4m3_float8 weight tensor.""" - # Does *not* have FT reoder support right now, can add back in (need to verify bit-match for fp8) - if ft_reorder: - raise NotImplementedError() - assert quant_dtype in ["e4m3_float8", "e5m2_float8"] - elem_storage_dtype = f"uint{bits}" - tir_bin_mask = tir.const((1 << bits) - 1, elem_storage_dtype) - if out_shape is None: - out_shape = weight.shape - out_shape[axis] *= num_elem_per_storage - axis = axis if axis >= 0 else len(out_shape) + axis - return te.compute( - shape=out_shape, - fcompute=lambda *idx: tir.reinterpret( - DataType(quant_dtype), - tir.bitwise_and( - tir.shift_right( - weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]), - ((idx[axis] % num_elem_per_storage) * bits).astype(storage_dtype), - ).astype(elem_storage_dtype), - tir_bin_mask, - ), - ).astype(model_dtype), - ) - - def is_final_fc(name: str) -> bool: """Determines whether the parameter is the last layer based on its name.""" # TODO: use more specious condition to determine final fc # pylint: disable=fixme return name in ["head", "lm_head", "lm_head.linear", "embed_out"] -def is_moe_gate(name: str) -> bool: +def is_moe_gate(name: str, node: nn.Linear) -> bool: """Check whether the parameter is the MoE gate layer.""" - return name.endswith("gate") + return name.endswith("gate") and isinstance(node.out_features, int) and node.out_features <= 64 def compile_quantize_func(mod: IRModule, device) -> Callable: """Compile a quantization function for a given device.""" device_type = device.MASK2STR[device.device_type] - if device_type in ["cuda", "rocm", "metal", "vulkan"]: + if device_type in ["cuda", "rocm", "metal", "vulkan", "opencl"]: target = Target.current() if target is None: target = Target.from_device(device) @@ -119,19 +82,55 @@ def 
compile_quantize_func(mod: IRModule, device) -> Callable: return vm["main"] -def apply_sharding(shard, name: str, weight: nn.Parameter): - if isinstance(shard, tp.ShardSingleDim): +def apply_sharding(shard_strategy, name: str, weight: nn.Parameter): + """Apply sharding strategy to a weight.""" + if isinstance(shard_strategy, tp.ShardSingleDim): weight.attrs["shard_strategy"] = tp.ShardSingleDim( name=name, - dim=shard.dim, - segs=shard.segs, - ) - elif isinstance(shard, tp.ShardScalar): - weight.attrs["shard_strategy"] = tp.ShardScalar( - name=name, + dim=shard_strategy.dim, + segs=shard_strategy.segs, ) else: - raise NotImplementedError(f"Unknowing sharding strategy: {shard}") + raise NotImplementedError(f"Unknown sharding strategy: {shard_strategy}") + + +def convert_uint_packed_fp8_to_float( # pylint: disable=too-many-arguments + weight: te.Tensor, + num_elem_per_storage: int, + storage_dtype: str, + model_dtype: str, + quant_dtype: str, + axis: int = -1, + out_shape: Optional[Sequence[tir.PrimExpr]] = None, +) -> te.Tensor: + """Unpack a fp8 value from the storage dtype and convert to float.""" + assert quant_dtype in ["e4m3_float8", "e5m2_float8"] + assert DataType(storage_dtype).type_code == DataTypeCode.UINT + bits = DataType(quant_dtype).bits + elem_storage_dtype = DataType(f"uint{bits}") + tir_bin_mask = tir.const((1 << bits) - 1, "uint8") + if axis < 0: + axis += len(weight.shape) + if out_shape is None: + out_shape = ( + *weight.shape[:axis], + weight.shape[axis] * num_elem_per_storage, + *weight.shape[axis + 1 :], + ) + axis = axis if axis >= 0 else len(out_shape) + axis + return te.compute( + shape=out_shape, + fcompute=lambda *idx: tir.reinterpret( + quant_dtype, + tir.bitwise_and( + tir.shift_right( + weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]), + ((idx[axis] % num_elem_per_storage) * bits).astype(storage_dtype), + ).astype(elem_storage_dtype), + tir_bin_mask, + ), + ).astype(model_dtype), + ) def pack_weight( @@ -162,10 +161,12 @@ def pack_weight( """ assert weight.dtype == storage_dtype shape = weight.shape + if axis < 0: + axis += len(shape) k = shape[axis] axis = axis if axis >= 0 else len(shape) + axis if out_shape is None: - out_shape = (*shape[axis], tir.ceildiv(k, num_elem_per_storage), *shape[axis + 1 :]) + out_shape = (*shape[:axis], tir.ceildiv(k, num_elem_per_storage), *shape[axis + 1 :]) r = te.reduce_axis((0, num_elem_per_storage), name="r") # pylint: disable=invalid-name packed_weight = te.compute( shape=out_shape, diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 59358c1646..6d362c88d0 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -1,11 +1,13 @@ """Subdirectory of serving.""" # Load MLC LLM library by importing base +""" +# NOTE(@sunggg): These are disabled because we don't use them from .. 
import base -from .config import EngineConfig, GenerationConfig, SpeculativeMode +from .config import EngineConfig from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncMLCEngine, MLCEngine -from .grammar import BNFGrammar, GrammarStateMatcher from .radix_tree import PagedRadixTree from .request import Request from .server import PopenServer +""" diff --git a/python/mlc_llm/serve/_ffi_api.py b/python/mlc_llm/serve/_ffi_api.py index d755fea6d3..30de604f4d 100644 --- a/python/mlc_llm/serve/_ffi_api.py +++ b/python/mlc_llm/serve/_ffi_api.py @@ -1,4 +1,5 @@ """FFI APIs for mlc_llm.serve""" + import tvm._ffi # Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc.serve" prefix. diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 6b808ac37b..c790a22d5a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -1,153 +1,12 @@ """Configuration dataclasses used in MLC LLM serving""" -import enum import json from dataclasses import asdict, dataclass, field -from typing import Dict, List, Literal, Optional - -import tvm - -from . import _ffi_api +from typing import List, Literal, Optional, Tuple, Union @dataclass -class ResponseFormat: - """The response format dataclass. - - Parameters - ---------- - type : Literal["text", "json_object"] - The type of response format. Default: "text". - - schema : Optional[str] - The JSON schema string for the JSON response format. If None, a legal json string without - special restrictions will be generated. - - Could be specified when the response format is "json_object". Default: None. - """ - - type: Literal["text", "json_object"] = "text" - schema: Optional[str] = None - - def __post_init__(self): - if self.schema is not None and self.type != "json_object": - raise ValueError("JSON schema is only supported in JSON response format") - - -@dataclass -class GenerationConfig: # pylint: disable=too-many-instance-attributes - """The generation configuration dataclass. - - Parameters - ---------- - n : int - How many chat completion choices to generate for each input message. - - temperature : float - The value that applies to logits and modulates the next token probabilities. - - top_p : float - In sampling, only the most probable tokens with probabilities summed up to - `top_p` are kept for sampling. - - frequency_penalty : float - Positive values penalize new tokens based on their existing frequency - in the text so far, decreasing the model's likelihood to repeat the same - line verbatim. - - presence_penalty : float - Positive values penalize new tokens based on whether they appear in the text - so far, increasing the model's likelihood to talk about new topics. - - repetition_penalty : float - The penalty term that applies to logits to control token repetition in generation. - It will be suppressed when any of frequency_penalty and presence_penalty is - non-zero. - - logprobs : bool - Whether to return log probabilities of the output tokens or not. - If true, the log probabilities of each output token will be returned. - - top_logprobs : int - An integer between 0 and 5 specifying the number of most likely - tokens to return at each token position, each with an associated - log probability. - `logprobs` must be set to True if this parameter is used. - - logit_bias : Optional[Dict[int, float]] - The bias logit value added to selected tokens prior to sampling. 
- - max_tokens : Optional[int] - The maximum number of generated tokens, - or None, in which case the generation will not stop - until exceeding model capability or hit any stop criteria. - - seed : Optional[int] - The random seed of the generation. - The seed will be a random value if not specified. - - stop_strs : List[str] - The list of strings that mark the end of generation. - - stop_token_ids : List[int] - The list of token ids that mark the end of generation. - - ignore_eos: bool - When it is true, ignore the eos token and generate tokens until `max_tokens`. - Default is set to False. - - response_format : ResponseFormat - The response format of the generation output. - """ - - n: int = 1 - temperature: float = 0.8 - top_p: float = 0.95 - frequency_penalty: float = 0.0 - presence_penalty: float = 0.0 - repetition_penalty: float = 1.0 - logprobs: bool = False - top_logprobs: int = 0 - logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) - - max_tokens: Optional[int] = 128 - seed: Optional[int] = None - stop_strs: List[str] = field(default_factory=list) - stop_token_ids: List[int] = field(default_factory=list) - ignore_eos: bool = False - - response_format: ResponseFormat = field(default_factory=ResponseFormat) - - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) - - @staticmethod - def from_json(json_str: str) -> "GenerationConfig": - """Construct a config from JSON string.""" - return GenerationConfig(**json.loads(json_str)) - - -class KVStateKind(enum.IntEnum): # pylint: disable=too-few-public-methods - """Possible kinds of KV state.""" - - ATTENTION = 0 - RNNSTATE = 1 - - -class SpeculativeMode(enum.IntEnum): - """The speculative mode.""" - - # Disable speculative decoding. - DISABLE = 0 - # The normal speculative decoding (small draft) mode. - SMALL_DRAFT = 1 - # The eagle-style speculative decoding. - EAGLE = 2 - - -@tvm._ffi.register_object("mlc.serve.EngineConfig") # pylint: disable=protected-access -class EngineConfig(tvm.runtime.Object): +class EngineConfig: # pylint: disable=too-many-instance-attributes """The class of MLCEngine execution configuration. Parameters @@ -155,74 +14,125 @@ class EngineConfig(tvm.runtime.Object): model : str The path to the model directory. - model_lib_path : str + model_lib : str The path to the model library. - additional_models : List[str] - The path to the additional models' directories. - - additional_model_lib_paths : List[str] - The path to the additional models' libraries. + additional_models : List[Union[str, Tuple[str, str]]] + The paths to the additional models' directories (and model libraries). + Each element is a single string (denoting the model directory) + or a tuple of two strings (denoting the model directory and model lib path). + + mode : Literal["local", "interactive", "server"] + The engine mode in MLC LLM. + We provide three preset modes: "local", "interactive" and "server". + The default mode is "local". + The choice of mode decides the values of "max_num_sequence", "max_total_sequence_length" + and "prefill_chunk_size" when they are not explicitly specified. + 1. Mode "local" refers to the local server deployment which has low + request concurrency. So the max batch size will be set to 4, and max + total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 2. Mode "interactive" refers to the interactive use of server, which + has at most 1 concurrent request. 
So the max batch size will be set to 1, + and max total sequence length and prefill chunk size are set to the context + window size (or sliding window size) of the model. + 3. Mode "server" refers to the large server use case which may handle + many concurrent requests and wants to use GPU memory as much as possible. + In this mode, we will automatically infer the largest possible max batch + size and max total sequence length. + + You can manually specify arguments "max_num_sequence", "max_total_sequence_length" and + "prefill_chunk_size" to override the automatically inferred values. + + tensor_parallel_shards : Optional[int] + Number of shards to split the model into for tensor-parallel multi-GPU inference. + + gpu_memory_utilization : Optional[float] + A number in (0, 1) denoting the fraction of GPU memory used by the server in total. + It is used to infer the maximum possible KV cache capacity. + When it is unspecified, it defaults to 0.85. + Under mode "local" or "interactive", the actual memory usage may be + significantly smaller than this number. Under mode "server", the actual + memory usage may be slightly larger than this number. kv_cache_page_size : int The number of consecutive tokens handled in each page in paged KV cache. - max_num_sequence : int + max_num_sequence : Optional[int] The maximum number of sequences that are allowed to be processed by the KV cache at any time. - max_total_sequence_length : int - The maximum length allowed for a single sequence in the engine. - - max_single_sequence_length : int + max_total_sequence_length : Optional[int] The maximum total number of tokens whose KV data are allowed to exist in the KV cache at any time. - prefill_chunk_size : int + max_single_sequence_length : Optional[int] + The maximum length allowed for a single sequence in the engine. + + prefill_chunk_size : Optional[int] The maximum total sequence length in a prefill. - max_history_size: int - The maximum history size for RNN state to rool back. + sliding_window_size : Optional[int] + The sliding window size in sliding window attention (SWA). - kv_state_kind: KVStateKind + attention_sink_size : Optional[int] + The number of attention sinks when sliding window is enabled. + + max_history_size: Optional[int] + The maximum history size for RNN state to roll back. + + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] The kind of cache. - speculative_mode : SpeculativeMode + speculative_mode : Literal["disable", "small_draft", "eagle", "medusa"] The speculative mode. + "disable" means speculative decoding is disabled. + "small_draft" means the normal speculative decoding (small draft) mode. + "eagle" means the eagle-style speculative decoding. + "medusa" means the medusa-style speculative decoding. spec_draft_length : int The number of tokens to generate in speculative proposal (draft). + + prefix_cache_mode : Literal["disable", "radix"] + The prefix cache mode. + "disable" means the prefix cache is disabled. + "radix" means the paged radix tree based prefix cache mode. + + prefix_cache_max_num_recycling_seqs: Optional[int] + The maximum number of recycling sequences in the prefix cache, defaulting to max_num_sequence. + Set it to 0 to disable the prefix cache, or to -1 for an unbounded prefix cache. + + verbose : bool + A boolean indicating whether to print logging info in the engine.
""" - def __init__( # pylint: disable=too-many-arguments - self, - model: str, - model_lib_path: str, - additional_models: List[str], - additional_model_lib_paths: List[str], - kv_cache_page_size: int, - max_num_sequence: int, - max_total_sequence_length: int, - max_single_sequence_length: int, - prefill_chunk_size: int, - max_history_size: int, - kv_state_kind: KVStateKind, - speculative_mode: SpeculativeMode, - spec_draft_length: int, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.EngineConfig, # type: ignore # pylint: disable=no-member - model, - model_lib_path, - additional_models, - additional_model_lib_paths, - kv_cache_page_size, - max_num_sequence, - max_total_sequence_length, - max_single_sequence_length, - prefill_chunk_size, - max_history_size, - kv_state_kind, - speculative_mode, - spec_draft_length, - ) + model: Optional[str] = None + model_lib: Optional[str] = None + additional_models: List[Union[str, Tuple[str, str]]] = field(default_factory=list) + mode: Optional[Literal["local", "interactive", "server"]] = None + tensor_parallel_shards: Optional[int] = None + gpu_memory_utilization: Optional[float] = None + kv_cache_page_size: int = 16 + max_num_sequence: Optional[int] = None + max_total_sequence_length: Optional[int] = None + max_single_sequence_length: Optional[int] = None + prefill_chunk_size: Optional[int] = None + sliding_window_size: Optional[int] = None + attention_sink_size: Optional[int] = None + max_history_size: Optional[int] = None + kv_state_kind: Optional[Literal["kv_cache", "rnn_state"]] = None + speculative_mode: Literal["disable", "small_draft", "eagle", "medusa"] = "disable" + spec_draft_length: int = 4 + prefix_cache_mode: Literal["disable", "radix"] = "radix" + prefix_cache_max_num_recycling_seqs: Optional[int] = None + verbose: bool = True + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "EngineConfig": + """Construct a config from JSON string.""" + return EngineConfig(**json.loads(json_str)) diff --git a/python/mlc_llm/serve/data.py b/python/mlc_llm/serve/data.py index 1c56178ad1..53f7b3007c 100644 --- a/python/mlc_llm/serve/data.py +++ b/python/mlc_llm/serve/data.py @@ -112,11 +112,9 @@ def from_url(url: str, config: Dict) -> "ImageData": # pylint: disable=too-many size={"shortest_edge": image_input_size}, crop_size={"height": image_input_size, "width": image_input_size}, ) - quantization = config["quantization"] - out_dtype = "float16" if "f16" in quantization else "float32" image_features = tvm.nd.array( image_processor.preprocess(image_tensor, return_tensors="np")["pixel_values"].astype( - out_dtype + "float32" ) ) image_data = ImageData(image_features, image_embed_size) @@ -159,6 +157,8 @@ class SingleRequestStreamOutput: delta_token_ids: List[int] delta_logprob_json_strs: Optional[List[str]] finish_reason: Optional[str] + request_final_usage_json_str: Optional[str] + extra_prefix_string: str @tvm._ffi.register_object("mlc.serve.RequestStreamOutput") # pylint: disable=protected-access @@ -191,9 +191,18 @@ def unpack(self) -> Tuple[str, List[SingleRequestStreamOutput]]: The output instances, one for a request. 
""" fields = _ffi_api.RequestStreamOutputUnpack(self) # type: ignore # pylint: disable=no-member + request_final_usage_json_str = fields[4] request_id = str(fields[0]) + if request_final_usage_json_str is not None: + return ( + request_id, + [SingleRequestStreamOutput([], None, None, request_final_usage_json_str, "")], + ) + stream_outputs = [] - for i, (delta_token_ids, finish_reason) in enumerate(zip(fields[1], fields[3])): + for i, (delta_token_ids, finish_reason, extra_prefix_string) in enumerate( + zip(fields[1], fields[3], fields[5]) + ): delta_logprob_json_strs = ( [str(logprob_json_str) for logprob_json_str in fields[2][i]] if fields[2] is not None @@ -204,6 +213,8 @@ def unpack(self) -> Tuple[str, List[SingleRequestStreamOutput]]: delta_token_ids=list(delta_token_ids), delta_logprob_json_strs=delta_logprob_json_strs, finish_reason=str(finish_reason) if finish_reason is not None else None, + request_final_usage_json_str=None, + extra_prefix_string=str(extra_prefix_string), ) ) return request_id, stream_outputs diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 413c856db1..fa67c7a81c 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -14,18 +14,19 @@ List, Literal, Optional, + Tuple, Union, overload, ) from tvm.runtime import Device -from mlc_llm.protocol import openai_api_protocol +from mlc_llm.protocol import debug_protocol, openai_api_protocol +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import GenerationConfig, SpeculativeMode -from mlc_llm.serve.request import Request -from mlc_llm.streamer import TextStreamer +from mlc_llm.serve.config import EngineConfig from mlc_llm.support import logging +from mlc_llm.tokenizers import TextStreamer from . import engine_base @@ -33,16 +34,21 @@ logger = logging.getLogger(__name__) +# Note: we define both AsyncChat and Chat for Python type analysis. 
+class AsyncChat: # pylint: disable=too-few-public-methods + """The proxy class to direct to async chat completions.""" + + def __init__(self, engine: weakref.ReferenceType) -> None: + assert isinstance(engine(), AsyncMLCEngine) + self.completions = AsyncChatCompletion(engine) + + class Chat: # pylint: disable=too-few-public-methods """The proxy class to direct to chat completions.""" def __init__(self, engine: weakref.ReferenceType) -> None: - assert isinstance(engine(), (AsyncMLCEngine, MLCEngine)) - self.completions = ( - AsyncChatCompletion(engine) # type: ignore - if isinstance(engine(), AsyncMLCEngine) - else ChatCompletion(engine) # type: ignore - ) + assert isinstance(engine(), MLCEngine) + self.completions = ChatCompletion(engine) class AsyncChatCompletion: # pylint: disable=too-few-public-methods @@ -63,8 +69,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -72,14 +78,15 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: """Asynchronous streaming chat completion interface with OpenAI API compatibility. The method is a coroutine that streams ChatCompletionStreamResponse @@ -93,6 +100,10 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. 
+ Can be used to pass debug config as extra_body["debug_config"] + Yields ------ stream_response : ChatCompletionStreamResponse @@ -112,8 +123,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -122,14 +133,15 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Literal[None] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> openai_api_protocol.ChatCompletionResponse: """Asynchronous non-streaming chat completion interface with OpenAI API compatibility. The method is a coroutine that streams ChatCompletionStreamResponse @@ -143,8 +155,12 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Returns - ------ + ------- response : ChatCompletionResponse The chat completion response conforming to OpenAI API. See mlc_llm/protocol/openai_api_protocol.py or @@ -161,8 +177,8 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -171,14 +187,15 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> Union[ AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], openai_api_protocol.ChatCompletionResponse, @@ -193,6 +210,10 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. 
+ Can be used to pass debug config as extra_body["debug_config"] + Raises ------ e : BadRequestError @@ -211,14 +232,19 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), temperature=temperature, top_p=top_p, tools=tools, tool_choice=tool_choice, user=user, - ignore_eos=ignore_eos, response_format=response_format, request_id=request_id, + debug_config=(extra_body.get("debug_config", None) if extra_body is not None else None), ) @@ -240,8 +266,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals messages: List[Dict[str, Any]], stream: Literal[True], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -249,14 +275,15 @@ def create( # pylint: disable=too-many-arguments,too-many-locals n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]: """Synchronous streaming chat completion interface with OpenAI API compatibility. The method streams back ChatCompletionStreamResponse that conforms to @@ -270,6 +297,10 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Yields ------ stream_response : ChatCompletionStreamResponse @@ -289,8 +320,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -299,14 +330,15 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Literal[None] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> openai_api_protocol.ChatCompletionResponse: """Synchronous non-streaming chat completion interface with OpenAI API compatibility. 
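Usage illustration (not part of the patch) for the reworked ChatCompletion.create signature: stream_options and extra_body["debug_config"] are forwarded with the request, and the final streamed chunk carries only usage (handled later in this diff). The model id, the package-level MLCEngine import, the include_usage field, and the terminate() call are assumptions rather than parts of this change.

from mlc_llm import MLCEngine  # assuming the package-level re-export

engine = MLCEngine("HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC")  # illustrative model id

final_usage = None
text = ""
for chunk in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is MLC LLM?"}],
    stream=True,
    stream_options={"include_usage": True},  # validated via openai_api_protocol.StreamOptions; field name assumed
    extra_body={"debug_config": {}},  # validated via debug_protocol.DebugConfig; assumes all fields have defaults
):
    if chunk.usage is not None:  # usage-only final chunk introduced by this change
        final_usage = chunk.usage
        continue
    for choice in chunk.choices:
        text += choice.delta.content or ""

print(text)
engine.terminate()  # assumed cleanup helper; not shown in this diff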
@@ -318,6 +350,10 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Returns ------ response : ChatCompletionResponse @@ -336,8 +372,8 @@ def create( # pylint: disable=too-many-arguments,too-many-locals *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -346,14 +382,15 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> Union[ Iterator[openai_api_protocol.ChatCompletionStreamResponse], openai_api_protocol.ChatCompletionResponse, @@ -368,6 +405,10 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Raises ------ e : BadRequestError @@ -386,14 +427,19 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), temperature=temperature, top_p=top_p, tools=tools, tool_choice=tool_choice, user=user, - ignore_eos=ignore_eos, response_format=response_format, request_id=request_id, + debug_config=(extra_body.get("debug_config", None) if extra_body is not None else None), ) @@ -417,22 +463,23 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: """Asynchronous streaming completion interface with OpenAI API compatibility. 
The method is a coroutine that streams CompletionResponse @@ -446,6 +493,10 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Yields ------ stream_response : CompletionResponse @@ -467,23 +518,24 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, + stream_options: Literal[None] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> openai_api_protocol.CompletionResponse: """Asynchronous non-streaming completion interface with OpenAI API compatibility. @@ -495,6 +547,10 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Returns ------ response : CompletionResponse @@ -515,23 +571,24 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> Union[ AsyncGenerator[openai_api_protocol.CompletionResponse, Any], openai_api_protocol.CompletionResponse, @@ -546,6 +603,10 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. 
+ Can be used to pass debug config as extra_body["debug_config"] + Raises ------ e : BadRequestError @@ -566,13 +627,18 @@ async def create( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), suffix=suffix, temperature=temperature, top_p=top_p, user=user, - ignore_eos=ignore_eos, response_format=response_format, request_id=request_id, + debug_config=(extra_body.get("debug_config", None) if extra_body is not None else None), ) @@ -596,23 +662,24 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, - ) -> openai_api_protocol.CompletionResponse: + extra_body: Optional[Dict[str, Any]] = None, + ) -> Iterator[openai_api_protocol.CompletionResponse]: """Synchronous streaming completion interface with OpenAI API compatibility. The method streams back CompletionResponse that conforms to OpenAI API one at a time via yield. @@ -625,6 +692,10 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Yields ------ stream_response : CompletionResponse @@ -646,24 +717,25 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: Literal[False] = False, + stream_options: Literal[None] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.CompletionResponse]: + extra_body: Optional[Dict[str, Any]] = None, + ) -> openai_api_protocol.CompletionResponse: """Synchronous non-streaming completion interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/completions/create for specification. @@ -674,8 +746,12 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. 
A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Returns - ------ + ------- response : CompletionResponse The completion response conforming to OpenAI API. See mlc_llm/protocol/openai_api_protocol.py or @@ -694,24 +770,28 @@ def create( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.CompletionResponse]: + extra_body: Optional[Dict[str, Any]] = None, + ) -> Union[ + Iterator[openai_api_protocol.CompletionResponse], + openai_api_protocol.CompletionResponse, + ]: """Synchronous completion interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/completions/create for specification. @@ -722,6 +802,10 @@ def create( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + extra_body: Optional[Dict[str, Any]] = None, + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + Raises ------ e : BadRequestError @@ -742,13 +826,18 @@ def create( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), suffix=suffix, temperature=temperature, top_p=top_p, user=user, - ignore_eos=ignore_eos, response_format=response_format, request_id=request_id, + debug_config=(extra_body.get("debug_config", None) if extra_body is not None else None), ) @@ -758,7 +847,7 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. @@ -767,16 +856,16 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + It the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. We provide three preset modes: "local", "interactive" and "server". The default mode is "local". 
- The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + The choice of mode decides the values of "max_num_sequence", "max_total_sequence_length" and "prefill_chunk_size" when they are not explicitly specified. 1. Mode "local" refers to the local server deployment which has low request concurrency. So the max batch size will be set to 4, and max @@ -791,106 +880,67 @@ class AsyncMLCEngine(engine_base.MLCEngineBase): In this mode, we will automatically infer the largest possible max batch size and max total sequence length. - You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + You can manually specify arguments "max_num_sequence", "max_total_sequence_length" and "prefill_chunk_size" to override the automatic inferred values. - additional_models : Optional[List[str]] - The model paths and (optional) model library paths of additional models - (other than the main model). - When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will - be activated to compile the model automatically. - - max_batch_size : Optional[int] - The maximum allowed batch size set for the KV cache to concurrently support. - - max_total_sequence_length : Optional[int] - The KV cache total token capacity, i.e., the maximum total number of tokens that - the KV cache support. This decides the GPU memory size that the KV cache consumes. - If not specified, system will automatically estimate the maximum capacity based - on the vRAM size on GPU. - - prefill_chunk_size : Optional[int] - The maximum number of tokens the model passes for prefill each time. - It should not exceed the prefill chunk size in model config. - If not specified, this defaults to the prefill chunk size in model config. - - max_history_size : Optional[int] - The maximum history for RNN state. - - gpu_memory_utilization : Optional[float] - A number in (0, 1) denoting the fraction of GPU memory used by the server in total. - It is used to infer to maximum possible KV cache capacity. - When it is unspecified, it defaults to 0.85. - Under mode "local" or "interactive", the actual memory usage may be - significantly smaller than this number. Under mode "server", the actual - memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + Additional configurable arguments of MLC engine. + See class "EngineConfig" for more detail. enable_tracing : bool A boolean indicating if to enable event logging for requests. 
""" - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments,too-many-locals self, model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - max_history_size: Optional[int] = None, - gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, ) -> None: super().__init__( "async", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, - additional_models=additional_models, - max_batch_size=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - gpu_memory_utilization=gpu_memory_utilization, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, + engine_config=engine_config, enable_tracing=enable_tracing, ) - self.chat = Chat(weakref.ref(self)) + self.chat = AsyncChat(weakref.ref(self)) self.completions = AsyncCompletion(weakref.ref(self)) async def abort(self, request_id: str) -> None: """Generation abortion interface. - Parameter + Parameters --------- request_id : str The id of the request to abort. """ self._abort(request_id) + async def metrics(self) -> engine_base.EngineMetrics: + """Get engine metrics + + Returns + ------- + metrics: EngineMetrics + The engine metrics + """ + # pylint: disable=protected-access + return await engine_base._async_query_engine_metrics(self) + async def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -899,14 +949,15 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + debug_config: Optional[Dict[str, Any]] = None, ) -> Union[ AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any], openai_api_protocol.ChatCompletionResponse, @@ -920,6 +971,11 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local request_id : Optional[str] The optional request id. A random one will be generated if it is not given. + Extra body options to pass to the request. + Can be used to pass debug config as extra_body["debug_config"] + + debug_config: Optional[Dict[str, Any]] = None, + Debug config body options to pass to the request. 
Raises ------ @@ -946,6 +1002,11 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), temperature=temperature, top_p=top_p, tools=( @@ -955,39 +1016,52 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local ), tool_choice=tool_choice, user=user, - ignore_eos=ignore_eos, response_format=( openai_api_protocol.RequestResponseFormat.model_validate(response_format) if response_format is not None else None ), + debug_config=( + debug_protocol.DebugConfig.model_validate(debug_config) + if debug_config is not None + else None + ), ), request_id=request_id, + request_final_usage_include_extra=True, ) if stream: # Stream response. return chatcmpl_generator # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 output_texts = ["" for _ in range(n)] finish_reasons: List[Optional[str]] = [None for _ in range(n)] logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( [[] for _ in range(n)] if logprobs else None ) - async for response in chatcmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens - for choice in response.choices: - assert isinstance(choice.delta.content, str) - output_texts[choice.index] += choice.delta.content - if choice.finish_reason is not None and finish_reasons[choice.index] is None: - finish_reasons[choice.index] = choice.finish_reason - if choice.logprobs is not None: - assert logprob_results is not None - logprob_results[ # pylint: disable=unsupported-assignment-operation - choice.index - ] += choice.logprobs.content + request_final_usage = None + try: + async for response in chatcmpl_generator: + # when usage is not None this is the last chunk + if response.usage is not None: + request_final_usage = response.usage + continue + for choice in response.choices: + assert isinstance(choice.delta.content, str) + output_texts[choice.index] += choice.delta.content + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + assert logprob_results is not None + logprob_results[ # pylint: disable=unsupported-assignment-operation + choice.index + ] += choice.logprobs.content + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in chat completion with request ID %s: %s", request_id, err) + raise err assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( @@ -1001,8 +1075,7 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local tool_calls_list=tool_calls_list, logprob_results=logprob_results, use_function_calling=use_function_calling, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) async def _completion( # pylint: disable=too-many-arguments,too-many-locals @@ -1012,23 +1085,24 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: 
bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + debug_config: Optional[Dict[str, Any]] = None, ) -> Union[ AsyncGenerator[openai_api_protocol.CompletionResponse, Any], openai_api_protocol.CompletionResponse, @@ -1043,6 +1117,9 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + debug_config: Optional[Dict[str, Any]] = None, + Extra debug options to pass to the request. + Raises ------ e : BadRequestError @@ -1066,25 +1143,34 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), suffix=suffix, temperature=temperature, top_p=top_p, user=user, - ignore_eos=ignore_eos, response_format=( openai_api_protocol.RequestResponseFormat.model_validate(response_format) if response_format is not None else None ), + debug_config=( + debug_protocol.DebugConfig.model_validate(debug_config) + if debug_config is not None + else None + ), ), - request_id, + request_id=request_id, + request_final_usage_include_extra=True, ) if stream: # Stream response. return cmpl_generator # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 + request_final_usage = None output_texts = ["" for _ in range(n)] finish_reasons: List[Optional[str]] = [None for _ in range(n)] logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( @@ -1092,8 +1178,10 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals ) async for response in cmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens + # this is the final chunk + if response.usage is not None: + request_final_usage = response.usage + continue for choice in response.choices: output_texts[choice.index] += choice.text if choice.finish_reason is not None and finish_reasons[choice.index] is None: @@ -1105,18 +1193,21 @@ async def _completion( # pylint: disable=too-many-arguments,too-many-locals ] += choice.logprobs.content assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( request_id=request_id, model=model, output_texts=output_texts, finish_reasons=finish_reasons, logprob_results=logprob_results, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) async def _handle_chat_completion( - self, request: openai_api_protocol.ChatCompletionRequest, request_id: str + self, + request: openai_api_protocol.ChatCompletionRequest, + request_id: str, + request_final_usage_include_extra: bool, ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]: """The implementation fo asynchronous ChatCompletionRequest handling. 
@@ -1146,30 +1237,41 @@ async def _handle_chat_completion( self.max_input_sequence_length, self.conv_template.model_copy(deep=True), ) - + # prompt length is not used + _ = prompt_length finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - num_completion_tokens = 0 self.state.record_event(request_id, event="invoke generate") - async for delta_outputs in self._generate( - prompts, generation_cfg, request_id # type: ignore - ): - response, num_completion_tokens = engine_base.process_chat_completion_stream_output( - delta_outputs, - request_id, - self.state, - request.model, - generation_cfg, - use_function_calling, - prompt_length, - finish_reasons, - num_completion_tokens, - ) - if response is not None: - yield response - self.state.record_event(request_id, event="finish") + try: + async for delta_outputs in self._generate( + prompts, generation_cfg, request_id # type: ignore + ): + response = engine_base.process_chat_completion_stream_output( + delta_outputs, + request, + request_id, + self.state, + use_function_calling, + finish_reasons, + ) + + if response is not None: + if response.usage is not None: + if not request_final_usage_include_extra: + response.usage.extra = None + yield response + self.state.record_event(request_id, event="finish") + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err) + raise err async def _handle_completion( - self, request: openai_api_protocol.CompletionRequest, request_id: str + self, + request: openai_api_protocol.CompletionRequest, + request_id: str, + request_final_usage_include_extra: bool, ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: """The implementation fo asynchronous CompletionRequest handling. 
@@ -1194,38 +1296,46 @@ async def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), ) + _ = prompt_length if echo_response is not None: yield echo_response - num_completion_tokens = 0 finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] self.state.record_event(request_id, event="invoke generate") - async for delta_outputs in self._generate( - prompt, generation_cfg, request_id # type: ignore - ): - response, num_completion_tokens = engine_base.process_completion_stream_output( - delta_outputs, - request_id, - self.state, - request.model, - generation_cfg, - prompt_length, - finish_reasons, - num_completion_tokens, - ) - if response is not None: - yield response + try: + async for delta_outputs in self._generate( + prompt, generation_cfg, request_id # type: ignore + ): + response = engine_base.process_completion_stream_output( + delta_outputs, + request, + request_id, + self.state, + finish_reasons, + ) - suffix_response = engine_base.create_completion_suffix_response( - request, request_id, prompt_length, finish_reasons, num_completion_tokens - ) - if suffix_response is not None: - yield suffix_response - self.state.record_event(request_id, event="finish") + if response is not None: + if response.usage is not None: + if not request_final_usage_include_extra: + response.usage.extra = None + yield response + + suffix_response = engine_base.create_completion_suffix_response( + request, request_id, finish_reasons + ) + if suffix_response is not None: + yield suffix_response + self.state.record_event(request_id, event="finish") + except ( + Exception, + asyncio.CancelledError, + ) as err: # pylint: disable=broad-exception-caught + logger.error("Error in _handle_completion for request %s: %s", request_id, err) + raise err async def _generate( self, @@ -1264,7 +1374,9 @@ async def _generate( # Create the request with the given id, input data, generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = self._ffi["create_request"]( + request_id, input_data, generation_config.model_dump_json(by_alias=True) + ) # Create the unique async request stream of the request. stream = engine_base.AsyncRequestStream() @@ -1282,7 +1394,6 @@ async def _generate( stream, [TextStreamer(self.tokenizer) for _ in range(generation_config.n)], ) - self.state.async_num_unfinished_generations[request_id] = generation_config.n self._ffi["add_request"](request) # Iterate the stream asynchronously and yield the output. @@ -1293,13 +1404,13 @@ async def _generate( Exception, asyncio.CancelledError, ) as exception: # pylint: disable=broad-exception-caught + logger.error("Error in _generate for request %s: %s", request_id, exception) await self.abort(request_id) raise exception def _abort(self, request_id: str): """Internal implementation of request abortion.""" self.state.async_streamers.pop(request_id, None) - self.state.async_num_unfinished_generations.pop(request_id, None) self._ffi["abort_request"](request_id) @@ -1309,7 +1420,7 @@ class MLCEngine(engine_base.MLCEngineBase): Parameters ---------- - models : str + model : str A path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`. It can also be a link to a HF repository pointing to an MLC compiled model. 
@@ -1318,16 +1429,16 @@ class MLCEngine(engine_base.MLCEngineBase): The device used to deploy the model such as "cuda" or "cuda:0". Will default to "auto" and detect from local available GPUs if not specified. - model_lib_path : Optional[str] + model_lib : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. - It the model lib path is not found, it will be compiled in a JIT manner. + It the model lib is not found, it will be compiled in a JIT manner. mode : Literal["local", "interactive", "server"] The engine mode in MLC LLM. We provide three preset modes: "local", "interactive" and "server". The default mode is "local". - The choice of mode decides the values of "max_batch_size", "max_total_sequence_length" + The choice of mode decides the values of "max_num_sequence", "max_total_sequence_length" and "prefill_chunk_size" when they are not explicitly specified. 1. Mode "local" refers to the local server deployment which has low request concurrency. So the max batch size will be set to 4, and max @@ -1342,81 +1453,34 @@ class MLCEngine(engine_base.MLCEngineBase): In this mode, we will automatically infer the largest possible max batch size and max total sequence length. - You can manually specify arguments "max_batch_size", "max_total_sequence_length" and + You can manually specify arguments "max_num_sequence", "max_total_sequence_length" and "prefill_chunk_size" to override the automatic inferred values. - additional_models : Optional[List[str]] - The model paths and (optional) model library paths of additional models - (other than the main model). - When engine is enabled with speculative decoding, additional models are needed. - Each string in the list is either in form "model_path" or "model_path:model_lib_path". - When the model lib path of a model is not given, JIT model compilation will - be activated to compile the model automatically. - - max_batch_size : Optional[int] - The maximum allowed batch size set for the KV cache to concurrently support. - - max_total_sequence_length : Optional[int] - The KV cache total token capacity, i.e., the maximum total number of tokens that - the KV cache support. This decides the GPU memory size that the KV cache consumes. - If not specified, system will automatically estimate the maximum capacity based - on the vRAM size on GPU. - - prefill_chunk_size : Optional[int] - The maximum number of tokens the model passes for prefill each time. - It should not exceed the prefill chunk size in model config. - If not specified, this defaults to the prefill chunk size in model config. - - gpu_memory_utilization : Optional[float] - A number in (0, 1) denoting the fraction of GPU memory used by the server in total. - It is used to infer to maximum possible KV cache capacity. - When it is unspecified, it defaults to 0.85. - Under mode "local" or "interactive", the actual memory usage may be - significantly smaller than this number. Under mode "server", the actual - memory usage may be slightly larger than this number. - engine_config : Optional[EngineConfig] - The MLCEngine execution configuration. - Currently speculative decoding mode is specified via engine config. - For example, you can use "--engine-config='spec_draft_length=4;speculative_mode=EAGLE'" - to specify the eagle-style speculative decoding. - Check out class `EngineConfig` in mlc_llm/serve/config.py for detailed specification. + Additional configurable arguments of MLC engine. 
+ See class "EngineConfig" for more detail. enable_tracing : bool A boolean indicating if to enable event logging for requests. """ - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments,too-many-locals self, model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - max_history_size: Optional[int] = None, - gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, ) -> None: super().__init__( "sync", model=model, device=device, - model_lib_path=model_lib_path, + model_lib=model_lib, mode=mode, - additional_models=additional_models, - max_batch_size=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - gpu_memory_utilization=gpu_memory_utilization, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, + engine_config=engine_config, enable_tracing=enable_tracing, ) self.chat = Chat(weakref.ref(self)) @@ -1425,20 +1489,31 @@ def __init__( # pylint: disable=too-many-arguments def abort(self, request_id: str) -> None: """Generation abortion interface. - Parameter + Parameters --------- request_id : str The id of the request to abort. """ self._ffi["abort_request"](request_id) + def metrics(self) -> engine_base.EngineMetrics: + """Get engine metrics + + Returns + ------- + metrics: EngineMetrics + The engine metrics + """ + # pylint: disable=protected-access + return engine_base._query_engine_metrics(self) + def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals self, *, messages: List[Dict[str, Any]], model: Optional[str] = None, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, @@ -1447,14 +1522,15 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, - temperature: float = 1.0, - top_p: float = 1.0, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, + debug_config: Optional[Dict[str, Any]] = None, ) -> Union[ Iterator[openai_api_protocol.ChatCompletionStreamResponse], openai_api_protocol.ChatCompletionResponse, @@ -1469,6 +1545,9 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. + debug_config: Optional[Dict[str, Any]] = None, + Extra debug options to pass to the request. 
+ Raises ------ e : BadRequestError @@ -1494,6 +1573,11 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), temperature=temperature, top_p=top_p, tools=( @@ -1503,12 +1587,16 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals ), tool_choice=tool_choice, user=user, - ignore_eos=ignore_eos, response_format=( openai_api_protocol.RequestResponseFormat.model_validate(response_format) if response_format is not None else None ), + debug_config=( + debug_protocol.DebugConfig.model_validate(debug_config) + if debug_config is not None + else None + ), ), request_id=request_id, ) @@ -1516,16 +1604,17 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals # Stream response. return chatcmpl_generator # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 + request_final_usage = None output_texts = ["" for _ in range(n)] finish_reasons: List[Optional[str]] = [None for _ in range(n)] logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( [[] for _ in range(n)] if logprobs else None ) for response in chatcmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens + # if usage is not None, this is the last chunk + if response.usage is not None: + request_final_usage = response.usage + continue for choice in response.choices: assert isinstance(choice.delta.content, str) output_texts[choice.index] += choice.delta.content @@ -1549,8 +1638,7 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals tool_calls_list=tool_calls_list, logprob_results=logprob_results, use_function_calling=use_function_calling, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) def _completion( # pylint: disable=too-many-arguments,too-many-locals @@ -1560,24 +1648,28 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals model: Optional[str] = None, best_of: int = 1, echo: bool = False, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, logprobs: bool = False, top_logprobs: int = 0, logit_bias: Optional[Dict[int, float]] = None, - max_tokens: int = 16, + max_tokens: Optional[int] = None, n: int = 1, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stream: bool = False, + stream_options: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, - temperature: float = 1.0, - top_p: float = 1.0, + temperature: Optional[float] = None, + top_p: Optional[float] = None, user: Optional[str] = None, - ignore_eos: bool = False, response_format: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, - ) -> Iterator[openai_api_protocol.CompletionResponse]: + debug_config: Optional[Dict[str, Any]] = None, + ) -> Union[ + Iterator[openai_api_protocol.CompletionResponse], + openai_api_protocol.CompletionResponse, + ]: """Synchronous completion internal interface with OpenAI API compatibility. See https://platform.openai.com/docs/api-reference/completions/create for specification. @@ -1588,6 +1680,9 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals The optional request id. A random one will be generated if it is not given. 
+ debug_config: Optional[Dict[str, Any]] = None, + Extra debug options to pass to the request. + Raises ------ e : BadRequestError @@ -1595,6 +1690,7 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals """ if request_id is None: request_id = f"cmpl-{engine_utils.random_uuid()}" + cmpl_generator = self._handle_completion( openai_api_protocol.CompletionRequest( model=model, @@ -1611,25 +1707,33 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals seed=seed, stop=stop, stream=stream, + stream_options=( + openai_api_protocol.StreamOptions.model_validate(stream_options) + if stream_options is not None + else None + ), suffix=suffix, temperature=temperature, top_p=top_p, user=user, - ignore_eos=ignore_eos, response_format=( openai_api_protocol.RequestResponseFormat.model_validate(response_format) if response_format is not None else None ), + debug_config=( + debug_protocol.DebugConfig.model_validate(debug_config) + if debug_config is not None + else None + ), ), - request_id, + request_id=request_id, ) if stream: # Stream response. return cmpl_generator # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 + request_final_usage = None output_texts = ["" for _ in range(n)] finish_reasons: List[Optional[str]] = [None for _ in range(n)] logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = ( @@ -1637,8 +1741,10 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals ) for response in cmpl_generator: - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens + # this is the final chunk + if response.usage is not None: + request_final_usage = response.usage + continue for choice in response.choices: output_texts[choice.index] += choice.text if choice.finish_reason is not None and finish_reasons[choice.index] is None: @@ -1656,8 +1762,7 @@ def _completion( # pylint: disable=too-many-arguments,too-many-locals output_texts=output_texts, finish_reasons=finish_reasons, logprob_results=logprob_results, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) def _handle_chat_completion( @@ -1691,21 +1796,18 @@ def _handle_chat_completion( self.max_input_sequence_length, self.conv_template.model_copy(deep=True), ) + _ = prompt_length finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] - num_completion_tokens = 0 self.state.record_event(request_id, event="invoke generate") for delta_outputs in self._generate(prompts, generation_cfg, request_id): # type: ignore - response, num_completion_tokens = engine_base.process_chat_completion_stream_output( + response = engine_base.process_chat_completion_stream_output( delta_outputs, + request, request_id, self.state, - request.model, - generation_cfg, use_function_calling, - prompt_length, finish_reasons, - num_completion_tokens, ) if response is not None: yield response @@ -1714,7 +1816,7 @@ def _handle_chat_completion( def _handle_completion( self, request: openai_api_protocol.CompletionRequest, request_id: str ) -> Iterator[openai_api_protocol.CompletionResponse]: - """The implementation fo synchronous CompletionRequest handling. + """The implementation for synchronous CompletionRequest handling. 
Yields ------ @@ -1737,32 +1839,29 @@ def _handle_completion( request, request_id, self.state, - self.model_config_dicts[0], self.tokenizer, self.max_input_sequence_length, + self.conv_template.model_copy(deep=True), ) + _ = prompt_length if echo_response is not None: yield echo_response - num_completion_tokens = 0 finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)] self.state.record_event(request_id, event="invoke generate") for delta_outputs in self._generate(prompt, generation_cfg, request_id): # type: ignore - response, num_completion_tokens = engine_base.process_completion_stream_output( + response = engine_base.process_completion_stream_output( delta_outputs, + request, request_id, self.state, - request.model, - generation_cfg, - prompt_length, finish_reasons, - num_completion_tokens, ) if response is not None: yield response suffix_response = engine_base.create_completion_suffix_response( - request, request_id, prompt_length, finish_reasons, num_completion_tokens + request, request_id, finish_reasons ) if suffix_response is not None: yield suffix_response @@ -1777,7 +1876,8 @@ def _generate( # pylint: disable=too-many-locals """Internal synchronous text generation interface of AsyncMLCEngine. The method is a coroutine that streams a list of CallbackStreamOutput at a time via yield. The returned list length is the number of - parallel generations specified by `generation_config.n`. + parallel generations specified by `generation_config.n` + except for the final chunk (which is always a list of size 1 and comes with usage). Parameters ---------- @@ -1794,9 +1894,8 @@ def _generate( # pylint: disable=too-many-locals ------ request_output : List[engine_base.CallbackStreamOutput] The delta generated outputs in a list. - The number of list elements equals to `generation_config.n`, - and each element corresponds to the delta output of a parallel - generation. + Except for the final chunk, the number of list elements equals `generation_config.n`, + and each element corresponds to the delta output of a parallel generation. """ if self._terminated: raise ValueError("The engine has terminated.") @@ -1804,39 +1903,62 @@ def _generate( # pylint: disable=too-many-locals # Create the request with the given id, input data, generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = Request(request_id, input_data, generation_config) + request = self._ffi["create_request"]( + request_id, input_data, generation_config.model_dump_json(by_alias=True) + ) # Record the stream in the tracker self.state.sync_output_queue = queue.Queue() self.state.sync_text_streamers = [ TextStreamer(self.tokenizer) for _ in range(generation_config.n) ] - self.state.sync_num_unfinished_generations = generation_config.n self._ffi["add_request"](request) # Iterate the stream asynchronously and yield the token.
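# Aside (sketch, not part of the patch): a caller of this synchronous generator
# would consume it as below, relying on the usage-only sentinel chunk described
# in the docstring; `engine`, `prompt`, and `gen_cfg` are hypothetical.
#
#     pieces = ["" for _ in range(gen_cfg.n)]
#     final_usage_json = None
#     for delta_outputs in engine._generate(prompt, gen_cfg, request_id="demo"):
#         if delta_outputs[0].request_final_usage_json_str is not None:
#             final_usage_json = delta_outputs[0].request_final_usage_json_str
#             break  # usage-only chunk: single-element list, empty delta_text
#         for i, out in enumerate(delta_outputs):
#             pieces[i] += out.delta_text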
try: - while self.state.sync_num_unfinished_generations > 0: + while True: delta_outputs = self.state.sync_output_queue.get() - request_outputs = self._request_stream_callback_impl(delta_outputs) + request_outputs, request_final_usage_json_str = self._request_stream_callback_impl( + delta_outputs + ) for request_output in request_outputs: yield request_output + + if request_final_usage_json_str is not None: + # final chunk, we can break + output = engine_base.CallbackStreamOutput( + delta_text="", + delta_logprob_json_strs=None, + finish_reason=None, + request_final_usage_json_str=request_final_usage_json_str, + ) + yield [output] + break except Exception as exception: # pylint: disable=broad-exception-caught self.abort(request_id) raise exception def _request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] - ) -> List[List[engine_base.CallbackStreamOutput]]: + ) -> Tuple[List[List[engine_base.CallbackStreamOutput]], Optional[str]]: """The underlying implementation of request stream callback of MLCEngine.""" batch_outputs: List[List[engine_base.CallbackStreamOutput]] = [] for delta_output in delta_outputs: request_id, stream_outputs = delta_output.unpack() self.state.record_event(request_id, event="start callback") + + # final chunk is now always indicated by a chunk + # where usage json is present + # the backend engine always streams back this chunk + # regardless of include_usage option + is_final_chunk = stream_outputs[0].request_final_usage_json_str is not None + if is_final_chunk: + return (batch_outputs, stream_outputs[0].request_final_usage_json_str) + outputs: List[engine_base.CallbackStreamOutput] = [] for stream_output, text_streamer in zip(stream_outputs, self.state.sync_text_streamers): self.state.record_event(request_id, event="start detokenization") - delta_text = ( + delta_text = stream_output.extra_prefix_string + ( text_streamer.put(stream_output.delta_token_ids) if len(stream_output.delta_token_ids) > 0 else "" @@ -1848,13 +1970,11 @@ def _request_stream_callback_impl( outputs.append( engine_base.CallbackStreamOutput( delta_text=delta_text, - num_delta_tokens=len(stream_output.delta_token_ids), delta_logprob_json_strs=stream_output.delta_logprob_json_strs, finish_reason=stream_output.finish_reason, + request_final_usage_json_str=None, ) ) - if stream_output.finish_reason is not None: - self.state.sync_num_unfinished_generations -= 1 batch_outputs.append(outputs) self.state.record_event(request_id, event="finish callback") - return batch_outputs + return (batch_outputs, None) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 65b41a66ac..146cf7fa50 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -5,33 +5,28 @@ import ast import asyncio import json +import numbers import queue -import subprocess import sys import threading -from dataclasses import asdict, dataclass +from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import tvm from tvm.runtime import Device -from mlc_llm.chat_module import _get_chat_config, _get_lib_module_path, _get_model_path -from mlc_llm.protocol import openai_api_protocol, protocol_utils +from mlc_llm.protocol import openai_api_protocol from mlc_llm.protocol.conversation_protocol import Conversation +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.protocol.mlc_chat_config import MLCChatConfig from mlc_llm.serve import data, 
engine_utils -from mlc_llm.serve.config import ( - EngineConfig, - GenerationConfig, - KVStateKind, - SpeculativeMode, -) +from mlc_llm.serve.config import EngineConfig from mlc_llm.serve.event_trace_recorder import EventTraceRecorder -from mlc_llm.streamer import TextStreamer -from mlc_llm.support import logging +from mlc_llm.support import download_cache, logging from mlc_llm.support.auto_device import detect_device from mlc_llm.support.style import green -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.tokenizers import TextStreamer, Tokenizer logging.enable_logging() logger = logging.getLogger(__name__) @@ -49,37 +44,76 @@ class ModelInfo: or a full path to a model directory (e.g., "dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1") - model_lib_path : Optional[str] + model_lib : Optional[str] The path to the compiled library of the model. E.g., "dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so" """ model: str - model_lib_path: Optional[str] = None + model_lib: Optional[str] = None + + +def _check_engine_config( + model: str, + model_lib: Optional[str], + mode: Literal["local", "interactive", "server"], + engine_config: EngineConfig, +) -> None: + """Check if the given engine config is valid.""" + if engine_config.model is not None and engine_config.model != model: + raise ValueError( + f'The argument "model" of engine constructor is "{model}", while the "model" ' + f'field in argument "engine_config" is "{engine_config.model}". ' + 'Please set the "engine_config.model" to None or set it to the same as the ' + 'argument "model".' + ) + if ( + engine_config.model_lib is not None + and model_lib is not None + and engine_config.model_lib != model_lib + ): + raise ValueError( + f'The argument "model_lib" of engine constructor is "{model_lib}", while the ' + f'"model_lib" field in argument "engine_config" is "{engine_config.model_lib}". ' + 'Please set the "engine_config.model_lib" to None or set it to the same as the ' + 'argument "model_lib".' + ) + if engine_config.mode is not None and engine_config.mode != mode: + raise ValueError( + f'The argument "mode" of engine constructor is "{mode}", while the ' + f'"mode" field in argument "engine_config" is "{engine_config.mode}". ' + 'Please set the "engine_config.mode" to None or set it to the same as the ' + 'argument "mode".' + ) + if engine_config.kv_cache_page_size != 16: + raise ValueError( + 'KV cache only supports page size 16, while the "kv_cache_page_size" field in ' + f'argument "engine_config" is "{engine_config.kv_cache_page_size}". ' + 'Please set "engine_config.kv_cache_page_size" to 16.' + ) def _parse_models( - model: str, model_lib_path: Optional[str], additional_models: Optional[List[str]] + model: str, + model_lib: Optional[str], + additional_models: List[Union[str, Tuple[str, str]]], ) -> List[ModelInfo]: - """Parse the specified model paths and model lib paths. + """Parse the specified model paths and model libs. Return a list of ModelInfo, which is a wrapper class of the model path + lib path. - - Each additional model is expected to follow the format of either - "{MODEL_PATH}" or "{MODEL_PATH}:{MODEL_LIB_PATH}". 
""" - models = [ModelInfo(model, model_lib_path)] - if additional_models is not None: - for additional_model in additional_models: - splits = additional_model.split(":", maxsplit=1) - if len(splits) == 2: - models.append(ModelInfo(splits[0], splits[1])) - else: - models.append(ModelInfo(splits[0])) + models = [ModelInfo(model, model_lib)] + for additional_model in additional_models: + if isinstance(additional_model, str): + models.append(ModelInfo(additional_model)) + else: + models.append(ModelInfo(additional_model[0], additional_model[1])) return models def _process_model_args( - models: List[ModelInfo], device: tvm.runtime.Device + models: List[ModelInfo], + device: tvm.runtime.Device, + engine_config: EngineConfig, ) -> Tuple[List[Tuple[str, str]], List[str], Conversation]: """Process the input ModelInfo to get the engine initialization arguments.""" conversation: Optional[Conversation] = None @@ -88,37 +122,50 @@ def _process_model_args( def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: nonlocal conversation - model_path, config_file_path = _get_model_path(model.model) - config_file_paths.append(config_file_path) - chat_config = _get_chat_config(config_file_path, user_chat_config=None) + model_path = download_cache.get_or_download_model(model.model) + mlc_config_path = model_path / "mlc-chat-config.json" + config_file_paths.append(str(mlc_config_path)) + + with open(mlc_config_path, mode="rt", encoding="utf-8") as file: + mlc_chat_config = MLCChatConfig.model_validate_json(file.read()) + if conversation is None: - assert isinstance(chat_config.conv_template, Conversation) - conversation = chat_config.conv_template + conversation = mlc_chat_config.conv_template - if model.model_lib_path is not None: - # do model lib search if the model lib path is provided + if model.model_lib is not None: + # do model lib search if the model lib is provided # error out if file not found - model_lib_path = _get_lib_module_path( - model=model.model, - model_path=model_path, - chat_config=chat_config, - model_lib_path=model.model_lib_path, - device_name=device.MASK2STR[device.device_type], - config_file_path=config_file_path, - ) + if model.model_lib.startswith("mock://"): + model_lib = model.model_lib + logger.info("[DEBUG] mock test: %s", model_lib) + elif Path(model.model_lib).is_file(): + model_lib = model.model_lib + logger.info("Using library model: %s", model_lib) + else: + raise FileNotFoundError( + f"The `model_lib` you passed in is not a file: {model.model_lib}.\n" + ) else: - # TODO(mlc-team) add logging information - # Run jit if model_lib_path is not provided + # Run jit if model_lib is not provided + # NOTE: we only import jit when necessary + # so the engine do not have to depend on compilation from mlc_llm.interface import jit # pylint: disable=import-outside-toplevel - model_lib_path = str( - jit.jit( - model_path=Path(model_path), - chat_config=asdict(chat_config), - device=device, - ) - ) - return model_path, model_lib_path + model_compile_overrides = { + "context_window_size": engine_config.max_single_sequence_length, + "prefill_chunk_size": engine_config.prefill_chunk_size, + "sliding_window_size": engine_config.sliding_window_size, + "attention_sink_size": engine_config.attention_sink_size, + "tensor_parallel_shards": engine_config.tensor_parallel_shards, + "max_batch_size": engine_config.max_num_sequence, + } + + model_lib = jit.jit( + model_path=model_path, + overrides=model_compile_overrides, + device=device, + ).model_lib_path + return str(model_path), model_lib 
model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models] @@ -126,618 +173,126 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: return model_args, config_file_paths, conversation -def _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( # pylint: disable=too-many-locals,too-many-arguments - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, float, float, int]: - """Estimate the memory usage and the max total sequence length (capacity) - that the KV cache can support. - """ - assert len(models) != 0 - - kv_bytes_per_token = 0 - kv_aux_workspace_bytes = 0 - model_workspace_bytes = 0 - logit_processor_workspace_bytes = 0 - params_bytes = 0 - temp_func_bytes = 0 - - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - params_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-kv-cache-metadata-in-json", - ] - kv_cache_metadata_str = subprocess.check_output(cmd, universal_newlines=True) - kv_cache_metadata = json.loads(kv_cache_metadata_str) - - # Read model config and compute the kv size per token. - model_config = model_config_dict["model_config"] - vocab_size = model_config["vocab_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - num_layers = kv_cache_metadata["num_hidden_layers"] - head_dim = kv_cache_metadata["head_dim"] - num_qo_heads = kv_cache_metadata["num_attention_heads"] - num_kv_heads = kv_cache_metadata["num_key_value_heads"] - hidden_size = head_dim * num_qo_heads - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25 - kv_aux_workspace_bytes += ( - (max_num_sequence + 1) * 88 - + prefill_chunk_size * (num_qo_heads + 1) * 8 - + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 - + 48 * 1024 * 1024 - ) - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.85 - - model_max_total_sequence_length = int( - ( - int(gpu_size_bytes) * gpu_memory_utilization - - params_bytes - - temp_func_bytes - - kv_aux_workspace_bytes - - model_workspace_bytes - - logit_processor_workspace_bytes - ) - / kv_bytes_per_token - ) - if model_max_total_sequence_length <= 0: - raise ValueError( - f"The model weight size {params_bytes} may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." 
- ) - - if device.device_type == Device.kDLMetal: - # NOTE: Metal runtime has severe performance issues with large buffers. - # To work around the issue, we limit the KV cache capacity to 32768. - model_max_total_sequence_length = min(model_max_total_sequence_length, 32768) - - total_mem_usage_except_kv_cache = ( - params_bytes - + temp_func_bytes - + kv_aux_workspace_bytes - + model_workspace_bytes - + logit_processor_workspace_bytes - ) - return ( - total_mem_usage_except_kv_cache, - params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - int(model_max_total_sequence_length), - ) - - -def _estimate_mem_usage_and_max_history_size_for_rnn_state( # pylint: disable=too-many-arguments, too-many-locals, unused-argument - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_paths: List[str], - model_config_dicts: List[Dict[str, Any]], - max_num_sequence: int, - gpu_memory_utilization: Optional[float], -) -> Tuple[float, float, float, int]: - # Get single-card GPU size. - gpu_size_bytes = device.total_global_memory - if gpu_size_bytes is None: - raise ValueError("Cannot read total GPU global memory from device.") - if gpu_memory_utilization is None: - gpu_memory_utilization = 0.90 - - rnn_state_base_bytes = 0.0 # the memory usage for rnn state when history = 1 - param_bytes = 0.0 - temp_func_bytes = 0.0 - model_workspace_bytes = 0.0 - logit_processor_workspace_bytes = 0.0 - for model, model_config_path, model_config_dict in zip( - models, model_config_paths, model_config_dicts - ): - # Read metadata for the parameter size and the temporary memory size. - cmd = [ - sys.executable, - "-m", - "mlc_llm.cli.model_metadata", - model.model_lib_path, - "--print-memory-usage-in-json", - "--mlc-chat-config", - model_config_path, - ] - usage_str = subprocess.check_output(cmd, universal_newlines=True) - usage_json = json.loads(usage_str) - param_bytes += usage_json["params_bytes"] - temp_func_bytes = max(temp_func_bytes, usage_json["temp_func_bytes"]) - - model_config = model_config_dict["model_config"] - vocab_size = model_config_dict["vocab_size"] - head_size = model_config["head_size"] - num_heads = model_config["num_heads"] - num_layers = model_config["num_hidden_layers"] - hidden_size = model_config["hidden_size"] - prefill_chunk_size = model_config["prefill_chunk_size"] - logit_processor_workspace_bytes += ( - max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125 - ) - - model_workspace_bytes += ( - prefill_chunk_size * 4 - + max_num_sequence * 4 - + (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2 - ) - - rnn_state_base_bytes += ( - max_num_sequence * hidden_size * num_layers * 2 * 2 - + max_num_sequence * num_heads * head_size * head_size * num_layers * 2 +def _print_engine_mode_logging_msg(mode: Literal["local", "interactive", "server"]) -> None: + """Print the logging info for engine mode selection.""" + if mode == "local": + logger.info( + "The selected engine mode is %s. " + "We choose small max batch size and KV cache capacity to use less GPU memory.", + green(mode), ) - - max_history_size = int( - ( - gpu_size_bytes * gpu_memory_utilization - - logit_processor_workspace_bytes - - model_workspace_bytes - - param_bytes - - temp_func_bytes + elif mode == "interactive": + logger.info( + "The selected engine mode is %s. 
" + "We fix max batch size to 1 for interactive single sequence use.", + green(mode), ) - / rnn_state_base_bytes - ) - if max_history_size < 1: - raise ValueError( - f"Memory required by models may be larger than available GPU memory " - f"size {gpu_size_bytes * gpu_memory_utilization} bytes." + else: + logger.info( + "The selected engine mode is %s. " + "We use as much GPU memory as possible (within the limit " + "of gpu_memory_utilization).", + green(mode), ) - return ( - param_bytes, - model_workspace_bytes + logit_processor_workspace_bytes + temp_func_bytes, - rnn_state_base_bytes, - max_history_size, - ) - - -def _get_model_config_limit(model_config_dicts: List[Dict[str, Any]]) -> Tuple[int, int, int]: - """Read the model config dictionaries, and return the maximum single - sequence length the models can support, the maximum prefill chunk - size the models can support, and the max batch size the models can support. - - Returns - ------- - model_max_single_sequence_length : int - The maximum single sequence length the models can support. - model_max_prefill_chunk_size : int - The maximum prefill chunk size the models can support. - model_max_batch_size : int - The max batch size the models can support. - """ - model_max_single_sequence_length = int(1e9) - model_max_prefill_chunk_size = int(1e9) - model_max_batch_size = int(1e9) - for i, config in enumerate(model_config_dicts): - runtime_context_window_size = config["context_window_size"] - compile_time_context_window_size = config["model_config"]["context_window_size"] - if runtime_context_window_size > compile_time_context_window_size: - raise ValueError( - f"Model {i}'s runtime context window size ({runtime_context_window_size}) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size == -1 and compile_time_context_window_size != -1: - raise ValueError( - f"Model {i}'s runtime context window size (infinite) is " - "larger than the context window size used at compile time " - f"({compile_time_context_window_size})" - ) - if runtime_context_window_size != -1: - model_max_single_sequence_length = min( - model_max_single_sequence_length, runtime_context_window_size - ) - - runtime_prefill_chunk_size = config["prefill_chunk_size"] - compile_time_prefill_chunk_size = config["model_config"]["prefill_chunk_size"] - if runtime_prefill_chunk_size > compile_time_prefill_chunk_size: - raise ValueError( - f"Model {i}'s runtime prefill chunk size ({runtime_prefill_chunk_size}) is " - "larger than the prefill chunk size used at compile time " - f"({compile_time_prefill_chunk_size})" - ) - model_max_prefill_chunk_size = min(model_max_prefill_chunk_size, runtime_prefill_chunk_size) - - model_max_batch_size = min(model_max_batch_size, config["model_config"]["max_batch_size"]) - - assert model_max_prefill_chunk_size != int(1e9) - assert model_max_batch_size != int(1e9) - return model_max_single_sequence_length, model_max_prefill_chunk_size, model_max_batch_size - - -def _infer_kv_cache_config_for_kv_cache( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, 
KVStateKind, int]: - """Initialize the KV cache config with user input and GPU memory usage estimation. - The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - model_max_single_sequence_length - """ - ( - model_max_single_sequence_length, - model_max_prefill_chunk_size, - model_max_batch_size, - ) = _get_model_config_limit(model_config_dicts) - - def infer_args_under_mode( - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - ) -> Tuple[Tuple[int, int, int, KVStateKind], List[float]]: - logging_msg = "" - # - max_batch_size - if max_batch_size is None: - max_batch_size = ( - min(4, model_max_batch_size) - if mode == "local" - else (1 if mode == "interactive" else model_max_batch_size) - ) - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " - # - infer the maximum total sequence length that can fit GPU memory. - ( - total_mem_usage_except_kv_cache, - model_params_bytes, - kv_bytes_per_token, - kv_aux_workspace_bytes, - temp_workspace_bytes, - model_max_total_sequence_length, - ) = _estimate_mem_usage_and_max_total_sequence_length_for_kv_cache( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, + if mode != "local": + logger.info( + "If you have low concurrent requests and want to use less GPU memory, " + 'please select mode "local".' ) - # - max_total_sequence_length - if max_total_sequence_length is None: - if mode == "local": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length, 8192 - ) - elif mode == "interactive": - max_total_sequence_length = min( - model_max_total_sequence_length, model_max_single_sequence_length - ) - else: - max_total_sequence_length = min( - model_max_total_sequence_length, - max_batch_size * model_max_single_sequence_length, - ) - logging_msg += f"max KV cache token capacity is set to {max_total_sequence_length}, " - else: - logging_msg += ( - f"max KV cache token capacity {max_total_sequence_length} is specified by user. " - ) - # - prefill_chunk_size - if prefill_chunk_size is None: - if mode in ["local", "interactive"]: - prefill_chunk_size = min( - model_max_prefill_chunk_size, - model_max_total_sequence_length, - model_max_single_sequence_length, - ) - else: - prefill_chunk_size = model_max_prefill_chunk_size - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - - if mode == "local": - logging_msg += ( - "We choose small max batch size and KV cache capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" - " limit of gpu_memory_utilization)." - ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - # - Construct the KV cache config - # - Estimate total GPU memory usage on single GPU. 
- return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - KVStateKind.ATTENTION, - ), [ - total_mem_usage_except_kv_cache + max_total_sequence_length * kv_bytes_per_token, - model_params_bytes, - kv_bytes_per_token * max_total_sequence_length + kv_aux_workspace_bytes, - temp_workspace_bytes, - ] - - # - Infer KV cache config and estimate memory usage for each mode. - local_kv_cache_config, local_mem_usage_list = infer_args_under_mode( - "local", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - interactive_kv_cache_config, interactive_mem_usage_list = infer_args_under_mode( - "interactive", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - server_kv_cache_config, server_mem_usage_list = infer_args_under_mode( - "server", max_batch_size, max_total_sequence_length, prefill_chunk_size - ) - - # - Select the config based on the actual mode. - if mode == "local": - kv_cache_config = local_kv_cache_config - mem_usage_list = local_mem_usage_list - elif mode == "interactive": - kv_cache_config = interactive_kv_cache_config - mem_usage_list = interactive_mem_usage_list - else: - kv_cache_config = server_kv_cache_config - mem_usage_list = server_mem_usage_list - - logger.info( - 'The actual engine mode is "%s". So max batch size is %s, ' - "max KV cache token capacity is %s, prefill chunk size is %s.", - green(mode), - green(str(kv_cache_config[0])), - green(str(kv_cache_config[1])), - green(str(kv_cache_config[2])), - ) - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. KVCache: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - *list(mem_usage / 1024 / 1024 for mem_usage in mem_usage_list), - ) - # - Final messages - override_msg = "Please override the arguments if you have particular values to set." - if mode in ["local", "interactive"]: + if mode != "interactive": logger.info( - 'Please switch to mode "server" if you want to use more GPU memory ' - "and support more concurrent requests. %s", - override_msg, + "If you don't have concurrent requests and only use the engine interactively, " + 'please select mode "interactive".' ) - else: + if mode != "server": logger.info( - 'Please switch to mode "local" or "interactive" if you want to use less GPU memory ' - "or do not have many concurrent requests to process. %s", - override_msg, + "If you have high concurrent requests and want to maximize the GPU memory utilization, " + 'please select mode "server".' ) - return *kv_cache_config, model_max_single_sequence_length +class EngineMetrics: + """Class to store the result returned by engine metrics""" -def _infer_kv_cache_config_for_rnn_state( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, KVStateKind, int]: - """Initialize the RNN state config with user input and GPU memory usage estimation. 
- The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - kv_state_kind - - max_history_size - """ - logging_msg = "" - prefill_chunk_size = 0 + metrics: dict - if prefill_chunk_size is None: - prefill_chunk_size = min( - config["prefill_chunk_size"] if "prefill_chunk_size" in config else 4096 - for config in model_config_dicts - ) - logging_msg += f"prefill chunk size is set to {prefill_chunk_size}. " - else: - logging_msg += f"prefill chunk size {prefill_chunk_size} is specified by user. " - if max_batch_size is None: - max_batch_size = 1 if mode == "interactive" else 4 - logging_msg += f"max batch size is set to {max_batch_size}, " - else: - logging_msg += f"max batch size {max_batch_size} is specified by user, " + def __init__(self, metrics): + self.metrics = metrics - if mode == "local": - logging_msg += ( - "We choose small max batch size and RNN state capacity to use less GPU memory." - ) - elif mode == "interactive": - logging_msg += "We fix max batch size to 1 for interactive single sequence use." - else: - logging_msg += ( - "We use as much GPU memory as possible (within the" " limit of gpu_memory_utilization)." - ) - logger.info('Under mode "%s", %s', mode, logging_msg) - - ( - model_param_bytes, - model_temp_bytes, - model_rnn_state_base_bytes, - model_max_history_size, - ) = _estimate_mem_usage_and_max_history_size_for_rnn_state( - models, - device, - model_config_paths, - model_config_dicts, - max_batch_size, - gpu_memory_utilization, - ) - if max_history_size is None: - max_history_size = model_max_history_size - else: - max_history_size = min(max_history_size, model_max_history_size) - max_total_sequence_length = 32768 - prefill_chunk_size = 0 - kind = KVStateKind.RNNSTATE - - logger.info( - "%s: %.2f MB (Parameters: %.2f MB. RNNState: %.2f MB. Temporary buffer: %.2f MB). " - "The actual usage might be slightly larger than the estimated number.", - green("Estimated total single GPU memory usage"), - (model_param_bytes + model_temp_bytes + model_rnn_state_base_bytes) / 1024 / 1024, - model_param_bytes / 1024 / 1024, - max_history_size * model_rnn_state_base_bytes / 1024 / 1024, - model_temp_bytes / 1024 / 1024, - ) + def __str__(self): + return self.metrics.__str__() - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kind, - max_history_size, - ) + def __repr__(self): + return self.metrics.__repr__() + def __getitem__(self, key): + return self.metrics[key] -def _infer_kv_cache_config( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - mode: Literal["local", "interactive", "server"], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - models: List[ModelInfo], - device: tvm.runtime.Device, - model_config_dicts: List[Dict[str, Any]], - model_config_paths: List[str], -) -> Tuple[int, int, int, int, int, KVStateKind]: - """Initialize the cache config with user input and GPU memory usage estimation. 
- The returned four integers are: - - max_batch_size - - max_total_sequence_length - - prefill_chunk_size - - max_single_sequence_length - - max_history_size - - kv_state_kind - """ - if all("rwkv" not in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_single_sequence_length, - ) = _infer_kv_cache_config_for_kv_cache( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, - ) - max_history_size = 0 # KV cache doesn't need this - elif all("rwkv" in model.model for model in models): - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - kv_state_kind, - max_history_size, - ) = _infer_kv_cache_config_for_rnn_state( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - model_config_dicts, - model_config_paths, - ) - max_single_sequence_length = max_total_sequence_length # RNN state doesn't need this - else: - raise ValueError("The models should be either all KV cache models or all RNN state models.") - return ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) + def prometheus_text(self) -> str: + """Convert engine metrics into prometheus text format + Returns + ------- + text: str + The metrics in prometheus text format + """ + output_lines = [ + "# NOTE: these metrics count token in the unit of serving model's tokenization", + "# be careful when comparing them to client-side metrics that may use", + "# different tokenization to standardize across models.\n", + ] -def _infer_generation_config( - model_config_dicts: List[Dict[str, Any]] -) -> List[Tuple[float, float, float, float]]: - """Infer the generation config from the model config dictionaries. 
- The returned four floats are: - - temperature - - top_p - - frequency_penalty - - presence_penalty - """ - generation_configs = [] - - for model_config in model_config_dicts: - temperature = model_config.get("temperature", 1.0) - top_p = model_config.get("top_p", 1.0) - frequency_penalty = model_config.get("frequency_penalty", 0.0) - presence_penalty = model_config.get("presence_penalty", 0.0) - generation_configs.append((temperature, top_p, frequency_penalty, presence_penalty)) + def traverse(comment_scope, key_prefix, curr_value): + if isinstance(curr_value, dict): + if comment_scope: + output_lines.append(f"\n# {comment_scope}") + # first prioritize metrics in current scope + for key, value in curr_value.items(): + if isinstance(value, numbers.Number): + output_lines.append(f"{key_prefix}{key}\t{value}") + # then look into nested scopes if any + for key, value in curr_value.items(): + if isinstance(value, dict) and len(value) != 0: + traverse(f"{comment_scope}/{key}", f"{key_prefix}{key}_", value) + + traverse("", "", self.metrics) + return "\n".join(output_lines) + + +def _query_engine_metrics(engine): + """Query engine metrics via debug options""" + dummy_message = {"role": "user", "context": ""} + for response in engine.chat.completions.create( + messages=[dummy_message], + model="model", + stream=True, + stream_options={"include_usage": True}, + extra_body={"debug_config": {"special_request": "query_engine_metrics"}}, + ): + if response.usage is not None: + return EngineMetrics(response.usage.extra) + raise RuntimeError("query_engine metrics did not get metrics back") + + +async def _async_query_engine_metrics(engine): + """Query engine metrics via debug options""" + dummy_message = {"role": "user", "context": ""} + result = None + async for response in await engine.chat.completions.create( + messages=[dummy_message], + model="model", + stream=True, + stream_options={"include_usage": True}, + extra_body={"debug_config": {"special_request": "query_engine_metrics"}}, + ): + if response.usage is not None: + assert result is None + result = EngineMetrics(response.usage.extra) - return generation_configs + if result is not None: + return result + raise RuntimeError("query_engine metrics did not get metrics back") @dataclass @@ -749,21 +304,22 @@ class CallbackStreamOutput: delta_text : str The delta text generated since the last output. - num_delta_tokens : int - The number of delta tokens generated since the last output. - delta_logprob_json_strs : Optional[List[str]] The list of logprob JSON strings since the last output, or None if the request does not require logprobs. finish_reason : Optional[str] The finish reason of the request, or None if unfinished. 
+ + request_final_usage_json_str: Optional[str] + The usage JSON string which appears in the last chunk; + when it appears, all other fields will be empty. """ delta_text: str - num_delta_tokens: int delta_logprob_json_strs: Optional[List[str]] finish_reason: Optional[str] + request_final_usage_json_str: Optional[str] class AsyncRequestStream: @@ -847,11 +403,9 @@ class EngineState: # States used for AsyncMLCEngine async_event_loop: Optional[asyncio.AbstractEventLoop] = None async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {} - async_num_unfinished_generations: Dict[str, int] = {} # States used for MLCEngine sync_output_queue: queue.Queue = queue.Queue() sync_text_streamers: List[TextStreamer] = [] - sync_num_unfinished_generations: int = 0 def __init__(self, enable_tracing: bool) -> None: """Constructor.""" @@ -939,10 +493,29 @@ def _async_request_stream_callback_impl( self.record_event(request_id, event="start callback") stream, text_streamers = streamers + + # final chunk is now always indicated by a chunk + # where usage json is present + # the backend engine always streams back this chunk + # regardless of include_usage option + is_final_chunk = stream_outputs[0].request_final_usage_json_str is not None + if is_final_chunk: + # stream back this final usage chunk + output = CallbackStreamOutput( + delta_text="", + delta_logprob_json_strs=None, + finish_reason=None, + request_final_usage_json_str=stream_outputs[0].request_final_usage_json_str, + ) + stream.push([output]) + stream.finish() + self.async_streamers.pop(request_id, None) + continue + outputs = [] for stream_output, text_streamer in zip(stream_outputs, text_streamers): self.record_event(request_id, event="start detokenization") - delta_text = ( + delta_text = stream_output.extra_prefix_string + ( text_streamer.put(stream_output.delta_token_ids) if len(stream_output.delta_token_ids) > 0 else "" @@ -954,20 +527,14 @@ outputs.append( CallbackStreamOutput( delta_text=delta_text, - num_delta_tokens=len(stream_output.delta_token_ids), delta_logprob_json_strs=stream_output.delta_logprob_json_strs, finish_reason=stream_output.finish_reason, + request_final_usage_json_str=None, ) ) - if stream_output.finish_reason is not None: - self.async_num_unfinished_generations[request_id] -= 1 # Push new delta text to the stream. stream.push(outputs) - if self.async_num_unfinished_generations[request_id] == 0: - stream.finish() - self.async_streamers.pop(request_id, None) - self.async_num_unfinished_generations.pop(request_id, None) self.record_event(request_id, event="finish callback") def _sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None: @@ -1000,20 +567,18 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals kind: Literal["async", "sync"], model: str, device: Union[str, tvm.runtime.Device], - model_lib_path: Optional[str], + model_lib: Optional[str], mode: Literal["local", "interactive", "server"], - additional_models: Optional[List[str]], - max_batch_size: Optional[int], - max_total_sequence_length: Optional[int], - prefill_chunk_size: Optional[int], - max_history_size: Optional[int], - gpu_memory_utilization: Optional[float], - speculative_mode: SpeculativeMode, - spec_draft_length: int, + engine_config: Optional[EngineConfig], enable_tracing: bool, ) -> None: + # - Check the fields of `engine_config`.
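# Aside (sketch, not part of the patch): an engine_config that passes
# _check_engine_config must not conflict with the constructor arguments
# (model, model_lib and mode left unset or set to identical values) and must
# keep the only supported KV-cache page size of 16. The values below are
# hypothetical.
#
#     engine_config = EngineConfig(
#         additional_models=[],   # no extra models for speculative decoding
#         kv_cache_page_size=16,  # required by the check
#         max_num_sequence=4,     # optional override of the mode-based default
#     )
#     MLCEngine(model="HF://mlc-ai/SomeModel-q4f16_1-MLC", mode="local",
#               engine_config=engine_config)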
+ if engine_config is None: + engine_config = EngineConfig() + _check_engine_config(model, model_lib, mode, engine_config) + # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, engine_config.additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, Device) @@ -1021,36 +586,18 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_args, model_config_paths, self.conv_template, - ) = _process_model_args(models, device) + ) = _process_model_args(models, device, engine_config) # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. + if engine_config.verbose: + _print_engine_mode_logging_msg(mode) # - Initialize engine state and engine. self.state = EngineState(enable_tracing) @@ -1063,61 +610,67 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "run_background_loop", "run_background_stream_back_loop", "reload", - "init_background_engine", + "init_threaded_engine", "exit_background_loop", + "create_request", + "get_complete_engine_config", + "reset", "debug_call_func_on_all_worker", ] } self.tokenizer = Tokenizer(model_args[0][0]) - self._ffi["init_background_engine"]( + self._ffi["init_threaded_engine"]( device, self.state.get_request_stream_callback(kind), self.state.trace_recorder, ) - self._ffi["reload"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ) - ) - def _background_loop(): - self._ffi["run_background_loop"]() - - def _background_stream_back_loop(): - self._ffi["run_background_stream_back_loop"]() + background_loop = self._ffi["run_background_loop"] + background_stream_back_loop = self._ffi["run_background_stream_back_loop"] # - Create the background engine-driving thread and start the loop. 
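# Aside (sketch, not part of the patch): once the two threads below are running,
# the constructor finishes a config handshake with the backend engine, which
# appears to be where the capacity inference formerly done in Python now
# happens: the user-facing EngineConfig is serialized and sent via reload, and
# the fully resolved config is read back, conceptually:
#
#     self._ffi["reload"](engine_config.asjson())
#     resolved = EngineConfig.from_json(self._ffi["get_complete_engine_config"]())
#     # resolved.max_total_sequence_length, resolved.max_single_sequence_length,
#     # etc. are now concrete values chosen by the engine.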
- self._background_loop_thread: threading.Thread = threading.Thread(target=_background_loop) + self._background_loop_thread: threading.Thread = threading.Thread(target=background_loop) self._background_stream_back_loop_thread: threading.Thread = threading.Thread( - target=_background_stream_back_loop + target=background_stream_back_loop ) self._background_loop_thread.start() self._background_stream_back_loop_thread.start() self._terminated = False + engine_config.model = model_args[0][0] + engine_config.model_lib = model_args[0][1] + engine_config.additional_models = model_args[1:] # type: ignore + engine_config.mode = mode + self._ffi["reload"](engine_config.asjson()) + self.engine_config = EngineConfig.from_json(self._ffi["get_complete_engine_config"]()) + self.max_input_sequence_length = min( + self.engine_config.max_single_sequence_length, + self.engine_config.max_total_sequence_length, + ) + + def __del__(self): + """deleter, auto terminate""" + self.terminate() + def terminate(self): """Terminate the engine.""" + if hasattr(self, "_terminated") and self._terminated: + return self._terminated = True self._ffi["exit_background_loop"]() - self._background_loop_thread.join() - self._background_stream_back_loop_thread.join() + if hasattr(self, "_background_loop_thread"): + self._background_loop_thread.join() + if hasattr(self, "_background_stream_back_loop_thread"): + self._background_stream_back_loop_thread.join() def _debug_call_func_on_all_worker(self, func_name: str) -> None: """Call the given global function on all workers. Only for debug purpose.""" self._ffi["debug_call_func_on_all_worker"](func_name) + def reset(self): + """Reset the engine, clear the running data and metrics.""" + return self._ffi["reset"]() + def process_chat_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.ChatCompletionRequest, @@ -1208,9 +761,8 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length) # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config( + generation_cfg = engine_utils.get_generation_config( request, - model_config, extra_stop_token_ids=conv_template.stop_token_ids, extra_stop_str=conv_template.stop_str, ) @@ -1219,15 +771,12 @@ def process_chat_completion_request( # pylint: disable=too-many-arguments def process_chat_completion_stream_output( # pylint: disable=too-many-arguments delta_outputs: List[CallbackStreamOutput], + request: openai_api_protocol.ChatCompletionRequest, request_id: str, engine_state: EngineState, - model: str, - generation_cfg: GenerationConfig, use_function_calling: bool, - prompt_length: int, finish_reasons: List[Optional[str]], - num_completion_tokens: int, -) -> Tuple[Optional[openai_api_protocol.ChatCompletionStreamResponse], int]: +) -> Optional[openai_api_protocol.ChatCompletionStreamResponse]: """Process the delta outputs of a single request of ChatCompletion, convert the delta output to ChatCompletionStreamResponse and return. @@ -1244,43 +793,49 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments engine_state : EngineState The state of the engine. - model : str - The requested model. - - generation_cfg : GenerationConfig - The generation config of the request. - use_function_calling : bool A boolean flag indicating if the request uses function call. - prompt_length : int - The total prompt length. 
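Continuing the sketch above: after `reload`, the engine keeps the fully resolved configuration returned by `get_complete_engine_config`, and the lifecycle helpers behave as in the diff (`terminate` is idempotent and also runs from `__del__`; `reset` clears running data and metrics). Illustrative only.

```python
# `engine` is the MLCEngine instance from the previous sketch.
print(engine.engine_config.max_total_sequence_length)  # resolved value, not the requested one
print(engine.engine_config.prefill_chunk_size)

engine.reset()      # clear running data and metrics
engine.terminate()  # safe to call more than once; __del__ calls it again without error
```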
- finish_reasons : List[Optional[str]] The list of finish reasons of each generation. The list length is the number of parallel generation specified by "n". This list is updated in place. - num_completion_tokens : int - The number of total completion tokens so far. - Returns ------- response : Optional[openai_api_protocol.ChatCompletionStreamResponse] The converted OpenAI API ChatCompletionStreamResponse instance. It can be none when there is no content. - - num_completion_tokens : int - The updated number of total completion tokens. - It is sum of the input number and the number of new completion tokens - from the given delta outputs. """ - assert len(delta_outputs) == generation_cfg.n + # we always stream back the final chunk with usage + is_final_chunk = delta_outputs[0].request_final_usage_json_str is not None + if is_final_chunk: + assert len(delta_outputs) == 1 + engine_state.record_event(request_id, event="yield final usage") + response = openai_api_protocol.ChatCompletionStreamResponse( + id=request_id, + choices=[], + model=request.model, + system_fingerprint="", + usage=openai_api_protocol.CompletionUsage.model_validate_json( + delta_outputs[0].request_final_usage_json_str + ), + ) + # non streaming mode always comes with usage + if not request.stream: + return response + # skip usage if stream option does not indicate include usage + if request.stream_options is None: + return None + if not request.stream_options.include_usage: + return None + return response + + # normal chunk + assert len(delta_outputs) == request.n choices = [] - num_new_completion_tokens = 0 for i, delta_output in enumerate(delta_outputs): finish_reason_updated = False - num_new_completion_tokens += delta_output.num_delta_tokens if delta_output.finish_reason is not None and finish_reasons[i] is None: finish_reasons[i] = ( delta_output.finish_reason if not use_function_calling else "tool_calls" @@ -1313,31 +868,23 @@ def process_chat_completion_stream_output( # pylint: disable=too-many-arguments ) ) - if len(choices) == 0 and num_new_completion_tokens == 0: + if len(choices) == 0: # Skip return when there is no delta output and no number of completion tokens. - return None, num_completion_tokens - num_completion_tokens += num_new_completion_tokens + return None response = openai_api_protocol.ChatCompletionStreamResponse( - id=request_id, - choices=choices, - model=model, - system_fingerprint="", - usage=openai_api_protocol.UsageInfo( - prompt_tokens=prompt_length, - completion_tokens=num_completion_tokens, - ), + id=request_id, choices=choices, model=request.model, system_fingerprint="" ) engine_state.record_event(request_id, event="yield delta output") - return response, num_completion_tokens + return response def process_completion_request( # pylint: disable=too-many-arguments request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, - model_config: Dict[str, Any], tokenizer: Tokenizer, max_input_sequence_length: int, + conv_template: Conversation, ) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]: """Process the given CompletionRequest, apply request validity checks, and return the processed prompts, and other info. @@ -1359,6 +906,9 @@ def process_completion_request( # pylint: disable=too-many-arguments max_input_sequence_length : int The maximum allowed total prompt length. + conv_template : Conversation + The conversation template of the model. 
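From the client side, the usage-only trailing chunk is surfaced in a streaming chat completion only when `stream_options.include_usage` is set. A hedged sketch using the `openai` Python client against a locally running server; host, port, API key, and model name are placeholders.

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="none")  # placeholder endpoint
stream = client.chat.completions.create(
    model="SOME-MODEL",  # placeholder; must match a served model
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    for choice in chunk.choices:
        print(choice.delta.content or "", end="")
    if chunk.usage is not None:
        # Trailing chunk: empty `choices`, populated `usage`.
        print("\n", chunk.usage)
```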
+ Returns ------- prompt : List[int] @@ -1387,7 +937,11 @@ def process_completion_request( # pylint: disable=too-many-arguments assert isinstance(prompt, list) # Process generation config. Create request id. - generation_cfg = protocol_utils.get_generation_config(request, model_config) + generation_cfg = engine_utils.get_generation_config( + request, + extra_stop_token_ids=conv_template.stop_token_ids, + extra_stop_str=conv_template.stop_str, + ) # - Echo back the prompt. echo_response = None @@ -1400,10 +954,7 @@ def process_completion_request( # pylint: disable=too-many-arguments for i in range(generation_cfg.n) ], model=request.model, - usage=openai_api_protocol.UsageInfo( - prompt_tokens=prompt_length, - completion_tokens=0, - ), + usage=None, ) echo_response = response return prompt, generation_cfg, prompt_length, echo_response @@ -1411,14 +962,11 @@ def process_completion_request( # pylint: disable=too-many-arguments def process_completion_stream_output( # pylint: disable=too-many-arguments delta_outputs: List[CallbackStreamOutput], + request: openai_api_protocol.CompletionRequest, request_id: str, engine_state: EngineState, - model: str, - generation_cfg: GenerationConfig, - prompt_length: int, finish_reasons: List[Optional[str]], - num_completion_tokens: int, -) -> Tuple[Optional[openai_api_protocol.CompletionResponse], int]: +) -> Optional[openai_api_protocol.CompletionResponse]: """Process the delta outputs of a single request of Completion, convert the delta output to CompletionResponse and return. @@ -1429,49 +977,57 @@ def process_completion_stream_output( # pylint: disable=too-many-arguments The list length is the number of parallel generation specified by "n". Each element corresponds to a generation. + request: openai_api_protocol.CompletionRequest + Information about the request + request_id : str The id of the request. engine_state : EngineState The state of the engine. - model : str - The requested model. - - generation_cfg : GenerationConfig - The generation config of the request. - - prompt_length : int - The total prompt length. - finish_reasons : List[Optional[str]] The list of finish reasons of each generation. The list length is the number of parallel generation specified by "n". This list is updated in place. - num_completion_tokens : int - The number of total completion tokens so far. - Returns ------- response : Optional[openai_api_protocol.CompletionResponse] The converted OpenAI API CompletionResponse instance. It can be none when there is no content. - - num_completion_tokens : int - The updated number of total completion tokens. - It is sum of the input number and the number of new completion tokens - from the given delta outputs. 
""" - assert len(delta_outputs) == generation_cfg.n + # we always stream back the final chunk with usage + is_final_chunk = delta_outputs[0].request_final_usage_json_str is not None + if is_final_chunk: + assert len(delta_outputs) == 1 + engine_state.record_event(request_id, event="yield final usage") + response = openai_api_protocol.CompletionResponse( + id=request_id, + choices=[], + model=request.model, + system_fingerprint="", + usage=openai_api_protocol.CompletionUsage.model_validate_json( + delta_outputs[0].request_final_usage_json_str + ), + ) + # non streaming mode always comes with usage + if not request.stream: + return response + if request.stream_options is None: + return None + if not request.stream_options.include_usage: + return None + return response + + # normal chunk + assert len(delta_outputs) == request.n choices = [] - num_new_completion_tokens = 0 for i, delta_output in enumerate(delta_outputs): finish_reason_updated = False if delta_output.finish_reason is not None and finish_reasons[i] is None: finish_reasons[i] = delta_output.finish_reason finish_reason_updated = True - num_new_completion_tokens += delta_output.num_delta_tokens if not finish_reason_updated and delta_output.delta_text == "": # Ignore empty delta text when finish reason is not updated. continue @@ -1496,29 +1052,23 @@ def process_completion_stream_output( # pylint: disable=too-many-arguments ) ) - if len(choices) == 0 and num_new_completion_tokens == 0: + if len(choices) == 0: # Skip return when there is no delta output and no number of completion tokens. - return None, num_completion_tokens - num_completion_tokens += num_new_completion_tokens + return None response = openai_api_protocol.CompletionResponse( id=request_id, choices=choices, - model=model, - usage=openai_api_protocol.UsageInfo( - prompt_tokens=prompt_length, - completion_tokens=num_completion_tokens, - ), + model=request.model, + usage=None, ) engine_state.record_event(request_id, event="yield delta output") - return response, num_completion_tokens + return response def create_completion_suffix_response( request: openai_api_protocol.CompletionRequest, request_id: str, - prompt_length: int, finish_reasons: List[Optional[str]], - num_completion_tokens: int, ) -> Optional[openai_api_protocol.CompletionResponse]: """Create the suffix response of Completion request when the request requires suffix. @@ -1531,17 +1081,11 @@ def create_completion_suffix_response( request_id : str The id of the request. - prompt_length : int - The total prompt length. - finish_reasons : List[Optional[str]] The list of finish reasons of each generation. The list length is the number of parallel generation specified by "n". This list is updated in place. - num_completion_tokens : int - The number of total completion tokens so far. 
- Returns ------- suffix_response : Optional[openai_api_protocol.CompletionResponse] @@ -1563,10 +1107,7 @@ def create_completion_suffix_response( for i, finish_reason in enumerate(finish_reasons) ], model=request.model, - usage=openai_api_protocol.UsageInfo( - prompt_tokens=prompt_length, - completion_tokens=num_completion_tokens, - ), + usage=None, ) return response @@ -1640,8 +1181,7 @@ def wrap_chat_completion_response( # pylint: disable=too-many-arguments tool_calls_list: List[List[openai_api_protocol.ChatToolCall]], logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], use_function_calling: bool, - num_prompt_tokens: int, - num_completion_tokens: int, + usage: Optional[Dict[str, Any]], ) -> openai_api_protocol.ChatCompletionResponse: """Wrap the non-streaming chat completion results to ChatCompletionResponse instance.""" return openai_api_protocol.ChatCompletionResponse( @@ -1669,9 +1209,7 @@ def wrap_chat_completion_response( # pylint: disable=too-many-arguments ], model=model, system_fingerprint="", - usage=openai_api_protocol.UsageInfo( - prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens - ), + usage=usage, ) @@ -1681,8 +1219,7 @@ def wrap_completion_response( # pylint: disable=too-many-arguments output_texts: List[str], finish_reasons: List[str], logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]], - num_prompt_tokens: int, - num_completion_tokens: int, + usage: openai_api_protocol.CompletionUsage, ) -> openai_api_protocol.CompletionResponse: """Wrap the non-streaming completion results to CompletionResponse instance.""" return openai_api_protocol.CompletionResponse( @@ -1701,7 +1238,5 @@ def wrap_completion_response( # pylint: disable=too-many-arguments for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons)) ], model=model, - usage=openai_api_protocol.UsageInfo( - prompt_tokens=num_prompt_tokens, completion_tokens=num_completion_tokens - ), + usage=usage, ) diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py index d1c96e37d4..6ccbc0e621 100644 --- a/python/mlc_llm/serve/engine_utils.py +++ b/python/mlc_llm/serve/engine_utils.py @@ -1,11 +1,83 @@ """Utility functions for MLC Serve engine""" import uuid -from typing import Callable, List, Union +from typing import Any, Callable, Dict, List, Optional, Union +from mlc_llm.protocol import error_protocol, openai_api_protocol +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data -from ..protocol import RequestProtocol, error_protocol, protocol_utils +RequestProtocol = Union[ + openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest +] + + +def get_unsupported_fields(request: RequestProtocol) -> List[str]: + """Get the unsupported fields of the request. + Return the list of unsupported field names. 
+ """ + if isinstance( + request, (openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest) + ): + return openai_api_protocol.openai_api_get_unsupported_fields(request) + raise RuntimeError("Cannot reach here") + + +def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]: + """Create the generation config from the given request.""" + kwargs: Dict[str, Any] = {} + arg_names = [ + "n", + "temperature", + "top_p", + "max_tokens", + "frequency_penalty", + "presence_penalty", + "logprobs", + "top_logprobs", + "logit_bias", + "seed", + "response_format", + "debug_config", + ] + for arg_name in arg_names: + kwargs[arg_name] = getattr(request, arg_name) + if kwargs["max_tokens"] is None: + # Setting to -1 means the generation will not stop until + # exceeding model capability or hit any stop criteria. + kwargs["max_tokens"] = -1 + if request.stop is not None: + kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop + return kwargs + + +def get_generation_config( + request: RequestProtocol, + extra_stop_token_ids: Optional[List[int]] = None, + extra_stop_str: Optional[List[str]] = None, +) -> GenerationConfig: + """Create the generation config in MLC LLM out from the input request protocol.""" + kwargs: Dict[str, Any] + if isinstance( + request, (openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest) + ): + kwargs = openai_api_get_generation_config(request) + else: + raise RuntimeError("Cannot reach here") + + if extra_stop_token_ids is not None: + stop_token_ids = kwargs.get("stop_token_ids", []) + assert isinstance(stop_token_ids, list) + stop_token_ids += extra_stop_token_ids + kwargs["stop_token_ids"] = stop_token_ids + + if extra_stop_str is not None: + stop_strs = kwargs.get("stop_strs", []) + assert isinstance(stop_strs, list) + stop_strs += extra_stop_str + kwargs["stop_strs"] = stop_strs + + return GenerationConfig(**kwargs) def random_uuid() -> str: @@ -15,7 +87,7 @@ def random_uuid() -> str: def check_unsupported_fields(request: RequestProtocol) -> None: """Check if the request has unsupported fields. Raise BadRequestError if so.""" - unsupported_fields = protocol_utils.get_unsupported_fields(request) + unsupported_fields = get_unsupported_fields(request) if len(unsupported_fields) != 0: unsupported_fields = [f'"{field}"' for field in unsupported_fields] raise error_protocol.BadRequestError( diff --git a/python/mlc_llm/serve/entrypoints/__init__.py b/python/mlc_llm/serve/entrypoints/__init__.py index 3002bf80c7..c846cefe15 100644 --- a/python/mlc_llm/serve/entrypoints/__init__.py +++ b/python/mlc_llm/serve/entrypoints/__init__.py @@ -1,2 +1,3 @@ """The entrypoints for MLC LLM server.""" -from . import debug_entrypoints, openai_entrypoints + +from . import debug_entrypoints, metrics_entrypoints, openai_entrypoints diff --git a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py index af1613c027..1f1170b42b 100644 --- a/python/mlc_llm/serve/entrypoints/debug_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/debug_entrypoints.py @@ -34,7 +34,7 @@ async def debug_dump_event_trace(request: fastapi.Request): HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" ) - # - Check the requested model. + # Check the requested model. 
model = request_dict["model"] server_context: ServerContext = ServerContext.current() @@ -79,3 +79,56 @@ async def debug_cuda_profiler_stop(_request: fastapi.Request): "mlc.debug_cuda_profiler_stop" ) break + + +@app.post("/debug/dump_engine_metrics") +async def debug_dump_engine_metrics(request: fastapi.Request): + """Dump the engine metrics for the engine. Only for debug purpose.""" + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # Check the requested model. + model = request_dict["model"] + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + res = async_engine.metrics() + return res + + +@app.post("/debug/reset_engine") +async def debug_reset_engine_stats(request: fastapi.Request): + """Reset the engine, clean up all running data and metrics.""" + # Get the raw request body as bytes + request_raw_data = await request.body() + request_json_str = request_raw_data.decode("utf-8") + try: + # Parse the JSON string + request_dict = json.loads(request_json_str) + except json.JSONDecodeError: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + if "model" not in request_dict: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message=f"Invalid request {request_json_str}" + ) + + # Check the requested model. + model = request_dict["model"] + + server_context: ServerContext = ServerContext.current() + async_engine = server_context.get_engine(model) + async_engine.reset() diff --git a/python/mlc_llm/serve/entrypoints/metrics_entrypoints.py b/python/mlc_llm/serve/entrypoints/metrics_entrypoints.py new file mode 100644 index 0000000000..71ee65d65b --- /dev/null +++ b/python/mlc_llm/serve/entrypoints/metrics_entrypoints.py @@ -0,0 +1,23 @@ +"""MLC LLM server metrics entrypoints""" + +import fastapi +from fastapi.responses import PlainTextResponse + +from mlc_llm.serve.server import ServerContext + +app = fastapi.APIRouter() + +################ /metrics ################ + + +@app.get("/metrics", response_class=PlainTextResponse) +async def metrics(_request: fastapi.Request): + """Start the cuda profiler for the engine. Only for debug purpose.""" + server_context: ServerContext = ServerContext.current() + # Use the metrics from first engine for now + # TODO(mlc-team): consider refactor server context to + # single engine since multiple AsyncMLCEngine do not work well with each other + # We need to work within the internal engine instead. 
+ for model in server_context.get_model_list(): + async_engine = server_context.get_engine(model) + return (await async_engine.metrics()).prometheus_text() diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 23a279021f..7f62c2ad3f 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -1,6 +1,8 @@ """OpenAI API-compatible server entrypoints in MLC LLM""" # pylint: disable=too-many-locals,too-many-return-statements,too-many-statements +import json +from datetime import datetime from http import HTTPStatus from typing import AsyncGenerator, List, Optional @@ -18,12 +20,11 @@ from mlc_llm.serve.server import ServerContext app = fastapi.APIRouter() - ################ v1/models ################ @app.get("/v1/models") -async def request_models(): +async def request_models() -> ListResponse: """OpenAI-compatible served model query API. API reference: https://platform.openai.com/docs/api-reference/models """ @@ -41,6 +42,12 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re """ # - Check the requested model. server_context: ServerContext = ServerContext.current() + request_final_usage_include_extra = server_context.enable_debug + request_include_debug_config = server_context.enable_debug + + if not request_include_debug_config: + request.debug_config = None + async_engine = server_context.get_engine(request.model) if async_engine is None: return error_protocol.create_error_response( @@ -54,7 +61,7 @@ async def request_completion(request: CompletionRequest, raw_request: fastapi.Re # capture potential exceptions in this scope, rather then # the StreamingResponse scope. stream_generator = async_engine._handle_completion( # pylint: disable=protected-access - request, request_id + request, request_id, request_final_usage_include_extra=request_final_usage_include_extra ) first_response = await anext( # type: ignore # pylint: disable=undefined-variable stream_generator @@ -64,9 +71,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if isinstance(first_response, StopAsyncIteration): yield "data: [DONE]\n\n" return - yield f"data: {first_response.model_dump_json()}\n\n" + yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n" async for response in stream_generator: - yield f"data: {response.model_dump_json()}\n\n" + yield f"data: {response.model_dump_json(by_alias=True)}\n\n" yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -74,8 +81,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. 
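The new `/metrics` route returns Prometheus-format plain text for the first served engine. A minimal scrape sketch; the address is a placeholder.

```python
import requests

resp = requests.get("http://127.0.0.1:8000/metrics", timeout=10)  # placeholder address
resp.raise_for_status()
print(resp.text)  # Prometheus exposition format, one metric per line
```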
- num_prompt_tokens = 0 - num_completion_tokens = 0 + request_final_usage = None output_texts = ["" for _ in range(request.n)] finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] logprob_results: Optional[List[List[LogProbsContent]]] = ( @@ -83,7 +89,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) async for response in async_engine._handle_completion( # pylint: disable=protected-access - request, request_id + request, request_id, request_final_usage_include_extra=request_final_usage_include_extra ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified @@ -94,8 +100,13 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens + # this is the final chunk + if response.usage is not None: + request_final_usage = response.usage + # remove extra information if debug is not enabled + if not server_context.enable_debug: + request_final_usage.extra = None + continue for choice in response.choices: output_texts[choice.index] += choice.text if choice.finish_reason is not None and finish_reasons[choice.index] is None: @@ -111,8 +122,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: output_texts=output_texts, finish_reasons=finish_reasons, logprob_results=logprob_results, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) @@ -128,6 +138,28 @@ async def request_chat_completion( """ # - Check the requested model. server_context: ServerContext = ServerContext.current() + request_final_usage_include_extra = server_context.enable_debug + request_include_debug_config = server_context.enable_debug + + if server_context.enable_debug: + import structlog # pylint: disable=import-outside-toplevel,import-error + + logger = structlog.stdlib.get_logger(__name__) + + request_param = await raw_request.json() + timestamp = {"timestamp": datetime.now().isoformat()} + request_param = {**timestamp, **request_param} + try: + logger.info("Received chat completion request", request=json.dumps(request_param)) + except ( # pylint: disable=broad-exception-caught + Exception, + json.JSONDecodeError, + ) as err: + logger.error("Error in dumping request parameters: %s", err) + + if not request_include_debug_config: + request.debug_config = None + async_engine = server_context.get_engine(request.model) if async_engine is None: return error_protocol.create_error_response( @@ -141,7 +173,7 @@ async def request_chat_completion( # capture potential exceptions in this scope, rather then # the StreamingResponse scope. 
stream_generator = async_engine._handle_chat_completion( # pylint: disable=protected-access - request, request_id + request, request_id, request_final_usage_include_extra=request_final_usage_include_extra ) first_response = await anext( # type: ignore # pylint: disable=undefined-variable stream_generator @@ -151,9 +183,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if isinstance(first_response, StopAsyncIteration): yield "data: [DONE]\n\n" return - yield f"data: {first_response.model_dump_json()}\n\n" + yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n" async for response in stream_generator: - yield f"data: {response.model_dump_json()}\n\n" + yield f"data: {response.model_dump_json(by_alias=True)}\n\n" yield "data: [DONE]\n\n" return fastapi.responses.StreamingResponse( @@ -161,8 +193,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) # Normal response. - num_prompt_tokens = 0 - num_completion_tokens = 0 + request_final_usage = None output_texts = ["" for _ in range(request.n)] finish_reasons: List[Optional[str]] = [None for _ in range(request.n)] logprob_results: Optional[List[List[LogProbsContent]]] = ( @@ -170,7 +201,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) async for response in async_engine._handle_chat_completion( # pylint: disable=protected-access - request, request_id + request, request_id, request_final_usage_include_extra=request_final_usage_include_extra ): if await raw_request.is_disconnected(): # In non-streaming cases, the engine will not be notified @@ -181,8 +212,13 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) - num_prompt_tokens = response.usage.prompt_tokens - num_completion_tokens = response.usage.completion_tokens + # usage is always the last chunk + if response.usage is not None: + request_final_usage = response.usage + # remove extra information if debug is not enabled + if not server_context.enable_debug: + request_final_usage.extra = None + for choice in response.choices: assert isinstance(choice.delta.content, str) output_texts[choice.index] += choice.delta.content @@ -196,6 +232,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: use_function_calling, tool_calls_list = engine_base.process_function_call_output( output_texts, finish_reasons ) + return engine_base.wrap_chat_completion_response( request_id=request_id, model=request.model, @@ -204,6 +241,5 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: tool_calls_list=tool_calls_list, logprob_results=logprob_results, use_function_calling=use_function_calling, - num_prompt_tokens=num_prompt_tokens, - num_completion_tokens=num_completion_tokens, + usage=request_final_usage, ) diff --git a/python/mlc_llm/serve/radix_tree.py b/python/mlc_llm/serve/radix_tree.py index 102cdac675..5561e8f846 100644 --- a/python/mlc_llm/serve/radix_tree.py +++ b/python/mlc_llm/serve/radix_tree.py @@ -13,20 +13,11 @@ class PagedRadixTree(Object): """The paged radix tree to manage prefix and sequence.""" - def __init__(self, num_pages: int, page_size: int, num_seqs: int): + def __init__(self): """ Constructor of paged radix tree. - - Parameters - ---------- - num_pages : int - The number of radix tree pages. - page_size : int - The page size of each radix tree page. - num_seqs : int - The maximum number of sequence ID. 
""" - self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree, num_pages, page_size, num_seqs) # type: ignore # pylint: disable=no-member + self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree) # type: ignore # pylint: disable=no-member def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple]: """ @@ -53,7 +44,7 @@ def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple def add(self, seq_id: int) -> None: """ - Get all sequences with longest common prefix with give prefix tokens. + Add an empty sequence. Parameters ---------- @@ -75,7 +66,7 @@ def remove(self, seq_id: int) -> None: def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None: """ - Get all sequences with longest common prefix with give prefix tokens. + Extend a sequence with given tokens. Parameters ---------- @@ -88,6 +79,19 @@ def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None: tokens = ShapeTuple(tokens) _ffi_api.PagedRadixTreeExtendSequence(self, seq_id, tokens) # type: ignore # pylint: disable=no-member + def rollback(self, seq_id: int, num_tokens: int) -> None: + """ + Roll back a sequence by number of tokens. + + Parameters + ---------- + seq_id : int + The sequence ID for index. + num_tokens : int + The number of tokens to be rolled back. + """ + _ffi_api.PagedRadixTreeRollBackSequence(self, seq_id, num_tokens) # type: ignore # pylint: disable=no-member + def fork(self, seq_id: int, parent_seq_id: int, forked_offset: int) -> None: """ Fork a sequence from parent sequence at given position. diff --git a/python/mlc_llm/serve/request.py b/python/mlc_llm/serve/request.py index 5c2d8ad196..85e5c5410d 100644 --- a/python/mlc_llm/serve/request.py +++ b/python/mlc_llm/serve/request.py @@ -1,12 +1,13 @@ """The request class in MLC LLM serving""" -from typing import List, Union +from typing import List import tvm._ffi from tvm.runtime import Object +from mlc_llm.protocol.generation_config import GenerationConfig + from . import _ffi_api -from .config import GenerationConfig from .data import Data @@ -16,35 +17,12 @@ class Request(Object): a unique request id, a list of multi-modal inputs, a set of generation configuration parameters. - Parameters - ---------- - request_id : str - The unique identifier of the request. - Different requests should have different ids. - - inputs : List[Data] - The user inputs of a request. Input may have multi-modality. - - generation_config : GenerationConfig - The sampling configuration which may contain temperature, - top_p, repetition_penalty, max_gen_len, etc. + Note + ---- + Do not explicitly construct this class. + Construct this object via engine.create_request functions. 
""" - def __init__( - self, - request_id: str, - inputs: Union[Data, List[Data]], - generation_config: GenerationConfig, - ): - if not isinstance(inputs, list): - inputs = [inputs] - self.__init_handle_by_constructor__( - _ffi_api.Request, # type: ignore # pylint: disable=no-member - request_id, - inputs, - generation_config.asjson(), - ) - @property def inputs(self) -> List[Data]: """The inputs of the request.""" @@ -53,6 +31,6 @@ def inputs(self) -> List[Data]: @property def generation_config(self) -> GenerationConfig: """The generation config of the request.""" - return GenerationConfig.from_json( + return GenerationConfig.model_validate_json( _ffi_api.RequestGetGenerationConfigJSON(self) # type: ignore # pylint: disable=no-member ) diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 1d17f8e66a..54e0b0c7df 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -5,84 +5,113 @@ import sys import time from pathlib import Path -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union import psutil import requests from tvm.runtime import Device -from mlc_llm.serve.config import SpeculativeMode +from mlc_llm.serve.config import EngineConfig +from mlc_llm.serve.engine_base import _check_engine_config class PopenServer: # pylint: disable=too-many-instance-attributes """The wrapper of MLC LLM server, which runs the server in - a background subprocess.""" + a background subprocess. + + This server can be used for debugging purposes. + """ def __init__( # pylint: disable=too-many-arguments self, model: str, device: Union[str, Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - gpu_memory_utilization: Optional[float] = None, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, + engine_config: Optional[EngineConfig] = None, + enable_debug: bool = True, enable_tracing: bool = False, host: str = "127.0.0.1", - port: int = 8000, + port: int = 8082, ) -> None: """Please check out `python/mlc_llm/cli/serve.py` for the server arguments.""" + # - Check the fields fields of `engine_config`. + if engine_config is None: + engine_config = EngineConfig() + _check_engine_config(model, model_lib, mode, engine_config) + self.model = model - self.model_lib_path = model_lib_path + self.model_lib = model_lib self.device = device self.mode = mode - self.additional_models = additional_models - self.max_batch_size = max_batch_size - self.max_total_sequence_length = max_total_sequence_length - self.prefill_chunk_size = prefill_chunk_size - self.gpu_memory_utilization = gpu_memory_utilization - self.speculative_mode = speculative_mode - self.spec_draft_length = spec_draft_length + self.enable_debug = enable_debug + self.engine_config = engine_config self.enable_tracing = enable_tracing + self.enable_debug = enable_debug self.host = host self.port = port self._proc: Optional[subprocess.Popen] = None - def start(self) -> None: # pylint: disable=too-many-branches + self.base_url = "" + self.openai_v1_base_url = "" + + def start(self) -> None: # pylint: disable=too-many-branches,too-many-statements """Launch the server in a popen subprocess. 
Wait until the server becomes ready before return. """ cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] - if self.model_lib_path is not None: - cmd += ["--model-lib-path", self.model_lib_path] + if self.model_lib is not None: + cmd += ["--model-lib", self.model_lib] cmd += ["--device", self.device] + + if self.enable_debug: + cmd += ["--enable-debug"] + if self.mode is not None: cmd += ["--mode", self.mode] - if self.additional_models is not None: - cmd += ["--additional-models", *self.additional_models] - if self.max_batch_size is not None: - cmd += ["--max-batch-size", str(self.max_batch_size)] - if self.max_total_sequence_length is not None: - cmd += ["--max-total-seq-length", str(self.max_total_sequence_length)] - if self.prefill_chunk_size is not None: - cmd += ["--prefill-chunk-size", str(self.prefill_chunk_size)] - if self.speculative_mode != SpeculativeMode.DISABLE: - cmd += [ - "--speculative-mode", - self.speculative_mode.name, - "--spec-draft-length", - str(self.spec_draft_length), - ] - if self.gpu_memory_utilization is not None: - cmd += ["--gpu-memory-utilization", str(self.gpu_memory_utilization)] + + if len(self.engine_config.additional_models) > 0: + args_additional_model = [] + for additional_model in self.engine_config.additional_models: + if isinstance(additional_model, str): + args_additional_model.append(additional_model) + else: + args_additional_model.append(additional_model[0] + "," + additional_model[1]) + cmd += ["--additional-models", *args_additional_model] + cmd += ["--speculative-mode", self.engine_config.speculative_mode] + cmd += ["--prefix-cache-mode", self.engine_config.prefix_cache_mode] + + args_overrides = [] + if self.engine_config.max_num_sequence is not None: + args_overrides.append(f"max_num_sequence={self.engine_config.max_num_sequence}") + if self.engine_config.max_total_sequence_length is not None: + args_overrides.append( + f"max_total_seq_length={self.engine_config.max_total_sequence_length}" + ) + if self.engine_config.prefill_chunk_size is not None: + args_overrides.append(f"prefill_chunk_size={self.engine_config.prefill_chunk_size}") + if self.engine_config.max_history_size is not None: + args_overrides.append(f"max_history_size={self.engine_config.max_history_size}") + if self.engine_config.gpu_memory_utilization is not None: + args_overrides.append( + f"gpu_memory_utilization={self.engine_config.gpu_memory_utilization}" + ) + if self.engine_config.spec_draft_length is not None: + args_overrides.append(f"spec_draft_length={self.engine_config.spec_draft_length}") + if self.engine_config.prefix_cache_max_num_recycling_seqs is not None: + args_overrides.append( + "prefix_cache_max_num_recycling_seqs=" + + str(self.engine_config.prefix_cache_max_num_recycling_seqs) + ) + if len(args_overrides) > 0: + cmd += ["--overrides", ";".join(args_overrides)] + if self.enable_tracing: cmd += ["--enable-tracing"] + if self.enable_debug: + cmd += ["--enable-debug"] cmd += ["--host", self.host] cmd += ["--port", str(self.port)] @@ -95,9 +124,12 @@ def start(self) -> None: # pylint: disable=too-many-branches # and hang forever. # Try to query the server until it is ready. 
- openai_v1_models_url = f"http://{self.host}:{str(self.port)}/v1/models" + self.base_url = f"http://{self.host}:{str(self.port)}" + self.openai_v1_base_url = f"http://{self.host}:{str(self.port)}/v1" + openai_v1_models_url = f"{self.base_url}/v1/models" + query_result = None - timeout = 60 + timeout = 120 attempts = 0.0 while query_result is None and attempts < timeout: try: diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index d6acd4a2be..2f4bf26626 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -11,8 +11,9 @@ class ServerContext: """ server_context: Optional["ServerContext"] = None + enable_debug: bool = False - def __init__(self): + def __init__(self) -> None: self._models: Dict[str, AsyncMLCEngine] = {} def __enter__(self): diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 1be841cb08..6c2c7b701a 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -13,19 +13,21 @@ import tvm +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data -from mlc_llm.serve.config import EngineConfig, GenerationConfig, SpeculativeMode +from mlc_llm.serve.config import EngineConfig from mlc_llm.serve.engine_base import ( - _infer_kv_cache_config, + EngineMetrics, + _check_engine_config, _parse_models, + _print_engine_mode_logging_msg, _process_model_args, detect_device, ) from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.serve.request import Request -from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.tokenizers import TextStreamer, Tokenizer logging.enable_logging() logger = logging.getLogger(__name__) @@ -58,12 +60,12 @@ class SyncMLCEngine: Parameters ---------- - models : Union[ModelInfo, List[ModelInfo]] - One or a list of model info (specifying which models to load and - which device to load to) to launch the engine. + engine_config : Optional[EngineConfig] + Additional configurable arguments of MLC engine. + See class "EngineConfig" for more detail. - kv_cache_config : KVCacheConfig - The configuration of the paged KV cache. + enable_tracing : bool + A boolean indicating if to enable event logging for requests. request_stream_callback : Optional[Callable[[str, data.TokenData, Optional[str]], None]] The provided callback function to handle the generation @@ -79,12 +81,6 @@ class SyncMLCEngine: be set before the engine executing requests. This can be done via the `set_request_stream_callback` method. Otherwise, the engine will raise exception. - - engine_config : Optional[EngineConfig] - The Engine execution configuration. - - enable_tracing : bool - A boolean indicating if to enable event logging for requests. 
""" def __init__( # pylint: disable=too-many-arguments,too-many-locals @@ -92,21 +88,24 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model: str, device: Union[str, tvm.runtime.Device] = "auto", *, - model_lib_path: Optional[str] = None, + model_lib: Optional[str] = None, mode: Literal["local", "interactive", "server"] = "local", - additional_models: Optional[List[str]] = None, - max_batch_size: Optional[int] = None, - max_total_sequence_length: Optional[int] = None, - prefill_chunk_size: Optional[int] = None, - max_history_size: Optional[int] = None, - gpu_memory_utilization: Optional[float] = None, + engine_config: Optional[EngineConfig] = None, enable_tracing: bool = False, - speculative_mode: SpeculativeMode = SpeculativeMode.DISABLE, - spec_draft_length: int = 4, request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None, ): + # - Check the fields fields of `engine_config`. + if engine_config is None: + engine_config = EngineConfig() + _check_engine_config( + model, + model_lib, + mode, + engine_config, + ) + # - Initialize model loading info. - models = _parse_models(model, model_lib_path, additional_models) + models = _parse_models(model, model_lib, engine_config.additional_models) if isinstance(device, str): device = detect_device(device) assert isinstance(device, tvm.runtime.Device) @@ -114,36 +113,18 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals model_args, model_config_paths, self.conv_template, - ) = _process_model_args(models, device) + ) = _process_model_args(models, device, engine_config) # - Load the raw model config into dict self.model_config_dicts = [] for i, model_info in enumerate(models): - model_info.model_lib_path = model_args[i][1] + model_info.model_lib = model_args[i][1] with open(model_config_paths[i], "r", encoding="utf-8") as file: self.model_config_dicts.append(json.load(file)) - # - Decide the KV cache config based on mode and user input. - ( - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_single_sequence_length, - max_history_size, - kv_state_kind, - ) = _infer_kv_cache_config( - mode, - max_batch_size, - max_total_sequence_length, - prefill_chunk_size, - max_history_size, - gpu_memory_utilization, - models, - device, - self.model_config_dicts, - model_config_paths, - ) - self.max_input_sequence_length = min(max_single_sequence_length, max_total_sequence_length) + # - Print logging info for regarding the mode selection. 
+ if engine_config.verbose: + _print_engine_mode_logging_msg(mode) self._ffi = _create_tvm_module( "mlc.serve.create_engine", @@ -152,30 +133,21 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals "add_request", "abort_request", "step", - "stats", "reset", + "json_metrics", "get_request_stream_callback", "set_request_stream_callback", + "create_request", ], ) self.trace_recorder = EventTraceRecorder() if enable_tracing else None + engine_config.model = model_args[0][0] + engine_config.model_lib = model_args[0][1] + engine_config.additional_models = model_args[1:] # type: ignore + engine_config.mode = mode self._ffi["init"]( - EngineConfig( - model=model_args[0][0], - model_lib_path=model_args[0][1], - additional_models=[model_arg[0] for model_arg in model_args[1:]], - additional_model_lib_paths=[model_arg[1] for model_arg in model_args[1:]], - kv_cache_page_size=16, - max_num_sequence=max_batch_size, - max_total_sequence_length=max_total_sequence_length, - max_single_sequence_length=max_single_sequence_length, - prefill_chunk_size=prefill_chunk_size, - max_history_size=max_history_size, - kv_state_kind=kv_state_kind, - speculative_mode=speculative_mode, - spec_draft_length=spec_draft_length, - ), + engine_config.asjson(), device, request_stream_callback, self.trace_recorder, @@ -266,7 +238,7 @@ def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): request_id, stream_outputs = delta_output.unpack() rid = int(request_id) - assert len(stream_outputs) == generation_config[rid].n + assert len(stream_outputs) == generation_config[rid].n # type:ignore for i, (stream_output, text_streamer) in enumerate( zip(stream_outputs, text_streamers[rid]) ): @@ -274,7 +246,7 @@ def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]): assert stream_output.delta_logprob_json_strs is not None output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs - delta_text = ( + delta_text = stream_output.extra_prefix_string + ( text_streamer.put(stream_output.delta_token_ids) if len(stream_output.delta_token_ids) > 0 else "" @@ -300,7 +272,7 @@ def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)): input_data = convert_to_data(prompt) # type: ignore self.add_request( - Request( + self.create_request( request_id=str(req_id), inputs=input_data, generation_config=generation_cfg, @@ -314,6 +286,36 @@ def convert_to_data(prompt: Union[str, List[int], List[data.Data]]) -> List[data self._ffi["set_request_stream_callback"](original_callback) return output_texts, output_logprobs_str + def create_request( + self, + request_id: str, + inputs: Union[data.Data, List[data.Data]], + generation_config: GenerationConfig, + ): + """Create a new request that can be added to engine. + + Parameters + ---------- + request_id : str + The unique identifier of the request. + Different requests should have different ids. + + inputs : List[Data] + The user inputs of a request. Input may have multi-modality. + + generation_config : GenerationConfig + The generation configuration of the request. + + Note + ---- + engine may fill in default generation config of the model. + """ + if not isinstance(inputs, list): + inputs = [inputs] + return self._ffi["create_request"]( + request_id, inputs, generation_config.model_dump_json(by_alias=True) + ) + def add_request(self, request: Request) -> None: """Add a new request to the engine. 
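Since `Request` can no longer be constructed directly, the sync engine builds it through `create_request`, letting the backend fill in the model's default generation settings. A minimal sketch; the prompt and parameters are arbitrary, and `data.TextData` is assumed to be the plain-text input type in `mlc_llm.serve.data`.

```python
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve import data

# `engine` is a SyncMLCEngine instance with a request_stream_callback installed.
request = engine.create_request(
    request_id="0",
    inputs=[data.TextData("What is the capital of France?")],  # assumed input type
    generation_config=GenerationConfig(max_tokens=32),
)
engine.add_request(request)
```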
@@ -349,18 +351,9 @@ def step(self) -> None: self._ffi["step"]() def reset(self) -> None: - """Reset the engine, clean up all running data and statistics.""" + """Reset the engine, clean up all running data and metrics.""" self._ffi["reset"]() - def stats(self) -> Dict[str, float]: - """The engine runtime statistics. - We collect the following entries: - - single token prefill latency (s/tok): avg latency of processing one token in prefill - - single token decode latency (s/tok): avg latency of processing one token in decode - - engine time for prefill (sec) - - engine time for decode (sec) - - total number of processed tokens in prefill. - - total number of processed tokens in decode. - """ - stats_json_str = self._ffi["stats"]() - return json.loads(stats_json_str) + def metrics(self) -> EngineMetrics: + """Reset the engine, clean up all running data and metrics.""" + return EngineMetrics(json.loads(self._ffi["json_metrics"]())) diff --git a/python/mlc_llm/streamer.py b/python/mlc_llm/streamer.py index 1eb88afb97..37179f17f7 100644 --- a/python/mlc_llm/streamer.py +++ b/python/mlc_llm/streamer.py @@ -7,7 +7,7 @@ from tvm.runtime import Object, ShapeTuple from . import _ffi_api -from .tokenizer import Tokenizer +from .tokenizers import Tokenizer @tvm._ffi.register_object("mlc.TextStreamer") # pylint: disable=protected-access diff --git a/python/mlc_llm/support/argparse.py b/python/mlc_llm/support/argparse.py index 81211e8e07..6d36f43c83 100644 --- a/python/mlc_llm/support/argparse.py +++ b/python/mlc_llm/support/argparse.py @@ -1,4 +1,5 @@ """An enhanced argument parser for mlc-chat.""" + import argparse import sys diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index be0ee8af98..a1c9e0bc70 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -1,4 +1,5 @@ """Help function for detecting the model configuration file `config.json`""" + import json import tempfile from pathlib import Path @@ -35,12 +36,12 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # pylint: disable=import-outside-toplevel from mlc_llm.model import MODEL_PRESETS - from .download import download_mlc_weights + from .download_cache import download_and_cache_mlc_weights # pylint: enable=import-outside-toplevel if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"): - mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config)) + mlc_chat_config_path = Path(download_and_cache_mlc_weights(model_url=mlc_chat_config)) elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS: logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config) content = MODEL_PRESETS[mlc_chat_config].copy() diff --git a/python/mlc_llm/support/auto_target.py b/python/mlc_llm/support/auto_target.py index 5239756d9d..701b3c1bc8 100644 --- a/python/mlc_llm/support/auto_target.py +++ b/python/mlc_llm/support/auto_target.py @@ -220,15 +220,15 @@ def build(mod: IRModule, args: "CompileArgs", pipeline=None): # Try to locate `mlc_wasm_runtime.bc` bc_path = None bc_candidates = ["web/dist/wasm/mlc_wasm_runtime.bc"] - if os.environ.get("MLC_LLM_HOME", None): - mlc_source_home_dir = os.environ["MLC_LLM_HOME"] + if os.environ.get("MLC_LLM_SOURCE_DIR", None): + mlc_source_home_dir = os.environ["MLC_LLM_SOURCE_DIR"] bc_candidates.append( os.path.join(mlc_source_home_dir, "web", "dist", "wasm", "mlc_wasm_runtime.bc") ) error_info = ( "Cannot find library: mlc_wasm_runtime.bc\n" + "Make sure you have run 
`./web/prep_emcc_deps.sh` and " - + "`export MLC_LLM_HOME=/path/to/mlc-llm` so that we can locate the file. " + + "`export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm` so that we can locate the file. " + "We tried to look at candidate paths:\n" ) for candidate in bc_candidates: @@ -400,6 +400,7 @@ def detect_system_lib_prefix( "target": { "kind": "opencl", "device": "adreno", + "max_threads_per_block": 512, "host": { "kind": "llvm", "mtriple": "aarch64-linux-android", @@ -411,6 +412,7 @@ def detect_system_lib_prefix( "target": { "kind": "opencl", "device": "adreno", + "max_threads_per_block": 512, "host": { "kind": "llvm", "mtriple": "aarch64-linux-android", diff --git a/python/mlc_llm/support/auto_weight.py b/python/mlc_llm/support/auto_weight.py index 84d8621026..5a561193fe 100644 --- a/python/mlc_llm/support/auto_weight.py +++ b/python/mlc_llm/support/auto_weight.py @@ -1,4 +1,5 @@ """Help functions for detecting weight paths and weight formats.""" + import json from pathlib import Path from typing import List, Optional, Tuple diff --git a/python/mlc_llm/support/config.py b/python/mlc_llm/support/config.py index e3ccfcec29..715a4b2fa4 100644 --- a/python/mlc_llm/support/config.py +++ b/python/mlc_llm/support/config.py @@ -9,6 +9,7 @@ The base class allows us to load the configuration from this JSON file, moving irrelevant fields into `kwargs`, such as `transformers_version` and `use_cache`. """ + # pylint: disable=too-few-public-methods import dataclasses import json diff --git a/python/mlc_llm/support/constants.py b/python/mlc_llm/support/constants.py index 82697ff71a..9e862a3b65 100644 --- a/python/mlc_llm/support/constants.py +++ b/python/mlc_llm/support/constants.py @@ -1,7 +1,11 @@ """Environment variables used by the MLC LLM.""" + import os import sys from pathlib import Path +from typing import List + +MLC_CHAT_CONFIG_VERSION = "0.1.0" def _check(): @@ -11,10 +15,17 @@ def _check(): f"but got {MLC_JIT_POLICY}." ) + if MLC_DOWNLOAD_CACHE_POLICY not in ["ON", "OFF", "REDO", "READONLY"]: + raise ValueError( + "Invalid MLC_AUTO_DOWNLOAD_POLICY. " + 'It has to be one of "ON", "OFF", "REDO", "READONLY"' + f"but got {MLC_DOWNLOAD_CACHE_POLICY}." + ) + def _get_cache_dir() -> Path: - if "MLC_CACHE_DIR" in os.environ: - result = Path(os.environ["MLC_CACHE_DIR"]) + if "MLC_LLM_HOME" in os.environ: + result = Path(os.environ["MLC_LLM_HOME"]) elif sys.platform == "win32": result = Path(os.environ["LOCALAPPDATA"]) result = result / "mlc_llm" @@ -28,7 +39,7 @@ def _get_cache_dir() -> Path: if not result.is_dir(): raise ValueError( f"The default cache directory is not a directory: {result}. " - "Use environment variable MLC_CACHE_DIR to specify a valid cache directory." + "Use environment variable MLC_LLM_HOME to specify a valid cache directory." 
) (result / "model_weights").mkdir(parents=True, exist_ok=True) (result / "model_lib").mkdir(parents=True, exist_ok=True) @@ -45,11 +56,35 @@ def _get_dso_suffix() -> str: return "so" +def _get_test_model_path() -> List[Path]: + paths = [] + if "MLC_LLM_TEST_MODEL_PATH" in os.environ: + paths += [Path(p) for p in os.environ["MLC_LLM_TEST_MODEL_PATH"].split(os.pathsep)] + # by default, we reuse the cache dir via mlc_llm chat + # note that we do not auto download for testcase + # to avoid networking dependencies + base_list = ["hf"] + paths += [_get_cache_dir() / "model_weights" / base / "mlc-ai" for base in base_list] + [ + Path(os.path.abspath(os.path.curdir)), + Path(os.path.abspath(os.path.curdir)) / "dist", + ] + return paths + + +def _get_read_only_weight_caches() -> List[Path]: + if "MLC_LLM_READONLY_WEIGHT_CACHE" in os.environ: + return [Path(p) for p in os.environ["MLC_LLM_READONLY_WEIGHT_CACHE"].split(os.pathsep)] + return [] + + MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None) MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None) -MLC_CACHE_DIR: Path = _get_cache_dir() MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON") MLC_DSO_SUFFIX = _get_dso_suffix() +MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path() +MLC_DOWNLOAD_CACHE_POLICY = os.environ.get("MLC_DOWNLOAD_CACHE_POLICY", "ON") +MLC_LLM_HOME: Path = _get_cache_dir() +MLC_LLM_READONLY_WEIGHT_CACHE = _get_read_only_weight_caches() _check() diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index 770833e9af..c0ab9cdbc1 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -13,12 +13,26 @@ import requests # pylint: disable=import-error from . import logging, tqdm -from .constants import MLC_CACHE_DIR, MLC_TEMP_DIR +from .constants import ( + MLC_DOWNLOAD_CACHE_POLICY, + MLC_LLM_HOME, + MLC_LLM_READONLY_WEIGHT_CACHE, + MLC_TEMP_DIR, +) from .style import bold logger = logging.getLogger(__name__) +def log_download_cache_policy(): + """log current download policy""" + logger.info( + "%s = %s. 
Can be one of: ON, OFF, REDO, READONLY", + bold("MLC_DOWNLOAD_CACHE_POLICY"), + MLC_DOWNLOAD_CACHE_POLICY, + ) + + def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: if path.exists(): if force_redo: @@ -92,9 +106,9 @@ def download_file( ) -> Tuple[str, Path]: """Download a file from a URL to a destination file.""" with requests.get(url, stream=True, timeout=30) as response: - response.raise_for_status() + response.raise_for_status() # type: ignore with destination.open("wb") as file: - for chunk in response.iter_content(chunk_size=8192): + for chunk in response.iter_content(chunk_size=8192): # type: ignore file.write(chunk) if md5sum is not None: hash_md5 = hashlib.md5() @@ -110,12 +124,16 @@ def download_file( return url, destination -def download_mlc_weights( # pylint: disable=too-many-locals +def download_and_cache_mlc_weights( # pylint: disable=too-many-locals model_url: str, num_processes: int = 4, - force_redo: bool = False, + force_redo: Optional[bool] = None, ) -> Path: """Download weights for a model from the HuggingFace Git LFS repo.""" + log_download_cache_policy() + if MLC_DOWNLOAD_CACHE_POLICY == "OFF": + raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_CACHE_POLICY=OFF") + prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], "" mlc_prefix = next(p for p in prefixes if model_url.startswith(p)) assert mlc_prefix @@ -126,12 +144,36 @@ def download_mlc_weights( # pylint: disable=too-many-locals if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix): raise ValueError(f"Invalid model URL: {model_url}") user, repo = model_url[len(mlc_prefix) :].split("/") - git_dir = MLC_CACHE_DIR / "model_weights" / user / repo + domain = "hf" + + readonly_cache_dirs = [] + for base in MLC_LLM_READONLY_WEIGHT_CACHE: + cache_dir = base / domain / user / repo + readonly_cache_dirs.append(str(cache_dir)) + if (cache_dir / "mlc-chat-config.json").is_file(): + logger.info("Use cached weight: %s", bold(str(cache_dir))) + return cache_dir + + if force_redo is None: + force_redo = MLC_DOWNLOAD_CACHE_POLICY == "REDO" + + git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo + readonly_cache_dirs.append(str(git_dir)) + try: _ensure_directory_not_exist(git_dir, force_redo=force_redo) except ValueError: logger.info("Weights already downloaded: %s", bold(str(git_dir))) return git_dir + + if MLC_DOWNLOAD_CACHE_POLICY == "READONLY": + raise RuntimeError( + f"Cannot find cache for {model_url}, " + "cannot proceed to download as MLC_DOWNLOAD_CACHE_POLICY=READONLY, " + "please check settings MLC_LLM_READONLY_WEIGHT_CACHE, " + f"local path candidates: {readonly_cache_dirs}" + ) + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix: tmp_dir = Path(tmp_dir_prefix) / "tmp" git_url = git_url_template.format(user=user, repo=repo) @@ -155,3 +197,41 @@ def download_mlc_weights( # pylint: disable=too-many-locals logger.info("Moving %s to %s", tmp_dir, bold(str(git_dir))) shutil.move(str(tmp_dir), str(git_dir)) return git_dir + + +def get_or_download_model(model: str) -> Path: + """Use user-provided argument ``model`` to get model_path + + We define "valid" as having an ``mlc-chat-config.json`` right under the folder. 
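
The MLC_DOWNLOAD_CACHE_POLICY, MLC_LLM_HOME and MLC_LLM_READONLY_WEIGHT_CACHE variables introduced above are read once, at import time, by mlc_llm.support.constants. A minimal sketch of how a caller might combine them with get_or_download_model; the policy value, cache paths and model URL below are illustrative, not taken from this change:

    import os

    # Must be set before mlc_llm.support.constants is imported, since that
    # module reads the environment exactly once at import time.
    os.environ["MLC_DOWNLOAD_CACHE_POLICY"] = "ON"          # ON | OFF | REDO | READONLY
    os.environ["MLC_LLM_HOME"] = "/home/me/.cache/mlc_llm"  # writable cache root
    os.environ["MLC_LLM_READONLY_WEIGHT_CACHE"] = "/shared/mlc-weights"

    from mlc_llm.support.download import get_or_download_model

    # Resolves to a local directory containing mlc-chat-config.json,
    # downloading into MLC_LLM_HOME/model_weights/hf/<user>/<repo> if needed.
    model_path = get_or_download_model("HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC")
    print(model_path)
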
+ + Parameters + ---------- + model : str + User's input; may a path or url + + Returns + ------ + model_path : Path + A "valid" path to model folder, with + ``(model_path / "mlc-chat-config.json").is_file`` being True + + Note + ---- + This function may perform additional download and caching + + Raises + ------ + FileNotFoundError: if we cannot find a valid `model_path`. + """ + if model.startswith("HF://"): + logger.info("Downloading model from HuggingFace: %s", model) + model_path = download_and_cache_mlc_weights(model) + else: + model_path = Path(model) + + if not model_path.is_dir(): + raise FileNotFoundError(f"Cannot find model {model}, directory does not exist") + mlc_config_path = model_path / "mlc-chat-config.json" + if mlc_config_path.is_file(): + return model_path + raise FileNotFoundError(f"Cannot find {str(mlc_config_path)} in the model directory provided") diff --git a/python/mlc_llm/support/download_cache.py b/python/mlc_llm/support/download_cache.py new file mode 100644 index 0000000000..c0ab9cdbc1 --- /dev/null +++ b/python/mlc_llm/support/download_cache.py @@ -0,0 +1,237 @@ +"""Common utilities for downloading files from HuggingFace or other URLs online.""" + +import concurrent.futures as cf +import hashlib +import json +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import List, Optional, Tuple + +import requests # pylint: disable=import-error + +from . import logging, tqdm +from .constants import ( + MLC_DOWNLOAD_CACHE_POLICY, + MLC_LLM_HOME, + MLC_LLM_READONLY_WEIGHT_CACHE, + MLC_TEMP_DIR, +) +from .style import bold + +logger = logging.getLogger(__name__) + + +def log_download_cache_policy(): + """log current download policy""" + logger.info( + "%s = %s. Can be one of: ON, OFF, REDO, READONLY", + bold("MLC_DOWNLOAD_CACHE_POLICY"), + MLC_DOWNLOAD_CACHE_POLICY, + ) + + +def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: + if path.exists(): + if force_redo: + logger.info("Deleting existing directory: %s", path) + shutil.rmtree(path) + else: + raise ValueError(f"Directory already exists: {path}") + else: + path.parent.mkdir(parents=True, exist_ok=True) + + +def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: + """Clone a git repository into a directory.""" + repo_name = ".tmp" + command = ["git", "clone", url, repo_name] + _ensure_directory_not_exist(destination, force_redo=False) + try: + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: + logger.info("[Git] Cloning %s to %s", bold(url), destination) + subprocess.run( + command, + env=env, + cwd=tmp_dir, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + git_dir = os.path.join(tmp_dir, repo_name) + if not ignore_lfs: + git_lfs_pull(Path(git_dir)) + shutil.move(git_dir, str(destination)) + except subprocess.CalledProcessError as error: + raise ValueError( + f"Git clone failed with return code {error.returncode}: {error.stderr}. 
" + f"The command was: {command}" + ) from error + + +def git_lfs_pull(repo_dir: Path, ignore_extensions: Optional[List[str]] = None) -> None: + """Pull files with Git LFS.""" + filenames = ( + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"], + stderr=subprocess.STDOUT, + ) + .decode("utf-8") + .splitlines() + ) + if ignore_extensions is not None: + filenames = [ + filename + for filename in filenames + if not any(filename.endswith(extension) for extension in ignore_extensions) + ] + logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames) + with tqdm.redirect(): + for file in tqdm.tqdm(filenames): + logger.info("[Git LFS] Downloading %s", file) + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "pull", "--include", file], + stderr=subprocess.STDOUT, + ) + + +def download_file( + url: str, + destination: Path, + md5sum: Optional[str], +) -> Tuple[str, Path]: + """Download a file from a URL to a destination file.""" + with requests.get(url, stream=True, timeout=30) as response: + response.raise_for_status() # type: ignore + with destination.open("wb") as file: + for chunk in response.iter_content(chunk_size=8192): # type: ignore + file.write(chunk) + if md5sum is not None: + hash_md5 = hashlib.md5() + with destination.open("rb") as file: + for chunk in iter(lambda: file.read(8192), b""): + hash_md5.update(chunk) + file_md5 = hash_md5.hexdigest() + if file_md5 != md5sum: + raise ValueError( + f"MD5 checksum mismatch for downloaded file: {destination}. " + f"Expected {md5sum}, got {file_md5}" + ) + return url, destination + + +def download_and_cache_mlc_weights( # pylint: disable=too-many-locals + model_url: str, + num_processes: int = 4, + force_redo: Optional[bool] = None, +) -> Path: + """Download weights for a model from the HuggingFace Git LFS repo.""" + log_download_cache_policy() + if MLC_DOWNLOAD_CACHE_POLICY == "OFF": + raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_CACHE_POLICY=OFF") + + prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], "" + mlc_prefix = next(p for p in prefixes if model_url.startswith(p)) + assert mlc_prefix + + git_url_template = "https://huggingface.co/{user}/{repo}.git" + bin_url_template = "https://huggingface.co/{user}/{repo}/resolve/main/{record_name}" + + if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix): + raise ValueError(f"Invalid model URL: {model_url}") + user, repo = model_url[len(mlc_prefix) :].split("/") + domain = "hf" + + readonly_cache_dirs = [] + for base in MLC_LLM_READONLY_WEIGHT_CACHE: + cache_dir = base / domain / user / repo + readonly_cache_dirs.append(str(cache_dir)) + if (cache_dir / "mlc-chat-config.json").is_file(): + logger.info("Use cached weight: %s", bold(str(cache_dir))) + return cache_dir + + if force_redo is None: + force_redo = MLC_DOWNLOAD_CACHE_POLICY == "REDO" + + git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo + readonly_cache_dirs.append(str(git_dir)) + + try: + _ensure_directory_not_exist(git_dir, force_redo=force_redo) + except ValueError: + logger.info("Weights already downloaded: %s", bold(str(git_dir))) + return git_dir + + if MLC_DOWNLOAD_CACHE_POLICY == "READONLY": + raise RuntimeError( + f"Cannot find cache for {model_url}, " + "cannot proceed to download as MLC_DOWNLOAD_CACHE_POLICY=READONLY, " + "please check settings MLC_LLM_READONLY_WEIGHT_CACHE, " + f"local path candidates: {readonly_cache_dirs}" + ) + + with 
tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix: + tmp_dir = Path(tmp_dir_prefix) / "tmp" + git_url = git_url_template.format(user=user, repo=repo) + git_clone(git_url, tmp_dir, ignore_lfs=True) + git_lfs_pull(tmp_dir, ignore_extensions=[".bin"]) + shutil.rmtree(tmp_dir / ".git", ignore_errors=True) + with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file: + param_metadata = json.load(in_file)["records"] + with cf.ProcessPoolExecutor(max_workers=num_processes) as executor: + futures = [] + for record in param_metadata: + record_name = record["dataPath"] + file_url = bin_url_template.format(user=user, repo=repo, record_name=record_name) + file_dest = tmp_dir / record_name + file_md5 = record.get("md5sum", None) + futures.append(executor.submit(download_file, file_url, file_dest, file_md5)) + with tqdm.redirect(): + for future in tqdm.tqdm(cf.as_completed(futures), total=len(futures)): + file_url, file_dest = future.result() + logger.info("Downloaded %s to %s", file_url, file_dest) + logger.info("Moving %s to %s", tmp_dir, bold(str(git_dir))) + shutil.move(str(tmp_dir), str(git_dir)) + return git_dir + + +def get_or_download_model(model: str) -> Path: + """Use user-provided argument ``model`` to get model_path + + We define "valid" as having an ``mlc-chat-config.json`` right under the folder. + + Parameters + ---------- + model : str + User's input; may a path or url + + Returns + ------ + model_path : Path + A "valid" path to model folder, with + ``(model_path / "mlc-chat-config.json").is_file`` being True + + Note + ---- + This function may perform additional download and caching + + Raises + ------ + FileNotFoundError: if we cannot find a valid `model_path`. + """ + if model.startswith("HF://"): + logger.info("Downloading model from HuggingFace: %s", model) + model_path = download_and_cache_mlc_weights(model) + else: + model_path = Path(model) + + if not model_path.is_dir(): + raise FileNotFoundError(f"Cannot find model {model}, directory does not exist") + mlc_config_path = model_path / "mlc-chat-config.json" + if mlc_config_path.is_file(): + return model_path + raise FileNotFoundError(f"Cannot find {str(mlc_config_path)} in the model directory provided") diff --git a/python/mlc_llm/support/logging.py b/python/mlc_llm/support/logging.py index f2611c7f1a..023d1240f1 100644 --- a/python/mlc_llm/support/logging.py +++ b/python/mlc_llm/support/logging.py @@ -2,6 +2,7 @@ Logging support for MLC. It derives from Python's logging module, and in the future, it can be easily replaced by other logging modules such as structlog. """ + import logging diff --git a/python/mlc_llm/support/random.py b/python/mlc_llm/support/random.py index 0568276d12..9c142ed36e 100644 --- a/python/mlc_llm/support/random.py +++ b/python/mlc_llm/support/random.py @@ -1,4 +1,5 @@ """Utility functions for random number generation.""" + import sys diff --git a/python/mlc_llm/support/tensor_parallel.py b/python/mlc_llm/support/tensor_parallel.py index 6b3f930576..cea22fdb7e 100644 --- a/python/mlc_llm/support/tensor_parallel.py +++ b/python/mlc_llm/support/tensor_parallel.py @@ -1,4 +1,5 @@ """Sharding operators for tensor parallelism.""" + import dataclasses from contextlib import contextmanager from typing import Any, Dict, List, Optional @@ -75,27 +76,6 @@ def gen_shard_info(self, shards: int, weight: nn.Tensor) -> Dict[str, Any]: "out_dtype": weight.dtype, } -@dataclasses.dataclass -class ShardScalar: - """ - Shard a scalar param into multiple distinct scalars, one for each shard. 
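
Recapping the lookup order that download_and_cache_mlc_weights implements above, a hedged sketch that approximates the path logic only (the helper name and paths are illustrative; the real code additionally treats an existing writable directory as a hit and honors force_redo):

    from pathlib import Path
    from typing import List, Optional

    def resolve_cached_weights(
        user: str, repo: str, readonly_bases: List[Path], llm_home: Path
    ) -> Optional[Path]:
        """Return the first cache directory that already holds the weights."""
        # 1) read-only caches are consulted first, in the order given
        candidates = [base / "hf" / user / repo for base in readonly_bases]
        # 2) then the writable cache under MLC_LLM_HOME
        candidates.append(llm_home / "model_weights" / "hf" / user / repo)
        for cand in candidates:
            # a directory counts as a hit only if mlc-chat-config.json is present
            if (cand / "mlc-chat-config.json").is_file():
                return cand
        return None  # caller falls back to git clone + LFS shard download
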
- - - Parameters - ---------- - name : str - The name of the shard func - """ - - name: str - def gen_shard_info(self, shards: int, weight: nn.Tensor) -> Dict[str, Any]: - """Generate shard info for this sharding strategy.""" - return { - "func_name": self.name, - "out_shape": (shards, *weight.shape), - "out_dtype": weight.dtype, - } - @contextmanager def shard_bias(linear: nn.Linear, tensor_parallel_shards: int): diff --git a/python/mlc_llm/support/tqdm.py b/python/mlc_llm/support/tqdm.py index 9adceca480..a2d1d43f42 100644 --- a/python/mlc_llm/support/tqdm.py +++ b/python/mlc_llm/support/tqdm.py @@ -1,4 +1,5 @@ """Utils to better use tqdm""" + import contextlib import inspect import io diff --git a/python/mlc_llm/testing/__init__.py b/python/mlc_llm/testing/__init__.py index e803641043..ef1c38828b 100644 --- a/python/mlc_llm/testing/__init__.py +++ b/python/mlc_llm/testing/__init__.py @@ -1,3 +1,5 @@ """ Test and debug tools for MLC LLM """ + +from .pytest_utils import require_test_model, require_test_tokenizers diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 4f1cfe103d..54f918d9b2 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -1,5 +1,6 @@ """Debug compiled models with TVM instrument""" +# pylint: disable=too-many-arguments import json import random from pathlib import Path @@ -12,20 +13,14 @@ from tvm.runtime import Device, Module, Object, ShapeTuple from tvm.runtime.relax_vm import VirtualMachine -from mlc_llm.chat_module import ( - ChatConfig, - GenerationConfig, - _get_chat_config, - _get_generation_config, - _get_model_path, -) from mlc_llm.conversation_template import ConvTemplateRegistry -from mlc_llm.help import HELP +from mlc_llm.interface.help import HELP +from mlc_llm.protocol.mlc_chat_config import MLCChatConfig from mlc_llm.serve import engine_utils from mlc_llm.support.argparse import ArgumentParser from mlc_llm.support.auto_device import detect_device from mlc_llm.support.style import green, red -from mlc_llm.tokenizer import Tokenizer +from mlc_llm.tokenizers import Tokenizer def _extract_metadata(mod: Module): @@ -144,7 +139,7 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public dc = DebugChat( model="./dist/Llama-2-7b-chat-hf-q4f16_1-MLC", debug_dir=Path("./debug-llama-2"), - model_lib_path="./dist/llama-2-7b-chat-q4f16_1-metal.so", + model_lib="./dist/llama-2-7b-chat-q4f16_1-metal.so", ) dc.generate("hello world", 3) """ @@ -152,10 +147,9 @@ class DebugChat: # pylint: disable=too-many-instance-attributes, too-few-public def __init__( # pylint: disable=too-many-arguments self, model: str, - model_lib_path: str, + model_lib: str, debug_dir: Path, device: Optional[str] = "auto", - chat_config: Optional[ChatConfig] = None, debug_instrument: Optional[Any] = None, ): """_summary_ @@ -169,7 +163,7 @@ def __init__( # pylint: disable=too-many-arguments folder. In the former case, we will use the provided name to search for the model folder over possible paths. - model_lib_path : str + model_lib : str The full path to the model library file to use (e.g. a ``.so`` file). 
debug_dir: Path @@ -213,17 +207,21 @@ def instrument( debug_instrument if debug_instrument else DefaultDebugInstrument(debug_dir / "prefill") ) self.mod, self.params, self.metadata = _get_tvm_module( - model, model_lib_path, self.device, self.instrument + model, model_lib, self.device, self.instrument ) - self.model_path, self.config_file_path = _get_model_path(model) - self.chat_config = _get_chat_config(self.config_file_path, chat_config) + self.model_path = Path(model) + self.config_file_path = self.model_path / "mlc-chat-config.json" + with open(self.config_file_path, mode="rt", encoding="utf-8") as file: + self.chat_config = MLCChatConfig.model_validate_json(file.read()) + conv_template = self.chat_config.conv_template + self.conversation = ( ConvTemplateRegistry.get_conv_template(conv_template) if isinstance(conv_template, str) else conv_template ) - self.tokenizer = Tokenizer(self.model_path) + self.tokenizer = Tokenizer(str(self.model_path)) self.add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence") self.begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward") @@ -340,18 +338,22 @@ def _apply_presence_and_freq_penalty( logits[:, :, token_id] -= freq * freq_penalty + presence_penalty def _sample_token_from_logits( - self, logits: tvm.nd.NDArray, generation_config: GenerationConfig + self, + logits: tvm.nd.NDArray, + *, + temperature=1.0, + top_p=1.0, + presence_penalty=0.0, + frequency_penalty=0.0, ): logits_np = logits.numpy() - temperature = generation_config.temperature if generation_config.temperature else 1.0 - top_p = generation_config.top_p if generation_config.top_p else 0.95 - presence_penalty = generation_config.presence_penalty - frequency_penalty = generation_config.frequency_penalty if presence_penalty != 0.0 or frequency_penalty != 0.0: self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty) - self._softmax_with_temperature(logits_np, temperature) + logits_np = self._softmax_with_temperature(logits_np, temperature) + np.savez(self.instrument.debug_out / "logits.npz", logits_np) + logits = logits.copyfrom(logits_np) next_token = self.sample_topp_from_prob_func(logits, top_p, random.random()) return next_token @@ -360,7 +362,6 @@ def generate( self, prompt: str, generate_length: int, - generation_config: Optional[GenerationConfig] = None, ): """Generates the response from the model given a user prompt. User will need to specify the generation length for debugging purpose. For example, a generation @@ -373,9 +374,6 @@ def generate( generate_length : int How many tokens to generate. - - generation_config : Optional[GenerationConfig] - Will be used to override the GenerationConfig in ``mlc-chat-config.json``. 
""" out_tokens = [] @@ -383,8 +381,7 @@ def generate( print(f"{green('Input tokens')}: {input_tokens.numpy()}") embedding, input_len = self._embed(input_tokens) logits, kv_caches = self._prefill(embedding, input_len) - generation_config = _get_generation_config(self.chat_config, generation_config) - next_token = self._sample_token_from_logits(logits, generation_config) + next_token = self._sample_token_from_logits(logits) out_tokens.append(next_token) path_str = (self.debug_dir / "prefill").as_posix() print(f"Debug instrument output dumped to {green(path_str)}") @@ -393,8 +390,7 @@ def generate( for i in range(generate_length - 1): self.instrument.reset(self.debug_dir / f"decode_{i}") logits = self._decode(next_token, kv_caches) - generation_config = _get_generation_config(self.chat_config, generation_config) - next_token = self._sample_token_from_logits(logits, generation_config) + next_token = self._sample_token_from_logits(logits) out_tokens.append(next_token) path_str = (self.debug_dir / f"decode_{i}").as_posix() print(f"Debug instrument output dumped to {green(path_str)}") @@ -427,7 +423,7 @@ def main(): required=True, ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, help="The full path to the model library file to use (e.g. a ``.so`` file).", required=True, @@ -447,7 +443,7 @@ def main(): parsed = parser.parse_args() dc = DebugChat( model=parsed.model, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, debug_dir=Path(parsed.debug_dir), device=parsed.device, ) diff --git a/python/mlc_llm/testing/debug_compare.py b/python/mlc_llm/testing/debug_compare.py index b3487e3e48..2ad640920f 100644 --- a/python/mlc_llm/testing/debug_compare.py +++ b/python/mlc_llm/testing/debug_compare.py @@ -8,7 +8,7 @@ from tvm import rpc, runtime from tvm.relax.testing.lib_comparator import LibCompareVMInstrument -from mlc_llm.help import HELP +from mlc_llm.interface.help import HELP from mlc_llm.support.argparse import ArgumentParser from mlc_llm.testing.debug_chat import DebugChat @@ -67,31 +67,34 @@ def __init__( # pylint: disable=too-many-arguments, unused-argument self, mod: runtime.Module, device: runtime.Device, - debug_dir: Path, + debug_out: Path, time_eval: bool = True, rtol: float = 1e-2, atol: float = 1, skip_rounds: int = 0, ): super().__init__(mod, device, True, rtol, atol) + self.debug_out = debug_out self.time_eval = time_eval self.time_eval_results: Dict[str, Tuple[float, int]] = {} self.visited: Set[str] = set([]) self.skip_rounds = skip_rounds self.counter = 0 + debug_out.mkdir(exist_ok=True, parents=True) - def reset(self, debug_dir: Path): # pylint: disable=unused-argument + def reset(self, debug_out: Path): # pylint: disable=unused-argument """Reset the state of the Instrument class Note ---- - `debug_dir` is not used in this class. + `debug_out` is not used in this class. 
Parameters ---------- debug_out : Path the directory to dump the .npz files """ + self.debug_out = debug_out _print_as_table( sorted( self.time_eval_results.items(), @@ -101,6 +104,7 @@ def reset(self, debug_dir: Path): # pylint: disable=unused-argument self.time_eval_results = {} self.visited = set([]) self.counter = 0 + debug_out.mkdir(exist_ok=True, parents=True) def skip_instrument(self, func, name, before_run, ret_val, *args): if name.startswith("shape_func"): @@ -128,7 +132,12 @@ def compare( if self.time_eval and name not in self.time_eval_results: res = self.mod.time_evaluator( - name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 + name, + self.device, + number=20, + repeat=3, + min_repeat_ms=100, + # cache_flush_bytes=256 * 10**6 )(*new_args) self.time_eval_results[name] = (res.mean, 1) print(f"Time-eval result {name} on {self.device}:\n {res}") @@ -139,7 +148,7 @@ def get_instrument(args): if args.cmp_device is None: assert args.cmp_lib_path is None, "cmp_lib_path must be None if cmp_device is None" args.cmp_device = args.device - args.cmp_lib_path = args.model_lib_path + args.cmp_lib_path = args.model_lib if args.cmp_device == "iphone": assert args.cmp_lib_path.endswith(".dylib"), "Require a dylib file for iPhone" @@ -159,19 +168,14 @@ def get_instrument(args): lib = sess.load_module(os.path.basename(args.cmp_lib_path)) cmp_device = sess.cl(0) else: - lib = tvm.runtime.load_module( - os.path.join( - args.artifact_path, - f"{args.model}-{args.quantization.name}-{args.cmp_device}.so", - ) - ) + lib = tvm.runtime.load_module(args.cmp_lib_path) cmp_device = tvm.device(args.cmp_device) return LibCompare( lib, cmp_device, time_eval=args.time_eval, - debug_dir=Path(args.debug_dir), + debug_out=Path(args.debug_dir), ) @@ -194,7 +198,7 @@ def main(): required=True, ) parser.add_argument( - "--model-lib-path", + "--model-lib", type=str, help="The full path to the model library file to use (e.g. a ``.so`` file).", required=True, @@ -230,7 +234,7 @@ def main(): instrument = get_instrument(parsed) debug_chat = DebugChat( model=parsed.model, - model_lib_path=parsed.model_lib_path, + model_lib=parsed.model_lib, debug_dir=Path(parsed.debug_dir), device=parsed.device, debug_instrument=instrument, diff --git a/python/mlc_llm/testing/pytest_utils.py b/python/mlc_llm/testing/pytest_utils.py new file mode 100644 index 0000000000..d7924e1c21 --- /dev/null +++ b/python/mlc_llm/testing/pytest_utils.py @@ -0,0 +1,86 @@ +"""Extra utilities to mark tests""" + +import functools +import inspect +from pathlib import Path +from typing import Callable + +import pytest + +from mlc_llm.support.constants import MLC_TEST_MODEL_PATH + + +def require_test_model(*models: str): + """Testcase decorator to require a model + + Examples + -------- + .. code:: + + @require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC") + def test_reload_reset_unload(model): + # model now points to the right path + # specified by MLC_TEST_MODEL_PATH + engine = mlc_llm.MLCEngine(model) + # test code follows + + Parameters + ---------- + models : List[str] + The model directories or URLs. 
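
Tests decorated this way only look at existing local folders and never download; a hedged sketch of pointing them at custom locations via MLC_TEST_MODEL_PATH before pytest collects the tests (the paths are illustrative):

    import os

    # Several locations can be joined with os.pathsep; each entry is expected to
    # contain <model-name>/mlc-chat-config.json. Set before mlc_llm is imported.
    os.environ["MLC_TEST_MODEL_PATH"] = os.pathsep.join([
        "/data/mlc-models",
        "./dist",
    ])
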
+ """ + model_paths = [] + missing_models = [] + + for model in models: + model_path = None + for base_path in MLC_TEST_MODEL_PATH: + if (base_path / model / "mlc-chat-config.json").is_file(): + model_path = base_path / model + break + if model_path is None and (Path(model) / "mlc-chat-config.json").is_file(): + model_path = Path(model) + + if model_path is None: + missing_models.append(model) + else: + model_paths.append(str(model_path)) + + message = ( + f"Model {', '.join(missing_models)} not found in candidate paths " + f"{[str(p) for p in MLC_TEST_MODEL_PATH]}," + " if you set MLC_TEST_MODEL_PATH, please ensure model paths are in the right location," + " by default we reuse cache, try to run mlc_llm chat to download right set of models." + ) + + def _decorator(func: Callable[..., None]): + wrapped = functools.partial(func, *model_paths) + wrapped.__name__ = func.__name__ # type: ignore + + if inspect.iscoroutinefunction(wrapped): + # The function is a coroutine function ("async def func(...)") + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + if len(missing_models) > 0: + print(f"{message} skipping...") + return + await wrapped(*args, **kwargs) + + else: + # The function is a normal function ("def func(...)") + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + if len(missing_models) > 0: + print(f"{message} skipping...") + return + wrapped(*args, **kwargs) + + return pytest.mark.skipif(len(missing_models) > 0, reason=message)(wrapper) + + return _decorator + + +def require_test_tokenizers(*models: str): + """Testcase decorator to require a path to tokenizers""" + # redirect to require models for now + return require_test_model(*models) diff --git a/python/mlc_llm/tokenizers/__init__.py b/python/mlc_llm/tokenizers/__init__.py new file mode 100644 index 0000000000..88704b49cc --- /dev/null +++ b/python/mlc_llm/tokenizers/__init__.py @@ -0,0 +1,4 @@ +"""Namespace for tokenizer rleated utilities""" + +from .streamer import StopStrHandler, TextStreamer +from .tokenizers import Tokenizer diff --git a/python/mlc_llm/tokenizers/_ffi_api.py b/python/mlc_llm/tokenizers/_ffi_api.py new file mode 100644 index 0000000000..3b08d33a7f --- /dev/null +++ b/python/mlc_llm/tokenizers/_ffi_api.py @@ -0,0 +1,7 @@ +"""FFI APIs for mlc_llm""" + +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "mlc" prefix. +# e.g. TVM_REGISTER_GLOBAL("mlc.Tokenizer") +tvm._ffi._init_api("mlc.tokenizers", __name__) # pylint: disable=protected-access diff --git a/python/mlc_llm/tokenizers/streamer.py b/python/mlc_llm/tokenizers/streamer.py new file mode 100644 index 0000000000..37179f17f7 --- /dev/null +++ b/python/mlc_llm/tokenizers/streamer.py @@ -0,0 +1,84 @@ +"""Streamers in MLC LLM.""" + +from typing import List, Union + +import tvm +import tvm._ffi +from tvm.runtime import Object, ShapeTuple + +from . import _ffi_api +from .tokenizers import Tokenizer + + +@tvm._ffi.register_object("mlc.TextStreamer") # pylint: disable=protected-access +class TextStreamer(Object): + """The class that streams back validated utf-8 text strings + that generated by tokenizer. + """ + + def __init__(self, tokenizer: Tokenizer) -> None: + """Create the text streamer from tokenizer""" + self.__init_handle_by_constructor__( + _ffi_api.TextStreamer, tokenizer # type: ignore # pylint: disable=no-member + ) + + def put(self, delta_tokens: Union[List[int], ShapeTuple]) -> str: + """Put new delta tokens into the streamer, and get the UTF-8-valid + delta string. 
The text streamer may hold some of the input delta tokens + which cannot decode into valid UTF-8 strings. The returned string + is always guaranteed to be UTF-8 valid. + + Parameters + ---------- + delta_tokens : Union[List[int], ShapeTuple] + The new tokens to put into the streamer. + + Returns + ------- + delta_text : str + The decoded delta string after putting the input new tokens. + """ + if isinstance(delta_tokens, list): + delta_tokens = ShapeTuple(delta_tokens) + return _ffi_api.TextStreamerPut( # type: ignore # pylint: disable=no-member + self, delta_tokens + ) + + def finish(self) -> str: + """Return the string decoded by remaining tokens.""" + return _ffi_api.TextStreamerFinish(self) # type: ignore # pylint: disable=no-member + + +@tvm._ffi.register_object("mlc.StopStrHandler") # pylint: disable=protected-access +class StopStrHandler(Object): + """The stop string handler in MLC LLM, which takes input delta tokens + one at a time, and return the output delta token before stopping due to + stop strings.""" + + def __init__(self, stop_strs: List[str], tokenizer: Tokenizer) -> None: + self.__init_handle_by_constructor__( + _ffi_api.StopStrHandler, # type: ignore # pylint: disable=no-member + stop_strs, + tokenizer, + ) + + def put(self, token_id: int) -> List[int]: + """Add new input delta token to the handler, return output + delta tokens before stopping. The stop string handler may hold + some of the input delta token which may be part of a stop string. + The returned tokens are always guaranteed not to be part of stop string. + """ + return list( + _ffi_api.StopStrHandlerPut(self, token_id) # type: ignore # pylint: disable=no-member + ) + + def finish(self) -> List[int]: + """Stop string handling has finished, return remaining cached token ids.""" + return list( + _ffi_api.StopStringHandlerFinish(self) # type: ignore # pylint: disable=no-member + ) + + @property + def stop_triggered(self) -> bool: + """Check if the generation has stopped due to stop string.""" + return _ffi_api.StopStrHandlerStopTriggered(self) # type: ignore # pylint: disable=no-member diff --git a/python/mlc_llm/tokenizers/tokenizers.py b/python/mlc_llm/tokenizers/tokenizers.py new file mode 100644 index 0000000000..8540bee0c9 --- /dev/null +++ b/python/mlc_llm/tokenizers/tokenizers.py @@ -0,0 +1,129 @@ +"""The tokenizer and related tools in MLC LLM. +This tokenizer essentially wraps and binds the HuggingFace tokenizer +library and sentencepiece. +Reference: https://github.com/mlc-ai/tokenizers-cpp +""" + +import json +from dataclasses import asdict, dataclass +from typing import List, Literal + +import tvm +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + + +@dataclass +class TokenizerInfo: # pylint: disable=too-many-instance-attributes + """Useful information of the tokenizer during generation. + + Attributes + ---------- + token_postproc_method : Literal["byte_fallback", "byte_level"] + The method to post-process the tokens to their original strings. + Possible values (each refers to a kind of tokenizer): + - "byte_fallback": The same as the byte-fallback BPE tokenizer, including LLaMA-2, + Mixtral-7b, etc. E.g. "▁of" -> " of", "<0x1B>" -> "\x1B". + This method: + 1) Transform tokens like <0x1B> to hex char byte 1B. (so-called byte-fallback) + 2) Replace \\u2581 "▁" with space. + - "byte_level": The same as the byte-level BPE tokenizer, including LLaMA-3, GPT-2, + Phi-2, etc. E.g. 
"Ġin" -> " in", "ě" -> "\x1B" + This method inverses the bytes-to-unicode transformation in the encoding process in + https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + + prepend_space_in_encode : bool + Whether to prepend a space during encoding. + + strip_space_in_decode : bool + Whether to strip the first space during decoding. + """ + + token_postproc_method: Literal["byte_fallback", "byte_level"] = "byte_fallback" + prepend_space_in_encode: bool = False + strip_space_in_decode: bool = False + + def asjson(self) -> str: + """Return the config in string of JSON format.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> "TokenizerInfo": + """Construct a config from JSON string.""" + return TokenizerInfo(**json.loads(json_str)) + + +@tvm._ffi.register_object("mlc.Tokenizer") # pylint: disable=protected-access +class Tokenizer(Object): + """The tokenizer class in MLC LLM.""" + + def __init__(self, tokenizer_path: str) -> None: + """Create the tokenizer from tokenizer directory path.""" + self.__init_handle_by_constructor__( + _ffi_api.Tokenizer, tokenizer_path # type: ignore # pylint: disable=no-member + ) + + def encode(self, text: str) -> List[int]: + """Encode text into ids. + + Parameters + ---------- + text : str + The text string to encode. + + Returns + ------- + token_ids : List[int] + The list of encoded token ids. + """ + return list(_ffi_api.TokenizerEncode(self, text)) # type: ignore # pylint: disable=no-member + + def encode_batch(self, texts: List[str]) -> List[List[int]]: + """Encode a batch of texts into ids. + + Parameters + ---------- + texts : List[str] + The list of text strings to encode. + + Returns + ------- + token_ids : List[List[int]] + The list of list of encoded token ids. + """ + return list(_ffi_api.TokenizerEncodeBatch(self, texts)) # type: ignore # pylint: disable=no-member + + def decode(self, token_ids: List[int]) -> str: + """Decode token ids into text. + + Parameters + ---------- + token_ids : List[int] + The token ids to decode to string. + + Returns + ------- + text : str + The decoded text string. + """ + return _ffi_api.TokenizerDecode( # type: ignore # pylint: disable=no-member + self, tvm.runtime.ShapeTuple(token_ids) + ) + + @staticmethod + def detect_tokenizer_info(tokenizer_path: str) -> TokenizerInfo: + """Detect the tokenizer info from the given path of the tokenizer. + + Parameters + ---------- + tokenizer_path : str + The tokenizer directory path. + + Returns + ------- + tokenizer_info : str + The detected tokenizer info in JSON string. 
+ """ + return TokenizerInfo.from_json(_ffi_api.DetectTokenizerInfo(tokenizer_path)) # type: ignore # pylint: disable=no-member diff --git a/rust/.gitignore b/rust/.gitignore deleted file mode 100644 index c5e4e0d10a..0000000000 --- a/rust/.gitignore +++ /dev/null @@ -1,20 +0,0 @@ -# Generated by Cargo -# will have compiled files and executables -debug/ -target/ - -# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries -# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html -Cargo.lock - -# Generated by Rust -**/*.rs.bk -/examples/pkg - -# MSVC Windows builds of rustc generate these, which store debugging information -*.pdb - -# IDE files -.idea/ -*.iml -.vscode/ diff --git a/rust/Cargo.toml b/rust/Cargo.toml deleted file mode 100644 index d7ffe2f333..0000000000 --- a/rust/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "mlc-llm" -version = "0.1.0" -license = "Apache-2.0" -description = "Rust API for MLC LLM." -homepage = "https://llm.mlc.ai/" -readme = "README.md" -keywords = ["rust", "mlc", "llm", "tvm", "AI"] -authors = ["MLC Contributors"] -repository = "https://github.com/mlc-ai/mlc-llm" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -tvm-rt = { path = "../3rdparty/tvm/rust/tvm-rt", version = "0.1.0-alpha", features = [ - "dynamic-linking", -] } -tracing = "0.1.32" -derive_builder = "0.12.0" -serde = { version = "1.0.160", features = ["derive"] } -serde_json = "1.0.107" diff --git a/rust/README.md b/rust/README.md deleted file mode 100644 index 971fb11200..0000000000 --- a/rust/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# MLC-LLM Rust Package - -This folder contains the source code of MLC-LLM Rust package. - -# Installations -To set up the MLC-LLM Rust package, please follow these steps: - -**Step 1:** Begin by following the detailed installation [instructions](https://llm.mlc.ai/docs/deploy/rest.html#optional-build-from-source) for TVM Unity and MLC-LLM. - -**Step 2:** Define the environment variables for TVM and MLC-LLM by running the following commands in your terminal: -```bash -export TVM_HOME=/path/to/tvm -export MLC_HOME=/path/to/mlc-llm -``` - -**Step 3:** Update your `LD_LIBRARY_PATH` to include the `libtvm_runtime` and `libmlc_llm_module` libraries. These can typically be found within the build directories of your TVM and MLC-LLM installations. - -# How to run it? -To start using the package, you can refer to the example code provided in the examples directory. This code demonstrates how to create a chat_module and serve prompts effectively. 
- -Execute the example with Cargo using the following command: -```bash -cargo run --example mlc_llm -``` - diff --git a/rust/build.rs b/rust/build.rs deleted file mode 100644 index ce928f51af..0000000000 --- a/rust/build.rs +++ /dev/null @@ -1,6 +0,0 @@ -fn main() { - let mlc_home = env!("MLC_HOME"); - - println!("cargo:rustc-link-lib=dylib=mlc_llm_module"); - println!("cargo:rustc-link-search=native={}/build", mlc_home); -} diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs deleted file mode 100644 index b3bbe97f29..0000000000 --- a/rust/examples/mlc_chat.rs +++ /dev/null @@ -1,28 +0,0 @@ -extern crate mlc_llm; - -use mlc_llm::chat_module::{ChatMessage, ChatModule}; - -fn main() { - // Single prompt example - let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap(); - let output = cm.generate("what is the meaning of life?", None).unwrap(); - println!("resp: {:?}", output); - println!("stats: {:?}", cm.stats(false)); - - // Multiple prompts example - let message1 = ChatMessage::new("user", "suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"); - let message2 = ChatMessage::new( - "assistant", - "based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. this name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."); - let message3 = ChatMessage::new("user", "I like cervidae, but the name is too long!"); - let message4 = ChatMessage::new( - "assistant", - "In that case, a shorter and catchier name for the next project could be \"DeerRun\" which plays on the idea of the project being fast and efficient, just like a deer running through the woods. 
This name is memorable and easy to pronounce, making it a good choice for a project name."); - let message5 = ChatMessage::new("user", "Summarize our conversations."); - - let messages = vec![message1, message2, message3, message4, message5]; - - let output = cm.generate(messages, None).unwrap(); - println!("resp: {:?}", output); - println!("stats: {:?}", cm.stats(false)); -} diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml deleted file mode 100644 index 8e52b87c0b..0000000000 --- a/rust/rustfmt.toml +++ /dev/null @@ -1,9 +0,0 @@ -edition = "2021" -unstable_features = true -max_width = 120 -binop_separator = "Back" -inline_attribute_width = 100 -fn_params_layout = "Compressed" -hard_tabs = false -tab_spaces = 4 -trailing_semicolon = false diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs deleted file mode 100644 index b90549d06c..0000000000 --- a/rust/src/chat_module.rs +++ /dev/null @@ -1,601 +0,0 @@ -use std::collections::HashMap; -use std::fs; -use std::path::{Path, PathBuf}; -use std::result; -use tracing::info; -use tvm_rt::{function::Function, Module}; - -use super::config::*; - -extern "C" { - fn LLMChatDummyLinkFunc(); -} - -#[derive(Debug)] -pub enum ChatModuleError { - /// Global function in a TVM Module is not found - GlobalFuncNotFound, - /// TVM Runtime error - TvmRuntime(tvm_rt::Error), -} - -impl From for ChatModuleError { - fn from(e: tvm_rt::Error) -> Self { - Self::TvmRuntime(e) - } -} - -pub type Result = result::Result; - -#[derive(Debug, Clone)] -pub struct ChatMessage { - role: String, - content: String, -} - -impl ChatMessage { - pub fn new(role: &str, content: &str) -> Self { - ChatMessage { - role: role.to_owned(), - content: content.to_owned(), - } - } -} - -#[derive(Debug, Clone)] -pub enum Prompt { - String(String), - MessageList(Vec), -} - -impl From<&str> for Prompt { - fn from(s: &str) -> Self { - Prompt::String(s.to_owned()) - } -} - -impl From for Prompt { - fn from(s: String) -> Self { - Prompt::String(s) - } -} - -impl From> for Prompt { - fn from(messages: Vec) -> Self { - Prompt::MessageList(messages) - } -} - -#[derive(Debug, Copy, Clone)] -pub enum PlaceInPrompt { - All = 0, - Begin = 1, - Middle = 2, - End = 3, -} - -impl PlaceInPrompt { - pub fn to_value(&self) -> i32 { - *self as i32 - } -} - -macro_rules! tvm_func_invoke { - // Handle the case with return type - ($self:ident, $func_name:ident($($args:expr),*) -> $ret_type:ty) => { - { - let f = $self.chat_module.get_function(stringify!($func_name), false)?; - let res: $ret_type = f.invoke(vec![$($args.into()),*])?.try_into().expect("call should succeed"); - Ok(res) - } - }; - // Handle the case without return type - ($self:ident, $func_name:ident($($args:expr),*)) => { - { - let f = $self.chat_module.get_function(stringify!($func_name), false)?; - f.invoke(vec![$($args.into()),*])?; - Ok(()) - } - }; -} - -/// Parse the input device identifier into device name and id. -/// -/// # Arguments -/// * `device` - The device identifier to parse. It can be in the format "device_name" (e.g., "cuda") -/// or "device_name:device_id" (e.g., "cuda:1"). -/// -/// # Returns -/// * `device_name` - The name of the device. -/// * `device_id` - The id of the device, or 0 if not specified in the input. -fn parse_device_str(device: &str) -> (&str, i32) { - let device_err_msg = format!( - "Invalid device name: {}. 
Please enter the device in the form \ - 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ - one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", - device - ); - let device_args: Vec<&str> = device.split(':').collect(); - match device_args.len() { - 1 => (device_args[0], 0), - 2 => (device_args[0], device_args[1].parse::().unwrap()), - _ => panic!("{}", device_err_msg), - } -} - -/// Use user-provided argument `model` to search for a valid model path. -/// We define "valid" as having an `mlc-chat-config.json` right under the folder. -/// -/// # Arguments -/// * `model`: User's input; may be a compiled model's name, or a full path. -/// -/// # Returns -/// * `model_path`: A "valid" path to model folder with `mlc-chat-config.json` existing under it. -/// * `chat_file`: The path to the `mlc-chat-config.json` file. -/// -/// # Panics -/// * If a valid model_path cannot be found. -pub fn get_model_path(model: &str) -> (PathBuf, PathBuf) { - // Note that the order of this list corresponds to our search priority - let candidate_paths = vec![ - PathBuf::from(model), // full path, or just the name - PathBuf::from(format!("{}/params", model)), // Default directory after mlc_llm.build_model() - PathBuf::from(format!("dist/prebuilt/{}", model)), // Using prebuilt workflow - PathBuf::from(format!("dist/{}/params", model)), // Default directory after mlc_llm.build_model() in the current path - PathBuf::from(format!("dist/prebuilt/mlc-chat-{}", model)), // Also prebuilt workflow, but missed prefix - ]; - - // Look for the first folder that has `mlc-chat-config.json` under it - for candidate in &candidate_paths { - let chat_file = candidate.join("mlc-chat-config.json"); - if chat_file.is_file() { - info!("Using model folder: {:?}", candidate.canonicalize().unwrap()); - info!("Using mlc chat config: {:?}", chat_file.canonicalize().unwrap()); - return (candidate.clone(), chat_file); - } - } - - let mut found_folder = false; - let mut valid_dir_str = String::new(); - for candidate in &candidate_paths { - if candidate.is_dir() { - valid_dir_str += &format!("- {:?}\n", candidate.canonicalize().unwrap()); - found_folder = true; - } - } - - if found_folder { - // Error 1: there is a folder, but not an mlc-llm model folder (E1) - let err_msg = format!( - "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n\ - Specifically, we cannot find `mlc-chat-config.json`, a required file. You should \ - provide a path that contains the file.\n\ - According to your input `model`, we looked at folder(s):\n\ - {}\n\ - MLC-Chat consumes models that are processed by the MLC-LLM build process.\n\ - ", - valid_dir_str, - ); - panic!("{}", err_msg); - } else { - // Error 2: cannot find a folder (E0) - let all_paths_str = candidate_paths - .iter() - .map(|path| format!("- {}\n", path.display())) - .collect::(); - let err_msg = format!( - "Cannot find the model folder. We searched over the following possible paths:\n\ - {}\n\ - You can try to pass in `model=/path/to/your-model-path`, and confirm \ - that it contains `mlc-chat-config.json`, among other essential files.\n\ - ", - all_paths_str, - ); - panic!("{}", err_msg); - } -} - -/// Read in the config file in model path, then potentially override with user input. -/// -/// # Arguments -/// * `config_file_path`: &Path -/// `chat_file` returned by a function like `get_model_path()`. 
-fn get_chat_config(config_file_path: &Path) -> result::Result> { - // Read the base configuration from the file - let file_contents = fs::read_to_string(config_file_path)?; - let final_chat_config = ChatConfig::from_json(&file_contents)?; - Ok(final_chat_config) -} - -/// Look up the model library and return a corresponding `tvm` runtime Module. -/// -/// # Arguments -/// * `model` - A string representing either the name of a compiled model or a full path to it. -/// * `model_path` - The path to the model, as determined by `get_model_path`. -/// * `chat_config` - The chat configuration, possibly with overrides, returned by `get_chat_config`. -/// * `model_lib_path` - An optional string specifying the full path to the model library. This is prioritized if provided. -/// * `device_name` - A string representing the device for which the library model file name will be constructed. -/// * `config_file_path` - The path to the `mlc-chat-config.json` file, used for constructing error messages. -/// -/// # Returns -/// The path pointing to the model library we find. -fn get_lib_module_path( - model: &str, model_path: &Path, chat_config: &ChatConfig, model_lib_path: Option<&str>, device_name: &str, - config_file_path: &Path, -) -> PathBuf { - // 1. Use user's model_lib_path if provided - if let Some(lib_path) = model_lib_path { - let path = Path::new(lib_path); - if path.is_file() { - info!("Using library model: {:?}", path); - return path.to_path_buf(); - } else { - panic!("The `model_lib_path` you passed in is not a file: {:?}.", lib_path); - } - } - - // 2. Generate all possible file names according to OS - let mut candidate_paths = Vec::new(); - if let Some(model_lib) = &chat_config.model_lib { - let candidate_lib_names: Vec = if cfg!(target_os = "linux") { - vec![format!("{}-{}.so", model_lib, device_name)] - } else if cfg!(target_os = "macos") { - vec![ - format!("{}-{}.dylib", model_lib, device_name), - format!("{}-{}.so", model_lib, device_name), - ] - } else if cfg!(target_os = "windows") { - vec![format!("{}-{}.dll", model_lib, device_name)] - } else { - vec![ - format!("{}-{}.dylib", model_lib, device_name), - format!("{}-{}.so", model_lib, device_name), - format!("{}-{}.dll", model_lib, device_name), - ] - }; - - // 3. Generate possible model library paths - let pardir_model_path = model_path.parent().unwrap(); - for lib_name in &candidate_lib_names { - let paths: Vec = vec![ - lib_name.clone(), - format!("dist/prebuilt/lib/{}", lib_name), - format!("dist/{}/{}", model, lib_name), - model_path.join(lib_name).to_string_lossy().into_owned(), - pardir_model_path.join(lib_name).to_string_lossy().into_owned(), - ]; - - candidate_paths.extend(paths); - } - - // 4. Search for model library - for candidate in &candidate_paths { - let candidate_path = Path::new(candidate); - if candidate_path.is_file() { - info!("Using library model: {:?}", candidate_path); - return candidate_path.to_path_buf(); - } - } - - // 5. Error - let mut err_msg = format!( - "Cannot find the model library that corresponds to `{:?}`.\n\ - `{:?}` is either provided in the `chat_config` \ - you passed in, or specified in {:?}.\n\ - We searched over the following possible paths: \n", - model_lib, model_lib, config_file_path - ); - for candidate in &candidate_paths { - err_msg += &format!("- {}\n", candidate); - } - err_msg += &format!( - "If you would like to directly specify the model library path, you may \ - consider passing in the `ChatModule.model_lib_path` parameter." 
- ); - - panic!("{}", err_msg); - } else { - panic!("Cannot find the model library, you need to either pass it in, or specify in the chat_config file."); - } -} - -/// The ChatModule for MLC LLM. -/// -/// # Examples -/// -/// ``` -/// use mlc_llm::chat_module::ChatModule; -/// -/// // Create a ChatModule instance -/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None, None).unwrap(); -/// -/// // Generate a response for a given prompt -/// let output = cm.generate("what is the meaning of life?", None).unwrap(); -/// -/// // Print prefill and decode performance statistics -/// println!("Statistics: {:?}\n", cm.stats(false).unwrap()); -/// -/// let output = cm.generate("what is Rust?", None).unwrap(); -/// ``` -pub struct ChatModule { - chat_module: Module, - chat_config: ChatConfig, -} - -impl ChatModule { - pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result { - let device_err_msg = format!( - "Invalid device name: {}. Please enter the device in the form \ - 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ - one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", - device - ); - - let (device_name, device_id) = parse_device_str(device); - - // 1. Get device name and id - let device_type = match device_name { - "cuda" => 2, - "opencl" => 4, - "vulkan" => 7, - "metal" => 8, - "rocm" => 10, - _ => panic!("{}", device_err_msg), - }; - - unsafe { - LLMChatDummyLinkFunc(); - } - - static GLOBAL_FUNC_NAME: &str = "mlc.llm_chat_create"; - let f = Function::get(GLOBAL_FUNC_NAME).ok_or(ChatModuleError::GlobalFuncNotFound)?; - let m: Module = f - .invoke(vec![device_type.into(), device_id.into()]) - .unwrap() - .try_into() - .expect("call should succeed"); - - // 2. Look up the model path - let (model_path, config_file_path) = get_model_path(model); - - // 3. Instantiate chat_config - let chat_config = get_chat_config(&config_file_path).unwrap(); - - // 4. Look up the model library - let model_lib_path = get_lib_module_path( - model, - &model_path, - &chat_config, - model_lib_path, - device_name, - &config_file_path, - ); - - let chat_mod = Self { - chat_module: m, - chat_config, - }; - let model_lib_str = model_lib_path.as_path().display().to_string(); - let model_path_str = model_path.as_path().display().to_string(); - chat_mod.reload(&model_lib_str, &model_path_str, "").unwrap(); - Ok(chat_mod) - } - - /// Reload the chat module from the given library and model path. - fn reload(&self, lib: &str, model_path: &str, app_config_json: &str) -> Result<()> { - tvm_func_invoke!(self, reload(lib, model_path, app_config_json)) - } - - /// Reset the chat session, clear all chat history, and potentially - /// override the original `mlc-chat-config.json`. - pub fn reset_chat(&self) -> Result<()> { - // TODO: add optional user-specified ChatConfig - tvm_func_invoke!(self, reset_chat()) - } - - /// Get the runtime stats of the encoding step, decoding step (and embedding step if exists) - /// of the chat module in text form. - pub fn stats(&self, verbose: bool) -> Result { - if verbose { - return tvm_func_invoke!(self, verbose_runtime_stats_text() -> String); - } - tvm_func_invoke!(self, runtime_stats_text() -> String) - } - - /// Check if the stop condition is met for the current round. - fn stopped(&self) -> Result { - tvm_func_invoke!(self, stopped() -> bool) - } - - /// Get the output message in the current round. 
- fn get_message(&self) -> Result { - tvm_func_invoke!(self, get_message() -> String) - } - - /// Decode the next token, the decoding result is stored in a buffer and - /// can be retrieved by [get_message]. - fn decode(&self, generation_config: Option<&GenerationConfig>) -> Result<()> { - let generation_config_str = match generation_config { - Some(config) => serde_json::to_string(config).unwrap(), - None => { - let config = GenerationConfig::from_chat_config(&self.chat_config); - serde_json::to_string(&config).unwrap() - } - }; - tvm_func_invoke!(self, decode(generation_config_str)) - } - - /// Load JSON config and override existing configurations for the chat module. - fn load_json_override(&self, config_str: &str, partial_update: bool) -> Result<()> { - tvm_func_invoke!(self, load_json_override(config_str, &partial_update)) - } - - /// Get the configuration of the chat module in a single json string. - fn get_config_json(&self) -> Result { - tvm_func_invoke!(self, get_config_json() -> String) - } - - /// Get the name of role 0 in the conversation. - fn get_role_0(&self) -> Result { - tvm_func_invoke!(self, get_role0() -> String) - } - - /// Get the name of role 1 in the conversation. - fn get_role_1(&self) -> Result { - tvm_func_invoke!(self, get_role1() -> String) - } - - /// A high-level method that returns the full response from the chat module given a user - /// prompt. User can optionally specify which callback method to use upon receiving the - /// response. - /// - /// # Arguments - /// * `prompt` - The user input prompt, i.e. a question to ask the chat module. - /// It can also be the whole conversation history (list of messages with role and content) - /// - /// # Examples - /// ``` - /// // Single prompt case, the `prompt` can be a &str - /// let prompt = "what is the meaning of life?"; - /// - /// // Multi-prompt case, the `prompt` can be Vec - /// let message1 = ChatMessage::new("user", "suppose we already have projects llama, alpaca and vicuna, what do you think would be a great name for the next project?"); - /// let message2 = ChatMessage::new( - /// "assistant", - /// "based on the previous projects, a possible name for the next project could be \"cervidae\" which is the scientific name for deer family. this name reflects the collaboration and teamwork involved in the development of the project, and also nods to the previous projects that have been developed by the team."); - /// let message3 = ChatMessage::new("user", "I like cervidae, but the name is too long!"); - /// let prompt = vec![message1, message2, message3]; - /// ``` - /// - /// * `generation_config` - The generation config object to override the ChatConfig generation settings. - /// - /// # Returns - /// * `output` - The generated full output from the chat module. 
- pub fn generate( - &self, prompt: impl Into, generation_config: Option<&GenerationConfig>, - ) -> Result> { - // TODO: add progress_callback - let mut new_msgs: Vec = vec![]; - let mut num_return_sequences: usize = 1; - - if let Some(gc) = generation_config { - if let Some(n) = gc.n { - num_return_sequences = n; - } - } - - let prompt = prompt.into(); - for _ in 0..num_return_sequences { - self.reset_chat().unwrap(); - self.prefill(&prompt, true, PlaceInPrompt::All, generation_config) - .unwrap(); - - while !self.stopped().unwrap() { - self.decode(generation_config)?; - } - let new_msg = self.get_message().unwrap(); - new_msgs.push(new_msg); - } - - Ok(new_msgs) - } - - /// Runs the prefill stage for a given input and optionally decodes the first output token. - /// The user can decide where to place the input in the prompt. - /// - /// # Arguments - /// - /// * `input` - A `String` or a `Vec`. The user input prompt, i.e., a question to ask the chat module. - /// It can also be the whole conversation history (list of messages with role and content). - /// - /// # Examples - /// ``` - /// // Single prompt case, the `prompt` can be a &str - /// "what is the meaning of life?"; - /// - /// // Multi-prompt case, the `prompt` can be Vec - /// vec![ - /// ChatMessage::new("user", "Hello, how are you?"), - /// ChatMessage::new("assistant", "I'm fine, thank you. How about you?"), - /// ChatMessage::new("user", "I'm good too."), - /// ] - /// ``` - /// * `decode_next_token` - A boolean indicating whether to decode the next token after prefilling. - /// * `place_in_prompt` - The place of the input message in the prompt, as defined by the `PlaceInPrompt` enum. - /// * `generation_config` - An optional `GenerationConfig` to override the ChatConfig generation settings. 
- /// - /// # Examples - /// - /// ``` - /// let input = "Hello, how are you?"; - /// let decode_next_token = true; - /// let place_in_prompt = PlaceInPrompt::All; - /// let generation_config = Some(GenerationConfig::new()); - /// - /// prefill(input, decode_next_token, place_in_prompt, generation_config); - /// ``` - fn prefill( - &self, input: &Prompt, decode_next_token: bool, place_in_promt: PlaceInPrompt, - generation_config: Option<&GenerationConfig>, - ) -> Result<()> { - let generation_config_str = match generation_config { - Some(config) => serde_json::to_string(config).unwrap(), - None => { - let config = GenerationConfig::from_chat_config(&self.chat_config); - serde_json::to_string(&config).unwrap() - } - }; - - let input_string = match input { - Prompt::String(inp) => inp.clone(), - Prompt::MessageList(chat_msgs) => { - let mut chat_msgs = chat_msgs.clone(); - if chat_msgs.len() == 1 { - chat_msgs.remove(0).content - } else { - let chat_config = ChatConfig::from_json(&(self.get_config_json()?)).unwrap(); - let mut conv_config = chat_config - .conv_config - .unwrap_or_else(|| ConvConfigBuilder::default().build().unwrap()); - - let role0 = self.get_role_0()?; - let role1 = self.get_role_1()?; - - let last_msg = chat_msgs.last().expect("No last message in the vector").clone(); - if last_msg.role != "user" { - panic!("Last message should be from user."); - } - - let mut messages = Vec::new(); - let msg_len = chat_msgs.len(); - for msg in chat_msgs.into_iter().take(msg_len - 1) { - match msg.role.as_str() { - "user" => messages.push(vec![role0.clone(), msg.content]), - "assistant" => messages.push(vec![role1.clone(), msg.content]), - _ => panic!("Only user and assistant roles are supported."), - } - } - - conv_config.messages = Some(messages); - conv_config.offset = Some(0); - - let mut map = HashMap::new(); - map.insert("conv_config", conv_config); - self.load_json_override(&serde_json::to_string(&map).unwrap(), true)?; - - last_msg.content - } - } - }; - - tvm_func_invoke!( - self, - prefill( - input_string, - &decode_next_token, - place_in_promt.to_value(), - generation_config_str - ) - ) - } -} diff --git a/rust/src/config.rs b/rust/src/config.rs deleted file mode 100644 index a6233952c4..0000000000 --- a/rust/src/config.rs +++ /dev/null @@ -1,273 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// A struct that represents user-defined partial configuration for conversation template. -/// -/// This can be passed in to the instantiation of a [ChatModule](crate::chat_module::ChatModule) -/// instance to override the default setting in `mlc-chat-config.json` under the -/// model folder. Note that we will first load the predefined template -/// with the name specified in `conv_template`. -/// -/// Since the configuration is partial, everything will be optional. -#[derive(Clone, Default, Builder, Debug, Serialize, Deserialize)] -#[builder(default)] -pub struct ConvConfig { - /// Token list prefixing the conversation. - prefix_tokens: Option>, - - /// Name of the conversation. - name: Option, - - /// The prompt encoded before starting the chat. - system: Option, - - /// An array that describes the role names of the user and the model. - roles: Option>, - - /// The chat history represented as an array of string pairs. - pub messages: Option>>, - - /// The offset used to begin the chat from the chat history. - pub offset: Option, - - /// Specifies whether we are in chat-bot mode (`0`) or pure LM prompt mode (`1`). 
- separator_style: Option, - - /// An array of strings indicating the separators to be used after a user message and a model message respectively. - seps: Option>, - - /// A string indicating the separator between a role and a message. - role_msg_sep: Option, - - /// A string indicating the separator to append to a role when there is no message yet. - role_empty_sep: Option, - - /// When the `stop_str` is encountered, the model will stop generating output. - stop_str: Option, - - /// A list of token IDs that act as stop tokens. - stop_tokens: Option>, - - /// Determines whether a beginning-of-string (bos) token should be added before the input tokens. - add_bos: Option, -} - -impl ConvConfig { - pub fn post_init(&mut self) { - if let Some(messages) = &self.messages { - if self.offset.is_none() { - self.offset = Some(messages.len()); - } - } - } -} - -/// A struct that represents user-defined partial configuration for the chat config file. -/// -/// An instance of [ChatConfig] can be passed in to override the default setting. -/// Since the configuration is partial, everything will be optional. -/// -/// Note: This struct is used to represent the chat config during intermediate processing. -#[derive(Builder, Debug, Default, Serialize, Deserialize)] -#[builder(default)] -pub struct ChatConfig { - /// The necessary model library to launch this model architecture. - /// Recommended to reuse model library when possible. - pub model_lib: Option, - - /// Uniquely identifying the model in application. Also used by - /// CLI to specify which model to run. - pub local_id: Option, - - /// The name of the conversation template that this chat uses. - pub conv_template: Option, - - /// Temperature applied to logits before sampling. Encourages diverse outputs if higher. - pub temperature: Option, - - /// Controls the likelihood of the model generating repeated texts. - /// See the CTRL paper for more details: - repetition_penalty: Option, - - /// Determines the set of tokens from which we sample during decoding. - /// More info on top-p sampling: - top_p: Option, - - /// Approximated average number of generated tokens in each round. - mean_gen_len: Option, - - /// Maximum number of tokens to be generated in each round. - max_gen_len: Option, - - /// Fraction of maximum window size to shift when it is exceeded. - shift_fill_factor: Option, - - /// List of tokenizer files of the model. - tokenizer_files: Option>, - - /// Partial overriding configuration for conversation template. - pub conv_config: Option, - - /// The category of the model's architecture (e.g. `llama`, `gpt_neox`, `rwkv`). - model_category: Option, - - /// Name of the model (e.g. `Llama-2-7b-chat-hf`). - model_name: Option, - - /// Tensor parallel degree. - num_shards: Option, - - /// Maximum kv cache window size. - max_window_size: Option, -} - -impl ChatConfig { - pub fn from_json(json_str: &str) -> Result { - serde_json::from_str(json_str) - } -} - -/// A struct that represents user-defined generation configuration. -/// -/// An instance of [GenerationConfig] can be passed into the -/// [ChatModule::generate](crate::chat_module::ChatModule::generate) function -/// to override the default generation settings specified in `mlc-chat-config.json` -/// and `ChatConfig` under the model folder. -/// -/// Once the generation ends, `GenerationConfig` is discarded, as the values -/// are only intended to override the `ChatConfig` generation settings during a -/// single generation, unless it is recurrently passed to the `generate` function. 
-/// This allows for changing generation settings over time, without permanently -/// overriding the `ChatConfig`. -/// -/// Since the configuration is partial, all fields are optional. -#[derive(Builder, Debug, Default, Serialize, Deserialize)] -#[builder(default)] -pub struct GenerationConfig { - /// The temperature applied to logits before sampling. The default value is - /// `0.7`. A higher temperature encourages more diverse outputs, while a - /// lower temperature produces more deterministic outputs. - temperature: Option, - - /// The repetition penalty controls the likelihood of the model generating - /// repeated texts. The default value is set to `1.0`, indicating that no - /// repetition penalty is applied. Increasing the value reduces the - /// likelihood of repeat text generation. However, setting a high - /// `repetition_penalty` may result in the model generating meaningless - /// texts. The ideal choice of repetition penalty may vary among models. Only - /// Active when presence_penalty and frequency_penalty are both `0.0`. - - /// For more details on how repetition penalty controls text generation, please - /// check out the CTRL paper . - repetition_penalty: Option, - - /// This parameter determines the set of tokens from which we sample during - /// decoding. The default value is set to `0.95`. At each step, we select - /// tokens from the minimal set that has a cumulative probability exceeding - /// the ``top_p` parameter. - - /// For additional information on top-p sampling, please refer to this blog - /// post: . - top_p: Option, - - /// The approximated average number of generated tokens in each round. Used - /// to determine whether the maximum window size would be exceeded. - mean_gen_len: Option, - - /// This parameter determines the maximum length of the generated text. If it is - /// not set, the model will generate text until it encounters a stop token. - max_gen_len: Option, - - /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on - /// whether they appear in the text so far, increasing the model's likelihood - /// to talk about new topics. Negative values can increase the likelihood of - /// repetition. - presence_penalty: Option, - - /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on their - /// existing frequency in the text so far, decreasing the model's likelihood to - /// repeat the same line verbatim. Negative values can increase the likelihood of - /// repetition. - frequency_penalty: Option, - - /// This parameter determines the number of text samples to generate. The default - /// value is `1`. Note that this parameter is only used when `stream` is set to - /// `false`. - pub n: Option, - - /// When `stop` is encountered, the model will stop generating output. - /// It can be a string or a list of strings. If it is a list of strings, the model - /// will stop generating output when any of the strings in the list is encountered. - /// Note that this parameter does not override the default stop string of the model. 
- stop: Option>, -} - -impl GenerationConfig { - pub fn from_chat_config(chat_config: &ChatConfig) -> Self { - Self { - temperature: chat_config.temperature, - repetition_penalty: chat_config.repetition_penalty, - top_p: chat_config.top_p, - mean_gen_len: chat_config.mean_gen_len, - max_gen_len: chat_config.max_gen_len, - presence_penalty: Some(0.0), - frequency_penalty: Some(0.0), - n: Some(0), - stop: None, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_conv_config() { - let mut config = ConvConfig { - messages: Some(vec![vec!["User: Hi".to_string(), "Assistant: Hello".to_string()]]), - offset: None, - ..Default::default() - }; - config.post_init(); - assert_eq!(config.offset, Some(1)); - } - - #[test] - fn test_chat_config() { - let json_data = r#" - { - "model_lib": "some_lib", - "local_id": "id123", - "temperature": 0.7 - } - "#; - - let config = ChatConfig::from_json(json_data).unwrap(); - - assert_eq!(config.model_lib, Some("some_lib".to_string())); - assert_eq!(config.local_id, Some("id123".to_string())); - assert_eq!(config.temperature, Some(0.7)); - let _pretty_json = serde_json::to_string_pretty(&config).unwrap(); - } - - #[test] - fn test_generation_config() { - let chat_config = ChatConfigBuilder::default() - .temperature(Some(0.7)) - .top_p(Some(0.8)) - .mean_gen_len(Some(50)) - .max_gen_len(Some(75)) - .build() - .unwrap(); - - let gen_config = GenerationConfig::from_chat_config(&chat_config); - - assert_eq!(gen_config.temperature, chat_config.temperature); - assert_eq!(gen_config.repetition_penalty, chat_config.repetition_penalty); - assert_eq!(gen_config.top_p, chat_config.top_p); - assert_eq!(gen_config.mean_gen_len, chat_config.mean_gen_len); - assert_eq!(gen_config.max_gen_len, chat_config.max_gen_len); - assert_eq!(gen_config.presence_penalty, Some(0.0)); - assert_eq!(gen_config.frequency_penalty, Some(0.0)); - } -} diff --git a/rust/src/lib.rs b/rust/src/lib.rs deleted file mode 100644 index a8315d7d41..0000000000 --- a/rust/src/lib.rs +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#[macro_use] extern crate derive_builder; - -pub mod chat_module; -pub mod config; diff --git a/scripts/build_mlc_for_docs.sh b/scripts/build_mlc_for_docs.sh deleted file mode 100755 index 50eee3231a..0000000000 --- a/scripts/build_mlc_for_docs.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -set -euxo pipefail - -mkdir -p build -cd build -cmake .. -make -j$(nproc) -cd - diff --git a/scripts/build_site.sh b/scripts/build_site.sh deleted file mode 100755 index 062f8094de..0000000000 --- a/scripts/build_site.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -set -euxo pipefail - -export PYTHONPATH=$PWD/python -cd docs && make html && cd .. - -cd site && jekyll b && cd .. 
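To make the deleted Rust bindings above easier to follow, here is a minimal usage sketch tying `ChatModule::generate` to the `derive_builder`-generated config builders. The crate name `mlc_llm` and the `ChatModule::new` constructor arguments are assumptions for illustration; only `generate`, `GenerationConfigBuilder`, and the setter style (passing `Some(...)`, as in the unit tests above) come from the code shown in this diff.

```rust
// Illustrative sketch of the removed Rust API; the crate path and the
// constructor signature are assumed, not taken from this diff.
use mlc_llm::chat_module::ChatModule;
use mlc_llm::config::GenerationConfigBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical constructor arguments (model id and device).
    let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda")?;

    // Builder generated by `derive_builder`; every field is optional and
    // the setters take the Option-typed value directly.
    let gen_cfg = GenerationConfigBuilder::default()
        .temperature(Some(0.6))
        .n(Some(2)) // request two candidate completions
        .build()?;

    // `generate` resets the chat, prefills the prompt, then calls `decode`
    // until `stopped()` is true, once per requested sequence.
    let outputs = cm.generate("What is the meaning of life?", Some(&gen_cfg))?;
    for (i, text) in outputs.iter().enumerate() {
        println!("candidate {i}: {text}");
    }
    Ok(())
}
```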
- -rm -rf site/_site/docs -cp -r docs/_build/html site/_site/docs diff --git a/scripts/check_url_validity.py b/scripts/check_url_validity.py deleted file mode 100644 index 3cbb29e6fb..0000000000 --- a/scripts/check_url_validity.py +++ /dev/null @@ -1,44 +0,0 @@ -import requests -import argparse -import re -from pathlib import Path - - -def find_urls_in_file(file_path): - with open(file_path, "r") as file: - content = file.read() - - # Regular expression pattern to match URLs - url_pattern = re.compile( - r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+" - ) - - # Find all matches of URLs in the content - urls = re.findall(url_pattern, content) - return [url.strip(">") for url in urls] - - -def main(): - parser = argparse.ArgumentParser( - description="Check validity of links in documentation" - ) - parser.add_argument( - "--directory", type=str, default="docs", help="Directory of documentation." - ) - args = parser.parse_args() - - # traversal the directory and find all rst files - doc_directory = Path(args.directory) - for file_path in doc_directory.glob("**/*.rst"): - print("Checking {}...".format(file_path)) - for url in find_urls_in_file(file_path): - try: - r = requests.get(url) - if r.status_code == 404: - print("404 not found: {}".format(url)) - except Exception as e: - print("Error connecting {}, error: {}".format(url, e)) - - -if __name__ == "__main__": - main() diff --git a/scripts/gh_deploy_site.sh b/scripts/gh_deploy_site.sh deleted file mode 100755 index 326c280484..0000000000 --- a/scripts/gh_deploy_site.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash -# NOTE: this script is triggered by github action automatically -# when megred into main - -set -euxo pipefail - -scripts/build_mlc_for_docs.sh -scripts/build_site.sh - -git fetch -git checkout -B gh-pages origin/gh-pages -rm -rf docs .gitignore -mkdir -p docs -cp -rf site/_site/* docs -touch docs/.nojekyll - -DATE=`date` -git add docs && git commit -am "Build at ${DATE}" -git push origin gh-pages -git checkout main && git submodule update -echo "Finish deployment at ${DATE}" diff --git a/scripts/local_deploy_site.sh b/scripts/local_deploy_site.sh deleted file mode 100755 index 52ba40b6fe..0000000000 --- a/scripts/local_deploy_site.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -# NOTE: use this script to check local site - -set -euxo pipefail - -scripts/build_site.sh - -cd site && jekyll serve --skip-initial-build --host localhost --baseurl / --port 8888 diff --git a/site/.gitignore b/site/.gitignore deleted file mode 100644 index 51b35994f6..0000000000 --- a/site/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -dist -llm-chat-config.json -_includes/stable_diffusion.html -_site diff --git a/site/CNAME b/site/CNAME deleted file mode 100644 index 0b04c4050f..0000000000 --- a/site/CNAME +++ /dev/null @@ -1 +0,0 @@ -llm.mlc.ai \ No newline at end of file diff --git a/site/_config.yml b/site/_config.yml deleted file mode 100644 index 9806232d73..0000000000 --- a/site/_config.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: "MLC LLM" -short_name: "MLC LLM" - -url: https://llm.mlc.ai/ - -exclude: [README.md, serve_local.sh] - -plugins: - - jekyll-remote-theme - -remote_theme: mlc-ai/jekyll-theme-mlc - - -# Colorize code snippets with the rogue module if we want to deploy on GH. -highlighter: rouge - -markdown: kramdown - -# The path structure for blog posts. -permalink: /blog/:year/:month/:day/:title.html - -# Number of news stories on the front page. -front_page_news: 8 - -# Base pathname for links. 
-base: '' - -# make pages for the _projects folder -collections: - projects: - output: true - -course_title: - -# Navigation bar links. -navigation: - - title: Home - link: / - - title: Docs - link: /docs - - title: Github - link: https://github.com/mlc-ai/mlc-llm diff --git a/site/gif/android-demo.gif b/site/gif/android-demo.gif deleted file mode 100644 index aec883f598..0000000000 Binary files a/site/gif/android-demo.gif and /dev/null differ diff --git a/site/gif/ios-demo.gif b/site/gif/ios-demo.gif deleted file mode 100644 index 7256afec90..0000000000 Binary files a/site/gif/ios-demo.gif and /dev/null differ diff --git a/site/gif/linux-demo.gif b/site/gif/linux-demo.gif deleted file mode 100644 index 15cfc9daa6..0000000000 Binary files a/site/gif/linux-demo.gif and /dev/null differ diff --git a/site/img/android/android-diagram.png b/site/img/android/android-diagram.png deleted file mode 100644 index 5f49f7cd9d..0000000000 Binary files a/site/img/android/android-diagram.png and /dev/null differ diff --git a/site/img/android/android-studio.png b/site/img/android/android-studio.png deleted file mode 100644 index 7c40215ecc..0000000000 Binary files a/site/img/android/android-studio.png and /dev/null differ diff --git a/site/img/android/android-vs-ios.png b/site/img/android/android-vs-ios.png deleted file mode 100644 index 24367975ea..0000000000 Binary files a/site/img/android/android-vs-ios.png and /dev/null differ diff --git a/site/img/android/local-advantage.png b/site/img/android/local-advantage.png deleted file mode 100644 index 854864fc72..0000000000 Binary files a/site/img/android/local-advantage.png and /dev/null differ diff --git a/site/img/diag.svg b/site/img/diag.svg deleted file mode 100644 index af9d1c7d34..0000000000 --- a/site/img/diag.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/site/img/multi-gpu/figure-1.svg b/site/img/multi-gpu/figure-1.svg deleted file mode 100644 index d3083cf775..0000000000 --- a/site/img/multi-gpu/figure-1.svg +++ /dev/null @@ -1,247 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/site/img/multi-gpu/figure-2.svg b/site/img/multi-gpu/figure-2.svg deleted file mode 100644 index 70d35f5037..0000000000 --- a/site/img/multi-gpu/figure-2.svg +++ /dev/null @@ -1,418 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/site/img/multi-gpu/figure-3.svg b/site/img/multi-gpu/figure-3.svg deleted file mode 100644 index 078231fae6..0000000000 --- a/site/img/multi-gpu/figure-3.svg +++ /dev/null @@ -1,167 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/site/index.md b/site/index.md deleted file mode 100644 index ac0367cdb2..0000000000 --- a/site/index.md +++ /dev/null @@ -1,47 +0,0 @@ ---- -layout: default -title: Home -notitle: true ---- - -# MLC LLM - -Documentation: [https://llm.mlc.ai/docs](https://llm.mlc.ai/docs) - -**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. - -
- -## Installation - -MLC LLM is available via [pip](https://llm.mlc.ai/docs/install/mlc_llm.html#install-mlc-packages). -It is always recommended to install it in an isolated conda virtual environment. - -To verify the installation, activate your virtual environment, run - -```bash -python -c "import mlc_llm; print(mlc_llm.__path__)" -``` - -You are expected to see the installation path of MLC LLM Python package. - -## Quick Start - -Please check out our documentation for the [quick start](https://llm.mlc.ai/docs/get_started/quick_start.html). - -## Introduction - -Please check out our documentation for the [introduction](https://llm.mlc.ai/docs/get_started/introduction.html). - -## Links - -- You might want to check out our online public [Machine Learning Compilation course](https://mlc.ai) for a systematic -walkthrough of our approaches. -- [WebLLM](https://webllm.mlc.ai/) is a companion project using MLC LLM's WebGPU and WebAssembly backend. -- [WebStableDiffusion](https://websd.mlc.ai/) is a companion project for diffusion models with the WebGPU backend. - -## Disclaimer - -The pre-packaged demos are subject to the model License. diff --git a/site/privacy.md b/site/privacy.md deleted file mode 100644 index f7f2d29a06..0000000000 --- a/site/privacy.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -layout: default -title: Home -notitle: true ---- - -# MLC Chat App Privacy - -MLC Chat run all generation locally. -All data stays in users' device and is not collected by the app. diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc deleted file mode 100644 index d49c7107cd..0000000000 --- a/tests/cpp/conv_unittest.cc +++ /dev/null @@ -1,84 +0,0 @@ -#include -#include - -void _TestConversationLoadJSON() { - std::string conv_template = - "{\n" - " \"name\": \"test\",\n" - " \"system_template\": \"abc{system_message}\",\n" - " \"system_message\": \"de\",\n" - " \"roles\": {\n" - " \"user\": \"Instruct\",\n" - " \"assistant\": \"Output\",\n" - " \"tool\": \"Instruct\"\n" - " },\n" - " \"role_templates\": {\n" - " \"user\": \"{user_message}\",\n" - " \"assistant\": \"{assistant_message}\",\n" - " \"tool\": \"{tool_message}\"\n" - " },\n" - " \"messages\": [[\"Instruct\", \"Hello\"], [\"Output\", \"Hey\"]],\n" - " \"seps\": [\n" - " \"\\n\"\n" - " ],\n" - " \"role_content_sep\": \": \",\n" - " \"role_empty_sep\": \":\",\n" - " \"stop_str\": [\n" - " \"<|endoftext|>\"\n" - " ],\n" - " \"stop_token_ids\": [\n" - " 50256\n" - " ],\n" - " \"function_string\": \"\",\n" - " \"use_function_calling\": false\n" - "}"; - mlc::llm::Conversation conv; - conv.LoadJSONOverride(conv_template, true); - ASSERT_EQ(conv.name, "test"); - ASSERT_EQ(conv.system, "abcde"); - - std::vector expected_roles{"Instruct", "Output"}; - ASSERT_EQ(conv.roles, expected_roles); - - std::vector> expected_messages = {{"Instruct", "Hello"}, - {"Output", "Hey"}}; - ASSERT_EQ(conv.messages, expected_messages); - ASSERT_EQ(conv.offset, 2); - - std::vector expected_seps = {"\n"}; - ASSERT_EQ(conv.seps, expected_seps); - - ASSERT_EQ(conv.role_msg_sep, ": "); - ASSERT_EQ(conv.role_empty_sep, ":"); - ASSERT_EQ(conv.stop_str, "<|endoftext|>"); - - std::vector expected_stop_tokens = {50256}; - ASSERT_EQ(conv.stop_tokens, expected_stop_tokens); -} - -void _TestConversationJSONRoundTrip(std::string templ_name) { - mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name); - std::string conv_json = conv.GetConfigJSON(); - mlc::llm::Conversation conv_new; - conv_new.LoadJSONOverride(conv_json, false); - ASSERT_EQ(conv, 
conv_new); -} - -void _TestConversationPartialUpdate() { - mlc::llm::Conversation conv; - std::string json_str = "{\"name\": \"test\"}"; - ASSERT_ANY_THROW(conv.LoadJSONOverride(json_str, false)); - conv.LoadJSONOverride(json_str, true); - ASSERT_EQ(conv.name, "test"); -} - -TEST(ConversationTest, ConversationLoadJSONTest) { _TestConversationLoadJSON(); } - -TEST(ConversationTest, ConversationJSONRoundTripTest) { - _TestConversationJSONRoundTrip("vicuna_v1.1"); - _TestConversationJSONRoundTrip("conv_one_shot"); - _TestConversationJSONRoundTrip("redpajama_chat"); - _TestConversationJSONRoundTrip("LM"); -} - -TEST(ConversationTest, ConversationPartialUpdateTest) { _TestConversationPartialUpdate(); } diff --git a/tests/python/api/test_python.py b/tests/python/api/test_python.py deleted file mode 100644 index d4945f9503..0000000000 --- a/tests/python/api/test_python.py +++ /dev/null @@ -1,45 +0,0 @@ -# pylint: disable=missing-docstring -import pytest - -from mlc_llm import ChatModule, GenerationConfig -from mlc_llm.callback import StreamToStdout - -MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] - - -@pytest.mark.parametrize("model", MODELS) -def test_chat_module_creation_and_generate(model: str): - chat_module = ChatModule(model=model) - _ = chat_module.generate( - prompt="How to make a cake?", - ) - print(f"Statistics: {chat_module.stats()}\n") - - -@pytest.mark.parametrize("model", MODELS) -def test_chat_module_creation_and_generate_with_stream(model: str): - chat_module = ChatModule(model=model) - _ = chat_module.generate( - prompt="How to make a cake?", - progress_callback=StreamToStdout(callback_interval=2), - ) - print(f"Statistics: {chat_module.stats()}\n") - - -@pytest.mark.parametrize( - "generation_config", - [ - GenerationConfig(temperature=0.7, presence_penalty=0.1, frequency_penalty=0.5, top_p=0.9), - GenerationConfig(stop=["cake", "make"], n=3), - GenerationConfig(max_gen_len=40, repetition_penalty=1.2), - ], -) -@pytest.mark.parametrize("model", MODELS) -def test_chat_module_generation_config(generation_config: GenerationConfig, model: str): - chat_module = ChatModule(model=model) - output = chat_module.generate( - prompt="How to make a cake?", - generation_config=generation_config, - ) - print(output) - print(f"Statistics: {chat_module.stats()}\n") diff --git a/tests/python/api/test_rest.py b/tests/python/api/test_rest.py deleted file mode 100644 index f617c5727d..0000000000 --- a/tests/python/api/test_rest.py +++ /dev/null @@ -1,105 +0,0 @@ -# pylint: disable=missing-docstring -import json -import os -import signal -import subprocess -import time - -import pytest -import requests - -MODELS = ["Llama-2-7b-chat-hf-q4f16_1"] - - -@pytest.fixture -def run_rest_server(model): - cmd = f"python -m mlc_llm.rest --model {model}" - print(cmd) - os.environ["PYTHONPATH"] = "./python" - with subprocess.Popen(cmd.split()) as server_proc: - # wait for server to start - while True: - try: - _ = requests.get("http://localhost:8000/stats", timeout=5) - break - except requests.exceptions.ConnectionError: - time.sleep(1) - yield - server_proc.send_signal(signal.SIGINT) - server_proc.wait() - - -@pytest.mark.usefixtures("run_rest_server") -@pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.parametrize("model", MODELS) -def test_rest_chat_completions(model, stream): - payload = { - "model": model, - "messages": [ - { - "role": "user", - "content": "Hello, I am Bob", - }, - { - "role": "assistant", - "content": "Hello, I am a chatbot.", - }, - { - "role": "user", - "content": "What is my 
name?", - }, - ], - "stream": stream, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 1.0, - "top_p": 0.95, - } - if stream: - with requests.post( - "http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True, timeout=120 - ) as model_response: - print("With streaming:") - for chunk in model_response: - data = chunk[6:-2] - if data != b"[DONE]": - content = json.loads(data)["choices"][0]["delta"].get("content", "") - print(f"{content}", end="", flush=True) - print("\n") - else: - model_response = requests.post( - "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=120 - ) - print(f"\n{model_response.json()['choices'][0]['message']['content']}\n") - - -@pytest.mark.usefixtures("run_rest_server") -@pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.parametrize("model", MODELS) -def test_rest_completions(model, stream): - payload = { - "model": model, - "prompt": "What is the meaning of life?", - "stream": stream, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "temperature": 1.0, - "n": 3, - } - if stream: - with requests.post( - "http://127.0.0.1:8000/v1/completions", json=payload, stream=True, timeout=120 - ) as model_response: - print("With streaming:") - for chunk in model_response: - data = chunk[6:-2] - if data != b"[DONE]": - content = json.loads(data)["choices"][0]["text"] - print(f"{content}", end="", flush=True) - print("\n") - else: - model_response = requests.post( - "http://127.0.0.1:8000/v1/completions", json=payload, timeout=120 - ) - assert len(model_response.json()["choices"]) == 3 - print(f"\n{model_response.json()['choices'][0]['text']}\n") diff --git a/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py b/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py deleted file mode 100644 index 1035ce96fd..0000000000 --- a/tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py +++ /dev/null @@ -1,342 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring,too-few-public-methods -import tvm -from tvm.ir import assert_structural_equal -from tvm.script import ir as I -from tvm.script import relax as R - -from mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import ( - FuseFTDequantizeEpilogue, -) - - -def test_fuse_bias(): - @I.ir_module - class Before: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - ): - with R.dataflow(): - lv1 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - "identity", - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - lv2 = R.add(lv1, bias) - R.output(lv2) - return lv2 - - @I.ir_module - class After: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - ) -> R.Tensor((1, 1, 1024), "float16"): - with R.dataflow(): - lv2 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int_bias", - ( - x, - weight, - scale, - bias, - R.str("identity"), - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - R.prim_value(0), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - R.output(lv2) - return lv2 - - seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) - mod = seq(Before) - 
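For reference alongside `tests/python/api/test_rest.py` above, a minimal non-streaming client for the same `/v1/chat/completions` endpoint can be sketched in Rust. This assumes the REST server started by that test is listening on `localhost:8000` and that the `reqwest` crate (with the `blocking` and `json` features) and `serde_json` are available; none of this is part of the deleted code itself.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same OpenAI-style payload shape that the Python test posts.
    let payload = json!({
        "model": "Llama-2-7b-chat-hf-q4f16_1",
        "messages": [
            {"role": "user", "content": "Hello, I am Bob"},
            {"role": "assistant", "content": "Hello, I am a chatbot."},
            {"role": "user", "content": "What is my name?"}
        ],
        "stream": false,
        "temperature": 1.0,
        "top_p": 0.95
    });

    let response: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:8000/v1/chat/completions")
        .json(&payload)
        .send()?
        .json()?;

    // The test reads the same field: choices[0].message.content.
    println!("{}", response["choices"][0]["message"]["content"]);
    Ok(())
}
```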
assert_structural_equal(mod, After) - - -def test_fuse_activation(): - @I.ir_module - class Before: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - ): - with R.dataflow(): - lv1 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - "identity", - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - lv2 = R.nn.silu(lv1) - R.output(lv2) - return lv2 - - @I.ir_module - class After: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - ) -> R.Tensor((1, 1, 1024), "float16"): - with R.dataflow(): - lv2 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - R.str("silu"), - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - R.output(lv2) - return lv2 - - seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) - mod = seq(Before) - assert_structural_equal(mod, After) - - -def test_fuse_bias_activation(): - @I.ir_module - class Before: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - ): - with R.dataflow(): - lv1 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - "identity", - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - lv2 = R.add(lv1, bias) - lv3 = R.nn.relu(lv2) - R.output(lv3) - return lv3 - - @I.ir_module - class After: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - ) -> R.Tensor((1, 1, 1024), "float16"): - with R.dataflow(): - lv2 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int_bias", - ( - x, - weight, - scale, - bias, - R.str("relu"), - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - R.prim_value(0), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - R.output(lv2) - return lv2 - - seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) - mod = seq(Before) - assert_structural_equal(mod, After) - - -def test_fuse_residual_binary(): - @I.ir_module - class Before: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - residual: R.Tensor((1, 1, 1024), "float16"), - ): - with R.dataflow(): - lv1 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - "identity", - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - lv2 = R.add(lv1, bias) - lv3 = R.nn.relu(lv2) - lv4 = R.multiply(lv3, residual) - R.output(lv4) - return lv4 - - @I.ir_module - class After: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - residual: R.Tensor((1, 1, 1024), "float16"), - ) -> R.Tensor((1, 1, 1024), "float16"): - 
with R.dataflow(): - lv2 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int_bias_residual", - ( - x, - weight, - scale, - bias, - residual, - R.str("relu"), - R.str("multiply"), - R.str("identity"), - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - R.output(lv2) - return lv2 - - seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) - mod = seq(Before) - assert_structural_equal(mod, After) - - -def test_fuse_residual_unary(): - @I.ir_module - class Before: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - residual: R.Tensor((1, 1, 1024), "float16"), - ): - with R.dataflow(): - lv1 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int", - ( - x, - weight, - scale, - "identity", - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - lv2 = R.add(lv1, bias) - lv3 = R.nn.relu(lv2) - lv4 = R.add(lv3, residual) - lv5 = R.nn.gelu(lv4) - R.output(lv5) - return lv5 - - @I.ir_module - class After: - @R.function - def main( - x: R.Tensor((1, 1, 4096), "float16"), - weight: R.Tensor((4096, 512), "int8"), - scale: R.Tensor((1, 1024), "float16"), - bias: R.Tensor((1, 1, 1024), "float16"), - residual: R.Tensor((1, 1, 1024), "float16"), - ) -> R.Tensor((1, 1, 1024), "float16"): - with R.dataflow(): - lv2 = R.call_dps_packed( - "fastertransformer.gemm_fp16_int_bias_residual", - ( - x, - weight, - scale, - bias, - residual, - R.str("relu"), - R.str("plus"), - R.str("gelu"), - R.prim_value(1), - R.prim_value(1024), - R.prim_value(4096), - R.prim_value(4096), - ), - out_sinfo=R.Tensor((1, 1, 1024), "float16"), - ) - R.output(lv2) - return lv2 - - seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()]) - mod = seq(Before) - assert_structural_equal(mod, After) - - -if __name__ == "__main__": - test_fuse_bias() - test_fuse_activation() - test_fuse_bias_activation() - test_fuse_residual_binary() - test_fuse_residual_unary() diff --git a/tests/python/conftest.py b/tests/python/conftest.py deleted file mode 100644 index b19fce722c..0000000000 --- a/tests/python/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-module-docstring,unused-import -import pytest -import tvm.testing - -pytest_plugins = ["tvm.testing.plugin"] diff --git a/tests/python/integration/test_model_compile.py b/tests/python/integration/test_model_compile.py deleted file mode 100644 index 3ec70b61b3..0000000000 --- a/tests/python/integration/test_model_compile.py +++ /dev/null @@ -1,167 +0,0 @@ -# pylint: disable=missing-docstring -import concurrent.futures as cf -import os -import shlex -import subprocess -import sys -import tempfile -from itertools import product - -import tvm - -from mlc_llm.model import MODEL_PRESETS -from mlc_llm.model import MODELS as SUPPORTED_MODELS -from mlc_llm.quantization import QUANTIZATION as SUPPORTED_QUANTS -from mlc_llm.support.constants import MLC_TEMP_DIR - -OPT_LEVEL = "O2" -DEVICE2TARGET = { - "cuda": { - "kind": "cuda", - "arch": "sm_86", - "max_threads_per_block": 1024, - "max_num_threads": 1024, - "max_shared_memory_per_block": 49152, - "thread_warp_size": 32, - }, - "rocm": { - "kind": "rocm", - "mtriple": "amdgcn-amd-amdhsa-hcc", - "mcpu": "gfx1100", - "thread_warp_size": 32, - "max_threads_per_block": 1024, - "max_num_threads": 256, - "max_shared_memory_per_block": 65536, - }, - "vulkan": { - "kind": "vulkan", - "max_threads_per_block": 1024, - "max_num_threads": 256, - "max_shared_memory_per_block": 32768, - "thread_warp_size": 1, - "supports_float32": 1, - "supports_float16": 1, - "supports_int64": 1, - "supports_int32": 1, - "supports_int16": 1, - "supports_int8": 1, - "supports_16bit_buffer": 1, - }, - "metal": "metal", - "wasm": "webgpu", - "android": "android", - "ios": "iphone", -} -DEVICE2SUFFIX = { - "cuda": "so", - "rocm": "so", - "vulkan": "so", - "metal": "dylib", - "wasm": "wasm", - "android": "tar", - "ios": "tar", -} -MODELS = list(MODEL_PRESETS.keys()) -QUANTS = [ # TODO(@junrushao): use `list(mlc_llm.quantization.QUANTIZATION.keys())` - "q0f16", - "q0f32", - "q3f16_1", - "q4f16_1", - "q4f32_1", - "q4f16_ft", -] -TENSOR_PARALLEL_SHARDS = [ - 1, -] - - -def run_command(log_file, cmd): - with open(log_file, "w", encoding="utf-8") as file: - subprocess.check_call( - cmd, - stdout=file, - stderr=subprocess.STDOUT, - ) - - -def test_model_compile(): # pylint: disable=too-many-locals - device = sys.argv[1] - num_workers = int(sys.argv[2]) - target = DEVICE2TARGET[device] - if not isinstance(target, str): - target = str(tvm.target.Target(target)) - suffix = DEVICE2SUFFIX[device] - - passed_cmds = [] - failed_cmds = [] - with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir: - with cf.ProcessPoolExecutor(max_workers=num_workers) as executor: - log_files = [] - cmds = [] - futures = [] - for idx, (model, quant, tp_shard) in enumerate( - product( - MODELS, - QUANTS, - TENSOR_PARALLEL_SHARDS, - ) - ): - if ( - SUPPORTED_QUANTS[quant].kind - not in SUPPORTED_MODELS[MODEL_PRESETS[model]["model_type"]].quantize - ): - continue - if not target.startswith("cuda") and quant == "q4f16_ft": - # FasterTransformer only works with cuda - continue - log_file = os.path.join(tmp_dir, f"lib{idx}.log") - cmd = [ - sys.executable, - "-m", - "mlc_llm", - "compile", - model, - "--quantization", - quant, - "--overrides", - f"tensor_parallel_shards={tp_shard}", - "--device", - target, - "--opt", - OPT_LEVEL, - "-o", - os.path.join(tmp_dir, f"lib{idx}.{suffix}"), - ] - future = executor.submit(run_command, log_file, cmd) - log_files.append(log_file) - cmds.append(cmd) - futures.append(future) - for log_file, cmd, future in zip(log_files, cmds, futures): - cmd = 
shlex.join(cmd) - try: - future.result() - passed_cmds.append(cmd) - print(f"[PASS] {cmd}") - except Exception: # pylint: disable=broad-except - failed_cmds.append(cmd) - print("-------------------------------") - print(f"[FAIL] {cmd}") - with open(log_file, "r", encoding="utf-8") as file: - print(file.read()) - print("-------------------------------") - print("-------------------------------") - print(f"Total {len(passed_cmds)} passed, {len(failed_cmds)} failed.") - print("-------------------------------") - print("Passed commands:") - for cmd in passed_cmds: - print(cmd) - if failed_cmds: - print("-------------------------------") - print("Failed commands:") - for cmd in failed_cmds: - print(cmd) - sys.exit(1) - - -if __name__ == "__main__": - test_model_compile() diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py deleted file mode 100644 index c52571b522..0000000000 --- a/tests/python/json_ffi/test_json_ffi_engine.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union - -from mlc_llm.json_ffi import JSONFFIEngine - -chat_completion_prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", - "Why is Vitamin D important to human beings? Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", -] - -function_calling_prompts = [ - "What is the temperature in Pittsburgh, PA?", - "What is the temperature in Tokyo, JP?", - "What is the temperature in Pittsburgh, PA and Tokyo, JP?", -] - -tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } -] - - -def run_chat_completion( - engine: JSONFFIEngine, - model: str, - prompts: List[str] = chat_completion_prompts, - tools: Optional[List[Dict]] = None, -): - num_requests = 2 - max_tokens = 64 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - for rid in range(num_requests): - print(f"chat completion for request {rid}") - for response in engine.chat_completion( - messages=[{"role": "user", "content": [{"type": "text", "text": prompts[rid]}]}], - model=model, - max_tokens=max_tokens, - n=n, - request_id=str(rid), - tools=tools, - ): - for choice in response.choices: - assert choice.delta.role == "assistant" - assert isinstance(choice.delta.content[0], Dict) - assert choice.delta.content[0]["type"] == "text" - output_texts[rid][choice.index] += choice.delta.content[0]["text"] - - # Print output. 
- print("Chat completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - -def test_chat_completion(): - # Create engine. - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = JSONFFIEngine( - model, - max_total_sequence_length=1024, - ) - - run_chat_completion(engine, model) - - # Test malformed requests. - for response in engine._handle_chat_completion("malformed_string", n=1, request_id="123"): - assert len(response.choices) == 1 - assert response.choices[0].finish_reason == "error" - - engine.terminate() - - -def test_reload_reset_unload(): - # Create engine. - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - engine = JSONFFIEngine( - model, - max_total_sequence_length=1024, - ) - - # Run chat completion before and after reload/reset. - run_chat_completion(engine, model) - engine._test_reload() - run_chat_completion(engine, model) - engine._test_reset() - run_chat_completion(engine, model) - engine._test_unload() - - engine.terminate() - - -def test_function_calling(): - model = "dist/gorilla-openfunctions-v1-q4f16_1-MLC" - model_lib_path = ( - "dist/gorilla-openfunctions-v1-q4f16_1-MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so" - ) - engine = JSONFFIEngine( - model, - model_lib_path=model_lib_path, - max_total_sequence_length=1024, - ) - - # run function calling - run_chat_completion(engine, model, function_calling_prompts, tools) - - engine.terminate() - - -if __name__ == "__main__": - test_chat_completion() - test_reload_reset_unload() - test_function_calling() diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py deleted file mode 100644 index 3ab5bd911e..0000000000 --- a/tests/python/loader/test_awq.py +++ /dev/null @@ -1,40 +0,0 @@ -# pylint: disable=missing-docstring -from pathlib import Path -from typing import Union - -import pytest -import tvm - -from mlc_llm.loader import HuggingFaceLoader -from mlc_llm.model import MODEL_PRESETS, MODELS -from mlc_llm.quantization import QUANTIZATION -from mlc_llm.support import logging, tqdm - -logging.enable_logging() - - -@pytest.mark.parametrize( - "param_path", - [ - "./dist/models/llama-2-7b-w4-g128-awq.pt", - "./dist/models/Llama-2-7B-AWQ/model.safetensors", - ], -) -def test_load_llama(param_path: Union[str, Path]): - path_params = Path(param_path) - - model = MODELS["llama"] - quantization = QUANTIZATION["q4f16_awq"] - config = model.config.from_dict(MODEL_PRESETS["llama2_7b"]) - loader = HuggingFaceLoader( - path=path_params, - extern_param_map=model.source["awq"](config, quantization), - ) - with tqdm.redirect(): - for _name, _param in loader.load(tvm.device("cpu")): - ... 
- - -if __name__ == "__main__": - test_load_llama(param_path="./dist/models/llama-2-7b-w4-g128-awq.pt") - test_load_llama(param_path="./dist/models/Llama-2-7B-AWQ/model.safetensors") diff --git a/tests/python/loader/test_huggingface.py b/tests/python/loader/test_huggingface.py deleted file mode 100644 index 1b7bd3c02d..0000000000 --- a/tests/python/loader/test_huggingface.py +++ /dev/null @@ -1,69 +0,0 @@ -# pylint: disable=missing-docstring -from pathlib import Path -from typing import Union - -import pytest -import tvm - -from mlc_llm.loader import HuggingFaceLoader -from mlc_llm.model import MODELS -from mlc_llm.support import logging, tqdm - -logging.enable_logging() - - -@pytest.mark.parametrize( - "base_path", - [ - "./dist/models/Llama-2-7b-hf", - "./dist/models/Llama-2-13b-hf", - "./dist/models/Llama-2-70b-hf", - ], -) -def test_load_torch_llama(base_path: Union[str, Path]): - base_path = Path(base_path) - path_config = base_path / "config.json" - path_params = base_path / "pytorch_model.bin.index.json" - - model = MODELS["llama"] - config = model.config.from_file(path_config) - loader = HuggingFaceLoader( - path=path_params, - extern_param_map=model.source["huggingface-torch"](config, None), - ) - with tqdm.redirect(): - for _name, _param in loader.load(device=tvm.device("cpu")): - return # To reduce the time of the test - - -@pytest.mark.parametrize( - "base_path", - [ - "./dist/models/Llama-2-7b-hf", - "./dist/models/Llama-2-13b-hf", - "./dist/models/Llama-2-70b-hf", - ], -) -def test_load_safetensor_llama(base_path: Union[str, Path]): - base_path = Path(base_path) - path_config = base_path / "config.json" - path_params = base_path / "model.safetensors.index.json" - - model = MODELS["llama"] - config = model.config.from_file(path_config) - loader = HuggingFaceLoader( - path=path_params, - extern_param_map=model.source["huggingface-safetensor"](config, None), - ) - with tqdm.redirect(): - for _name, _param in loader.load(device=tvm.device("cpu")): - return # To reduce the time of the test - - -if __name__ == "__main__": - test_load_torch_llama(base_path="./dist/models/Llama-2-7b-hf") - test_load_torch_llama(base_path="./dist/models/Llama-2-13b-hf") - test_load_torch_llama(base_path="./dist/models/Llama-2-70b-hf") - test_load_safetensor_llama(base_path="./dist/models/Llama-2-7b-hf") - test_load_safetensor_llama(base_path="./dist/models/Llama-2-13b-hf") - test_load_safetensor_llama(base_path="./dist/models/Llama-2-70b-hf") diff --git a/tests/python/model/test_gpt2.py b/tests/python/model/test_gpt2.py deleted file mode 100644 index cdbe7ff222..0000000000 --- a/tests/python/model/test_gpt2.py +++ /dev/null @@ -1,21 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest - -from mlc_llm.model import MODEL_PRESETS, MODELS - - -@pytest.mark.parametrize("model_name", ["gpt2"]) -def test_gpt2_creation(model_name: str): - model_info = MODELS["gpt2"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model = model_info.model(config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(name, param.shape, param.dtype) - - -if __name__ == "__main__": - test_gpt2_creation("gpt2") diff --git a/tests/python/model/test_gptNeox.py b/tests/python/model/test_gptNeox.py deleted file mode 100644 index 5983a5b491..0000000000 --- a/tests/python/model/test_gptNeox.py +++ /dev/null @@ -1,21 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest 
- -from mlc_llm.model import MODEL_PRESETS, MODELS - - -@pytest.mark.parametrize("model_name", ["redpajama_3b_v1"]) -def test_mistral_creation(model_name: str): - model_info = MODELS["gpt_neox"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model = model_info.model(config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(name, param.shape, param.dtype) - - -if __name__ == "__main__": - test_mistral_creation("redpajama_3b_v1") diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py deleted file mode 100644 index 3e3afb92cc..0000000000 --- a/tests/python/model/test_kv_cache.py +++ /dev/null @@ -1,114 +0,0 @@ -# pylint: disable=line-too-long,missing-docstring -import tvm -from tvm import tir -from tvm.relax.frontend.nn import core, modules, spec -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tir as T - -from mlc_llm.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode - -# mypy: disable-error-code="attr-defined" -# pylint: disable=invalid-name,unused-argument,too-many-locals,too-many-statements - - -def test_nn_module_paged_kv_cache(): - # fmt: off - @I.ir_module - class Module: - @R.function - def create_paged_kv_cache( - max_batch_size: R.Shape(["max_batch_size_1"]), # type: ignore - max_total_seq_len: R.Shape(["max_total_seq_len_1"]), # type: ignore - prefill_chunk_size: R.Shape(["prefill_chunk_size_1"]), # type: ignore - page_size: R.Shape(["page_size_1"]), # type: ignore - support_sliding_window: R.Shape(["support_sliding_window_1"]), # type: ignore - ) -> R.Object: - max_batch_size_1 = T.int64() - max_total_seq_len_1 = T.int64() - prefill_chunk_size_1 = T.int64() - page_size_1 = T.int64() - support_sliding_window_1 = T.int64() - R.func_attr({"num_input": 5}) - with R.dataflow(): - paged_kv_cache: R.Object = R.call_pure_packed("mlc.create_paged_kv_cache_generic", R.shape([max_batch_size_1, max_total_seq_len_1, prefill_chunk_size_1, page_size_1, support_sliding_window_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(10000), R.prim_value(128), R.dtype("float16"), sinfo_args=(R.Object,)) - gv1: R.Object = paged_kv_cache - R.output(gv1) - return gv1 - - @R.function - def forward( - cache: R.Object, qkv: R.Tensor((1, 100, 96, 128), dtype="float16") # type: ignore - ) -> R.Tensor((1, 100, 32, 128), dtype="float16"): # type: ignore - R.func_attr({"num_input": 2}) - with R.dataflow(): - reshape: R.Tensor((100, 96, 128), dtype="float16") = R.reshape( # type: ignore - qkv, R.shape([100, 96, 128]) - ) - lv = R.call_dps_packed( - "vm.builtin.attention_kv_cache_attention_with_fused_qkv", - (cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape), - out_sinfo=R.Tensor((100, 32, 128), dtype="float16"), - ) - reshape1: R.Tensor((1, 100, 32, 128), dtype="float16") = R.reshape( # type: ignore - lv, R.shape([1, 100, 32, 128]) - ) - gv: R.Tensor((1, 100, 32, 128), dtype="float16") = reshape1 # type: ignore - R.output(gv) - return gv - # fmt: on - - class PagedKVCacheTest(modules.Module): - def forward( - self, - cache: PagedKVCache, - qkv: core.Tensor, - ) -> core.Tensor: - return cache.attention_with_fused_qkv(0, qkv, num_qo_heads=32) - - def create_paged_kv_cache( - self, - max_batch_size: tir.Var, - max_total_seq_len: tir.Var, - prefill_chunk_size: tir.Var, - page_size: tir.Var, - 
support_sliding_window: tir.Var, - ) -> PagedKVCache: - return PagedKVCache.create_generic( - max_batch_size=max_batch_size, - max_total_seq_len=max_total_seq_len, - prefill_chunk_size=prefill_chunk_size, - page_size=page_size, - support_sliding_window=support_sliding_window, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - head_dim=128, - rope_mode=RopeMode.NORMAL, - rope_scale=1, - rope_theta=10000, - rotary_dim=128, - dtype="float16", - ) - - export_results = PagedKVCacheTest().export_tvm( - spec={ - "forward": { - "cache": spec.Object(object_type=PagedKVCache), - "qkv": spec.Tensor((1, 100, 96, 128), "float16"), - }, - "create_paged_kv_cache": { - "max_batch_size": int, - "max_total_seq_len": int, - "prefill_chunk_size": int, - "page_size": int, - "support_sliding_window": int, - }, - }, - ) - tvm_mod = export_results[0] - tvm.ir.assert_structural_equal(tvm_mod, Module, True) - - -if __name__ == "__main__": - test_nn_module_paged_kv_cache() diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py deleted file mode 100644 index 5591dcdca2..0000000000 --- a/tests/python/model/test_llama.py +++ /dev/null @@ -1,26 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest - -from mlc_llm.model import MODEL_PRESETS, MODELS - - -@pytest.mark.parametrize( - "model_name", ["llama2_7b", "llama2_13b", "llama2_70b", "tinyllama_1b_chat_v1.0"] -) -def test_llama2_creation(model_name: str): - model_info = MODELS["llama"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model = model_info.model(config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(name, param.shape, param.dtype) - - -if __name__ == "__main__": - test_llama2_creation("llama2_7b") - test_llama2_creation("llama2_13b") - test_llama2_creation("llama2_70b") - test_llama2_creation("tinyllama_1b_chat_v1") diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py deleted file mode 100644 index 87d9d2b282..0000000000 --- a/tests/python/model/test_llama_quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest - -from mlc_llm.model import MODEL_PRESETS, MODELS -from mlc_llm.quantization import QUANTIZATION -from mlc_llm.quantization.group_quantization import ( - GroupQuantizeEmbedding, - GroupQuantizeLinear, -) - - -@pytest.mark.parametrize( - "model_name", - ["llama2_7b", "llama2_13b", "llama2_70b"], -) -@pytest.mark.parametrize( - "quant_name", - ["q3f16_1", "q4f16_1", "q4f32_1"], -) -def test_llama2_group_quantization(model_name: str, quant_name: str): - model_info = MODELS["llama"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) - assert "model.embed_tokens.weight" in quant_map.param_map - assert isinstance( - model.model.embed_tokens, # type: ignore[attr-defined] - GroupQuantizeEmbedding, - ) - assert "lm_head.weight" in quant_map.param_map - assert isinstance(model.lm_head, GroupQuantizeLinear) # type: ignore[attr-defined] - for i in range(config.num_hidden_layers): - assert f"model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map - assert isinstance( - model.model.layers[i].self_attn.qkv_proj, # type: ignore[attr-defined] - GroupQuantizeLinear, - ) - assert f"model.layers.{i}.self_attn.o_proj.weight" in 
quant_map.param_map - assert isinstance( - model.model.layers[i].self_attn.o_proj, # type: ignore[attr-defined] - GroupQuantizeLinear, - ) - assert f"model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map - assert isinstance( - model.model.layers[i].mlp.gate_up_proj, # type: ignore[attr-defined] - GroupQuantizeLinear, - ) - assert f"model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map - assert isinstance( - model.model.layers[i].mlp.down_proj, # type: ignore[attr-defined] - GroupQuantizeLinear, - ) - - -@pytest.mark.parametrize( - "model_name", - ["llama2_7b", "llama2_13b", "llama2_70b"], -) -@pytest.mark.parametrize( - "quant_name", - ["q0f16", "q0f32"], -) -def test_llama2_no_quantization(model_name: str, quant_name: str): - model_info = MODELS["llama"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - _, quant_map = model_info.quantize["no-quant"](config, QUANTIZATION[quant_name]) - assert len(quant_map.param_map) == 0 - assert len(quant_map.map_func) == 0 - - -if __name__ == "__main__": - test_llama2_group_quantization("llama2_7b", "q4f16_1") - test_llama2_group_quantization("llama2_13b", "q4f16_1") - test_llama2_group_quantization("llama2_70b", "q4f16_1") diff --git a/tests/python/model/test_mistral.py b/tests/python/model/test_mistral.py deleted file mode 100644 index c1d47eba77..0000000000 --- a/tests/python/model/test_mistral.py +++ /dev/null @@ -1,21 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest - -from mlc_llm.model import MODEL_PRESETS, MODELS - - -@pytest.mark.parametrize("model_name", ["mistral_7b"]) -def test_mistral_creation(model_name: str): - model_info = MODELS["mistral"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model = model_info.model(config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(name, param.shape, param.dtype) - - -if __name__ == "__main__": - test_mistral_creation("mistral_7b") diff --git a/tests/python/model/test_phi.py b/tests/python/model/test_phi.py deleted file mode 100644 index e72effab35..0000000000 --- a/tests/python/model/test_phi.py +++ /dev/null @@ -1,22 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import pytest - -from mlc_llm.model import MODEL_PRESETS, MODELS - - -@pytest.mark.parametrize("model_name", ["phi-1_5", "phi-2"]) -def test_phi_creation(model_name: str): - model_info = MODELS["phi-msft"] - config = model_info.config.from_dict(MODEL_PRESETS[model_name]) - model = model_info.model(config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(name, param.shape, param.dtype) - - -if __name__ == "__main__": - test_phi_creation("phi-1_5") - test_phi_creation("phi-2") diff --git a/tests/python/op/test_batch_spec_verify.py b/tests/python/op/test_batch_spec_verify.py deleted file mode 100644 index f35a39d71e..0000000000 --- a/tests/python/op/test_batch_spec_verify.py +++ /dev/null @@ -1,160 +0,0 @@ -import numpy as np -import pytest -import tvm -import tvm.testing - -from mlc_llm.op.batch_spec_verify import batch_spec_verify - - -@pytest.mark.parametrize("nbatch", [32, 64]) -@pytest.mark.parametrize("vocab", [3, 32, 64, 32000, 33, 65, 32001, 128000]) -@pytest.mark.parametrize("plist", [[0.5, 0.5], [1, 0], [0, 1]]) -def test_batch_spec_verify(nbatch, vocab, plist): - def numpy_reference( - draft_probs, - 
draft_tokens, - model_probs, - token_tree_first_child, - token_tree_next_sibling, - uniform_samples, - token_tree_parent_ptr, - ): - nbatch = token_tree_parent_ptr.shape[0] - for b in range(nbatch): - parent_ptr = token_tree_parent_ptr[b] - child_ptr = token_tree_first_child[parent_ptr] - while child_ptr != -1: - child_token = draft_tokens[child_ptr] - p_child = model_probs[parent_ptr, child_token] - q_child = draft_probs[child_ptr, child_token] - uniform_sample = uniform_samples[child_ptr] - if p_child / q_child >= uniform_sample: - parent_ptr = child_ptr - child_ptr = token_tree_first_child[child_ptr] - else: - model_probs[parent_ptr, :] = np.maximum( - model_probs[parent_ptr, :] - draft_probs[child_ptr, :], 0.0 - ) - psum = np.sum(model_probs[parent_ptr, :]) - model_probs[parent_ptr, :] /= psum - child_ptr = token_tree_next_sibling[child_ptr] - token_tree_parent_ptr[b] = parent_ptr - - np.random.seed(0) - - def gen_chain(num_nodes, base): - token_tree_first_child = list() - token_tree_next_sibling = list() - for i in range(num_nodes): - token_tree_first_child.append(base + i + 1 if i + 1 < num_nodes else -1) - token_tree_next_sibling.append(-1) - return token_tree_first_child, token_tree_next_sibling, base, base + 1 - - def gen_full_binary_tree(height, base): - token_tree_first_child = list() - token_tree_next_sibling = list() - num_nodes = 2**height - 1 - for i in range(num_nodes): - token_tree_first_child.append(base + i * 2 + 1 if i * 2 + 1 < num_nodes else -1) - token_tree_next_sibling.append(base + i * 2 + 2 if i * 2 + 2 < num_nodes else -1) - return token_tree_first_child, token_tree_next_sibling, base, base + 1 - - ### Inputs - num_nodes = 0 - token_tree_first_child = list() - token_tree_next_sibling = list() - token_tree_parent_ptr = list() - - for _ in range(nbatch): - choice = np.random.choice(2, 1, p=plist) - if choice == 0: - nodes_batch = np.random.randint(3, 32) - res = gen_chain(nodes_batch, num_nodes) - num_nodes += nodes_batch - else: - height = np.random.randint(3, 5) - res = gen_full_binary_tree(height, num_nodes) - num_nodes += 2**height - 1 - token_tree_first_child.extend(res[0]) - token_tree_next_sibling.extend(res[1]) - token_tree_parent_ptr.append(res[2]) - - token_tree_first_child = np.array(token_tree_first_child).astype("int32") - token_tree_next_sibling = np.array(token_tree_next_sibling).astype("int32") - token_tree_parent_ptr = np.array(token_tree_parent_ptr).astype("int32") - - draft_probs = np.random.rand(num_nodes, vocab).astype("float32") - draft_probs /= np.sum(draft_probs, axis=1, keepdims=True) - draft_tokens = np.random.randint(0, vocab, num_nodes).astype("int32") - model_probs = np.random.rand(num_nodes, vocab).astype("float32") - model_probs /= np.sum(model_probs, axis=1, keepdims=True) - uniform_samples = np.random.rand(num_nodes).astype("float32") - - ### TVM Inputs - dev = tvm.cuda(0) - draft_probs_tvm = tvm.nd.array(draft_probs, dev) - draft_tokens_tvm = tvm.nd.array(draft_tokens, dev) - model_probs_tvm = tvm.nd.array(model_probs, dev) - token_tree_first_child_tvm = tvm.nd.array(token_tree_first_child, dev) - token_tree_next_sibling_tvm = tvm.nd.array(token_tree_next_sibling, dev) - uniform_samples_tvm = tvm.nd.array(uniform_samples, dev) - token_tree_parent_ptr_tvm = tvm.nd.array(token_tree_parent_ptr, dev) - - # print("draft_probs", draft_probs) - # print("draft_tokens", draft_tokens) - # print("model_probs", model_probs) - # print("token_tree_first_child", token_tree_first_child) - # print("token_tree_next_sibling", token_tree_next_sibling) 
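# Illustrative note (not part of the original test file): the token tree used by
# batch_spec_verify is stored as two flat arrays with -1 meaning "none":
# token_tree_first_child[i] is node i's first child and token_tree_next_sibling[i]
# is its next sibling. A minimal sketch with hypothetical demo names, assuming the
# gen_chain helper defined above:
_demo_first_child, _demo_next_sibling, _demo_root, _ = gen_chain(3, 5)
# a 3-node chain placed at offset 5 is encoded as 5 -> 6 -> 7
assert _demo_first_child == [6, 7, -1]
assert _demo_next_sibling == [-1, -1, -1]
assert _demo_root == 5
# Verification walks first_child links while tokens are accepted and falls back to
# next_sibling (as in gen_full_binary_tree) when a child is rejected.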
- # print("uniform_samples", uniform_samples) - # print("token_tree_parent_ptr", token_tree_parent_ptr) - - ### Numpy reference - numpy_reference( - draft_probs, - draft_tokens, - model_probs, - token_tree_first_child, - token_tree_next_sibling, - uniform_samples, - token_tree_parent_ptr, - ) - # print("model_probs", model_probs) - # print("token_tree_parent_ptr", token_tree_parent_ptr) - - ### TVM - kernel = batch_spec_verify(vocab) - mod = tvm.build(kernel, target="cuda") - mod( - draft_probs_tvm, - draft_tokens_tvm, - model_probs_tvm, - token_tree_first_child_tvm, - token_tree_next_sibling_tvm, - uniform_samples_tvm, - token_tree_parent_ptr_tvm, - ) - # print("model_probs", model_probs_tvm.asnumpy()) - # print("token_tree_parent_ptr", token_tree_parent_ptr_tvm.asnumpy()) - - tvm.testing.assert_allclose(model_probs, model_probs_tvm.asnumpy()) - tvm.testing.assert_allclose( - token_tree_parent_ptr, token_tree_parent_ptr_tvm.asnumpy(), rtol=0, atol=0 - ) - - time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3) - print(f"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}") - print( - time_evaluator( - draft_probs_tvm, - draft_tokens_tvm, - model_probs_tvm, - token_tree_first_child_tvm, - token_tree_next_sibling_tvm, - uniform_samples_tvm, - token_tree_parent_ptr_tvm, - ) - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/op/test_top_p_pivot.py b/tests/python/op/test_top_p_pivot.py deleted file mode 100644 index 7cfeb60e9c..0000000000 --- a/tests/python/op/test_top_p_pivot.py +++ /dev/null @@ -1,83 +0,0 @@ -import numpy as np -import pytest -import tvm -import tvm.testing - -from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm - -# mypy: disable-error-code="var-annotated" - - -@pytest.mark.parametrize("batch_size", [32, 64]) -@pytest.mark.parametrize("vocab", [3, 32, 64, 128]) -def test_top_p_renorm(batch_size, vocab): - top_p = 0.95 - init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32) - top_p_np = np.array([top_p]).astype(np.float32) - - p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32) - p_np /= np.sum(p_np, axis=-1, keepdims=True) - final_pivot_np = np.zeros(batch_size).astype(np.float32) - final_lsum_np = np.zeros(batch_size).astype(np.float32) - - dev = tvm.cuda(0) - var_prob = tvm.nd.array(p_np, dev) - var_init_pivots = tvm.nd.array(init_pivots_np, dev) - top_p_global = tvm.nd.array(top_p_np, dev) - var_final_pivot = tvm.nd.array(final_pivot_np, dev) - var_final_lsum = tvm.nd.array(final_lsum_np, dev) - - kernel = top_p_pivot(init_pivots_np.shape[0]) - mod = tvm.build(kernel, target="cuda") - mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum) - - final_pivot = var_final_pivot.asnumpy() - final_lsum = var_final_lsum.asnumpy() - - renorm_np = p_np.copy() - var_renorm = tvm.nd.array(renorm_np, dev) - - kernel_renorm = top_p_renorm() - mod_renorm = tvm.build(kernel_renorm, target="cuda") - mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm) - - renorm = var_renorm.asnumpy() - - def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray): - sorted_probs = np.sort(probs, axis=-1)[::-1] - num_larger_than_pivot = np.sum(sorted_probs >= pivot) - filtered_sorted_probs = sorted_probs[:num_larger_than_pivot] - min_larger_than_pivot = min(filtered_sorted_probs) - - sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0)) - sum_larger_than_pivot_exclude_min = np.sum( - 
np.where(filtered_sorted_probs != min_larger_than_pivot, filtered_sorted_probs, 0) - ) - - probs[probs < pivot] = 0 - renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True) - try: - assert sum_larger_than_pivot >= top_p - assert sum_larger_than_pivot_exclude_min < top_p - assert abs(lsum - sum_larger_than_pivot) < 1e-6 - assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6) - except AssertionError: - print("Failed") - print("probs:", repr(probs)) - print("pivot:", pivot) - print("sorted_probs:", sorted_probs) - print("num_larger_than_pivot:", num_larger_than_pivot) - print("filtered_sorted_probs:", filtered_sorted_probs) - print("min_larger_than_pivot:", min_larger_than_pivot) - print("sum_larger_than_pivot:", sum_larger_than_pivot) - print("sum_larger_than_pivot_exclude_min:", sum_larger_than_pivot_exclude_min) - print("renom_prob:", renorm_prob) - print("renorm:", renorm) - raise - - for i in range(batch_size): - verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i]) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/op/test_two_stage_softmax.py b/tests/python/op/test_two_stage_softmax.py deleted file mode 100644 index 1d3d55d8e3..0000000000 --- a/tests/python/op/test_two_stage_softmax.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np -import scipy.special -import tvm -from tvm import dlight - -from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func - - -def test_two_stage_softmax(): - chunk_size = 4096 - target = tvm.target.Target("cuda") - f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size) - mod = tvm.IRModule({"chunk_lse": f_chunk_lse, "softmax_with_chunked_lse": f_softmax_with_lse}) - with target: - mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod) - - runtime_mod = tvm.build(mod, target=target) - device = tvm.cuda() - - num_runs = 5 - vocab_size = 128256 - for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: - for _ in range(num_runs): - x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype( - "float32" - ) - y_np = scipy.special.softmax(x_np, axis=-1) - - x_nd = tvm.nd.array(x_np, device=device) - r_nd = tvm.nd.empty( - (batch_size, (vocab_size + chunk_size - 1) // chunk_size), - x_np.dtype, - device=device, - ) - y_nd = tvm.nd.empty(x_np.shape, x_np.dtype, device=device) - - runtime_mod["chunk_lse"](x_nd, r_nd) - runtime_mod["softmax_with_chunked_lse"](x_nd, r_nd, y_nd) - - y_nd_arr = y_nd.numpy() - np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6) - - print(f"pass batch size {batch_size}") - - -if __name__ == "__main__": - test_two_stage_softmax() diff --git a/tests/python/protocol/test_converation_protocol.py b/tests/python/protocol/test_converation_protocol.py deleted file mode 100644 index c7732cc8e4..0000000000 --- a/tests/python/protocol/test_converation_protocol.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest - -from mlc_llm.conversation_template import ConvTemplateRegistry -from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders - - -def get_conv_templates(): - return [ - "llama-2", - "mistral_default", - "gorilla", - "gorilla-openfunctions-v2", - "chatml", - "phi-2", - "codellama_completion", - "codellama_instruct", - "rwkv-world", - ] - - -@pytest.mark.parametrize("conv_template_name", get_conv_templates()) -def test_json(conv_template_name): - template = ConvTemplateRegistry.get_conv_template(conv_template_name) - j = template.to_json_dict() - template_parsed = 
Conversation.from_json_dict(j) - assert template == template_parsed - - -@pytest.mark.parametrize("conv_template_name", get_conv_templates()) -def test_prompt(conv_template_name): - conversation = ConvTemplateRegistry.get_conv_template(conv_template_name) - user_msg = "test1" - assistant_msg = "test2" - prompt = "test3" - - expected_user_msg = ( - conversation.role_templates["user"] - .replace(MessagePlaceholders.USER.value, user_msg) - .replace(MessagePlaceholders.FUNCTION.value, "") - ) - - expected_prompt = ( - conversation.role_templates["user"] - .replace(MessagePlaceholders.USER.value, prompt) - .replace(MessagePlaceholders.FUNCTION.value, "") - ) - - conversation.messages.append(("user", user_msg)) - conversation.messages.append(("assistant", assistant_msg)) - conversation.messages.append(("user", prompt)) - conversation.messages.append(("assistant", None)) - res = conversation.as_prompt() - - system_msg = conversation.system_template.replace( - MessagePlaceholders.SYSTEM.value, conversation.system_message - ) - expected_final_prompt = ( - system_msg - + (conversation.seps[0] if system_msg != "" else "") - + ( - conversation.roles["user"] + conversation.role_content_sep - if conversation.add_role_after_system_message - else "" - ) - + expected_user_msg - + conversation.seps[0 % len(conversation.seps)] - + conversation.roles["assistant"] - + conversation.role_content_sep - + assistant_msg - + conversation.seps[1 % len(conversation.seps)] - + conversation.roles["user"] - + conversation.role_content_sep - + expected_prompt - + conversation.seps[0 % len(conversation.seps)] - + conversation.roles["assistant"] - + conversation.role_empty_sep - ) - assert res == expected_final_prompt - - -if __name__ == "__main__": - test_json() diff --git a/tests/python/quantization/test_awq_quantization.py b/tests/python/quantization/test_awq_quantization.py deleted file mode 100644 index 0222a29b6f..0000000000 --- a/tests/python/quantization/test_awq_quantization.py +++ /dev/null @@ -1,89 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -from typing import List - -import numpy as np -import pytest -import torch -import tvm -import tvm.testing -from tvm import DataType -from tvm.relax.frontend import nn - -from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import QUANTIZATION, AWQQuantize - - -def dequantize_np( - config: AWQQuantize, - weight: np.ndarray, - zeros: np.ndarray, - scale: np.ndarray, -) -> np.ndarray: - def decode_int_arr(int_arr: np.ndarray, num_elem_per_storage: int, bits: int): - bin_mask = (1 << bits) - 1 - int_arr_repeated = np.repeat(int_arr, num_elem_per_storage, axis=-1) - indice_j = np.indices(int_arr_repeated.shape)[1] - arr_bin = np.bitwise_and( - np.right_shift( - int_arr_repeated, - (indice_j % num_elem_per_storage) * bits, - ), - bin_mask, - ) - return arr_bin - - weight_bin = decode_int_arr( - weight, config.num_elem_per_storage, DataType(config.quantize_dtype).bits - ) - zero_bin = decode_int_arr( - zeros, config.num_elem_per_storage, DataType(config.quantize_dtype).bits - ) - scale_repeated = np.repeat(scale, config.group_size, axis=-1) - zero_bin_repeated = np.repeat(zero_bin, config.group_size, axis=-1) - return (weight_bin - zero_bin_repeated) * scale_repeated - - -@pytest.mark.parametrize( - "quant_name, shape, dtype", - [ - ("q4f16_awq", [2, 4096], "float16"), - ], -) -def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): - class Test(nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = 
nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) - - def forward(self, x: nn.Tensor): - return self.linear(x) - - config = QUANTIZATION[quant_name] - assert isinstance(config, AWQQuantize) - weight_np = np.random.randint( - np.iinfo(config.storage_dtype).min, - np.iinfo(config.storage_dtype).max, - (shape[0], shape[1] // config.num_elem_per_storage), - ).astype(config.storage_dtype) - zeros_np = np.random.randint( - np.iinfo(config.storage_dtype).min, - np.iinfo(config.storage_dtype).max, - (shape[0], shape[1] // config.num_elem_per_storage // config.group_size), - ).astype(config.storage_dtype) - scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( - config.model_dtype - ) - mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") - mod.linear.qweight.data = weight_np - mod.linear.qzeros.data = zeros_np - mod.linear.scales.data = scale_np - model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) - out = model["forward"]( - torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member - ) - ref = dequantize_np(config, weight_np, zeros_np, scale_np).T - tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - test_dequantize_weight("q4f16_awq", [2, 4096], "float16") diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py deleted file mode 100644 index b3f9d8034c..0000000000 --- a/tests/python/quantization/test_group_quantization.py +++ /dev/null @@ -1,189 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -from typing import List - -import numpy as np -import pytest -import torch -import tvm -import tvm.testing -from tvm import DataType -from tvm.relax.frontend import nn - -from mlc_llm.loader import QuantizeMapping -from mlc_llm.quantization import QUANTIZATION -from mlc_llm.quantization.group_quantization import ( - GroupQuantize, - GroupQuantizeEmbedding, - GroupQuantizeLinear, -) - - -def quantize_np(config: GroupQuantize, weight: np.ndarray): - n, k = weight.shape - weight_padded = np.pad( - weight, ((0, 0), (0, (config.group_size - k % config.group_size) % config.group_size)) - ) - n, k = weight_padded.shape - weight_reshaped = np.reshape(weight_padded, (n, k // config.group_size, config.group_size)) - max_abs = np.maximum(np.max(np.abs(weight_reshaped), axis=-1), 1e-4) - scale = np.divide(max_abs, config.max_int_value) - scale_reshaped = np.reshape(scale, (*scale.shape, 1)) - weight_scaled_reshaped = np.clip( - np.add( - np.round(np.divide(weight_reshaped, scale_reshaped)), - config.max_int_value, - ), - 0, - config.max_int_value * 2, - ).astype(config.storage_dtype) - weight_filtered = np.reshape(weight_scaled_reshaped, (n, k)) - weight_filtered[..., weight.shape[1] :] = 0 - weight_scaled = np.reshape( - weight_filtered, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) - ) - indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1] - quantized_weight = np.sum( - np.left_shift(weight_scaled, indice_k * DataType(config.quantize_dtype).bits), - axis=-1, - dtype=config.storage_dtype, - ) - return quantized_weight, scale - - -def dequantize_np( - config: GroupQuantize, - weight: np.ndarray, - scale: np.ndarray, - out_shape: List[int] = None, -): - assert weight.shape[0] == scale.shape[0] - bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1 - max_int = config.max_int_value - out_shape = ( - [weight.shape[0], weight.shape[1] * 
config.num_elem_per_storage] - if out_shape is None - else out_shape - ) - weight_repeated = np.repeat(weight, config.num_elem_per_storage, axis=-1) - scale_repeated = np.repeat(scale, config.group_size, axis=-1) - indice_j = np.indices(weight_repeated.shape)[1] - weight_bin = np.bitwise_and( - np.right_shift( - weight_repeated, - (indice_j % config.num_elem_per_storage) * DataType(config.quantize_dtype).bits, - ), - bin_mask, - ) - assert weight_bin.shape[1] <= scale_repeated.shape[1] - return ((weight_bin - max_int) * scale_repeated[..., : weight_bin.shape[1]])[ - : out_shape[0], : out_shape[1] - ] - - -@pytest.mark.parametrize( - "quant_name, shape, dtype, device", - [ - ("q3f16_1", [2, 13], "float16", "cpu"), - ("q3f16_1", [16, 120], "float16", "cpu"), - ("q4f16_1", [2, 13], "float16", "cpu"), - ("q4f16_1", [16, 128], "float16", "cpu"), - ("q4f32_1", [2, 13], "float32", "cpu"), - ("q4f32_1", [16, 128], "float32", "cpu"), - ], -) -def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): - config = QUANTIZATION[quant_name] - assert isinstance(config, GroupQuantize) - weight_np = np.random.random(shape).astype(dtype) - output = config.quantize_weight(tvm.nd.array(weight_np, device=tvm.device(device))) - quantized_weight, scale = output[0].numpy(), output[1].numpy() - quantized_weight_ref, scale_ref = quantize_np(config, weight_np) - tvm.testing.assert_allclose(scale, scale_ref, rtol=1e-3, atol=1e-3) - tvm.testing.assert_allclose( - dequantize_np(config, quantized_weight, scale, shape), - dequantize_np(config, quantized_weight_ref, scale_ref, shape), - rtol=1e-2 if quant_name.startswith("q3") else 1e-3, - atol=0.4 if quant_name.startswith("q3") else 0.2, - ) - - -@pytest.mark.parametrize( - "quant_name, shape, dtype", - [ - ("q3f16_1", [2, 13], "float16"), - ("q3f16_1", [16, 120], "float16"), - ("q4f16_1", [2, 13], "float16"), - ("q4f16_1", [16, 128], "float16"), - ("q4f32_1", [2, 13], "float32"), - ("q4f32_1", [16, 128], "float32"), - ], -) -def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): - class Test(nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) - - def forward(self, x: nn.Tensor): - return self.linear(x) - - config = QUANTIZATION[quant_name] - assert isinstance(config, GroupQuantize) - num_group = -(shape[1] // -config.group_size) - weight_np = np.random.randint( - np.iinfo(config.storage_dtype).min, - np.iinfo(config.storage_dtype).max, - (shape[0], config.num_storage_per_group * num_group), - ).astype(config.storage_dtype) - scale_np = np.random.random((shape[0], num_group)).astype(config.model_dtype) - mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") - mod.linear.q_weight.data = weight_np - mod.linear.q_scale.data = scale_np - model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) - out = model["forward"]( - torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member - ) - ref = dequantize_np(config, weight_np, scale_np, shape).T - tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize( - "quant_name, shape, dtype", - [ - ("q3f16_1", [16, 128], "float16"), - ("q4f16_1", [16, 128], "float16"), - ("q4f32_1", [16, 128], "float32"), - ], -) -def test_quantize_model(quant_name: str, shape: List[int], dtype: str): - class Test(nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = nn.Linear(shape[0], shape[1], 
dtype=dtype) - self.embedding = nn.Embedding(shape[0], shape[1], dtype=dtype) - - def forward(self, x: nn.Tensor): - return self.linear(x) - - config = QUANTIZATION[quant_name] - assert isinstance(config, GroupQuantize) - quant_map = QuantizeMapping({}, {}) - mod = config.quantize_model(Test(), quant_map, "model") - assert quant_map.param_map["model.linear.weight"] == [ - "model.linear.q_weight", - "model.linear.q_scale", - ] - assert quant_map.map_func["model.linear.weight"] == config.quantize_weight - assert isinstance(mod.linear, GroupQuantizeLinear) - assert quant_map.param_map["model.embedding.weight"] == [ - "model.embedding.q_weight", - "model.embedding.q_scale", - ] - assert quant_map.map_func["model.embedding.weight"] == config.quantize_weight - assert isinstance(mod.embedding, GroupQuantizeEmbedding) - - -if __name__ == "__main__": - test_quantize_weight("q4f16_1", [16, 128], "float16", "llvm") - test_quantize_model("q4f16_1", [16, 128], "float16") - test_dequantize_weight("q4f16_1", [16, 128], "float16") diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py deleted file mode 100644 index c89a9e2c38..0000000000 --- a/tests/python/serve/evaluate_engine.py +++ /dev/null @@ -1,72 +0,0 @@ -# pylint: disable=line-too-long,missing-docstring -import argparse -import os -import random -from typing import List, Tuple - -from mlc_llm.serve import GenerationConfig -from mlc_llm.serve.sync_engine import SyncMLCEngine - - -def _parse_args(): - args = argparse.ArgumentParser() - args.add_argument("--model-lib-path", type=str) - args.add_argument("--device", type=str, default="auto") - args.add_argument("--batch-size", type=int, default=80) - args.add_argument("--max-total-seq-length", type=int) - args.add_argument("--seed", type=int, default=0) - - parsed = args.parse_args() - parsed.model = os.path.dirname(parsed.model_lib_path) - assert parsed.batch_size % 16 == 0 - return parsed - - -def generate_requests( - num_requests: int, input_length: int, output_length: int -) -> Tuple[List[List[int]], List[GenerationConfig]]: - prompt_ids = [] - for _ in range(num_requests): - token_ids = [] - for _ in range(input_length): - token_ids.append(random.randint(0, 30000)) - prompt_ids.append(token_ids) - generation_config_list = [ - GenerationConfig(temperature=1.0, top_p=1.0, max_tokens=output_length) - ] * num_requests - return prompt_ids, generation_config_list - - -def benchmark(args: argparse.Namespace): - random.seed(args.seed) - - # Create engine - engine = SyncMLCEngine( - model=args.model, - device=args.device, - model_lib_path=args.model_lib_path, - mode="server", - max_batch_size=args.batch_size, - max_total_sequence_length=args.max_total_seq_length, - ) - - print(args) - for num_requests in [1, 2, 4, 8, 16, 32, 64]: - if num_requests > args.batch_size: - continue - for input_length in [64, 128, 256, 512, 1024]: - if num_requests * input_length >= 16384: - continue - for output_length in [4]: - print(f"nreq={num_requests}\t" f"in={input_length}\t" f"out={output_length}") - prompt_ids, generation_config = generate_requests( - num_requests, input_length, output_length - ) - engine.reset() - engine.generate(prompt_ids, generation_config) - print() - - -if __name__ == "__main__": - ARGS = _parse_args() - benchmark(ARGS) diff --git a/tests/python/serve/json.ebnf b/tests/python/serve/json.ebnf deleted file mode 100644 index fc3fb22d65..0000000000 --- a/tests/python/serve/json.ebnf +++ /dev/null @@ -1,22 +0,0 @@ -# Adopted from 
https://www.crockford.com/mckeeman.html -main ::= element -value ::= object | array | string | number | "true" | "false" | "null" -object ::= "{" ws "}" | "{" members "}" -members ::= member | member "," members -member ::= ws string ws ":" element -array ::= "[" ws "]" | "[" elements "]" -elements ::= element | element "," elements -element ::= ws value ws -string ::= "\"" characters "\"" -characters ::= "" | character characters -character ::= [^"\\] | "\\" escape -escape ::= "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" hex hex hex hex -hex ::= [A-Fa-f0-9] -number ::= integer fraction exponent -integer ::= digit | onenine digits | "-" digit | "-" onenine digits -digits ::= digit | digit digits -digit ::= [0-9] -onenine ::= [1-9] -fraction ::= "" | "." digits -exponent ::= "" | ("e" | "E") ("" | "+" | "-") digits -ws ::= "" | "\u0020" ws | "\u000A" ws | "\u000D" ws | "\u0009" ws diff --git a/tests/python/serve/server/conftest.py b/tests/python/serve/server/conftest.py deleted file mode 100644 index e425494231..0000000000 --- a/tests/python/serve/server/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -import os -from typing import Tuple - -import pytest - -from mlc_llm.serve import PopenServer - - -@pytest.fixture(scope="session") -def served_model() -> Tuple[str, str]: - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: - raise ValueError( - 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' - "Please set it to model lib compiled by MLC LLM " - "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." - ) - model = os.path.dirname(model_lib_path) - return model, model_lib_path - - -@pytest.fixture(scope="session") -def launch_server(served_model): # pylint: disable=redefined-outer-name - """A pytest session-level fixture which launches the server in a subprocess.""" - server = PopenServer( - model=served_model[0], - model_lib_path=served_model[1], - enable_tracing=True, - ) - - with server: - yield diff --git a/tests/python/serve/server/test_server.py b/tests/python/serve/server/test_server.py deleted file mode 100644 index e4f64d2ce4..0000000000 --- a/tests/python/serve/server/test_server.py +++ /dev/null @@ -1,1341 +0,0 @@ -"""Server tests in MLC LLM. -Before running any test, we use pytest fixtures to launch a -test-session-wide server in a subprocess, and then execute the tests. - -The recommended way to run the tests is to use the following command: - MLC_SERVE_MODEL_LIB="YOUR_MODEL_LIB" pytest -vv tests/python/serve/server/test_server.py - -Here "YOUR_MODEL_LIB" is a compiled model library like -`dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so`, -as long as the model is built with batching and embedding separation enabled. - -To directly run the Python file (a.k.a., not using pytest), you need to -launch the server in ahead before running this file. 
This can be done in -two steps: -- start a new shell session, run - python -m mlc_llm.serve.server --model "YOUR_MODEL_LIB" -- start another shell session, run this file - MLC_SERVE_MODEL_LIB="YOUR_MODEL_LIB" python tests/python/serve/server/test_server.py -""" - -# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches -import json -import os -from http import HTTPStatus -from typing import Dict, List, Optional, Tuple - -import pytest -import regex -import requests -from openai import OpenAI -from pydantic import BaseModel - -OPENAI_BASE_URL = "http://127.0.0.1:8000/v1" -OPENAI_V1_MODELS_URL = "http://127.0.0.1:8000/v1/models" -OPENAI_V1_COMPLETION_URL = "http://127.0.0.1:8000/v1/completions" -OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" -DEBUG_DUMP_EVENT_TRACE_URL = "http://127.0.0.1:8000/debug/dump_event_trace" - - -JSON_TOKEN_PATTERN = ( - r"((-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?)|null|true|false|" - r'("((\\["\\\/bfnrt])|(\\u[0-9a-fA-F]{4})|[^"\\\x00-\x1f])*")' -) -JSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN) - - -def is_json(s: str) -> bool: - try: - json.loads(s) - return True - except json.JSONDecodeError: - return False - - -def is_json_prefix(s: str) -> bool: - try: - json.loads(s) - return True - except json.JSONDecodeError as e: - # If the JSON decoder reaches the end of s, it is a prefix of a JSON string. - if e.pos == len(s): - return True - # Since json.loads is token-based instead of char-based, there may remain half a token after - # the matching position. - # If the left part is a prefix of a valid JSON token, the output is also valid - regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True) - return regex_match is not None - - -def check_openai_nonstream_response( - response: Dict, - *, - is_chat_completion: bool, - model: str, - object_str: str, - num_choices: int, - finish_reasons: List[str], - completion_tokens: Optional[int] = None, - echo_prompt: Optional[str] = None, - suffix: Optional[str] = None, - stop: Optional[List[str]] = None, - require_substr: Optional[List[str]] = None, - check_json_output: bool = False, -): - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) <= num_choices - texts: List[str] = ["" for _ in range(num_choices)] - for choice in choices: - idx = choice["index"] - assert choice["finish_reason"] in finish_reasons - - if not is_chat_completion: - assert isinstance(choice["text"], str) - texts[idx] = choice["text"] - if echo_prompt is not None: - assert texts[idx] - if suffix is not None: - assert texts[idx] - else: - message = choice["message"] - assert message["role"] == "assistant" - assert isinstance(message["content"], str) - texts[idx] = message["content"] - - if stop is not None: - for stop_str in stop: - assert stop_str not in texts[idx] - if require_substr is not None: - for substr in require_substr: - assert substr in texts[idx] - if check_json_output: - # the output should be json or a prefix of a json string - # if the output is a prefix of a json string, the output must exceed the max output - # length - output_is_json = is_json(texts[idx]) - output_is_json_prefix = is_json_prefix(texts[idx]) - assert output_is_json or output_is_json_prefix - if not output_is_json and output_is_json_prefix: - assert choice["finish_reason"] == "length" - - usage = response["usage"] - assert isinstance(usage, dict) - assert usage["total_tokens"] == 
usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 - if completion_tokens is not None: - assert usage["completion_tokens"] == completion_tokens - - -def check_openai_stream_response( - responses: List[Dict], - *, - is_chat_completion: bool, - model: str, - object_str: str, - num_choices: int, - finish_reasons: List[str], - completion_tokens: Optional[int] = None, - echo_prompt: Optional[str] = None, - suffix: Optional[str] = None, - stop: Optional[List[str]] = None, - require_substr: Optional[List[str]] = None, - check_json_output: bool = False, -): - assert len(responses) > 0 - - finished = [False for _ in range(num_choices)] - outputs = ["" for _ in range(num_choices)] - finish_reason_list = ["" for _ in range(num_choices)] - for response in responses: - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) <= num_choices - for choice in choices: - idx = choice["index"] - - if not is_chat_completion: - assert isinstance(choice["text"], str) - outputs[idx] += choice["text"] - else: - delta = choice["delta"] - assert delta["role"] == "assistant" - assert isinstance(delta["content"], str) - outputs[idx] += delta["content"] - - if finished[idx]: - assert choice["finish_reason"] in finish_reasons - finish_reason_list[idx] = choice["finish_reason"] - elif choice["finish_reason"] is not None: - assert choice["finish_reason"] in finish_reasons - finish_reason_list[idx] = choice["finish_reason"] - finished[idx] = True - - if not is_chat_completion: - usage = response["usage"] - assert isinstance(usage, dict) - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] >= 0 - if completion_tokens is not None: - assert usage["completion_tokens"] <= completion_tokens - - if not is_chat_completion: - if completion_tokens is not None: - assert responses[-1]["usage"]["completion_tokens"] == completion_tokens - - for i, (output, finish_reason) in enumerate(zip(outputs, finish_reason_list)): - if echo_prompt is not None: - assert output.startswith(echo_prompt) - if suffix is not None: - assert output.endswith(suffix) - if stop is not None: - for stop_str in stop: - assert stop_str not in output - if require_substr is not None: - for substr in require_substr: - assert substr in output - if check_json_output: - # the output should be json or a prefix of a json string - # if the output is a prefix of a json string, the output must exceed the max output - # length - output_is_json = is_json(output) - output_is_json_prefix = is_json_prefix(output) - assert output_is_json or output_is_json_prefix - if not output_is_json and output_is_json_prefix: - assert finish_reason == "length" - - -def expect_error(response_str: str, msg_prefix: Optional[str] = None): - response = json.loads(response_str) - assert response["object"] == "error" - assert isinstance(response["message"], str) - if msg_prefix is not None: - assert response["message"].startswith(msg_prefix) - - -def test_openai_v1_models( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
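# Illustrative note on the JSON helpers defined earlier in this file (hypothetical
# inputs, not part of the original test): is_json_prefix accepts any string that
# could still grow into valid JSON. json.loads fails at the end of the input for a
# truncated document, and the partial regex match covers a half-emitted token:
#   is_json_prefix('{"a": [1, 2')  -> True   (decoder error is at end of input)
#   is_json_prefix('tru')          -> True   (prefix of the literal "true")
#   is_json_prefix('{"a": ]')      -> False  (']' cannot start any JSON token)
# check_json_output therefore also accepts a strict JSON prefix, but only when the
# choice finished with reason "length".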
- - response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json() - assert response["object"] == "list" - models = response["data"] - assert isinstance(models, list) - assert len(models) == 1 - - model_card = models[0] - assert isinstance(model_card, dict) - assert model_card["id"] == served_model[0], f"{model_card['id']} {served_model[0]}" - assert model_card["object"] == "model" - assert model_card["owned_by"] == "MLC-LLM" - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What is the meaning of life?" - max_tokens = 256 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_openai_package( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - client = OpenAI(base_url=OPENAI_BASE_URL, api_key="None") - prompt = "What is the meaning of life?" - max_tokens = 256 - response = client.completions.create( - model=served_model[0], - prompt=prompt, - max_tokens=max_tokens, - stream=stream, - ) - if not stream: - check_openai_nonstream_response( - response.model_dump(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - completion_tokens=max_tokens, - ) - else: - responses = [] - for chunk in response: # pylint: disable=not-an-iterable - responses.append(chunk.model_dump()) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - completion_tokens=max_tokens, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_echo( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What is the meaning of life?" 
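# Illustrative note on the streaming branches in these tests (not part of the
# original file): the server streams Server-Sent Events, so each payload line looks
# like b'data: {...json...}' and the stream is terminated by b'data: [DONE]'. The
# slice chunk.decode("utf-8")[6:] simply drops the 6-character "data: " prefix
# before json.loads, e.g. (hypothetical line):
#   line = b'data: {"object": "text_completion", "choices": []}'
#   json.loads(line.decode("utf-8")[6:])  # -> {'object': 'text_completion', 'choices': []}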
- max_tokens = 256 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "echo": True, - "stream": stream, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - echo_prompt=prompt, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - echo_prompt=prompt, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_suffix( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What is the meaning of life?" - suffix = "Hello, world!" - max_tokens = 256 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "suffix": suffix, - "stream": stream, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - suffix=suffix, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - suffix=suffix, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_stop_str( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # Choose "in" as the stop string since it is very unlikely that - # "in" does not appear in the generated output. - prompt = "What is the meaning of life?" 
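# Illustrative note (not part of the original file): the checks below accept both
# "stop" and "length" as finish reasons because the model may exhaust max_tokens
# before ever emitting the stop string; when generation does stop, the assertions in
# check_openai_*_response verify that the stop text itself never appears in the
# returned completion (stop_str not in output).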
- stop = ["in"] - max_tokens = 256 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stop": stop, - "stream": stream, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["stop", "length"], - stop=stop, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["stop", "length"], - stop=stop, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_temperature( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What's the meaning of life?" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "temperature": 0.0, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_json( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
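# Illustrative note (not part of the original file): with
# "response_format": {"type": "json_object"} the request constrains decoding to
# JSON, but the output can still be cut off by max_tokens. That is why
# check_json_output treats a strict prefix of a JSON document as acceptable only
# when the choice finished with reason "length".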
- - prompt = "Response with a json object:" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "response_format": {"type": "json_object"}, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_json_schema( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = ( - "Generate a json containing three fields: an integer field named size, a " - "boolean field named is_accepted, and a float field named num:" - ) - max_tokens = 128 - - class Schema(BaseModel): - size: int - is_accepted: bool - num: float - - schema_str = json.dumps(Schema.model_json_schema()) - - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "response_format": {"type": "json_object", "schema": schema_str}, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_logit_bias( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # NOTE: This test only tests that the system does not break on logit bias. - # The test does not promise the correctness of logit bias handling. - - prompt = "What's the meaning of life?" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "logit_bias": {338: -100}, # 338 is " is" in Llama tokenizer. 
- "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_presence_frequency_penalty( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What's the meaning of life?" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": stream, - "frequency_penalty": 2.0, - "presence_penalty": 2.0, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - -def test_openai_v1_completions_seed( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = "What's the meaning of life?" - max_tokens = 128 - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": max_tokens, - "stream": False, - "seed": 233, - "ignore_eos": True, - } - - response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - response2 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - for response in [response1, response2]: - check_openai_nonstream_response( - response.json(), - is_chat_completion=False, - model=served_model[0], - object_str="text_completion", - num_choices=1, - finish_reasons=["length"], - ) - - text1 = response1.json()["choices"][0]["text"] - text2 = response2.json()["choices"][0]["text"] - assert text1 == text2 - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_prompt_overlong( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
- - num_tokens = 1000000 - prompt = [128] * num_tokens - payload = { - "model": served_model[0], - "prompt": prompt, - "max_tokens": 256, - "stream": stream, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - error_msg_prefix = ( - f"Request prompt has {num_tokens} tokens in total, larger than the model input length limit" - ) - if not stream: - expect_error(response.json(), msg_prefix=error_msg_prefix) - else: - num_chunks = 0 - for chunk in response.iter_lines(chunk_size=512): - if not chunk: - continue - num_chunks += 1 - expect_error(json.loads(chunk.decode("utf-8")), msg_prefix=error_msg_prefix) - assert num_chunks == 1 - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_completions_invalid_logprobs( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - payload = { - "model": served_model[0], - "prompt": "What is the meaning of life?", - "max_tokens": 256, - "stream": stream, - "logprobs": False, - "top_logprobs": 4, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - assert response.json()["detail"][0]["msg"].endswith( - '"logprobs" must be True to support "top_logprobs"' - ) - - payload["logprobs"] = True - payload["top_logprobs"] = 6 - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - assert response.json()["detail"][0]["msg"].endswith('"top_logprobs" must be in range [0, 5]') - - -def test_openai_v1_completions_unsupported_args( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # Right now "best_of" is unsupported. - best_of = 2 - payload = { - "model": served_model[0], - "prompt": "What is the meaning of life?", - "max_tokens": 256, - "best_of": best_of, - } - - response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180) - error_msg_prefix = 'Request fields "best_of" are not supported right now.' - expect_error(response.json(), msg_prefix=error_msg_prefix) - - -def test_openai_v1_completions_request_cancellation( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - # Use a large max_tokens and small timeout to force timeouts. - payload = { - "model": served_model[0], - "prompt": "What is the meaning of life?", - "max_tokens": 2048, - "stream": False, - } - with pytest.raises(requests.exceptions.Timeout): - requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=1) - - # The server should still be alive after a request cancelled. - # We query `v1/models` to validate the server liveness. 
- response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json() - - assert response["object"] == "list" - models = response["data"] - assert isinstance(models, list) - assert len(models) == 1 - - model_card = models[0] - assert isinstance(model_card, dict) - assert model_card["id"] == served_model[0] - assert model_card["object"] == "model" - assert model_card["owned_by"] == "MLC-LLM" - - -CHAT_COMPLETION_MESSAGES = [ - # messages #0 - [{"role": "user", "content": "Hello! Our project is MLC LLM."}], - # messages #1 - [ - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - { - "role": "assistant", - "content": "Hello! It's great to hear about your project, MLC LLM.", - }, - {"role": "user", "content": "What is the name of our project?"}, - ], - # messages #2 - [ - { - "role": "system", - "content": "You are a helpful, respectful and honest assistant. " - "You always ends your response with an emoji.", - }, - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - ], -] - - -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) -def test_openai_v1_chat_completions( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, - messages: List[Dict[str, str]], -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["stop"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["stop"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) -def test_openai_v1_chat_completions_n( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, - messages: List[Dict[str, str]], -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
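# Illustrative note (not part of the original file): with n > 1 the stream may
# interleave chunks belonging to different choices, so check_openai_stream_response
# accumulates text per choice["index"] and tracks a separate finish_reason for each
# of the n outputs rather than assuming one choice per response.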
- - n = 3 - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "n": n, - "max_tokens": 300, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=n, - finish_reasons=["stop", "length"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=n, - finish_reasons=["stop", "length"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) -def test_openai_v1_chat_completions_openai_package( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, - messages: List[Dict[str, str]], -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - client = OpenAI(base_url=OPENAI_BASE_URL, api_key="None") - response = client.chat.completions.create( - model=served_model[0], - messages=messages, - stream=stream, - logprobs=True, - top_logprobs=2, - ) - if not stream: - check_openai_nonstream_response( - response.model_dump(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["stop"], - ) - else: - responses = [] - for chunk in response: # pylint: disable=not-an-iterable - responses.append(chunk.model_dump()) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["stop"], - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_chat_completions_max_tokens( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - messages = [{"role": "user", "content": "Write a novel with at least 500 words."}] - max_tokens = 16 - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "max_tokens": max_tokens, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_chat_completions_json( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
- - messages = [{"role": "user", "content": "Response with a json object:"}] - max_tokens = 128 - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "max_tokens": max_tokens, - "response_format": {"type": "json_object"}, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_chat_completions_json_schema( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - prompt = ( - "Generate a json containing three fields: an integer field named size, a " - "boolean field named is_accepted, and a float field named num:" - ) - messages = [{"role": "user", "content": prompt}] - max_tokens = 128 - - class Schema(BaseModel): - size: int - is_accepted: bool - num: float - - schema_str = json.dumps(Schema.model_json_schema()) - - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "max_tokens": max_tokens, - "response_format": {"type": "json_object", "schema": schema_str}, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["length", "stop"], - check_json_output=True, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_chat_completions_ignore_eos( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. 
- - messages = [{"role": "user", "content": "Write a sentence with less than 20 words."}] - max_tokens = 128 - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "max_tokens": max_tokens, - "ignore_eos": True, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["length"], - completion_tokens=max_tokens, - ) - - -@pytest.mark.parametrize("stream", [False, True]) -def test_openai_v1_chat_completions_system_prompt_wrong_pos( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - messages = [ - {"role": "user", "content": "Hello! Our project is MLC LLM."}, - { - "role": "system", - "content": "You are a helpful, respectful and honest assistant. " - "You always ends your response with an emoji.", - }, - ] - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - error_msg = "System prompt at position 1 in the message list is invalid." - if not stream: - expect_error(response.json(), msg_prefix=error_msg) - else: - num_chunks = 0 - for chunk in response.iter_lines(chunk_size=512): - if not chunk: - continue - num_chunks += 1 - expect_error(json.loads(chunk.decode("utf-8")), msg_prefix=error_msg) - assert num_chunks == 1 - - -def test_debug_dump_event_trace( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - # We only check that the request does not fail. - payload = {"model": served_model[0]} - response = requests.post(DEBUG_DUMP_EVENT_TRACE_URL, json=payload, timeout=180) - assert response.status_code == HTTPStatus.OK - - -if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: - raise ValueError( - 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' - "Please set it to model lib compiled by MLC LLM " - "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." 
- ) - MODEL = (os.path.dirname(model_lib_path), model_lib_path) - - test_openai_v1_models(MODEL, None) - - test_openai_v1_completions(MODEL, None, stream=False) - test_openai_v1_completions(MODEL, None, stream=True) - test_openai_v1_completions_openai_package(MODEL, None, stream=False) - test_openai_v1_completions_openai_package(MODEL, None, stream=True) - test_openai_v1_completions_echo(MODEL, None, stream=False) - test_openai_v1_completions_echo(MODEL, None, stream=True) - test_openai_v1_completions_suffix(MODEL, None, stream=False) - test_openai_v1_completions_suffix(MODEL, None, stream=True) - test_openai_v1_completions_stop_str(MODEL, None, stream=False) - test_openai_v1_completions_stop_str(MODEL, None, stream=True) - test_openai_v1_completions_temperature(MODEL, None, stream=False) - test_openai_v1_completions_temperature(MODEL, None, stream=True) - test_openai_v1_completions_logit_bias(MODEL, None, stream=False) - test_openai_v1_completions_logit_bias(MODEL, None, stream=True) - test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=False) - test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=True) - test_openai_v1_completions_seed(MODEL, None) - test_openai_v1_completions_prompt_overlong(MODEL, None, stream=False) - test_openai_v1_completions_prompt_overlong(MODEL, None, stream=True) - test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=False) - test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=True) - test_openai_v1_completions_unsupported_args(MODEL, None) - test_openai_v1_completions_request_cancellation(MODEL, None) - - for msg in CHAT_COMPLETION_MESSAGES: - test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) - test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg) - test_openai_v1_chat_completions_n(MODEL, None, stream=False, messages=msg) - test_openai_v1_chat_completions_n(MODEL, None, stream=True, messages=msg) - test_openai_v1_chat_completions_openai_package(MODEL, None, stream=False, messages=msg) - test_openai_v1_chat_completions_openai_package(MODEL, None, stream=True, messages=msg) - test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=False) - test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=True) - test_openai_v1_chat_completions_json(MODEL, None, stream=False) - test_openai_v1_chat_completions_json(MODEL, None, stream=True) - test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=False) - test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True) - test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False) - test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=True) - - test_debug_dump_event_trace(MODEL, None) diff --git a/tests/python/serve/server/test_server_function_call.py b/tests/python/serve/server/test_server_function_call.py deleted file mode 100644 index 3fff27b938..0000000000 --- a/tests/python/serve/server/test_server_function_call.py +++ /dev/null @@ -1,210 +0,0 @@ -# pylint: disable=line-too-long -""" -Test script for function call in chat completion. 
To run this script, use the following command: -MLC_SERVE_MODEL_LIB=dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so -MLC_SERVE_MODEL_LIB=${MLC_SERVE_MODEL_LIB} python -m pytest -x tests/python/serve/server/test_server_function_call.py -""" - -# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches -import json -import os -from typing import Dict, List, Optional, Tuple - -import pytest -import requests - -OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" - - -def check_openai_nonstream_response( - response: Dict, - *, - model: str, - object_str: str, - num_choices: int, - finish_reason: List[str], - completion_tokens: Optional[int] = None, -): - print(response) - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) == num_choices - for idx, choice in enumerate(choices): - assert choice["index"] == idx - assert choice["finish_reason"] in finish_reason - - # text: str - message = choice["message"] - assert message["role"] == "assistant" - if choice["finish_reason"] == "tool_calls": - assert message["content"] is None - assert isinstance(message["tool_calls"], list) - else: - assert message["tool_calls"] is None - assert message["content"] is not None - - usage = response["usage"] - assert isinstance(usage, dict) - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 - - if completion_tokens is not None: - assert usage["completion_tokens"] == completion_tokens - - -def check_openai_stream_response( - responses: List[Dict], - *, - model: str, - object_str: str, - num_choices: int, - finish_reason: str, - echo_prompt: Optional[str] = None, - suffix: Optional[str] = None, - stop: Optional[List[str]] = None, - require_substr: Optional[List[str]] = None, -): - assert len(responses) > 0 - - finished = [False for _ in range(num_choices)] - outputs = ["" for _ in range(num_choices)] - for response in responses: - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) == num_choices - for idx, choice in enumerate(choices): - assert choice["index"] == idx - - delta = choice["delta"] - assert delta["role"] == "assistant" - assert isinstance(delta["content"], str) - outputs[idx] += delta["content"] - - if finished[idx]: - assert choice["finish_reason"] == finish_reason - elif choice["finish_reason"] is not None: - assert choice["finish_reason"] == finish_reason - finished[idx] = True - - for output in outputs: - if echo_prompt is not None: - assert output.startswith(echo_prompt) - if suffix is not None: - assert output.endswith(suffix) - if stop is not None: - for stop_str in stop: - assert stop_str not in output - if require_substr is not None: - for substr in require_substr: - assert substr in output - - -tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - }, - } -] - - -CHAT_COMPLETION_MESSAGES = [ - # messages #0 - [ - { - "role": "user", - "content": "What is the current weather in Pittsburgh, PA?", - } - ], - # messages #1 - [ - { - "role": "user", - "content": "What is the current weather in Pittsburgh, PA and Tokyo, JP?", - } - ], - # messages #2 - [ - { - "role": "user", - "content": "What is the current weather in Pittsburgh, PA in fahrenheit?", - } - ], -] - - -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) -def test_openai_v1_chat_completion_function_call( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, - messages: List[Dict[str, str]], -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - "tools": tools, - } - - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) - if not stream: - check_openai_nonstream_response( - response.json(), - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reason=["tool_calls", "error"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reason="tool_calls", - ) - - -if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: - raise ValueError( - 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' - "Please set it to model lib compiled by MLC LLM " - "(e.g., `./dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so`) " - "which supports function calls." - ) - MODEL = (os.path.dirname(model_lib_path), model_lib_path) - - for msg in CHAT_COMPLETION_MESSAGES: - test_openai_v1_chat_completion_function_call(MODEL, None, stream=False, messages=msg) - test_openai_v1_chat_completion_function_call(MODEL, None, stream=True, messages=msg) diff --git a/tests/python/serve/server/test_server_image.py b/tests/python/serve/server/test_server_image.py deleted file mode 100644 index 9b016224e4..0000000000 --- a/tests/python/serve/server/test_server_image.py +++ /dev/null @@ -1,258 +0,0 @@ -# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches -import json -import os -from typing import Dict, List, Optional, Tuple - -import pytest -import regex -import requests - -OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8001/v1/chat/completions" - -JSON_TOKEN_PATTERN = ( - r"((-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?)|null|true|false|" - r'("((\\["\\\/bfnrt])|(\\u[0-9a-fA-F]{4})|[^"\\\x00-\x1f])*")' -) -JSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN) - - -def is_json_or_json_prefix(s: str) -> bool: - try: - json.loads(s) - return True - except json.JSONDecodeError as e: - # If the JSON decoder reaches the end of s, it is a prefix of a JSON string. - if e.pos == len(s): - return True - # Since json.loads is token-based instead of char-based, there may remain half a token after - # the matching position. 
- # If the left part is a prefix of a valid JSON token, the output is also valid - regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True) - return regex_match is not None - - -def check_openai_nonstream_response( - response: Dict, - *, - is_chat_completion: bool, - model: str, - object_str: str, - num_choices: int, - finish_reasons: List[str], - completion_tokens: Optional[int] = None, - echo_prompt: Optional[str] = None, - suffix: Optional[str] = None, - stop: Optional[List[str]] = None, - require_substr: Optional[List[str]] = None, - json_mode: bool = False, -): - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) <= num_choices - texts: List[str] = ["" for _ in range(num_choices)] - for choice in choices: - idx = choice["index"] - assert choice["finish_reason"] in finish_reasons - - if not is_chat_completion: - assert isinstance(choice["text"], str) - texts[idx] = choice["text"] - if echo_prompt is not None: - assert texts[idx] - if suffix is not None: - assert texts[idx] - else: - message = choice["message"] - assert message["role"] == "assistant" - assert isinstance(message["content"], str) - texts[idx] = message["content"] - - if stop is not None: - for stop_str in stop: - assert stop_str not in texts[idx] - if require_substr is not None: - for substr in require_substr: - assert substr in texts[idx] - if json_mode: - assert is_json_or_json_prefix(texts[idx]) - - usage = response["usage"] - assert isinstance(usage, dict) - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 - if completion_tokens is not None: - assert usage["completion_tokens"] == completion_tokens - - -def check_openai_stream_response( - responses: List[Dict], - *, - is_chat_completion: bool, - model: str, - object_str: str, - num_choices: int, - finish_reasons: List[str], - completion_tokens: Optional[int] = None, - echo_prompt: Optional[str] = None, - suffix: Optional[str] = None, - stop: Optional[List[str]] = None, - require_substr: Optional[List[str]] = None, - json_mode: bool = False, -): - assert len(responses) > 0 - - finished = [False for _ in range(num_choices)] - outputs = ["" for _ in range(num_choices)] - for response in responses: - assert response["model"] == model - assert response["object"] == object_str - - choices = response["choices"] - assert isinstance(choices, list) - assert len(choices) <= num_choices - for choice in choices: - idx = choice["index"] - - if not is_chat_completion: - assert isinstance(choice["text"], str) - outputs[idx] += choice["text"] - else: - delta = choice["delta"] - assert delta["role"] == "assistant" - assert isinstance(delta["content"], str) - outputs[idx] += delta["content"] - - if finished[idx]: - assert choice["finish_reason"] in finish_reasons - elif choice["finish_reason"] is not None: - assert choice["finish_reason"] in finish_reasons - finished[idx] = True - - if not is_chat_completion: - usage = response["usage"] - assert isinstance(usage, dict) - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] - assert usage["prompt_tokens"] > 0 - if completion_tokens is not None: - assert usage["completion_tokens"] <= completion_tokens - - if not is_chat_completion: - if completion_tokens is not None: - assert responses[-1]["usage"]["completion_tokens"] == completion_tokens - - for i, output in enumerate(outputs): - if echo_prompt is not None: - assert 
output.startswith(echo_prompt) - if suffix is not None: - assert output.endswith(suffix) - if stop is not None: - for stop_str in stop: - assert stop_str not in output - if require_substr is not None: - for substr in require_substr: - assert substr in output - if json_mode: - assert is_json_or_json_prefix(output) - - -CHAT_COMPLETION_MESSAGES = [ - # messages #0 - [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": "https://llava-vl.github.io/static/images/view.jpg", - }, - {"type": "text", "text": "What does this image represent?"}, - ], - }, - ], - # messages #1 - [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": "https://llava-vl.github.io/static/images/view.jpg", - }, - {"type": "text", "text": "What does this image represent?"}, - ], - }, - { - "role": "assistant", - "content": "The image represents a serene and peaceful scene of a pier extending over a body of water, such as a lake or a river.er. The pier is made of wood and has a bench on it, providing a place for people to sit and enjoy the view. The pier is situated in a natural environment, surrounded by trees and mountains in the background. This setting creates a tranquil atmosphere, inviting visitors to relax and appreciate the beauty of the landscape.", - }, - { - "role": "user", - "content": "What country is the image set in? Give me 10 ranked guesses and reasons why.", - }, - ], -] - - -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) -def test_openai_v1_chat_completions( - served_model: Tuple[str, str], - launch_server, # pylint: disable=unused-argument - stream: bool, - messages: List[Dict[str, str]], -): - # `served_model` and `launch_server` are pytest fixtures - # defined in conftest.py. - - payload = { - "model": served_model[0], - "messages": messages, - "stream": stream, - } - response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180) - if not stream: - check_openai_nonstream_response( - response.json(), - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion", - num_choices=1, - finish_reasons=["stop"], - ) - else: - responses = [] - for chunk in response.iter_lines(chunk_size=512): - if not chunk or chunk == b"data: [DONE]": - continue - responses.append(json.loads(chunk.decode("utf-8")[6:])) - check_openai_stream_response( - responses, - is_chat_completion=True, - model=served_model[0], - object_str="chat.completion.chunk", - num_choices=1, - finish_reasons=["stop"], - ) - - -if __name__ == "__main__": - model_lib_path = os.environ.get("MLC_SERVE_MODEL_LIB") - if model_lib_path is None: - raise ValueError( - 'Environment variable "MLC_SERVE_MODEL_LIB" not found. ' - "Please set it to model lib compiled by MLC LLM " - "(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`)." 
- ) - - model = os.environ.get("MLC_SERVE_MODEL") - if model is None: - MODEL = (os.path.dirname(model_lib_path), model_lib_path) - else: - MODEL = (model, model_lib_path) - - for msg in CHAT_COMPLETION_MESSAGES: - test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg) - test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg) diff --git a/tests/python/serve/test_event_trace_recorder.py b/tests/python/serve/test_event_trace_recorder.py deleted file mode 100644 index b22dfeddad..0000000000 --- a/tests/python/serve/test_event_trace_recorder.py +++ /dev/null @@ -1,44 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -import json - -from mlc_llm.serve.event_trace_recorder import EventTraceRecorder - - -def test_event_trace_recorder(): - trace_recorder = EventTraceRecorder() - request_ids = ["x", "y"] - num_decode = 5 - - for request_id in request_ids: - trace_recorder.add_event(request_id, event="start tokenization") - trace_recorder.add_event(request_id, event="finish tokenization") - trace_recorder.add_event(request_id, event="add request") - trace_recorder.add_event(request_id, event="start embed") - trace_recorder.add_event(request_id, event="finish embed") - trace_recorder.add_event(request_id, event="start prefill") - trace_recorder.add_event(request_id, event="finish prefill") - - for _ in range(num_decode): - for request_id in request_ids: - trace_recorder.add_event(request_id, event="start decode") - trace_recorder.add_event(request_id, event="finish decode") - for request_id in request_ids: - trace_recorder.add_event(request_id, event="start detokenization") - trace_recorder.add_event(request_id, event="finish detokenization") - - events = json.loads(trace_recorder.dump_json()) - decode_count = {} - for event in events: - request_id = event["tid"] - if event["name"].startswith("decode"): - if request_id not in decode_count: - decode_count[request_id] = 1 - else: - decode_count[request_id] += 1 - - for _, decode_cnt in decode_count.items(): - assert decode_cnt == num_decode * 2, decode_cnt - - -if __name__ == "__main__": - test_event_trace_recorder() diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py deleted file mode 100644 index 10eacdf9b9..0000000000 --- a/tests/python/serve/test_grammar_parser.py +++ /dev/null @@ -1,285 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -import os - -import pytest -import tvm.testing -from tvm import TVMError - -from mlc_llm.serve import BNFGrammar - - -def test_bnf_simple(): - before = """main ::= b c -b ::= "b" -c ::= "c" -""" - expected = """main ::= ((b c)) -b ::= (([b])) -c ::= (([c])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_ebnf(): - before = """main ::= b c | b main -b ::= "ab"* -c ::= [acep-z]+ -d ::= "d"? 
-""" - expected = """main ::= ((b c) | (b main)) -b ::= ((b_1)) -c ::= ((c_1)) -d ::= ((d_1)) -b_1 ::= ("" | ([a] [b] b_1)) -c_1 ::= (([acep-z] c_1) | ([acep-z])) -d_1 ::= ("" | ([d])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_star_quantifier(): - before = """main ::= b c d -b ::= [b]* -c ::= "b"* -d ::= ([b] [c] [d] | ([p] [q]))* -e ::= [e]* [f]* | [g]* -""" - expected = """main ::= ((b c d)) -b ::= [b]* -c ::= ((c_1)) -d ::= ((d_1)) -e ::= ((e_star e_star_1) | (e_star_2)) -c_1 ::= ("" | ([b] c_1)) -d_1 ::= ("" | (d_1_choice d_1)) -e_star ::= [e]* -e_star_1 ::= [f]* -e_star_2 ::= [g]* -d_1_choice ::= (([b] [c] [d]) | ([p] [q])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_char(): - before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest -rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 -rest1 ::= "\?\"\'测试あc" "👀" "" -""" - expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest)) -rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) -rest1 ::= ((([\?] [\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) -""" - # Disable unwrap_nesting_rules to expose the result before unwrapping. - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_space(): - before = """ - -main::="a" "b" ("c""d" -"e") | - -"f" | "g" -""" - expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_nest(): - before = """main::= "a" ("b" | "c" "d") | (("e" "f")) -""" - expected = """main ::= (([a] main_choice) | ([e] [f])) -main_choice ::= (([b]) | ([c] [d])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_flatten(): - before = """main ::= or_test sequence_test nested_test empty_test -or_test ::= ([a] | "b") | "de" | "" | or_test | [^a-z] -sequence_test ::= [a] "a" ("b" ("c" | "d")) ("d" "e") sequence_test "" -nested_test ::= ("a" ("b" ("c" "d"))) | ("a" | ("b" | "c")) | nested_rest -nested_rest ::= ("a" | ("b" "c" | ("d" | "e" "f"))) | ((("g"))) -empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" -""" - expected = """main ::= ((or_test sequence_test nested_test empty_test)) -or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z])) -sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test)) -nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest)) -nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g])) -empty_test ::= ("" | ([d]) | ([a])) -sequence_test_choice ::= (([c]) | ([d])) -""" - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_json(): - current_file_path = os.path.abspath(__file__) - json_ebnf_path = os.path.join(os.path.dirname(current_file_path), "json.ebnf") - - with open(json_ebnf_path, "r", encoding="utf-8") as file: - before = file.read() - - expected = r"""main ::= ((element)) -value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] 
[e]) | ([n] [u] [l] [l])) -object ::= (([{] ws [}]) | ([{] members [}])) -members ::= ((member) | (member [,] members)) -member ::= ((ws string ws [:] element)) -array ::= (([[] ws [\]]) | ([[] elements [\]])) -elements ::= ((element) | (element [,] elements)) -element ::= ((ws value ws)) -string ::= (([\"] characters [\"])) -characters ::= ("" | (character characters)) -character ::= (([^\"\\]) | ([\\] escape)) -escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex)) -hex ::= (([A-Fa-f0-9])) -number ::= ((integer fraction exponent)) -integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits)) -digits ::= ((digit) | (digit digits)) -digit ::= (([0-9])) -onenine ::= (([1-9])) -fraction ::= ("" | ([.] digits)) -exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) -ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws)) -exponent_choice ::= (([e]) | ([E])) -exponent_choice_1 ::= ("" | ([+]) | ([\-])) -""" - - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_string() - assert after == expected - - -def test_to_string_roundtrip(): - """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" - - before = r"""main ::= (b c) | (b main) -b ::= b_1 d -c ::= c_1 -d ::= d_1 -b_1 ::= ([b] b_1) | "" -c_1 ::= (c_2 c_1) | c_2 -c_2 ::= [acep-z] -d_1 ::= [d] | "" -""" - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) - output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False) - output_string_2 = bnf_grammar_2.to_string() - assert output_string_1 == output_string_2 - - -def test_error(): - with pytest.raises( - TVMError, match='TVMError: EBNF parse error at line 1, column 11: Rule "a" is not defined' - ): - BNFGrammar.from_ebnf_string("main ::= a b") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 15: Expect element" - ): - BNFGrammar.from_ebnf_string('main ::= "a" |') - - with pytest.raises(TVMError, match='TVMError: EBNF parse error at line 1, column 15: Expect "'): - BNFGrammar.from_ebnf_string('main ::= "a" "') - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 1: Expect rule name" - ): - BNFGrammar.from_ebnf_string('::= "a"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 12: Character class should not contain " - "newline", - ): - BNFGrammar.from_ebnf_string("main ::= [a\n]") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" - ): - BNFGrammar.from_ebnf_string(r'main ::= "\@"') - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 11: Invalid escape sequence" - ): - BNFGrammar.from_ebnf_string(r'main ::= "\uFF"') - - with pytest.raises( - TVMError, - match="TVMError: EBNF parse error at line 1, column 14: Invalid character class: " - "lower bound is larger than upper bound", - ): - BNFGrammar.from_ebnf_string(r"main ::= [Z-A]") - - with pytest.raises( - TVMError, match="TVMError: EBNF parse error at line 1, column 6: Expect ::=" - ): - BNFGrammar.from_ebnf_string(r'main := "a"') - - with pytest.raises( - TVMError, - match='TVMError: EBNF parse error at line 2, column 9: Rule "main" is defined multiple ' - "times", - ): - BNFGrammar.from_ebnf_string('main ::= "a"\nmain ::= "b"') - - with pytest.raises( - TVMError, - match="TVMError: 
EBNF parse error at line 1, column 10: " - 'The main rule with name "main" is not found.', - ): - BNFGrammar.from_ebnf_string('a ::= "a"') - - -def test_to_json(): - before = """main ::= b c | b main -b ::= "bcd" -c ::= [a-z] -""" - expected = ( - '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"' - ":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100," - '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' - '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' - ) - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_json(False) - assert after == expected - - -def test_to_json_roundtrip(): - before = r"""main ::= ((b c) | (b main)) -b ::= ((b_1 d)) -c ::= ((c_1)) -d ::= ((d_1)) -b_1 ::= ("" | ([b] b_1)) -c_1 ::= ((c_2 c_1) | (c_2)) -c_2 ::= (([acep-z])) -d_1 ::= ("" | ([d])) -""" - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) - output_json_1 = bnf_grammar_1.to_json(False) - bnf_grammar_2 = BNFGrammar.from_json(output_json_1) - output_json_2 = bnf_grammar_2.to_json(False) - output_str = bnf_grammar_2.to_string() - assert output_json_1 == output_json_2 - assert output_str == before - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py deleted file mode 100644 index 6fc48705d1..0000000000 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ /dev/null @@ -1,399 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -# pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking -"""This test is adopted from test_grammar_state_matcher_json.py, but the grammar is parsed from -a unoptimized, non-simplified EBNF string. This is to test the robustness of the grammar state -matcher.""" -import json -import sys -from typing import Dict, List, Optional, Tuple - -import pytest -import tvm -import tvm.testing -from pydantic import BaseModel - -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher -from mlc_llm.tokenizer import Tokenizer - - -def get_json_grammar(): - json_grammar_ebnf = r""" -main ::= basic_array | basic_object -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= (([\"] basic_string_1 [\"])) -basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 -escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]" -basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" -ws ::= [ \n\t]* -""" - grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf) - return grammar - - -@pytest.fixture(scope="function") -def json_grammar(): - return get_json_grammar() - - -(json_input_accepted,) = tvm.testing.parameters( - ('{"name": "John"}',), - ('{ "name" : "John" }',), - ("{}",), - ("[]",), - ('{"name": "Alice", "age": 30, "city": "New York"}',), - ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), - ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), - ('[{"name": "David"}, {"name": "Sophia"}]',), - ( - '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' - ' "hasPets": false}', - ), - ( - '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' - '{"city": "Chicago", "zipcode": "60601"}}}', - ), - ( - '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' - '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', - ), - ( - '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' - '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' - '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', - ), - ( - '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' - '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' - '["CAD", "Project Management"], "projects": [{"name": "Project A", ' - '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', - ), -) - - -def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) - - -(json_input_refused,) = tvm.testing.parameters( - (r'{ name: "John" }',), - (r'{ "name": "John" } ',), # trailing space is not accepted - (r'{ "name": "John", "age": 30, }',), - (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), - (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), - (r'{ "name": "John", "age": 30.5.7 }',), - (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' - r'["hiking", "swimming",]}] }', - ), - (r'{ "name": "John", "age": 30, "status": "\P\J" }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' - r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' - r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', - ), -) - - -def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): - assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) - - -(json_input_pressure,) = tvm.testing.parameters( - # Extra long string: 1k chars - ( - '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' - "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " - "imperdiet. 
Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " - "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. Class aptent taciti sociosqu " - "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " - "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " - "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " - "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " - "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " - "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " - "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " - "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " - "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " - "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " - "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " - "Integer euismod lacus luctus magna. Quisque cursus, metus vitae pharetra auctor, sem " - 'massa mattis sem, at interdum magna augue eget diam."]', - ), - # long and complex json: 3k chars - ( - r"""{ - "web-app": { - "servlet": [ - { - "servlet-name": "cofaxCDS", - "servlet-class": "org.cofax.cds.CDSServlet", - "init-param": { - "configGlossary:installationAt": "Philadelphia, PA", - "configGlossary:adminEmail": "ksm@pobox.com", - "configGlossary:poweredBy": "Cofax", - "configGlossary:poweredByIcon": "/images/cofax.gif", - "configGlossary:staticPath": "/content/static", - "templateProcessorClass": "org.cofax.WysiwygTemplate", - "templateLoaderClass": "org.cofax.FilesTemplateLoader", - "templatePath": "templates", - "templateOverridePath": "", - "defaultListTemplate": "listTemplate.htm", - "defaultFileTemplate": "articleTemplate.htm", - "useJSP": false, - "jspListTemplate": "listTemplate.jsp", - "jspFileTemplate": "articleTemplate.jsp", - "cachePackageTagsTrack": 200, - "cachePackageTagsStore": 200, - "cachePackageTagsRefresh": 60, - "cacheTemplatesTrack": 100, - "cacheTemplatesStore": 50, - "cacheTemplatesRefresh": 15, - "cachePagesTrack": 200, - "cachePagesStore": 100, - "cachePagesRefresh": 10, - "cachePagesDirtyRead": 10, - "searchEngineListTemplate": "forSearchEnginesList.htm", - "searchEngineFileTemplate": "forSearchEngines.htm", - "searchEngineRobotsDb": "WEB-INF/robots.db", - "useDataStore": true, - "dataStoreClass": "org.cofax.SqlDataStore", - "redirectionClass": "org.cofax.SqlRedirection", - "dataStoreName": "cofax", - "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", - "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", - "dataStoreUser": "sa", - "dataStorePassword": "dataStoreTestQuery", - "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", - "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", - "dataStoreInitConns": 10, - "dataStoreMaxConns": 100, - "dataStoreConnUsageLimit": 100, - "dataStoreLogLevel": "debug", - "maxUrlLength": 500 - } - }, - { - "servlet-name": "cofaxEmail", - "servlet-class": "org.cofax.cds.EmailServlet", - "init-param": { - "mailHost": "mail1", - "mailHostOverride": "mail2" - } - }, - { - "servlet-name": "cofaxAdmin", - "servlet-class": "org.cofax.cds.AdminServlet" - }, - { - "servlet-name": "fileServlet", - "servlet-class": "org.cofax.cds.FileServlet" - }, - 
{ - "servlet-name": "cofaxTools", - "servlet-class": "org.cofax.cms.CofaxToolsServlet", - "init-param": { - "templatePath": "toolstemplates/", - "log": 1, - "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", - "logMaxSize": "", - "dataLog": 1, - "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", - "dataLogMaxSize": "", - "removePageCache": "/content/admin/remove?cache=pages&id=", - "removeTemplateCache": "/content/admin/remove?cache=templates&id=", - "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", - "lookInContext": 1, - "adminGroupID": 4, - "betaServer": true - } - } - ], - "servlet-mapping": { - "cofaxCDS": "/", - "cofaxEmail": "/cofaxutil/aemail/*", - "cofaxAdmin": "/admin/*", - "fileServlet": "/static/*", - "cofaxTools": "/tools/*" - }, - "taglib": { - "taglib-uri": "cofax.tld", - "taglib-location": "/WEB-INF/tlds/cofax.tld" - } - } -}""", - ), -) - - -def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) - - -(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( - ( - # short test - '{"id": 1,"name": "Example"}', - [ - # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 - # fmt: on - ], - ), - ( - # long test - """{ -"id": 1, -"na": "ex", -"ac": true, -"t": ["t1", "t2"], -"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, -"res": "res" -}""", - [ - # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 - # fmt: on - ], - ), -) - - -def test_find_next_rejected_tokens( - json_grammar: BNFGrammar, - input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]] = None, -): - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) - - real_sizes = [] - for c in input_find_rejected_tokens: - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c, file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(ord(c)) - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - - if expected_rejected_sizes is not None: - assert real_sizes == expected_rejected_sizes - - -def test_token_based_operations(json_grammar: BNFGrammar): - """Test accepting token and finding the next token mask.""" - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = 
[token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - expected = [ - ["{"], - ['"', "}", "\n", " ", '"a":true'], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - [":", "\n", " ", ':"'], - ['"', "{", "6", "\n", " "], - ["}", ", ", "6", "\n", " "], - [" ", "\n", '"', '"a":true'], - [" ", "\n", '"', '"a":true'], - ["}", ", ", "\n", " "], - [""], - ] - - result = [] - - for id in input_ids: - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - assert id in accepted - assert grammar_state_matcher.accept_token(id) - - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - - assert result == expected - - -def test_custom_main_rule(): - json_grammar_ebnf = r""" -main ::= basic_object -basic_any ::= basic_string | basic_object -basic_string ::= (([\"] basic_string_1 [\"])) -basic_string_1 ::= "" | [^"\\\r\n] basic_string_1 | "\\" escape basic_string_1 -escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}" -ws ::= [ \n\t]* -""" - grammar = BNFGrammar.from_ebnf_string(json_grammar_ebnf, "basic_string") - assert GrammarStateMatcher(grammar).debug_match_complete_string(r'"abc\r\n"') - assert not GrammarStateMatcher(grammar).debug_match_complete_string(r'{"name": "John" }') - - -def test_find_next_rejected_tokens_schema(): - class MainModel(BaseModel): - integer_field: int - number_field: float - boolean_field: bool - any_array_field: List - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - nested_object_field: Dict[str, Dict[str, int]] - - schema = MainModel.model_json_schema() - schema_str = json.dumps(schema) - ebnf_grammar = BNFGrammar.from_schema(schema_str, indent=2) - - instance = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[3.14, "foo", None, True], - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - nested_object_field={"foo": {"bar": 42}}, - ) - instance_str = instance.model_dump_json(indent=2, round_trip=True) - - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - matcher = GrammarStateMatcher(ebnf_grammar, tokenizer) - - for c in instance_str: - matcher.find_next_rejected_tokens(True) - print("Accepting char:", c, file=sys.stderr) - assert matcher.debug_accept_char(ord(c)) - matcher.find_next_rejected_tokens(True) - - -if __name__ == "__main__": - # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') - - tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py deleted file mode 100644 index fc0f79a041..0000000000 --- a/tests/python/serve/test_grammar_state_matcher_json.py +++ /dev/null @@ -1,412 +0,0 @@ -# pylint: disable=missing-module-docstring,missing-function-docstring -# pylint: 
disable=redefined-outer-name,unbalanced-tuple-unpacking -"""This test uses the optimized JSON grammar provided by the grammar library.""" -import sys -from typing import List, Optional - -import pytest -import tvm -import tvm.testing -from tvm import TVMError - -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher -from mlc_llm.tokenizer import Tokenizer - - -@pytest.fixture(scope="function") -def json_grammar(): - return BNFGrammar.get_grammar_of_json() - - -(json_input_accepted,) = tvm.testing.parameters( - ('{"name": "John"}',), - ('{ "name" : "John" }',), - ("{}",), - ("[]",), - ('{"name": "Alice", "age": 30, "city": "New York"}',), - ('{"name": "Mike", "hobbies": ["reading", "cycling", "hiking"]}',), - ('{"name": "Emma", "address": {"street": "Maple Street", "city": "Boston"}}',), - ('[{"name": "David"}, {"name": "Sophia"}]',), - ( - '{"name": "William", "age": null, "married": true, "children": ["Liam", "Olivia"],' - ' "hasPets": false}', - ), - ( - '{"name": "Olivia", "contact": {"email": "olivia@example.com", "address": ' - '{"city": "Chicago", "zipcode": "60601"}}}', - ), - ( - '{"name": "Liam", "skills": ["Java", "Python"], "experience": ' - '[{"company": "CompanyA", "years": 5}, {"company": "CompanyB", "years": 3}]}', - ), - ( - '{"person": {"name": "Ethan", "age": 40}, "education": {"degree": "Masters", ' - '"university": "XYZ University"}, "work": [{"company": "ABC Corp", "position": ' - '"Manager"}, {"company": "DEF Corp", "position": "Senior Manager"}]}', - ), - ( - '{"name": "Charlotte", "details": {"personal": {"age": 35, "hobbies": ["gardening", ' - '"painting"]}, "professional": {"occupation": "Engineer", "skills": ' - '["CAD", "Project Management"], "projects": [{"name": "Project A", ' - '"status": "Completed"}, {"name": "Project B", "status": "In Progress"}]}}}', - ), -) - - -def test_json_accept(json_grammar: BNFGrammar, json_input_accepted: str): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_accepted) - - -(json_input_refused,) = tvm.testing.parameters( - (r'{ name: "John" }',), - (r'{ "name": "John" } ',), # trailing space is not accepted - (r'{ "name": "John", "age": 30, }',), - (r'{ "name": "John", "address": { "street": "123 Main St", "city": "New York" }',), - (r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling",], }',), - (r'{ "name": "John", "age": 30.5.7 }',), - (r'{ "name": "John, "age": 30, "hobbies": ["reading", "traveling"] }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", { "type": "outdoor", "list": ' - r'["hiking", "swimming",]}] }', - ), - (r'{ "name": "John", "age": 30, "status": "\P\J" }',), - ( - r'{ "name": "John", "age": 30, "hobbies": ["reading", "traveling"], "address": ' - r'{ "street": "123 Main St", "city": "New York", "coordinates": { "latitude": 40.7128, ' - r'"longitude": -74.0060 }}}, "work": { "company": "Acme", "position": "developer" }}', - ), -) - - -def test_json_refuse(json_grammar: BNFGrammar, json_input_refused): - assert not GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_refused) - - -(json_input_pressure,) = tvm.testing.parameters( - # Extra long string: 1k chars - ( - '["Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer nec odio. Praesent ' - "libero. Sed cursus ante dapibus diam. Sed nisi. Nulla quis sem at nibh elementum " - "imperdiet. Duis sagittis ipsum. Praesent mauris. Fusce nec tellus sed augue semper " - "porta. Mauris massa. Vestibulum lacinia arcu eget nulla. 
Class aptent taciti sociosqu " - "ad litora torquent per conubia nostra, per inceptos himenaeos. Curabitur sodales ligula " - "in libero. Sed dignissim lacinia nunc. Curabitur tortor. Pellentesque nibh. Aenean quam. " - "In scelerisque sem at dolor. Maecenas mattis. Sed convallis tristique sem. Proin ut " - "ligula vel nunc egestas porttitor. Morbi lectus risus, iaculis vel, suscipit quis, " - "luctus non, massa. Fusce ac turpis quis ligula lacinia aliquet. Mauris ipsum. Nulla " - "metus metus, ullamcorper vel, tincidunt sed, euismod in, nibh. Quisque volutpat " - "condimentum velit. Class aptent taciti sociosqu ad litora torquent per conubia nostra, " - "per inceptos himenaeos. Nam nec ante. Sed lacinia, urna non tincidunt mattis, tortor " - "neque adipiscing diam, a cursus ipsum ante quis turpis. Nulla facilisi. Ut fringilla. " - "Suspendisse potenti. Nunc feugiat mi a tellus consequat imperdiet. Vestibulum sapien. " - "Proin quam. Etiam ultrices. Suspendisse in justo eu magna luctus suscipit. Sed lectus. " - "Integer euismod lacus luctus magna. Quisque cursus, metus vitae pharetra auctor, sem " - 'massa mattis sem, at interdum magna augue eget diam."]', - ), - # long and complex json: 3k chars - ( - r"""{ - "web-app": { - "servlet": [ - { - "servlet-name": "cofaxCDS", - "servlet-class": "org.cofax.cds.CDSServlet", - "init-param": { - "configGlossary:installationAt": "Philadelphia, PA", - "configGlossary:adminEmail": "ksm@pobox.com", - "configGlossary:poweredBy": "Cofax", - "configGlossary:poweredByIcon": "/images/cofax.gif", - "configGlossary:staticPath": "/content/static", - "templateProcessorClass": "org.cofax.WysiwygTemplate", - "templateLoaderClass": "org.cofax.FilesTemplateLoader", - "templatePath": "templates", - "templateOverridePath": "", - "defaultListTemplate": "listTemplate.htm", - "defaultFileTemplate": "articleTemplate.htm", - "useJSP": false, - "jspListTemplate": "listTemplate.jsp", - "jspFileTemplate": "articleTemplate.jsp", - "cachePackageTagsTrack": 200, - "cachePackageTagsStore": 200, - "cachePackageTagsRefresh": 60, - "cacheTemplatesTrack": 100, - "cacheTemplatesStore": 50, - "cacheTemplatesRefresh": 15, - "cachePagesTrack": 200, - "cachePagesStore": 100, - "cachePagesRefresh": 10, - "cachePagesDirtyRead": 10, - "searchEngineListTemplate": "forSearchEnginesList.htm", - "searchEngineFileTemplate": "forSearchEngines.htm", - "searchEngineRobotsDb": "WEB-INF/robots.db", - "useDataStore": true, - "dataStoreClass": "org.cofax.SqlDataStore", - "redirectionClass": "org.cofax.SqlRedirection", - "dataStoreName": "cofax", - "dataStoreDriver": "com.microsoft.jdbc.sqlserver.SQLServerDriver", - "dataStoreUrl": "jdbc:microsoft:sqlserver://LOCALHOST:1433;DatabaseName=goon", - "dataStoreUser": "sa", - "dataStorePassword": "dataStoreTestQuery", - "dataStoreTestQuery": "SET NOCOUNT ON;select test='test';", - "dataStoreLogFile": "/usr/local/tomcat/logs/datastore.log", - "dataStoreInitConns": 10, - "dataStoreMaxConns": 100, - "dataStoreConnUsageLimit": 100, - "dataStoreLogLevel": "debug", - "maxUrlLength": 500 - } - }, - { - "servlet-name": "cofaxEmail", - "servlet-class": "org.cofax.cds.EmailServlet", - "init-param": { - "mailHost": "mail1", - "mailHostOverride": "mail2" - } - }, - { - "servlet-name": "cofaxAdmin", - "servlet-class": "org.cofax.cds.AdminServlet" - }, - { - "servlet-name": "fileServlet", - "servlet-class": "org.cofax.cds.FileServlet" - }, - { - "servlet-name": "cofaxTools", - "servlet-class": "org.cofax.cms.CofaxToolsServlet", - "init-param": { - "templatePath": 
"toolstemplates/", - "log": 1, - "logLocation": "/usr/local/tomcat/logs/CofaxTools.log", - "logMaxSize": "", - "dataLog": 1, - "dataLogLocation": "/usr/local/tomcat/logs/dataLog.log", - "dataLogMaxSize": "", - "removePageCache": "/content/admin/remove?cache=pages&id=", - "removeTemplateCache": "/content/admin/remove?cache=templates&id=", - "fileTransferFolder": "/usr/local/tomcat/webapps/content/fileTransferFolder", - "lookInContext": 1, - "adminGroupID": 4, - "betaServer": true - } - } - ], - "servlet-mapping": { - "cofaxCDS": "/", - "cofaxEmail": "/cofaxutil/aemail/*", - "cofaxAdmin": "/admin/*", - "fileServlet": "/static/*", - "cofaxTools": "/tools/*" - }, - "taglib": { - "taglib-uri": "cofax.tld", - "taglib-location": "/WEB-INF/tlds/cofax.tld" - } - } -}""", - ), -) - - -def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): - assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) - - -(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( - ( - # short test - '{"id": 1,"name": "Example"}', - [ - # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 - # fmt: on - ], - ), - ( - # long test - """{ -"id": 1, -"na": "ex", -"ac": true, -"t": ["t1", "t2"], -"ne": {"lv2": {"val": "dp"}, "arr": [1, 2, 3]}, -"res": "res" -}""", - [ - # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 - # fmt: on - ], - ), -) - - -def test_find_next_rejected_tokens( - json_grammar: BNFGrammar, - input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]] = None, -): - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) - - real_sizes = [] - for c in input_find_rejected_tokens: - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c, file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(ord(c)) - rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - if expected_rejected_sizes is not None: - assert real_sizes == expected_rejected_sizes - - -def test_token_based_operations(json_grammar: BNFGrammar): - """Test accepting token and finding the next token mask.""" - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - expected = [ - 
["{"], - ['"', "}", "\n", " ", '"a":true'], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - ["a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", " "], - [":", "\n", " ", ':"'], - ['"', "{", "6", "\n", " "], - ["}", ", ", "6", "\n", " "], - [" ", "\n", '"', '"a":true'], - [" ", "\n", '"', '"a":true'], - ["}", ", ", "\n", " "], - [""], - ] - - result = [] - - for id in input_ids: - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - assert id in accepted - assert grammar_state_matcher.accept_token(id) - - rejected = grammar_state_matcher.find_next_rejected_tokens() - accepted = list(set(range(len(token_table))) - set(rejected)) - accepted_tokens = [token_table[i] for i in accepted] - result.append(accepted_tokens) - - assert result == expected - - -def test_rollback(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) - - assert grammar_state_matcher.max_rollback_steps() == 5 - - input_ids_splitted = [input_ids[i : i + 2] for i in range(0, len(input_ids), 2)] - - for i_1, i_2 in input_ids_splitted: - orig_result = [] - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_1) - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_2) - grammar_state_matcher.rollback(2) - result_after_rollback = [] - result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_1) - result_after_rollback.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i_2) - assert orig_result == result_after_rollback - - -def test_reset(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}"] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table) - - orig_result = [] - - for i in input_ids: - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i) - - grammar_state_matcher.reset_state() - - result_after_reset = [] - - for i in input_ids: - result_after_reset.append(grammar_state_matcher.find_next_rejected_tokens()) - assert grammar_state_matcher.accept_token(i) - - assert orig_result == result_after_reset - - -def test_termination(json_grammar: BNFGrammar): - token_table = [ - # fmt: off - "", "", "a", "abc", 'b"', '"', ':"', "{", "}", ", ", "6", ":", "\n", " ", '"a":true', - # fmt: on - ] - input_splitted = ["{", '"', "abc", 'b"', ":", "6", ", ", " ", '"a":true', "}", ""] - input_ids = [token_table.index(t) for t in input_splitted] - - grammar_state_matcher = GrammarStateMatcher(json_grammar, token_table, 5) - - orig_result = [] - - for i in input_ids: - orig_result.append(grammar_state_matcher.find_next_rejected_tokens()) - assert 
grammar_state_matcher.accept_token(i) - - assert grammar_state_matcher.is_terminated() - - with pytest.raises(TVMError): - grammar_state_matcher.accept_token(0) - - with pytest.raises(TVMError): - grammar_state_matcher.find_next_rejected_tokens() - - grammar_state_matcher.rollback(2) - - assert not grammar_state_matcher.is_terminated() - assert grammar_state_matcher.accept_token(input_ids[-2]) - - -if __name__ == "__main__": - # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens(BNFGrammar.get_grammar_of_json(), '{"id": 1,"name": "Example"}') - - tvm.testing.main() diff --git a/tests/python/serve/test_json_schema_converter.py b/tests/python/serve/test_json_schema_converter.py deleted file mode 100644 index 84dbd2cb7b..0000000000 --- a/tests/python/serve/test_json_schema_converter.py +++ /dev/null @@ -1,478 +0,0 @@ -import json -from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union - -import tvm.testing -from pydantic import BaseModel, Field, TypeAdapter - -from mlc_llm.serve import BNFGrammar, GrammarStateMatcher - - -def check_schema_with_grammar( - schema: Dict[str, Any], - expected_grammar: str, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - schema_str = json.dumps(schema, indent=2) - grammar = BNFGrammar.debug_json_schema_to_ebnf( - schema_str, indent=indent, separators=separators, strict_mode=strict_mode - ) - assert grammar == expected_grammar - - -def check_schema_with_json( - schema: Dict[str, Any], - json_str: str, - check_accepted: bool = True, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - ebnf_grammar = BNFGrammar.from_schema( - json.dumps(schema, indent=2), indent=indent, separators=separators, strict_mode=strict_mode - ) - matcher = GrammarStateMatcher(ebnf_grammar) - - if check_accepted: - assert matcher.debug_match_complete_string(json_str) - else: - assert not matcher.debug_match_complete_string(json_str) - - -def check_schema_with_instance( - schema: Dict[str, Any], - instance: BaseModel, - check_accepted: bool = True, - indent: Optional[int] = None, - separators: Optional[Tuple[str, str]] = None, - strict_mode: bool = True, -): - instance_obj = instance.model_dump(mode="json", round_trip=True) - instance_str = json.dumps(instance_obj, indent=indent, separators=separators) - check_schema_with_json(schema, instance_str, check_accepted, indent, separators, strict_mode) - - -def test_basic(): - class MainModel(BaseModel): - integer_field: int - number_field: float - boolean_field: bool - any_array_field: List - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - nested_object_field: Dict[str, Dict[str, int]] - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_3 ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -main_prop_4 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_prop_5_item_2 ::= ("[" "" basic_string (", " basic_string)* "" "]") | "[]" -main_prop_5 ::= "[" "" basic_string ", " basic_integer ", " main_prop_5_item_2 "" "]" -main_prop_6 ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_prop_7_addl ::= ("{" "" basic_string ": " basic_integer (", " basic_string ": " basic_integer)* "" "}") | "{}" -main_prop_7 ::= ("{" "" basic_string ": " main_prop_7_addl (", " basic_string ": " main_prop_7_addl)* "" "}") | "{}" -main ::= "{" "" "\"integer_field\"" ": " basic_integer ", " "\"number_field\"" ": " basic_number ", " "\"boolean_field\"" ": " basic_boolean ", " "\"any_array_field\"" ": " main_prop_3 ", " "\"array_field\"" ": " main_prop_4 ", " "\"tuple_field\"" ": " main_prop_5 ", " "\"object_field\"" ": " main_prop_6 ", " "\"nested_object_field\"" ": " main_prop_7 "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[3.14, "foo", None, True], - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - nested_object_field={"foo": {"bar": 42}}, - ) - check_schema_with_instance(schema, instance) - - instance_empty = MainModel( - integer_field=42, - number_field=3.14e5, - boolean_field=True, - any_array_field=[], - array_field=[], - tuple_field=("foo", 42, []), - object_field={}, - nested_object_field={}, - ) - - schema = MainModel.model_json_schema() - check_schema_with_instance(schema, instance_empty) - - -def test_indent(): - class MainModel(BaseModel): - array_field: List[str] - tuple_field: Tuple[str, int, List[str]] - object_field: Dict[str, int] - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_prop_1_item_2 ::= ("[" "\n " basic_string (",\n " basic_string)* "\n " "]") | "[]" -main_prop_1 ::= "[" "\n " basic_string ",\n " basic_integer ",\n " main_prop_1_item_2 "\n " "]" -main_prop_2 ::= ("{" "\n " basic_string ": " basic_integer (",\n " basic_string ": " basic_integer)* "\n " "}") | "{}" -main ::= "{" "\n " "\"array_field\"" ": " main_prop_0 ",\n " "\"tuple_field\"" ": " main_prop_1 ",\n " "\"object_field\"" ": " main_prop_2 "\n" "}" -""" - - instance = MainModel( - array_field=["foo", "bar"], - tuple_field=("foo", 42, ["bar", "baz"]), - object_field={"foo": 42, "bar": 43}, - ) - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar, indent=2) - check_schema_with_instance(schema, instance, indent=2) - check_schema_with_instance(schema, instance, indent=None, separators=(",", ":")) - - -def test_non_strict(): - class Foo(BaseModel): - pass - - class MainModel(BaseModel): - tuple_field: Tuple[str, Tuple[int, int]] - foo_field: Foo - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any ("," basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any ("," basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0_item_1 ::= "[" "\n " basic_integer ",\n " basic_integer (",\n " basic_any)* "\n " "]" -main_prop_0 ::= "[" "\n " basic_string ",\n " main_prop_0_item_1 (",\n " basic_any)* "\n " "]" -main_prop_1 ::= ("{" "\n " basic_string ": " basic_any (",\n " basic_string ": " basic_any)* "\n " "}") | "{}" -main ::= "{" "\n " "\"tuple_field\"" ": " main_prop_0 ",\n " "\"foo_field\"" ": " main_prop_1 (",\n " basic_string ": " basic_any)* "\n" "}" -""" - - instance_json = """{ - "tuple_field": [ - "foo", - [ - 12, - 13, - "ext" - ], - "extra" - ], - "foo_field": { - "tmp": "str" - }, - "extra": "field" -}""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar, indent=2, strict_mode=False) - check_schema_with_json(schema, instance_json, indent=2, strict_mode=False) - - -def test_enum_const(): - class Field(Enum): - FOO = "foo" - BAR = "bar" - - class MainModel(BaseModel): - bars: Literal["a"] - str_values: Literal['a\n\r"'] - foo: Literal["a", "b", "c"] - values: Literal[1, "a", True] - field: Field - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= "\"a\"" -main_prop_1 ::= "\"a\\n\\r\\\"\"" -main_prop_2 ::= ("\"a\"") | ("\"b\"") | ("\"c\"") -main_prop_3 ::= ("1") | ("\"a\"") | ("true") -main_prop_4 ::= ("\"foo\"") | ("\"bar\"") -main ::= "{" "" "\"bars\"" ": " main_prop_0 ", " "\"str_values\"" ": " main_prop_1 ", " "\"foo\"" ": " main_prop_2 ", " "\"values\"" ": " main_prop_3 ", " "\"field\"" ": " main_prop_4 "" "}" -""" - - schema = MainModel.model_json_schema() - instance = MainModel(foo="a", values=1, bars="a", str_values='a\n\r"', field=Field.FOO) - check_schema_with_grammar(schema, ebnf_grammar) - check_schema_with_instance(schema, instance) - - -def test_optional(): - class MainModel(BaseModel): - num: int = 0 - opt_bool: Optional[bool] = None - size: Optional[float] - name: str = "" - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_1 ::= basic_boolean | basic_null -main_prop_2 ::= basic_number | basic_null -main ::= "{" "" ("\"num\"" ": " basic_integer ", ")? ("\"opt_bool\"" ": " main_prop_1 ", ")? "\"size\"" ": " main_prop_2 (", " "\"name\"" ": " basic_string)? "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel(num=42, opt_bool=True, size=3.14, name="foo") - check_schema_with_instance(schema, instance) - - instance = MainModel(size=None) - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"size": null}') - check_schema_with_json(schema, '{"size": null, "name": "foo"}') - check_schema_with_json(schema, '{"num": 1, "size": null, "name": "foo"}') - - -def test_all_optional(): - class MainModel(BaseModel): - size: int = 0 - state: bool = False - num: float = 0 - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_part_1 ::= "" | ", " "\"num\"" ": " basic_number "" -main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number "")) "" "}") | "{}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel(size=42, state=True, num=3.14) - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"state": false}') - check_schema_with_json(schema, '{"size": 1, "num": 1.5}') - - ebnf_grammar_non_strict = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_part_2 ::= (", " basic_string ": " basic_any)* -main_part_1 ::= main_part_2 | ", " "\"num\"" ": " basic_number main_part_2 -main_part_0 ::= main_part_1 | ", " "\"state\"" ": " basic_boolean main_part_1 -main ::= ("{" "" (("\"size\"" ": " basic_integer main_part_0) | ("\"state\"" ": " basic_boolean main_part_1) | ("\"num\"" ": " basic_number main_part_2) | basic_string ": " basic_any main_part_2) "" "}") | "{}" -""" - - check_schema_with_grammar(schema, ebnf_grammar_non_strict, strict_mode=False) - - check_schema_with_json(schema, '{"size": 1, "num": 1.5, "other": false}', strict_mode=False) - check_schema_with_json(schema, '{"other": false}', strict_mode=False) - - -def test_empty(): - class MainModel(BaseModel): - pass - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main ::= "{" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - - instance = MainModel() - check_schema_with_instance(schema, instance) - - check_schema_with_json(schema, '{"tmp": 123}', strict_mode=False) - - -def test_reference(): - class Foo(BaseModel): - count: int - size: Optional[float] = None - - class Bar(BaseModel): - apple: str = "x" - banana: str = "y" - - class MainModel(BaseModel): - foo: Foo - bars: List[Bar] - - instance = MainModel( - foo=Foo(count=42, size=3.14), - bars=[Bar(apple="a", banana="b"), Bar(apple="c", banana="d")], - ) - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0_prop_1 ::= basic_number | basic_null -main_prop_0 ::= "{" "" "\"count\"" ": " basic_integer (", " "\"size\"" ": " main_prop_0_prop_1)? "" "}" -main_prop_1_items_part_0 ::= "" | ", " "\"banana\"" ": " basic_string "" -main_prop_1_items ::= ("{" "" (("\"apple\"" ": " basic_string main_prop_1_items_part_0) | ("\"banana\"" ": " basic_string "")) "" "}") | "{}" -main_prop_1 ::= ("[" "" main_prop_1_items (", " main_prop_1_items)* "" "]") | "[]" -main ::= "{" "" "\"foo\"" ": " main_prop_0 ", " "\"bars\"" ": " main_prop_1 "" "}" -""" - - schema = MainModel.model_json_schema() - check_schema_with_grammar(schema, ebnf_grammar) - check_schema_with_instance(schema, instance) - - -def test_union(): - class Cat(BaseModel): - name: str - color: str - - class Dog(BaseModel): - name: str - breed: str - - ta = TypeAdapter(Union[Cat, Dog]) - - model_schema = ta.json_schema() - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_case_0 ::= "{" "" "\"name\"" ": " basic_string ", " "\"color\"" ": " basic_string "" "}" -main_case_1 ::= "{" "" "\"name\"" ": " basic_string ", " "\"breed\"" ": " basic_string "" "}" -main ::= main_case_0 | main_case_1 -""" - - check_schema_with_grammar(model_schema, ebnf_grammar) - - check_schema_with_instance(model_schema, Cat(name="kitty", color="black")) - check_schema_with_instance(model_schema, Dog(name="doggy", breed="bulldog")) - check_schema_with_json(model_schema, '{"name": "kitty", "test": "black"}', False) - - -def test_alias(): - class MainModel(BaseModel): - test: str = Field(..., alias="name") - - ebnf_grammar = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? -basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main ::= "{" "" "\"name\"" ": " basic_string "" "}" -""" - - check_schema_with_grammar(MainModel.model_json_schema(), ebnf_grammar) - - instance = MainModel(name="kitty") - instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=False)) - check_schema_with_json(MainModel.model_json_schema(by_alias=False), instance_str) - - instance_str = json.dumps(instance.model_dump(mode="json", round_trip=True, by_alias=True)) - check_schema_with_json(MainModel.model_json_schema(by_alias=True), instance_str) - - # property name contains space - class MainModelSpace(BaseModel): - test: Literal["abc"] = Field(..., alias="name 1") - - ebnf_grammar_space = r"""basic_escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -basic_string_sub ::= "" | [^"\\\r\n] basic_string_sub | "\\" basic_escape basic_string_sub -basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object -basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"? -basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)? 
-basic_string ::= ["] basic_string_sub ["] -basic_boolean ::= "true" | "false" -basic_null ::= "null" -basic_array ::= ("[" "" basic_any (", " basic_any)* "" "]") | "[]" -basic_object ::= ("{" "" basic_string ": " basic_any (", " basic_string ": " basic_any)* "" "}") | "{}" -main_prop_0 ::= "\"abc\"" -main ::= "{" "" "\"name 1\"" ": " main_prop_0 "" "}" -""" - - check_schema_with_grammar(MainModelSpace.model_json_schema(), ebnf_grammar_space) - - instance_space = MainModelSpace(**{"name 1": "abc"}) - instance_space_str = json.dumps( - instance_space.model_dump(mode="json", round_trip=True, by_alias=True) - ) - check_schema_with_json(MainModelSpace.model_json_schema(by_alias=True), instance_space_str) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/serve/test_radix_tree.py b/tests/python/serve/test_radix_tree.py deleted file mode 100644 index cea421cd95..0000000000 --- a/tests/python/serve/test_radix_tree.py +++ /dev/null @@ -1,79 +0,0 @@ -from tvm import TVMError -from tvm.runtime import ShapeTuple - -from mlc_llm.serve import PagedRadixTree - - -def test_add(): - prt = PagedRadixTree(16, 128, 16) - prt.add(0) - assert prt.get(0) == [] - - -def test_remove(): - prt = PagedRadixTree(32, 128, 16) - capacity = prt.free_capacity() - prt.add(0) - prt.remove(0) - prt.add(0) - prt.extend(0, [1 for _ in range(200)]) - prt.remove(0) - assert prt.free_capacity() == capacity - - prt.add(1) - prt.extend(1, [1 for _ in range(200)]) - capacity = prt.free_capacity() - prt.add(2) - prt.extend(2, [1 for _ in range(100)] + [2 for _ in range(100)]) - prt.remove(2) - assert prt.free_capacity() == capacity - - prt.add(3) - prt.extend(3, [1 for _ in range(200)]) - prt.remove(3) - assert prt.free_capacity() == capacity - - -def test_extend(): - prt = PagedRadixTree(1024, 256, 256) - L = prt.free_capacity() // 1024 - H = L // 2 - Q = L // 4 - seq_id = 0 - for start_pos in [0, H, L, L + H]: - for length in [Q, L - H, L, 2 * L - H, 2 * L]: - prt.add(seq_id) - if start_pos: - tokens_1 = [seq_id for _ in range(start_pos)] - prt.extend(seq_id, tokens_1) - assert prt.get(seq_id) == tokens_1 - else: - tokens_1 = [] - tokens_2 = [seq_id for _ in range(length)] - prt.extend(seq_id, tokens_2) - assert prt.get(seq_id) == tokens_1 + tokens_2 - seq_id += 1 - - -def test_fork(): - prt = PagedRadixTree(1024, 256, 256) - L = prt.free_capacity() // 1024 - H = L // 2 - Q = L // 4 - seq_id = 0 - length_list = [Q, H, L, L + Q, L + H, L * 2] - for p_idx in range(1, len(length_list)): - for c_idx in range(0, p_idx + 1): - prt.add(seq_id) - tokens = [seq_id for _ in range(length_list[p_idx])] - prt.extend(seq_id, tokens) - prt.fork(seq_id + 1, seq_id, length_list[c_idx]) - assert prt.get(seq_id + 1) == tokens[: length_list[c_idx]] - seq_id += 2 - - -if __name__ == "__main__": - test_add() - test_remove() - test_extend() - test_fork() diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py deleted file mode 100644 index 6e3835238a..0000000000 --- a/tests/python/serve/test_serve_async_engine.py +++ /dev/null @@ -1,290 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import asyncio -from typing import List - -from mlc_llm.serve import AsyncMLCEngine, GenerationConfig - -prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. 
Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", - "Why is Vitamin D important to human beings? Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", -] - - -async def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 10 - max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) - - output_texts: List[List[str]] = [ - ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) - ] - - async def generate_task( - async_engine: AsyncMLCEngine, - prompt: str, - generation_cfg: GenerationConfig, - request_id: str, - ): - print(f"generate task for request {request_id}") - rid = int(request_id) - async for delta_outputs in async_engine._generate( - prompt, generation_cfg, request_id=request_id - ): - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - output_texts[rid][i] += delta_output.delta_text - - tasks = [ - asyncio.create_task( - generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i)) - ) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. - print("All finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -async def test_chat_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 2 - max_tokens = 32 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - async def generate_task(prompt: str, request_id: str): - print(f"generate chat completion task for request {request_id}") - rid = int(request_id) - async for response in await async_engine.chat.completions.create( - messages=[{"role": "user", "content": prompt}], - model=model, - max_tokens=max_tokens, - n=n, - request_id=request_id, - stream=True, - ): - for choice in response.choices: - assert choice.delta.role == "assistant" - output_texts[rid][choice.index] += choice.delta.content - - tasks = [ - asyncio.create_task(generate_task(prompts[i], request_id=str(i))) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. 
- print("Chat completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -async def test_chat_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 2 - max_tokens = 32 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - async def generate_task(prompt: str, request_id: str): - print(f"generate chat completion task for request {request_id}") - rid = int(request_id) - response = await async_engine.chat.completions.create( - messages=[{"role": "user", "content": prompt}], - model=model, - max_tokens=max_tokens, - n=n, - request_id=request_id, - ) - for choice in response.choices: - assert choice.message.role == "assistant" - output_texts[rid][choice.index] += choice.message.content - - tasks = [ - asyncio.create_task(generate_task(prompts[i], request_id=str(i))) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. - print("Chat completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -async def test_completion(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 2 - max_tokens = 128 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - async def generate_task(prompt: str, request_id: str): - print(f"generate completion task for request {request_id}") - rid = int(request_id) - async for response in await async_engine.completions.create( - prompt=prompt, - model=model, - max_tokens=max_tokens, - n=n, - ignore_eos=True, - request_id=request_id, - stream=True, - ): - for choice in response.choices: - output_texts[rid][choice.index] += choice.text - - tasks = [ - asyncio.create_task(generate_task(prompts[i], request_id=str(i))) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. 
- print("Completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -async def test_completion_non_stream(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 2 - max_tokens = 128 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - async def generate_task(prompt: str, request_id: str): - print(f"generate completion task for request {request_id}") - rid = int(request_id) - response = await async_engine.completions.create( - prompt=prompt, - model=model, - max_tokens=max_tokens, - n=n, - ignore_eos=True, - request_id=request_id, - ) - for choice in response.choices: - output_texts[rid][choice.index] += choice.text - - tasks = [ - asyncio.create_task(generate_task(prompts[i], request_id=str(i))) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. - print("Completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -if __name__ == "__main__": - asyncio.run(test_engine_generate()) - asyncio.run(test_chat_completion()) - asyncio.run(test_chat_completion_non_stream()) - asyncio.run(test_completion()) - asyncio.run(test_completion_non_stream()) diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py deleted file mode 100644 index c3963af613..0000000000 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ /dev/null @@ -1,85 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals -import asyncio -from typing import List - -from mlc_llm.serve import AsyncMLCEngine, GenerationConfig, SpeculativeMode - -prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", - "Why is Vitamin D important to human beings? Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", -] - - -async def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) - async_engine = AsyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, - ) - - num_requests = 10 - max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens) - - output_texts: List[List[str]] = [ - ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) - ] - - async def generate_task( - async_engine: AsyncMLCEngine, - prompt: str, - generation_cfg: GenerationConfig, - request_id: str, - ): - print(f"generate task for request {request_id}") - rid = int(request_id) - async for delta_outputs in async_engine._generate( - prompt, generation_cfg, request_id=request_id - ): - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - output_texts[rid][i] += delta_output.delta_text - - tasks = [ - asyncio.create_task( - generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i)) - ) - for i in range(num_requests) - ] - - await asyncio.gather(*tasks) - - # Print output. - print("All finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - del async_engine - - -if __name__ == "__main__": - asyncio.run(test_engine_generate()) diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py deleted file mode 100644 index 37d1833b14..0000000000 --- a/tests/python/serve/test_serve_engine.py +++ /dev/null @@ -1,237 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -from typing import List - -import pytest - -from mlc_llm.serve import GenerationConfig, MLCEngine - -prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", - "Why is Vitamin D important to human beings? Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.", -] - -test_models = [ - ( - "dist/Llama-2-7b-chat-hf-q0f16-MLC", - "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so", - ), - ( - "dist/rwkv-6-world-1b6-q0f16-MLC", - "dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so", - ), -] - - -def create_engine(model: str, model_lib_path: str): - if "rwkv" in model: - return MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_batch_size=8, - max_history_size=1, - ) - else: - return MLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_engine_generate(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) - - num_requests = 10 - max_tokens = 256 - generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7) - - output_texts: List[List[str]] = [ - ["" for _ in range(generation_cfg.n)] for _ in range(num_requests) - ] - for rid in range(num_requests): - print(f"generating for request {rid}") - for delta_outputs in engine._generate(prompts[rid], generation_cfg, request_id=str(rid)): - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - output_texts[rid][i] += delta_output.delta_text - - # Print output. - print("All finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - engine.terminate() - del engine - - -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_chat_completion(model: str, model_lib_path: str): - # Create engine - engine = create_engine(model, model_lib_path) - - num_requests = 2 - max_tokens = 64 - n = 2 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - for rid in range(num_requests): - print(f"chat completion for request {rid}") - for response in engine.chat.completions.create( - messages=[{"role": "user", "content": prompts[rid]}], - model=model, - max_tokens=max_tokens, - n=n, - request_id=str(rid), - stream=True, - ): - for choice in response.choices: - assert choice.delta.role == "assistant" - output_texts[rid][choice.index] += choice.delta.content - - # Print output. - print("Chat completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - engine.terminate() - del engine - - -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_chat_completion_non_stream(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) - - num_requests = 2 - max_tokens = 64 - n = 2 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - for rid in range(num_requests): - print(f"chat completion for request {rid}") - response = engine.chat.completions.create( - messages=[{"role": "user", "content": prompts[rid]}], - model=model, - max_tokens=max_tokens, - n=n, - request_id=str(rid), - ) - for choice in response.choices: - assert choice.message.role == "assistant" - output_texts[rid][choice.index] += choice.message.content - - # Print output. 
- print("Chat completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - engine.terminate() - del engine - - -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_completion(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) - - num_requests = 2 - max_tokens = 128 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - for rid in range(num_requests): - print(f"completion for request {rid}") - for response in engine.completions.create( - prompt=prompts[rid], - model=model, - max_tokens=max_tokens, - n=n, - ignore_eos=True, - request_id=str(rid), - stream=True, - ): - for choice in response.choices: - output_texts[rid][choice.index] += choice.text - - # Print output. - print("Completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - engine.terminate() - del engine - - -@pytest.mark.parametrize("model,model_lib_path", test_models) -def test_completion_non_stream(model: str, model_lib_path: str): - engine = create_engine(model, model_lib_path) - - num_requests = 2 - max_tokens = 128 - n = 1 - output_texts: List[List[str]] = [["" for _ in range(n)] for _ in range(num_requests)] - - for rid in range(num_requests): - print(f"completion for request {rid}") - response = engine.completions.create( - prompt=prompts[rid], - model=model, - max_tokens=max_tokens, - n=n, - ignore_eos=True, - request_id=str(rid), - ) - for choice in response.choices: - output_texts[rid][choice.index] += choice.text - - # Print output. 
- print("Completion all finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - engine.terminate() - del engine - - -if __name__ == "__main__": - for model, model_lib_path in test_models: - test_engine_generate(model, model_lib_path) - test_chat_completion(model, model_lib_path) - test_chat_completion_non_stream(model, model_lib_path) - test_completion(model, model_lib_path) - test_completion_non_stream(model, model_lib_path) diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py deleted file mode 100644 index b764c62cd2..0000000000 --- a/tests/python/serve/test_serve_engine_grammar.py +++ /dev/null @@ -1,199 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -import asyncio -import json -from typing import List - -import pytest -from pydantic import BaseModel - -from mlc_llm.serve import AsyncMLCEngine, GenerationConfig -from mlc_llm.serve.config import ResponseFormat -from mlc_llm.serve.sync_engine import SyncMLCEngine - -prompts_list = [ - "Generate a JSON string containing 20 objects:", - "Generate a JSON containing a list:", - "Generate a JSON with 5 elements:", -] -model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" -model_lib_path = "dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so" - - -def test_batch_generation_with_grammar(): - # Create engine - engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") - - prompt_len = len(prompts_list) - prompts = prompts_list * 3 - - temperature = 1 - repetition_penalty = 1 - max_tokens = 512 - generation_config_no_json = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=[2], - response_format=ResponseFormat(type="text"), - ) - generation_config_json = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=[2], - response_format=ResponseFormat(type="json_object"), - ) - generation_config_json_no_stop_token = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - response_format=ResponseFormat(type="json_object"), - ) - all_generation_configs = ( - [generation_config_no_json] * prompt_len - + [generation_config_json] * prompt_len - + [generation_config_json_no_stop_token] * prompt_len - ) - - # Generate output. 
- output_texts, _ = engine.generate(prompts, all_generation_configs) - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - -def test_batch_generation_with_schema(): - # Create engine - engine = SyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") - - prompt = ( - "Generate a json containing three fields: an integer field named size, a " - "boolean field named is_accepted, and a float field named num:" - ) - repeat_cnt = 3 - prompts = [prompt] * repeat_cnt * 2 - - temperature = 1 - repetition_penalty = 1 - max_tokens = 512 - generation_config_no_json = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=[2], - response_format=ResponseFormat(type="text"), - ) - - class Schema(BaseModel): - size: int - is_accepted: bool - num: float - - schema_str = json.dumps(Schema.model_json_schema()) - - generation_config_json = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=[2], - response_format=ResponseFormat(type="json_object", schema=schema_str), - ) - - all_generation_configs = [generation_config_no_json] * repeat_cnt + [ - generation_config_json - ] * repeat_cnt - - # Generate output. - output_texts, _ = engine.generate(prompts, all_generation_configs) - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}: {outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}): {output}\n") - - -async def run_async_engine(): - # Create engine - async_engine = AsyncMLCEngine(model=model_path, model_lib_path=model_lib_path, mode="server") - - prompts = prompts_list * 20 - - max_tokens = 256 - temperature = 1 - repetition_penalty = 1 - max_tokens = 512 - generation_config = GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=[2], - response_format=ResponseFormat(type="json_object"), - ) - - output_texts: List[List[str]] = [ - ["" for _ in range(generation_config.n)] for _ in range(len(prompts)) - ] - - async def generate_task( - async_engine: AsyncMLCEngine, - prompt: str, - generation_cfg: GenerationConfig, - request_id: str, - ): - print(f"Start generation task for request {request_id}") - rid = int(request_id) - async for delta_outputs in async_engine._generate( - prompt, generation_cfg, request_id=request_id - ): - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - output_texts[rid][i] += delta_output.delta_text - - tasks = [ - asyncio.create_task( - generate_task(async_engine, prompts[i], generation_config, request_id=str(i)) - ) - for i in range(len(prompts)) - ] - - await asyncio.gather(*tasks) - - # Print output. 
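# Illustrative aside, not part of the deleted tests: once generation is
# constrained with ResponseFormat(type="json_object", schema=schema_str), as in
# test_batch_generation_with_schema above, each returned string can be parsed
# straight back into the pydantic model. The literal output below is a made-up
# stand-in for one schema-constrained engine output.
from pydantic import BaseModel


class _Schema(BaseModel):  # mirrors the Schema used in test_batch_generation_with_schema
    size: int
    is_accepted: bool
    num: float


_output_text = '{"size": 3, "is_accepted": true, "num": 1.5}'
_parsed = _Schema.model_validate_json(_output_text)  # raises ValidationError if the text drifts from the schema
assert _parsed.size == 3 and _parsed.is_accepted and _parsed.num == 1.5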
- print("All finished") - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - async_engine.terminate() - - -def test_async_engine(): - asyncio.run(run_async_engine()) - - -def test_generation_config_error(): - with pytest.raises(ValueError): - GenerationConfig( - temperature=1.0, - repetition_penalty=1.0, - max_tokens=128, - stop_token_ids=[2], - response_format=ResponseFormat(type="text", schema="{}"), - ) - - -if __name__ == "__main__": - test_batch_generation_with_grammar() - test_async_engine() - test_generation_config_error() diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py deleted file mode 100644 index 59e8c97196..0000000000 --- a/tests/python/serve/test_serve_engine_image.py +++ /dev/null @@ -1,55 +0,0 @@ -import json -from pathlib import Path - -from mlc_llm.serve import GenerationConfig, data -from mlc_llm.serve.sync_engine import SyncMLCEngine - - -def get_test_image(config) -> data.ImageData: - return data.ImageData.from_url("https://llava-vl.github.io/static/images/view.jpg", config) - - -def test_engine_generate(): - # Create engine - model = "dist/llava-1.5-7b-hf-q4f16_1-MLC/params" - model_lib_path = "dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - max_tokens = 256 - - with open(Path(model) / "mlc-chat-config.json", "r", encoding="utf-8") as file: - model_config = json.load(file) - - prompts = [ - [ - data.TextData("USER: "), - get_test_image(model_config), - data.TextData("\nWhat does this image represent? ASSISTANT:"), - ], - [ - data.TextData("USER: "), - get_test_image(model_config), - data.TextData("\nIs there a dog in this image? ASSISTANT:"), - ], - [data.TextData("USER: What is the meaning of life? ASSISTANT:")], - ] - - output_texts, _ = engine.generate( - prompts, GenerationConfig(max_tokens=max_tokens, stop_token_ids=[2]) - ) - - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - -if __name__ == "__main__": - test_engine_generate() diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py deleted file mode 100644 index 33c06b1c5e..0000000000 --- a/tests/python/serve/test_serve_engine_spec.py +++ /dev/null @@ -1,691 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals -from typing import Callable, List, Optional - -import numpy as np - -from mlc_llm.serve import ( - GenerationConfig, - Request, - RequestStreamOutput, - SpeculativeMode, - data, -) -from mlc_llm.serve.sync_engine import SyncMLCEngine - -prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? 
Please elaborate in detail.", - "Why is Vitamin D important to human beings? Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", -] - - -def create_requests( - num_requests: int, - stop_token_id: Optional[int] = None, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - max_tokens_low: int = 256, - max_tokens_high: int = 257, -) -> List[Request]: - assert num_requests >= 0 and num_requests <= len(prompts) - - stop_token_ids = [stop_token_id] if stop_token_id is not None else [] - requests = [] - for req_id, prompt in zip(range(num_requests), prompts): - max_tokens = np.random.randint(max_tokens_low, max_tokens_high) - requests.append( - Request( - request_id=str(req_id), - inputs=data.TextData(prompt), - generation_config=GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - ), - ) - ) - return requests - - -def test_engine_basic(): - """Test engine **without continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the same max_tokens. This means all requests - will end together. - - Engine keeps running `step` for estimated number of steps (number of - requests + max_tokens - 1). Then check the output of each request. - """ - - # Hyperparameters for tests (you can try different combinations). - num_requests = len(prompts) # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 256 # [32, 128, 256] - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for req_id, output in enumerate(outputs): - print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_eagle_basic(): - """Test engine **without continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the same max_tokens. 
This means all requests - will end together. - - Engine keeps running `step` for estimated number of steps (number of - requests + max_tokens - 1). Then check the output of each request. - - Use Eagle model as speculative model - """ - - # Hyperparameters for tests (you can try different combinations). - num_requests = len(prompts) # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 256 # [32, 128, 256] - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" - ) - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, - spec_draft_length=2, - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for req_id, output in enumerate(outputs): - print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_1(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - - # Hyperparameters for tests (you can try different combinations) - num_requests = len(prompts) # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - max_tokens_low = 128 - max_tokens_high = 384 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) - timer = CallbackTimer() - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, - request_stream_callback=timer.callback_getter(), - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - # assert fin_time == request.generation_config.max_tokens - 1 - - -def test_engine_eagle_continuous_batching_1(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - - # Hyperparameters for tests (you can try different combinations) - num_requests = len(prompts) # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - max_tokens_low = 128 - max_tokens_high = 384 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" - ) - timer = CallbackTimer() - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, - request_stream_callback=timer.callback_getter(), - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - # assert fin_time == request.generation_config.max_tokens - 1 - - -def compare_output_text(output_text1, output_text2): - if isinstance(output_text1, list) and isinstance(output_text2, list): - for item1, item2 in zip(output_text1, output_text2): - if not compare_output_text(item1, item2): - return False - elif output_text1 != output_text2: - print(output_text1) - print(output_text2) - return False - return True - - -def test_engine_generate(compare_precision=False): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) - - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.SMALL_DRAFT, - ) - - num_requests = 10 - max_tokens = 256 - - # Generate output. 
- if compare_precision: - print("compare precision") - generation_config = GenerationConfig( - temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], n=1 - ) - engine_single_model = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - output_texts_single_model, _ = engine_single_model.generate( - prompts[:num_requests], generation_config - ) - for req_id, outputs in enumerate(output_texts_single_model): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - # TODO: Add pytorch precision - else: - generation_config = GenerationConfig(max_tokens=max_tokens, n=3) - output_texts, _ = engine.generate(prompts[:num_requests], generation_config) - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - if compare_precision: - precision_flag = compare_output_text(output_texts, output_texts_single_model) - if precision_flag: - print("Accuracy verification succeeded\n") - else: - print("Accuracy verification failed\n") - - -def test_engine_eagle_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - small_model = "dist/Eagle-llama2-7b-chat-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so" - ) - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - speculative_mode=SpeculativeMode.EAGLE, - ) - - num_requests = 10 - max_tokens = 256 - - # Generate output. - output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3) - ) - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - -def test_engine_efficiency(): - """Test engine normal decoding efficiency (baseline, without speculative decoding).""" - - # Hyperparameters for tests (you can try different combinations).
- num_requests = 1 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for eg, name in zip([engine], ["Normal Decoding"]): - stats = eg.stats() - print("engine name:", name) - if name == "Speculative Decoding": - print("total draft tokens:", stats["total_draft_tokens"]) - print("total accepted tokens:", stats["total_accepted_tokens"]) - print( - "Accept rate:", - stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), - ) - print("engine total decode time:", stats["engine_total_decode_time"]) - print() - - -def test_engine_spec_efficiency(): - """Test engine speculative decoding efficiency.""" - - # Hyperparameters for tests (you can try different combinations).
- num_requests = 1 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-13b-chat-hf-q4f16_1-MLC/Llama-2-13b-chat-hf-q4f16_1-MLC-cuda.so" - small_model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - small_model_lib_path = ( - "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - ) - # If Flashinfer allows head_dim < 128, we can test this model - # small_model = "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC" - # small_model_lib_path = ( - # "dist/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC/TinyLlama-1.1B-Chat-v1.0-q0f16-MLC-cuda.so" - # ) - spec_engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - spec_draft_length=6, - speculative_mode=SpeculativeMode.SMALL_DRAFT, - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - spec_engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - spec_engine.step() - - for eg, name in zip([spec_engine], ["Speculative Decoding"]): - stats = eg.stats() - print("engine name:", name) - if name == "Speculative Decoding": - print("total draft tokens:", stats["total_draft_tokens"]) - print("total accepted tokens:", stats["total_accepted_tokens"]) - print( - "Accept rate:", - stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), - ) - print("engine total decode time:", stats["engine_total_decode_time"]) - print() - - -def test_engine_eagle_spec_efficiency(): - """Test engine speculative decoding efficiency.""" - - # Hyperparameters for tests (you can try different combinations). 
- num_requests = 1 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC/Llama-2-7b-chat-hf-q4f16_1-MLC-cuda.so" - small_model = "dist/Eagle-llama2-7b-chat-q0f16-MLC" - small_model_lib_path = ( - "dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so" - ) - spec_engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - additional_models=[small_model + ":" + small_model_lib_path], - spec_draft_length=6, - speculative_mode=SpeculativeMode.EAGLE, - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - spec_engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - spec_engine.step() - - for eg, name in zip([spec_engine], ["Speculative Decoding"]): - stats = eg.stats() - print("engine name:", name) - if name == "Speculative Decoding": - print("total draft tokens:", stats["total_draft_tokens"]) - print("total accepted tokens:", stats["total_accepted_tokens"]) - print( - "Accept rate:", - stats["total_accepted_tokens"] / (1e-10 + stats["total_draft_tokens"]), - ) - print("engine total decode time:", stats["engine_total_decode_time"]) - print() - - -if __name__ == "__main__": - test_engine_basic() - test_engine_eagle_basic() - test_engine_continuous_batching_1() - test_engine_eagle_continuous_batching_1() - test_engine_generate(compare_precision=True) - test_engine_eagle_generate() - test_engine_efficiency() - test_engine_spec_efficiency() - test_engine_eagle_spec_efficiency() diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py deleted file mode 100644 index f68f48b7c5..0000000000 --- a/tests/python/serve/test_serve_sync_engine.py +++ /dev/null @@ -1,396 +0,0 @@ -# pylint: disable=chained-comparison,line-too-long,missing-docstring, -# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable -from typing import Callable, List, Optional - -import numpy as np - -from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data -from mlc_llm.serve.sync_engine import SyncMLCEngine - -prompts = [ - "What is the meaning of life?", - "Introduce the history of Pittsburgh to me. Please elaborate in detail.", - "Write a three-day Seattle travel plan. Please elaborate in detail.", - "What is Alaska famous of? Please elaborate in detail.", - "What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.", - "What are the necessary components to assemble a desktop computer? Please elaborate in detail.", - "Why is Vitamin D important to human beings? 
Please elaborate in detail.", - "Where is milk tea originated from? Please elaborate in detail.", - "Where is the southernmost place in United States? Please elaborate in detail.", - "Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.", -] - - -def create_requests( - num_requests: int, - stop_token_id: Optional[int] = None, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - max_tokens_low: int = 256, - max_tokens_high: int = 257, -) -> List[Request]: - assert num_requests >= 0 and num_requests <= len(prompts) - - stop_token_ids = [stop_token_id] if stop_token_id is not None else [] - requests = [] - for req_id, prompt in zip(range(num_requests), prompts): - max_tokens = np.random.randint(max_tokens_low, max_tokens_high) - requests.append( - Request( - request_id=str(req_id), - inputs=data.TextData(prompt), - generation_config=GenerationConfig( - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - ), - ) - ) - return requests - - -def test_engine_basic(): - """Test engine **without continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the same max_tokens. This means all requests - will end together. - - Engine keeps running `step` for estimated number of steps (number of - requests + max_tokens - 1). Then check the output of each request. - """ - - # Hyperparameters for tests (you can try different combinations). - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0, 0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.0 # [1.0, 1.01] - max_tokens: int = 256 # [32, 128, 256] - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - - # Define the callback function for request generation results - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - request_stream_callback=fcallback, - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - engine.step() - - for req_id, output in enumerate(outputs): - print(f"Prompt {req_id}: {requests[req_id].inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_1(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - max_tokens_low = 128 - max_tokens_high = 384 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - # Create engine - timer = CallbackTimer() - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - request_stream_callback=timer.callback_getter(), - ) - - # Create requests - requests = create_requests( - num_requests, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - assert ( - fin_time == request.generation_config.max_tokens - 1 - ), f"finish time = {fin_time}, max tokens = {request.generation_config.max_tokens - 1}" - - -def test_engine_continuous_batching_2(): - """Test engine **with continuous batching**. - - - Add all requests to the engine altogether in the beginning. - - All requests have the stop token. So each request keeps generating - until having the stop token or reaching the maximum length. - - Engine keeps running `step` for estimated number of steps (number of - requests + the maximum max_tokens - 1). Then check the output - of each request. 
- """ - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens = 512 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - # Create engine - timer = CallbackTimer() - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - request_stream_callback=timer.callback_getter(), - ) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens, - max_tokens_high=max_tokens + 1, - ) - - # Add all requests to engine - for request in requests: - engine.add_request(request) - - num_steps = num_requests + max_tokens - 1 - # Run steps - for step in range(num_steps): - timer.step() - assert timer.timer == step - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - if fin_time < num_requests + max_tokens - 2: - print(f"Request {req_id} ends early on the stop token") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_continuous_batching_3(): - """Test engine **with continuous batching**. - - - Add requests randomly between time [0, 200). - - All requests have a random maximum generation length. So each - request keeps generating until reaching the maximum length. - - Engine keeps running `step` until all requests finish. - Then check the output of each request. 
- """ - - # Hyperparameters for tests (you can try different combinations) - num_requests = 10 # [4, 8, 10] - temperature = 0.9 # [0.8, 0.9, 1.0, 1.1] - repetition_penalty = 1.00 # [1.0, 1.01] - stop_token_id = 2 - max_tokens_low = 64 - max_tokens_high = 192 - np.random.seed(0) - - # Output list - outputs = [[] for _ in range(num_requests)] - finish_time = [None] * num_requests - - # Define the callback class for request generation results - class CallbackTimer: - timer: int = -1 - finished_requests: int = 0 - - def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]: - def fcallback(delta_outputs: List[RequestStreamOutput]): - for delta_output in delta_outputs: - request_id, stream_outputs = delta_output.unpack() - assert len(stream_outputs) == 1 - if stream_outputs[0].finish_reason is not None: - print(f"Request {request_id} finished at step {self.timer}.") - self.finished_requests += 1 - outputs[int(request_id)] += stream_outputs[0].delta_token_ids - finish_time[int(request_id)] = self.timer - - return fcallback - - def step(self) -> None: - self.timer += 1 - - def all_finished(self) -> bool: - return self.finished_requests == num_requests - - # Create engine - timer = CallbackTimer() - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - request_stream_callback=timer.callback_getter(), - ) - - # Create requests - requests = create_requests( - num_requests, - stop_token_id=stop_token_id, - temperature=temperature, - repetition_penalty=repetition_penalty, - max_tokens_low=max_tokens_low, - max_tokens_high=max_tokens_high, - ) - - # Assign the time to add requests to engine - request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)] - - # Run steps - while not timer.all_finished(): - timer.step() - - # Add requests to engine - for req_id, add_time in enumerate(request_add_time): - if add_time == timer.timer: - print(f"add request {req_id} at step {timer.timer}") - engine.add_request(requests[req_id]) - - engine.step() - - for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)): - print(f"Prompt {req_id}: {request.inputs[0]}") - print(f"Finish time: {fin_time}") - print(f"Output {req_id}:{engine.tokenizer.decode(output)}\n") - - -def test_engine_generate(): - # Create engine - model = "dist/Llama-2-7b-chat-hf-q0f16-MLC" - model_lib_path = "dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so" - engine = SyncMLCEngine( - model=model, - model_lib_path=model_lib_path, - mode="server", - max_total_sequence_length=4096, - ) - - num_requests = 10 - max_tokens = 256 - - # Generate output. 
- output_texts, _ = engine.generate( - prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=7) - ) - for req_id, outputs in enumerate(output_texts): - print(f"Prompt {req_id}: {prompts[req_id]}") - if len(outputs) == 1: - print(f"Output {req_id}:{outputs[0]}\n") - else: - for i, output in enumerate(outputs): - print(f"Output {req_id}({i}):{output}\n") - - -if __name__ == "__main__": - test_engine_basic() - test_engine_continuous_batching_1() - test_engine_continuous_batching_2() - test_engine_continuous_batching_3() - test_engine_generate() diff --git a/tests/python/support/test_auto_config.py b/tests/python/support/test_auto_config.py deleted file mode 100644 index 90e797b14e..0000000000 --- a/tests/python/support/test_auto_config.py +++ /dev/null @@ -1,41 +0,0 @@ -# pylint: disable=missing-docstring -import json -import tempfile -from pathlib import Path - -import pytest - -from mlc_llm.support import logging -from mlc_llm.support.auto_config import detect_config - -logging.enable_logging() - - -def _create_json_file(json_path, data): - with open(json_path, "w", encoding="utf-8") as i_f: - json.dump(data, i_f) - - -def test_detect_config(): - with tempfile.TemporaryDirectory() as tmpdir: - base_path = Path(tmpdir) - config_json_path = base_path / "config.json" - _create_json_file(config_json_path, {}) - - assert detect_config(base_path) == config_json_path - assert detect_config(config_json_path) == config_json_path - - -def test_detect_config_fail(): - with pytest.raises(ValueError): - detect_config(Path("do/not/exist")) - - with tempfile.TemporaryDirectory() as tmpdir: - base_path = Path(tmpdir) - with pytest.raises(ValueError): - assert detect_config(base_path) - - -if __name__ == "__main__": - test_detect_config() - test_detect_config_fail() diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py deleted file mode 100644 index 2b3ad48393..0000000000 --- a/tests/python/support/test_auto_weight.py +++ /dev/null @@ -1,118 +0,0 @@ -# pylint: disable=missing-docstring -import json -import os -import tempfile -from pathlib import Path - -import pytest - -from mlc_llm.support import logging -from mlc_llm.support.auto_weight import detect_weight - -logging.enable_logging() - - -def _create_json_file(json_path, data): - with open(json_path, "w", encoding="utf-8") as i_f: - json.dump(data, i_f) - - -@pytest.mark.parametrize( - "weight_format, index_filename, result", - [ - ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), - ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), - ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), - ("auto", "model.safetensors.index.json", "huggingface-safetensor"), - ], -) -def test_detect_weight(weight_format, index_filename, result): - with tempfile.TemporaryDirectory() as tmpdir: - base_path = Path(tmpdir) - if index_filename is not None: - weight_index_file = base_path / index_filename - _create_json_file(weight_index_file, {}) - assert detect_weight(base_path, None, weight_format) == (weight_index_file, result) - - -@pytest.mark.parametrize( - "weight_format, index_filename, result", - [ - ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), - ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), - ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), - ("auto", "model.safetensors.index.json", "huggingface-safetensor"), - ], -) -def 
test_detect_weight_in_config_json(weight_format, index_filename, result): - with tempfile.TemporaryDirectory() as config_dir, tempfile.TemporaryDirectory() as weight_dir: - config_path = Path(config_dir) - weight_path = Path(weight_dir) - config_json_path = config_path / "config.json" - _create_json_file(config_json_path, {"weight_path": weight_dir}) - if index_filename is not None: - weight_index_file = weight_path / index_filename - _create_json_file(weight_index_file, {}) - - assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) - - -@pytest.mark.parametrize( - "weight_format, index_filename, result", - [ - ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), - ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), - ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), - ("auto", "model.safetensors.index.json", "huggingface-safetensor"), - ], -) -def test_detect_weight_same_dir_config_json(weight_format, index_filename, result): - with tempfile.TemporaryDirectory() as tmpdir: - base_path = Path(tmpdir) - config_json_path = base_path / "config.json" - _create_json_file(config_json_path, {}) - if index_filename is not None: - weight_index_file = Path(os.path.join(tmpdir, index_filename)) - _create_json_file(weight_index_file, {}) - assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) - - -def test_find_weight_fail(): - with tempfile.TemporaryDirectory() as tmpdir: - base_path = Path(tmpdir) - with pytest.raises(ValueError): - detect_weight(Path("do/not/exist"), base_path, "awq") - with pytest.raises(AssertionError): - detect_weight(None, Path("do/not/exist"), "awq") - - -if __name__ == "__main__": - test_detect_weight("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch") - test_detect_weight( - "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" - ) - test_detect_weight("auto", "pytorch_model.bin.index.json", "huggingface-torch") - test_detect_weight("auto", "model.safetensors.index.json", "huggingface-safetensor") - test_detect_weight_in_config_json( - "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch" - ) - test_detect_weight_in_config_json( - "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" - ) - test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "huggingface-torch") - test_detect_weight_in_config_json( - "auto", "model.safetensors.index.json", "huggingface-safetensor" - ) - test_detect_weight_same_dir_config_json( - "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch" - ) - test_detect_weight_same_dir_config_json( - "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" - ) - test_detect_weight_same_dir_config_json( - "auto", "pytorch_model.bin.index.json", "huggingface-torch" - ) - test_detect_weight_same_dir_config_json( - "auto", "model.safetensors.index.json", "huggingface-safetensor" - ) - test_find_weight_fail() diff --git a/tests/python/support/test_streamer.py b/tests/python/support/test_streamer.py deleted file mode 100644 index 4ea4573c08..0000000000 --- a/tests/python/support/test_streamer.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Streamer tests in MLC LLM. - -Please specify the local path to llama2 tokenizer via environment -variable before running this test. 
-The recommended way to run the tests is to use the following command: - MLC_LLAMA_TOKENIZER_PATH="path/to/llama/tokenizer" \ - pytest -vv tests/python/support/test_text_streamer_stop_handler.py - -Here "MLC_LLAMA_TOKENIZER_PATH" can be chosen from -- a llama2 weight directory (e.g., "path/to/Llama-2-7b-chat-hf"), -- a sentencepiece llama2 tokenizer path - (e.g., "path/to/Llama-2-7b-chat-hf/tokenizer.model"). - -To directly run the Python file (a.k.a., not using pytest), you also need to -specify the tokenizer path via environment variable. -""" - -# pylint: disable=missing-function-docstring -import os -import time -from typing import List, Tuple - -import pytest - -from mlc_llm.streamer import StopStrHandler, TextStreamer -from mlc_llm.tokenizer import Tokenizer - -# fmt: off -para_input_tokens = [18585, 29892, 1244, 29915, 29879, 263, 3273, 14880, 1048, 953, 29877, 2397, - 29892, 988, 1269, 1734, 338, 5643, 491, 385, 953, 29877, 2397, 29901, 13, 13, - 29950, 1032, 727, 29991, 29871, 243, 162, 148, 142, 306, 29915, 29885, 1244, 304, - 1371, 1234, 738, 5155, 366, 505, 1048, 953, 29877, 2397, 29871, 243, 162, 167, 151, - 29889, 7440, 366, 1073, 393, 953, 29877, 2397, 508, 367, 1304, 304, 27769, 23023, - 1080, 322, 21737, 297, 263, 2090, 322, 1708, 1319, 982, 29973, 29871, 243, 162, 155, - 135, 2688, 508, 884, 367, 1304, 304, 788, 263, 6023, 310, 2022, 2877, 304, 596, 7191, - 322, 11803, 29889, 29871, 243, 162, 149, 152, 1126, 29892, 1258, 366, 1073, 393, 727, - 526, 1584, 953, 29877, 2397, 8090, 322, 14188, 366, 508, 1708, 29973, 29871, 243, 162, - 145, 177, 243, 162, 148, 131, 1105, 29892, 748, 14432, 322, 679, 907, 1230, 411, 953, - 29877, 2397, 29991, 29871, 243, 162, 149, 168, 243, 162, 145, 171] - -DECODED_PARAGRAPH = ( - "Sure, here's a short paragraph about emoji, " - "where each word is followed by an emoji:\n\n" - "Hey there! 👋 I'm here to help answer any questions you have about emoji 🤔. " - "Did you know that emoji can be used to convey emotions and feelings in a " - "fun and playful way? 😄 " - "They can also be used to add a touch of personality to your messages and posts. 💕 " - "And, did you know that there are even emoji games and activities you can play? 🎮👀 " - "So, go ahead and get creative with emoji! 💥🎨" -) -# fmt: on - - -def _get_tokenizer_path() -> str: - path = os.environ.get("MLC_LLAMA_TOKENIZER_PATH") - if path is None: - raise ValueError( - 'Environment variable "MLC_LLAMA_TOKENIZER_PATH" not found. ' - "Please set it to the a valid llama tokenizer path." 
- ) - return path - - -@pytest.fixture -def llama_tokenizer_path() -> str: - return _get_tokenizer_path() - - -def test_text_streamer(llama_tokenizer_path: str): # pylint: disable=redefined-outer-name - text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path)) - total_text = "" - for token in para_input_tokens: - total_text += text_streamer.put([token]) - total_text += text_streamer.finish() - - assert total_text == DECODED_PARAGRAPH - - -def stop_handler_process_tokens( - stop_handler: StopStrHandler, tokens: List[int], tokenizer: Tokenizer -) -> str: - returned_tokens = [] - for token in tokens: - returned_tokens += stop_handler.put(token) - if stop_handler.stop_triggered: - break - - if not stop_handler.stop_triggered: - returned_tokens += stop_handler.finish() - - return tokenizer.decode(returned_tokens) - - -def test_stop_str_handler_stop(llama_tokenizer_path: str): # pylint: disable=redefined-outer-name - stop_strs = [" 🤔"] - tokenizer = Tokenizer(llama_tokenizer_path) - stop_handler = StopStrHandler(stop_strs, tokenizer) - - total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer) - expected_text = ( - "Sure, here's a short paragraph about emoji, " - "where each word is followed by an emoji:\n\n" - "Hey there! 👋 I'm here to help answer any questions you have about emoji" - ) - - assert total_text == expected_text - - -def test_stop_str_handler_not_stop( - llama_tokenizer_path: str, # pylint: disable=redefined-outer-name -): - stop_strs = ["^^"] - tokenizer = Tokenizer(llama_tokenizer_path) - stop_handler = StopStrHandler(stop_strs, tokenizer) - - total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer) - assert total_text == DECODED_PARAGRAPH - - -def test_stop_str_handler_return_cached_tokens( - llama_tokenizer_path: str, # pylint: disable=redefined-outer-name -): - tokens = para_input_tokens[:26] # until "\n\n" - stop_strs = ["\n\n\n"] - tokenizer = Tokenizer(llama_tokenizer_path) - stop_handler = StopStrHandler(stop_strs, tokenizer) - - total_text = stop_handler_process_tokens(stop_handler, tokens, tokenizer) - expected_text = ( - "Sure, here's a short paragraph about emoji, " - "where each word is followed by an emoji:\n\n" - ) - - assert total_text == expected_text - - -def test_stop_str_handler_throughput( - llama_tokenizer_path: str, # pylint: disable=redefined-outer-name -): - stop_strs = ["[INST]"] - tokenizer = Tokenizer(llama_tokenizer_path) - stop_handler = StopStrHandler(stop_strs, tokenizer) - - tokens = para_input_tokens * 20 - returned_tokens = [] - - tbegin = time.perf_counter() - for token in tokens: - returned_tokens += stop_handler.put(token) - assert not stop_handler.stop_triggered - tend = time.perf_counter() - - throughput = len(tokens) / (tend - tbegin) - print( - f"num tokens = {len(tokens)}, " - f"time elapsed = {tend - tbegin:.5f} sec, " - f"throughput = {throughput}" - ) - assert throughput >= 100000 - - -emoji_tokens_expected_result = [ - # HF: "�����", SentencePiece: "�👀" - ([177, 243, 162, 148, 131], ("�����", "�👀")), - # Both: "👀👀" - ([243, 162, 148, 131, 243, 162, 148, 131], ("👀👀",)), - # Both: "👀👀👀" - ([243, 162, 148, 131, 243, 162, 148, 131, 243, 162, 148, 131], ("👀👀👀",)), - # HF: "👀�������", SentencePiece: "👀���👀" - ([243, 162, 148, 131, 162, 148, 131, 243, 162, 148, 131], ("👀�������", "👀���👀")), - # Both: "👀��� have👀" - ([243, 162, 148, 131, 162, 148, 131, 505, 243, 162, 148, 131], ("👀��� have👀",)), -] - - -@pytest.mark.parametrize("tokens_and_results", emoji_tokens_expected_result) -def 
test_text_streamer_emojis( - llama_tokenizer_path: str, tokens_and_results: Tuple[List[int], Tuple[str]] -): # pylint: disable=redefined-outer-name - text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path)) - total_text = "" - tokens, expected_results = tokens_and_results - for token in tokens: - total_text += text_streamer.put([token]) - total_text += text_streamer.finish() - assert total_text in expected_results - - -if __name__ == "__main__": - tokenizer_path = _get_tokenizer_path() - test_text_streamer(tokenizer_path) - test_stop_str_handler_stop(tokenizer_path) - test_stop_str_handler_not_stop(tokenizer_path) - test_stop_str_handler_return_cached_tokens(tokenizer_path) - test_stop_str_handler_throughput(tokenizer_path) - - for tokens_and_res in emoji_tokens_expected_result: - test_text_streamer_emojis(tokenizer_path, tokens_and_res) diff --git a/web/Makefile b/web/Makefile deleted file mode 100644 index 48f98b5e81..0000000000 --- a/web/Makefile +++ /dev/null @@ -1,51 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -TVM_ROOT=$(TVM_HOME) -MLC_LLM_ROOT=$(shell cd ..; pwd) - -INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ - -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\ - -I$(TVM_ROOT)/3rdparty/compiler-rt -I$(TVM_ROOT)/3rdparty/picojson\ - -I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp\ - -I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp/include -I$(MLC_LLM_ROOT)/cpp - -.PHONY: clean all rmtypedep preparetest - -all: dist/wasm/mlc_wasm_runtime.wasm - -EMCC = emcc - -EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes - -EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ - -s ERROR_ON_UNDEFINED_SYMBOLS=0 - -dist/wasm/mlc_wasm_runtime.bc: emcc/mlc_wasm_runtime.cc - @mkdir -p $(@D) - $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc >dist/wasm/mlc_wasm_runtime.d - $(EMCC) $(EMCC_CFLAGS) -emit-llvm -c -o dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc - -# Compile to wasm here so that errors can be caught earlier (rather than during export_library) -dist/wasm/mlc_wasm_runtime.wasm: dist/wasm/mlc_wasm_runtime.bc - @mkdir -p $(@D) - $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/mlc_wasm_runtime.wasm $+ $(EMCC_LDFLAGS) - -clean: - @rm -rf dist/wasm lib - --include dist/wasm/*.d diff --git a/web/README.md b/web/README.md deleted file mode 100644 index f4fc808b1f..0000000000 --- a/web/README.md +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - -# MLC-LLM WebAssembly Runtime - -This folder contains MLC-LLM WebAssembly Runtime. - -Please refer to https://llm.mlc.ai/docs/install/emcc.html. - -The main step is running `make` under this folder, a step included in `web/prep_emcc_deps.sh`. 
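As a concrete illustration of that `make` step, here is a minimal sketch of building the wasm runtime by hand. It assumes `emcc` is on `PATH`, that the command is run from the MLC-LLM repository root, and that `TVM_HOME` points at a TVM checkout (falling back to the bundled `3rdparty/tvm`, mirroring the deleted Makefile and `web/prep_emcc_deps.sh`); the exact paths are illustrative, not part of the original files.

```bash
# Sketch of the manual build described above (assumptions: emcc installed,
# run from the MLC-LLM repo root; TVM_HOME falls back to the bundled 3rdparty/tvm).
export TVM_HOME="${TVM_HOME:-$(pwd)/3rdparty/tvm}"
cd web
make                                    # builds dist/wasm/mlc_wasm_runtime.wasm per the Makefile above
ls -lh dist/wasm/mlc_wasm_runtime.wasm  # the artifact reused when compiling model libraries
cd -
```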
- -`make` creates `web/dist/wasm/mlc_wasm_runtime.bc`, which will be included in the model library wasm -when we compile the model. Thus during runtime, runtimes like WebLLM can directly reuse source -code from MLC-LLM. \ No newline at end of file diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc deleted file mode 100644 index b9a7f55bfa..0000000000 --- a/web/emcc/mlc_wasm_runtime.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file mlc_wasm_runtime.cc - * \brief MLC wasm runtime library pack. - */ - -// configurations for tvm logging -#define TVM_LOG_STACK_TRACE 0 -#define TVM_LOG_DEBUG 0 -#define TVM_LOG_CUSTOMIZE 1 - -// Pass in COMPILE_MLC_WASM_RUNTIME so unsupported code would not be compiled in to the .bc file -#define COMPILE_MLC_WASM_RUNTIME 1 -#define __STDC_FORMAT_MACROS 1 -#define PICOJSON_USE_INT64 - -#define DMLC_USE_LOGGING_LIBRARY - -// Grammar related -#include "serve/grammar/grammar.cc" -#include "serve/grammar/grammar_parser.cc" -#include "serve/grammar/grammar_serializer.cc" -#include "serve/grammar/grammar_simplifier.cc" -#include "serve/grammar/grammar_state_matcher.cc" -#include "serve/grammar/json_schema_converter.cc" -#include "support/encoding.cc" diff --git a/web/prep_emcc_deps.sh b/web/prep_emcc_deps.sh deleted file mode 100755 index 0ccf98698b..0000000000 --- a/web/prep_emcc_deps.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -# This file prepares all the necessary dependencies for the web build. -set -euxo pipefail - -emcc --version -npm --version - -TVM_HOME_SET="${TVM_HOME:-}" - -git submodule update --init --recursive - -# Build mlc_wasm_runtime -cd web && make -cd - - -# Build tvm's web runtime -if [[ -z ${TVM_HOME_SET} ]]; then - echo "Do not find TVM_HOME env variable, use 3rdparty/tvm". - echo "Make sure you set TVM_HOME in your env variable to use emcc build correctly" - export TVM_HOME="${TVM_HOME:-3rdparty/tvm}" -fi - -cd ${TVM_HOME}/web && make -cd -
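For completeness, a hedged usage sketch of the removed `web/prep_emcc_deps.sh` helper, which chains the `web/` build above with TVM's own web-runtime build; the `TVM_HOME` value shown is illustrative and not taken from the original files.

```bash
# Hypothetical invocation of the removed helper script, from the MLC-LLM repo root.
# It updates submodules, runs `make` under web/, then builds TVM's web runtime.
export TVM_HOME=/path/to/tvm   # illustrative; the script falls back to 3rdparty/tvm if unset
./web/prep_emcc_deps.sh
```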