diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
index cd8b23ce08..956b507ee5 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
@@ -1,6 +1,7 @@
 package ai.mlc.mlcchat

-import ai.mlc.mlcllm.ChatModule
+import ai.mlc.mlcllm.MLCEngine
+import ai.mlc.mlcllm.OpenAIProtocol
 import android.app.Application
 import android.content.ClipData
 import android.content.ClipboardManager
@@ -21,6 +22,8 @@ import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.Executors
 import kotlin.concurrent.thread
+import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage
+import kotlinx.coroutines.*

 class AppViewModel(application: Application) : AndroidViewModel(application) {
     val modelList = emptyList().toMutableStateList()
@@ -502,14 +505,14 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
         private var modelChatState = mutableStateOf(ModelChatState.Ready)
             @Synchronized get
             @Synchronized set
-        private val backend = ChatModule()
+        private val engine = MLCEngine()
         private var modelLib = ""
         private var modelPath = ""
         private val executorService = Executors.newSingleThreadExecutor()
-
+        private val viewModelScope = CoroutineScope(Dispatchers.Main + Job())
         private fun mainResetChat() {
             executorService.submit {
-                callBackend { backend.resetChat() }
+                callBackend { engine.reset() }
                 viewModelScope.launch {
                     clearHistory()
                     switchToReady()
@@ -551,7 +554,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
                 val stackTrace = e.stackTraceToString()
                 val errorMessage = e.localizedMessage
                 appendMessage(
-                    MessageRole.Bot,
+                    MessageRole.Assistant,
                     "MLCChat failed\n\nStack trace:\n$stackTrace\n\nError message:\n$errorMessage"
                 )
                 switchToFailed()
@@ -604,7 +607,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {

         private fun mainTerminateChat(callback: () -> Unit) {
             executorService.submit {
-                callBackend { backend.unload() }
+                callBackend { engine.unload() }
                 viewModelScope.launch {
                     clearHistory()
                     switchToReady()
@@ -644,11 +647,8 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
                     Toast.makeText(application, "Initialize...", Toast.LENGTH_SHORT).show()
                 }
                 if (!callBackend {
-                        backend.unload()
-                        backend.reload(
-                            modelConfig.modelLib,
-                            modelPath
-                        )
+                        engine.unload()
+                        engine.reload(modelPath, modelConfig.modelLib)
                     }) return@submit
                 viewModelScope.launch {
                     Toast.makeText(application, "Ready to chat", Toast.LENGTH_SHORT).show()
@@ -662,19 +662,30 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
             switchToGenerating()
             executorService.submit {
                 appendMessage(MessageRole.User, prompt)
-                appendMessage(MessageRole.Bot, "")
-                if (!callBackend { backend.prefill(prompt) }) return@submit
-                while (!backend.stopped()) {
-                    if (!callBackend {
-                            backend.decode()
-                            val newText = backend.message
-                            viewModelScope.launch { updateMessage(MessageRole.Bot, newText) }
-                        }) return@submit
-                    if (modelChatState.value != ModelChatState.Generating) return@submit
-                }
-                val runtimeStats = backend.runtimeStatsText()
+                appendMessage(MessageRole.Assistant, "")
                 viewModelScope.launch {
-                    report.value = runtimeStats
+                    val channel = engine.chat.completions.create(
+                        messages = listOf(
+                            ChatCompletionMessage(
+                                role = OpenAIProtocol.ChatCompletionRole.user,
+                                content = prompt
+                            )
+                        )
+                    )
+                    var texts = ""
+                    for (response in channel) {
+                        if (!callBackend {
+                                val finalUsage = response.usage
+                                if (finalUsage != null) {
+                                    report.value = (finalUsage.extra?.asTextLabel() ?: "")
+                                } else {
+                                    if (response.choices.size > 0) {
+                                        texts += response.choices[0].delta.content?.asText().orEmpty()
+                                    }
+                                }
+                                updateMessage(MessageRole.Assistant, texts)
+                            });
+                    }
                     if (modelChatState.value == ModelChatState.Generating) switchToReady()
                 }
             }
@@ -722,7 +733,7 @@ enum class ModelChatState {
 }

 enum class MessageRole {
-    Bot,
+    Assistant,
     User
 }

@@ -757,4 +768,4 @@ data class ParamsRecord(

 data class ParamsConfig(
     @SerializedName("records") val paramsRecords: List<ParamsRecord>
-)
\ No newline at end of file
+)
diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
index 9f581ab313..d92342b1d4 100644
--- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
+++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
@@ -136,7 +136,7 @@ fun ChatView(
 @Composable
 fun MessageView(messageData: MessageData) {
     SelectionContainer {
-        if (messageData.role == MessageRole.Bot) {
+        if (messageData.role == MessageRole.Assistant) {
             Row(
                 horizontalArrangement = Arrangement.Start,
                 modifier = Modifier.fillMaxWidth()
diff --git a/android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java b/android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java
index 59d8585426..ad95dd2a02 100644
--- a/android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java
+++ b/android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java
@@ -80,4 +80,8 @@ public interface KotlinFunction {
         void invoke(String arg);
     }

+    public void reset() {
+        resetFunc.invoke();
+    }
+
 }
diff --git a/android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt b/android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt
index a2b0a3de37..58760d045b 100644
--- a/android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt
+++ b/android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt
@@ -1,83 +1,152 @@
 package ai.mlc.mlcllm

-import ai.mlc.mlcllm.JSONFFIEngine
 import ai.mlc.mlcllm.OpenAIProtocol.*
 import kotlinx.coroutines.GlobalScope
-import kotlinx.serialization.json.Json
-import kotlinx.serialization.encodeToString
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.launch
-import java.lang.Exception
+import kotlinx.serialization.json.Json
+import kotlinx.serialization.encodeToString
+import kotlinx.serialization.decodeFromString
+import kotlin.concurrent.thread
 import java.util.UUID
+import java.util.logging.Logger
+
+class BackgroundWorker(private val task: () -> Unit) {
+
+    fun start() {
+        thread(start = true) {
+            task()
+        }
+    }
+}

-class MLCEngine () {
-    private val jsonFFIEngine = JSONFFIEngine()
-    private val channelMap = mutableMapOf<String, Channel<ChatCompletionStreamResponse>>()
+class MLCEngine {
+
+    private val state: EngineState
+    private val jsonFFIEngine: JSONFFIEngine
+    val chat: Chat
+    private val threads = mutableListOf<BackgroundWorker>()

     init {
-        jsonFFIEngine.initBackgroundEngine(this::streamCallback)
-        GlobalScope.launch {
+        state = EngineState()
+        jsonFFIEngine = JSONFFIEngine()
+        chat = Chat(jsonFFIEngine, state)
+
+        jsonFFIEngine.initBackgroundEngine { result ->
+            state.streamCallback(result)
+        }
+
+        val backgroundWorker = BackgroundWorker {
+            Thread.currentThread().priority = Thread.MAX_PRIORITY
             jsonFFIEngine.runBackgroundLoop()
         }
-        GlobalScope.launch {
+
+        val backgroundStreamBackWorker = BackgroundWorker {
             jsonFFIEngine.runBackgroundStreamBackLoop()
         }
+
+        threads.add(backgroundWorker)
+        threads.add(backgroundStreamBackWorker)
+
+        backgroundWorker.start()
+        backgroundStreamBackWorker.start()
     }

-    private fun streamCallback(result: String?) {
-        val responses = mutableListOf<ChatCompletionStreamResponse>()
+    fun reload(modelPath: String, modelLib: String) {
+        val engineConfig = """
+            {
+                "model": "$modelPath",
+                "model_lib": "system://$modelLib",
+                "mode": "interactive"
+            }
+        """
+        jsonFFIEngine.reload(engineConfig)
+    }
+
+    fun reset() {
+        jsonFFIEngine.reset()
+    }
+
+    fun unload() {
+        jsonFFIEngine.unload()
+    }
+}
+
+data class RequestState(
+    val request: ChatCompletionRequest,
+    val continuation: Channel<ChatCompletionStreamResponse>
+)
+
+class EngineState {
+
+    private val logger = Logger.getLogger(EngineState::class.java.name)
+    private val requestStateMap = mutableMapOf<String, RequestState>()
+
+    suspend fun chatCompletion(
+        jsonFFIEngine: JSONFFIEngine,
+        request: ChatCompletionRequest
+    ): ReceiveChannel<ChatCompletionStreamResponse> {
+        val json = Json { encodeDefaults = true }
+        val jsonRequest = json.encodeToString(request)
+        val requestID = UUID.randomUUID().toString()
+        val channel = Channel<ChatCompletionStreamResponse>(Channel.UNLIMITED)
+
+        requestStateMap[requestID] = RequestState(request, channel)
+
+        jsonFFIEngine.chatCompletion(jsonRequest, requestID)
+
+        return channel
+    }
+
+    fun streamCallback(result: String?) {
         val json = Json { ignoreUnknownKeys = true }
         try {
-            val msg = json.decodeFromString<ChatCompletionStreamResponse>(result!!)
-            responses.add(msg)
-        } catch (lastError: Exception) {
-            println("Kotlin json parsing error: error=$lastError, jsonsrc=$result")
-        }
+            val responses: List<ChatCompletionStreamResponse> = json.decodeFromString(result ?: return)

-        // dispatch to right request ID
-        for (res in responses) {
-            val channel = channelMap[res.id]
-            if (channel != null) {
+            responses.forEach { res ->
+                val requestState = requestStateMap[res.id] ?: return@forEach
                 GlobalScope.launch {
-                    channel.send(res)
-                    // detect finished from result
-                    var finished = false
-                    for (choice in res.choices) {
-                        if (choice.finish_reason != "" && choice.finish_reason != null) {
-                            finished = true
-                        }
+                    val sendResult = requestState.continuation.trySend(res)
+                    if (sendResult.isFailure) {
+                        // Handle the failure case if needed
+                        logger.severe("Failed to send response: ${sendResult.exceptionOrNull()}")
                     }
-                    if (finished) {
-                        channel.close()
-                        channelMap.remove(res.id)
+
+                    res.usage?.let { finalUsage ->
+                        requestState.request.stream_options?.include_usage?.let { includeUsage ->
+                            if (includeUsage) {
+                                requestState.continuation.send(res)
+                            }
+                        }
+                        requestState.continuation.close()
+                        requestStateMap.remove(res.id)
                     }
                 }
-            }
+        } catch (e: Exception) {
+            logger.severe("Kotlin JSON parsing error: $e, jsonsrc=$result")
         }
     }
+}

-    private fun deinit() {
-        jsonFFIEngine.exitBackgroundLoop()
-    }
+class Chat(
+    private val jsonFFIEngine: JSONFFIEngine,
+    private val state: EngineState
+) {
+    val completions = Completions(jsonFFIEngine, state)
+}

-    fun reload(modelPath: String, modelLib: String) {
-        val engineConfigJSONStr = """
-        {
-            "model": "$modelPath",
-            "model_lib": "system://$modelLib",
-            "mode": "interactive"
-        }
-        """.trimIndent()
-        jsonFFIEngine.reload(engineConfigJSONStr)
-    }
+class Completions(
+    private val jsonFFIEngine: JSONFFIEngine,
+    private val state: EngineState
+) {

-    private fun unload() {
-        jsonFFIEngine.unload()
+    suspend fun create(request: ChatCompletionRequest): ReceiveChannel<ChatCompletionStreamResponse> {
+        return state.chatCompletion(jsonFFIEngine, request)
     }

-    fun chatCompletion(
+    suspend fun create(
         messages: List<ChatCompletionMessage>,
         model: String? = null,
         frequency_penalty: Float? = null,
@@ -89,13 +158,18 @@ class MLCEngine () {
         n: Int = 1,
         seed: Int? = null,
         stop: List<String>? = null,
-        stream: Boolean = false,
+        stream: Boolean = true,
+        stream_options: StreamOptions? = null,
         temperature: Float? = null,
         top_p: Float? = null,
         tools: List? = null,
         user: String? = null,
         response_format: ResponseFormat? = null
     ): ReceiveChannel<ChatCompletionStreamResponse> {
+        if (!stream) {
+            throw IllegalArgumentException("Only stream=true is supported in MLCKotlin")
+        }
+
         val request = ChatCompletionRequest(
             messages = messages,
             model = model,
@@ -109,25 +183,14 @@ class MLCEngine () {
             seed = seed,
             stop = stop,
             stream = stream,
+            stream_options = stream_options,
             temperature = temperature,
             top_p = top_p,
             tools = tools,
             user = user,
             response_format = response_format
         )
-        return chatCompletion(request)
-    }
-
-    private fun chatCompletion(request: ChatCompletionRequest): ReceiveChannel<ChatCompletionStreamResponse> {
-        val channel = Channel<ChatCompletionStreamResponse>()
-        val jsonRequest = Json.encodeToString(request)
-        val requestId = UUID.randomUUID().toString()
-
-        // Store the channel in the map for further callbacks
-        channelMap[requestId] = channel
-
-        jsonFFIEngine.chatCompletion(jsonRequest, requestId)
-
-        return channel
+        return create(request)
     }
 }
+
diff --git a/android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt b/android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt
index f381ebc4e8..7bc4bc4bc1 100644
--- a/android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt
+++ b/android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt
@@ -1,6 +1,5 @@
 package ai.mlc.mlcllm

-import android.util.Log
 import kotlinx.serialization.KSerializer
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.builtins.ListSerializer
@@ -107,9 +106,8 @@ class OpenAIProtocol {
         override fun serialize(encoder: Encoder, value: ChatCompletionMessageContent) {
             if (value.isText()) {
                 encoder.encodeString(value.text!!)
-            }
-            else {
-                encoder.encodeSerializableValue(ListSerializer(MapSerializer(String.serializer(), String.serializer())), value.parts?: listOf())
+            } else {
+                encoder.encodeSerializableValue(ListSerializer(MapSerializer(String.serializer(), String.serializer())), value.parts ?: listOf())
             }
         }

@@ -144,6 +142,40 @@ class OpenAIProtocol {
         ) : this(role, ChatCompletionMessageContent(content), name, tool_calls, tool_call_id)
     }

+    @Serializable
+    data class CompletionUsageExtra(
+        val prefill_tokens_per_s: Float? = null,
+        val decode_tokens_per_s: Float? = null,
+        val num_prefill_tokens: Int? = null
+    ) {
+        fun asTextLabel(): String {
+            var outputText = ""
+            if (prefill_tokens_per_s != null) {
+                outputText += "prefill: ${String.format("%.1f", prefill_tokens_per_s)} tok/s"
+            }
+            if (decode_tokens_per_s != null) {
+                if (outputText.isNotEmpty()) {
+                    outputText += ", "
+                }
+                outputText += "decode: ${String.format("%.1f", decode_tokens_per_s)} tok/s"
+            }
+            return outputText
+        }
+    }
+
+    @Serializable
+    data class CompletionUsage(
+        val prompt_tokens: Int,
+        val completion_tokens: Int,
+        val total_tokens: Int,
+        val extra: CompletionUsageExtra? = null
+    )
+
+    @Serializable
+    data class StreamOptions(
+        val include_usage: Boolean = false
+    )
+
     @Serializable
     data class ChatCompletionStreamResponseChoice(
         var finish_reason: String? = null,
@@ -159,7 +191,8 @@ class OpenAIProtocol {
         var created: Int? = null,
         var model: String? = null,
         val system_fingerprint: String,
-        var `object`: String? = null
+        var `object`: String? = null,
+        val usage: CompletionUsage? = null
     )

     @Serializable
@@ -175,7 +208,8 @@ class OpenAIProtocol {
         val n: Int = 1,
         val seed: Int? = null,
         val stop: List<String>? = null,
-        val stream: Boolean = false,
+        val stream: Boolean = true,
+        val stream_options: StreamOptions? = null,
         val temperature: Float? = null,
         val top_p: Float? = null,
         val tools: List? = null,
         val user: String? = null,
@@ -188,4 +222,5 @@ class OpenAIProtocol {
         val type: String,
         val schema: String? = null
     )
-}
\ No newline at end of file
+}
+
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 6d8ec9a1c2..e17d3ee77d 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -19,6 +19,21 @@ namespace mlc {
 namespace llm {
 namespace serve {

+uint64_t TotalDetectGlobalMemory(DLDevice device) {
+  // Get single-card GPU size.
+  TVMRetValue rv;
+  DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv);
+  int64_t gpu_size_bytes = rv;
+  // Since the memory size returned by the OpenCL runtime is smaller than the actual available
+  // memory space, we set a best available space so that MLC LLM can run 7B or 8B models on Android
+  // with OpenCL.
+  if (device.device_type == kDLOpenCL) {
+    int64_t min_size_bytes = 5LL * 1024 * 1024 * 1024;  // Minimum size is 5 GB
+    gpu_size_bytes = std::max(gpu_size_bytes, min_size_bytes);
+  }
+  return gpu_size_bytes;
+}
+
 /****************** DebugConfig ******************/

 Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
@@ -522,10 +537,7 @@ Result EstimateMemoryUsageOnMode(
     logit_processor_workspace_bytes +=
         max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125;
   }
-  // Get single-card GPU size.
-  TVMRetValue rv;
-  DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv);
-  int64_t gpu_size_bytes = rv;
+  int64_t gpu_size_bytes = TotalDetectGlobalMemory(device);
   // Compute the maximum total sequence length under the GPU memory budget.
   int64_t model_max_total_sequence_length =
       static_cast<int64_t>((gpu_size_bytes * gpu_memory_utilization  //
@@ -817,10 +829,7 @@ Result InferrableEngineConfig::InferForRNNState(
     logit_processor_workspace_bytes +=
         max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125;
   }
-  // Get single-card GPU size.
-  TVMRetValue rv;
-  DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv);
-  int64_t gpu_size_bytes = rv;
+  int64_t gpu_size_bytes = TotalDetectGlobalMemory(device);
   // Compute the maximum history size length under the GPU memory budget.
   int64_t model_max_history_size = static_cast<int64_t>((gpu_size_bytes * gpu_memory_utilization  //
                                                          - params_bytes                           //
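
Usage sketch (illustrative, not part of the patch): the snippet below shows how a caller might drive the new streaming MLCEngine API introduced above, mirroring AppViewModel.requestGenerate. The model path and model library name are placeholder values, and the threading, error handling, and UI updates done by the app are omitted.

import ai.mlc.mlcllm.MLCEngine
import ai.mlc.mlcllm.OpenAIProtocol
import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage

// Streams one reply for a single user prompt and returns the accumulated text.
suspend fun generateOnce(prompt: String): String {
    val engine = MLCEngine()
    // Placeholder path and model library name; real values come from the app's model config.
    engine.reload("/path/to/model", "model_lib_name")

    val channel = engine.chat.completions.create(
        messages = listOf(
            ChatCompletionMessage(
                role = OpenAIProtocol.ChatCompletionRole.user,
                content = prompt
            )
        )
    )

    var text = ""
    for (response in channel) {
        // The final chunk carries usage statistics; earlier chunks carry streamed deltas.
        if (response.usage == null && response.choices.isNotEmpty()) {
            text += response.choices[0].delta.content?.asText().orEmpty()
        }
    }
    engine.unload()
    return text
}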