[Android] Switch MLC Chat to use MLCEngine #2410

Merged: 2 commits, merged on May 27, 2024
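This PR swaps the Android app's `ChatModule` backend for the new `MLCEngine`, which exposes an OpenAI-style `chat.completions` streaming API. A condensed before/after sketch of the call pattern, paraphrased from the diff below (view-model plumbing, error handling, and state switches are omitted; the "after" calls are suspend functions and must run inside a coroutine):

```kotlin
// Before: ChatModule is driven token by token.
val backend = ChatModule()
backend.reload(modelConfig.modelLib, modelPath)
backend.prefill(prompt)
while (!backend.stopped()) {
    backend.decode()
    val partialText = backend.message  // accumulate and display
}

// After: MLCEngine streams OpenAI-style completion chunks over a channel.
val engine = MLCEngine()
engine.reload(modelPath, modelConfig.modelLib)
val channel = engine.chat.completions.create(
    messages = listOf(
        ChatCompletionMessage(
            role = OpenAIProtocol.ChatCompletionRole.user,
            content = prompt
        )
    )
)
var text = ""
for (response in channel) {
    // The final chunk may carry only usage stats; text lives in delta chunks.
    if (response.usage == null && response.choices.isNotEmpty()) {
        text += response.choices[0].delta.content?.asText().orEmpty()
    }
}
```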
61 changes: 36 additions & 25 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt
@@ -1,6 +1,7 @@
 package ai.mlc.mlcchat

-import ai.mlc.mlcllm.ChatModule
+import ai.mlc.mlcllm.MLCEngine
+import ai.mlc.mlcllm.OpenAIProtocol
 import android.app.Application
 import android.content.ClipData
 import android.content.ClipboardManager
@@ -21,6 +22,8 @@ import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.Executors
 import kotlin.concurrent.thread
+import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage
+import kotlinx.coroutines.*

 class AppViewModel(application: Application) : AndroidViewModel(application) {
     val modelList = emptyList<ModelState>().toMutableStateList()
@@ -502,14 +505,14 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
     private var modelChatState = mutableStateOf(ModelChatState.Ready)
         @Synchronized get
         @Synchronized set
-    private val backend = ChatModule()
+    private val engine = MLCEngine()
     private var modelLib = ""
     private var modelPath = ""
     private val executorService = Executors.newSingleThreadExecutor()
-
+    private val viewModelScope = CoroutineScope(Dispatchers.Main + Job())
     private fun mainResetChat() {
         executorService.submit {
-            callBackend { backend.resetChat() }
+            callBackend { engine.reset() }
             viewModelScope.launch {
                 clearHistory()
                 switchToReady()
@@ -551,7 +554,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
                 val stackTrace = e.stackTraceToString()
                 val errorMessage = e.localizedMessage
                 appendMessage(
-                    MessageRole.Bot,
+                    MessageRole.Assistant,
                     "MLCChat failed\n\nStack trace:\n$stackTrace\n\nError message:\n$errorMessage"
                 )
                 switchToFailed()
@@ -604,7 +607,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {

     private fun mainTerminateChat(callback: () -> Unit) {
         executorService.submit {
-            callBackend { backend.unload() }
+            callBackend { engine.unload() }
             viewModelScope.launch {
                 clearHistory()
                 switchToReady()
@@ -644,11 +647,8 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
                 Toast.makeText(application, "Initialize...", Toast.LENGTH_SHORT).show()
             }
             if (!callBackend {
-                    backend.unload()
-                    backend.reload(
-                        modelConfig.modelLib,
-                        modelPath
-                    )
+                    engine.unload()
+                    engine.reload(modelPath, modelConfig.modelLib)
                 }) return@submit
             viewModelScope.launch {
                 Toast.makeText(application, "Ready to chat", Toast.LENGTH_SHORT).show()
@@ -662,19 +662,30 @@ class AppViewModel(application: Application) : AndroidViewModel(application) {
         switchToGenerating()
         executorService.submit {
             appendMessage(MessageRole.User, prompt)
-            appendMessage(MessageRole.Bot, "")
-            if (!callBackend { backend.prefill(prompt) }) return@submit
-            while (!backend.stopped()) {
-                if (!callBackend {
-                        backend.decode()
-                        val newText = backend.message
-                        viewModelScope.launch { updateMessage(MessageRole.Bot, newText) }
-                    }) return@submit
-                if (modelChatState.value != ModelChatState.Generating) return@submit
-            }
-            val runtimeStats = backend.runtimeStatsText()
+            appendMessage(MessageRole.Assistant, "")
             viewModelScope.launch {
-                report.value = runtimeStats
+                val channel = engine.chat.completions.create(
+                    messages = listOf(
+                        ChatCompletionMessage(
+                            role = OpenAIProtocol.ChatCompletionRole.user,
+                            content = prompt
+                        )
+                    )
+                )
+                var texts = ""
+                for (response in channel) {
+                    if (!callBackend {
+                            val finalUsage = response.usage
+                            if (finalUsage != null) {
+                                report.value = finalUsage.extra?.asTextLabel() ?: ""
+                            } else {
+                                if (response.choices.size > 0) {
+                                    texts += response.choices[0].delta.content?.asText().orEmpty()
+                                }
+                            }
+                            updateMessage(MessageRole.Assistant, texts)
+                        }) return@launch
+                }
                 if (modelChatState.value == ModelChatState.Generating) switchToReady()
             }
         }
@@ -722,7 +733,7 @@ enum class ModelChatState {
 }

 enum class MessageRole {
-    Bot,
+    Assistant,
     User
 }

@@ -757,4 +768,4 @@ data class ParamsRecord(

 data class ParamsConfig(
     @SerializedName("records") val paramsRecords: List<ParamsRecord>
-)
\ No newline at end of file
+)
android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt
@@ -136,7 +136,7 @@ fun ChatView(
 @Composable
 fun MessageView(messageData: MessageData) {
     SelectionContainer {
-        if (messageData.role == MessageRole.Bot) {
+        if (messageData.role == MessageRole.Assistant) {
             Row(
                 horizontalArrangement = Arrangement.Start,
                 modifier = Modifier.fillMaxWidth()
4 changes: 4 additions & 0 deletions android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java
@@ -80,4 +80,8 @@ public interface KotlinFunction {
         void invoke(String arg);
     }

+    public void reset() {
+        resetFunc.invoke();
+    }
+
 }
189 changes: 126 additions & 63 deletions android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt
@@ -1,83 +1,152 @@
 package ai.mlc.mlcllm

 import ai.mlc.mlcllm.JSONFFIEngine
 import ai.mlc.mlcllm.OpenAIProtocol.*
 import kotlinx.coroutines.GlobalScope
-import kotlinx.serialization.json.Json
-import kotlinx.serialization.encodeToString
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.ReceiveChannel
 import kotlinx.coroutines.launch
-import java.lang.Exception
+import kotlinx.serialization.json.Json
+import kotlinx.serialization.encodeToString
+import kotlinx.serialization.decodeFromString
+import kotlin.concurrent.thread
 import java.util.UUID
+import java.util.logging.Logger

+class BackgroundWorker(private val task: () -> Unit) {
+
+    fun start() {
+        thread(start = true) {
+            task()
+        }
+    }
+}

-class MLCEngine () {
-    private val jsonFFIEngine = JSONFFIEngine()
-    private val channelMap = mutableMapOf<String, Channel<ChatCompletionStreamResponse>>()
+class MLCEngine {
+
+    private val state: EngineState
+    private val jsonFFIEngine: JSONFFIEngine
+    val chat: Chat
+    private val threads = mutableListOf<BackgroundWorker>()
+
     init {
-        jsonFFIEngine.initBackgroundEngine(this::streamCallback)
-        GlobalScope.launch {
+        state = EngineState()
+        jsonFFIEngine = JSONFFIEngine()
+        chat = Chat(jsonFFIEngine, state)
+
+        jsonFFIEngine.initBackgroundEngine { result ->
+            state.streamCallback(result)
+        }
+
+        val backgroundWorker = BackgroundWorker {
+            Thread.currentThread().priority = Thread.MAX_PRIORITY
             jsonFFIEngine.runBackgroundLoop()
         }
-        GlobalScope.launch {
+
+        val backgroundStreamBackWorker = BackgroundWorker {
             jsonFFIEngine.runBackgroundStreamBackLoop()
         }
+
+        threads.add(backgroundWorker)
+        threads.add(backgroundStreamBackWorker)
+
+        backgroundWorker.start()
+        backgroundStreamBackWorker.start()
     }

-    private fun streamCallback(result: String?) {
-        val responses = mutableListOf<ChatCompletionStreamResponse>()
+    fun reload(modelPath: String, modelLib: String) {
+        val engineConfig = """
+            {
+                "model": "$modelPath",
+                "model_lib": "system://$modelLib",
+                "mode": "interactive"
+            }
+        """
+        jsonFFIEngine.reload(engineConfig)
+    }

+    fun reset() {
+        jsonFFIEngine.reset()
+    }
+
+    fun unload() {
+        jsonFFIEngine.unload()
+    }
+}

+data class RequestState(
+    val request: ChatCompletionRequest,
+    val continuation: Channel<ChatCompletionStreamResponse>
+)
+
+class EngineState {
+
+    private val logger = Logger.getLogger(EngineState::class.java.name)
+    private val requestStateMap = mutableMapOf<String, RequestState>()

+    suspend fun chatCompletion(
+        jsonFFIEngine: JSONFFIEngine,
+        request: ChatCompletionRequest
+    ): ReceiveChannel<ChatCompletionStreamResponse> {
+        val json = Json { encodeDefaults = true }
+        val jsonRequest = json.encodeToString(request)
+        val requestID = UUID.randomUUID().toString()
+        val channel = Channel<ChatCompletionStreamResponse>(Channel.UNLIMITED)
+
+        requestStateMap[requestID] = RequestState(request, channel)
+
+        jsonFFIEngine.chatCompletion(jsonRequest, requestID)
+
+        return channel
+    }

+    fun streamCallback(result: String?) {
         val json = Json { ignoreUnknownKeys = true }
         try {
-            val msg = json.decodeFromString<ChatCompletionStreamResponse>(result!!)
-            responses.add(msg)
-        } catch (lastError: Exception) {
-            println("Kotlin json parsing error: error=$lastError, jsonsrc=$result")
-        }
-
-        // dispatch to right request ID
-        for (res in responses) {
-            val channel = channelMap[res.id]
-            if (channel != null) {
+            val responses: List<ChatCompletionStreamResponse> = json.decodeFromString(result ?: return)
+
+            responses.forEach { res ->
+                val requestState = requestStateMap[res.id] ?: return@forEach
                 GlobalScope.launch {
-                    channel.send(res)
-                    // detect finished from result
-                    var finished = false
-                    for (choice in res.choices) {
-                        if (choice.finish_reason != "" && choice.finish_reason != null) {
-                            finished = true
-                        }
+                    val sendResult = requestState.continuation.trySend(res)
+                    if (sendResult.isFailure) {
+                        // Handle the failure case if needed
+                        logger.severe("Failed to send response: ${sendResult.exceptionOrNull()}")
                     }
-                    if (finished) {
-                        channel.close()
-                        channelMap.remove(res.id)
+
+                    res.usage?.let { finalUsage ->
+                        requestState.request.stream_options?.include_usage?.let { includeUsage ->
+                            if (includeUsage) {
+                                requestState.continuation.send(res)
+                            }
+                        }
+                        requestState.continuation.close()
+                        requestStateMap.remove(res.id)
                     }
                 }
             }
+        } catch (e: Exception) {
+            logger.severe("Kotlin JSON parsing error: $e, jsonsrc=$result")
+        }
     }
+}

-    private fun deinit() {
-        jsonFFIEngine.exitBackgroundLoop()
-    }
+class Chat(
+    private val jsonFFIEngine: JSONFFIEngine,
+    private val state: EngineState
+) {
+    val completions = Completions(jsonFFIEngine, state)
+}

-    fun reload(modelPath: String, modelLib: String) {
-        val engineConfigJSONStr = """
-            {
-                "model": "$modelPath",
-                "model_lib": "system://$modelLib",
-                "mode": "interactive"
-            }
-        """.trimIndent()
-        jsonFFIEngine.reload(engineConfigJSONStr)
-    }
+class Completions(
+    private val jsonFFIEngine: JSONFFIEngine,
+    private val state: EngineState
+) {

-    private fun unload() {
-        jsonFFIEngine.unload()
+    suspend fun create(request: ChatCompletionRequest): ReceiveChannel<ChatCompletionStreamResponse> {
+        return state.chatCompletion(jsonFFIEngine, request)
     }

-    fun chatCompletion(
+    suspend fun create(
         messages: List<ChatCompletionMessage>,
         model: String? = null,
         frequency_penalty: Float? = null,
@@ -89,13 +158,18 @@ class MLCEngine () {
         n: Int = 1,
         seed: Int? = null,
         stop: List<String>? = null,
-        stream: Boolean = false,
+        stream: Boolean = true,
+        stream_options: StreamOptions? = null,
         temperature: Float? = null,
         top_p: Float? = null,
         tools: List<ChatTool>? = null,
         user: String? = null,
         response_format: ResponseFormat? = null
     ): ReceiveChannel<ChatCompletionStreamResponse> {
+        if (!stream) {
+            throw IllegalArgumentException("Only stream=true is supported in MLCKotlin")
+        }
+
         val request = ChatCompletionRequest(
             messages = messages,
             model = model,
@@ -109,25 +183,14 @@ class MLCEngine () {
             seed = seed,
             stop = stop,
             stream = stream,
+            stream_options = stream_options,
             temperature = temperature,
             top_p = top_p,
             tools = tools,
             user = user,
             response_format = response_format
         )
-        return chatCompletion(request)
-    }
-
-    private fun chatCompletion(request: ChatCompletionRequest): ReceiveChannel<ChatCompletionStreamResponse> {
-        val channel = Channel<ChatCompletionStreamResponse>()
-        val jsonRequest = Json.encodeToString(request)
-        val requestId = UUID.randomUUID().toString()
-
-        // Store the channel in the map for further callbacks
-        channelMap[requestId] = channel
-
-        jsonFFIEngine.chatCompletion(jsonRequest, requestId)
-
-        return channel
+        return create(request)
     }
 }
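The reworked file mirrors the OpenAI client shape: MLCEngine owns the FFI engine and its two background workers, Chat exposes Completions, Completions.create encodes a ChatCompletionRequest, and EngineState routes each streamed chunk back to its request's channel by ID. A minimal consumer sketch under stated assumptions: the model path and library name are placeholders (the app supplies them from its ModelConfig), and the model library must be compiled into the binary so the "system://" reference in reload resolves:

```kotlin
import ai.mlc.mlcllm.MLCEngine
import ai.mlc.mlcllm.OpenAIProtocol
import ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val engine = MLCEngine()
    // Placeholder path and lib name; in the app these come from ModelConfig.
    engine.reload("/data/local/tmp/model", "model_android_lib")

    // create() is a suspend function returning a ReceiveChannel of chunks.
    val channel = engine.chat.completions.create(
        messages = listOf(
            ChatCompletionMessage(
                role = OpenAIProtocol.ChatCompletionRole.user,
                content = "What is the capital of France?"
            )
        )
    )

    var reply = ""
    for (response in channel) {
        // The final chunk may carry only usage stats; text lives in delta chunks.
        if (response.usage == null && response.choices.isNotEmpty()) {
            reply += response.choices[0].delta.content?.asText().orEmpty()
        }
    }
    println(reply)

    engine.unload()
}
```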
