generated from JetBrains/intellij-platform-plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(custom): add custom SSE processor and JSON response callback #10
Add `CustomSSEProcessor` for SSE processing and `JSONBodyResponseCallback` for JSON response handling.
- Loading branch information
Showing
5 changed files
with
386 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
201 changes: 201 additions & 0 deletions
201
src/main/kotlin/com/phodal/shire/custom/CustomSSEProcessor.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
package com.phodal.shire.custom | ||
|
||
import com.fasterxml.jackson.databind.ObjectMapper | ||
import com.intellij.openapi.diagnostic.logger | ||
import com.intellij.openapi.project.Project | ||
import com.nfeld.jsonpathkt.JsonPath | ||
import com.nfeld.jsonpathkt.extension.read | ||
import com.phodal.shire.llm.LlmProvider.Companion.ChatRole | ||
import com.theokanning.openai.completion.chat.ChatCompletionResult | ||
import com.theokanning.openai.service.SSE | ||
import io.reactivex.BackpressureStrategy | ||
import io.reactivex.Flowable | ||
import io.reactivex.FlowableEmitter | ||
import kotlinx.coroutines.Dispatchers | ||
import kotlinx.coroutines.channels.awaitClose | ||
import kotlinx.coroutines.flow.Flow | ||
import kotlinx.coroutines.flow.callbackFlow | ||
import kotlinx.coroutines.withContext | ||
import kotlinx.serialization.Serializable | ||
import kotlinx.serialization.encodeToString | ||
import kotlinx.serialization.json.* | ||
import okhttp3.Call | ||
import okhttp3.Request | ||
import org.jetbrains.annotations.VisibleForTesting | ||
|
||
/** | ||
* The `CustomSSEProcessor` class is responsible for processing server-sent events (SSE) in a custom manner. | ||
* It provides functions to stream JSON and SSE data from a given `Call` instance, and exposes properties for request and response formats. | ||
* | ||
* @property hasSuccessRequest A boolean flag indicating whether the request was successful. | ||
* @property requestFormat A string representing the format of the request. | ||
* @property responseFormat A string representing the format of the response. | ||
* @property logger An instance of the logger for logging purposes. | ||
* | ||
* @constructor Creates an instance of `CustomSSEProcessor`. | ||
*/ | ||
open class CustomSSEProcessor(private val project: Project) { | ||
open var hasSuccessRequest: Boolean = false | ||
private var parseFailedResponses: MutableList<String> = mutableListOf() | ||
open val requestFormat: String = "" | ||
open val responseFormat: String = "" | ||
private val logger = logger<CustomSSEProcessor>() | ||
|
||
|
||
fun streamJson(call: Call, promptText: String, messages: MutableList<Message>): Flow<String> = callbackFlow { | ||
call.enqueue(JSONBodyResponseCallback(responseFormat) { | ||
withContext(Dispatchers.IO) { | ||
send(it) | ||
} | ||
|
||
messages += Message(ChatRole.Assistant.roleName(), it) | ||
close() | ||
}) | ||
awaitClose() | ||
} | ||
|
||
@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) | ||
fun streamSSE(call: Call, promptText: String, keepHistory: Boolean = false, messages: MutableList<Message>): Flow<String> { | ||
val sseFlowable = Flowable | ||
.create({ emitter: FlowableEmitter<SSE> -> | ||
call.enqueue(ResponseBodyCallback(emitter, true)) | ||
}, BackpressureStrategy.BUFFER) | ||
|
||
try { | ||
var output = "" | ||
return callbackFlow { | ||
withContext(Dispatchers.IO) { | ||
sseFlowable | ||
.doOnError { | ||
it.printStackTrace() | ||
trySend(it.message ?: "Error occurs") | ||
close() | ||
} | ||
.blockingForEach { sse -> | ||
if (responseFormat.isNotEmpty()) { | ||
// {"id":"cmpl-a22a0d78fcf845be98660628fe5d995b","object":"chat.completion.chunk","created":822330,"model":"moonshot-v1-8k","choices":[{"index":0,"delta":{},"finish_reason":"stop","usage":{"prompt_tokens":434,"completion_tokens":68,"total_tokens":502}}]} | ||
// in some case, the response maybe not equal to our response format, so we need to ignore it | ||
// {"id":"cmpl-ac26a17e","object":"chat.completion.chunk","created":1858403,"model":"yi-34b-chat","choices":[{"delta":{"role":"assistant"},"index":0}],"content":"","lastOne":false} | ||
val chunk: String? = JsonPath.parse(sse!!.data)?.read(responseFormat) | ||
|
||
// new JsonPath lib caught the exception, so we need to handle when it is null | ||
if (chunk == null) { | ||
parseFailedResponses.add(sse.data) | ||
logger.warn("Failed to parse response.origin response is: ${sse.data}, response format: $responseFormat") | ||
} else { | ||
hasSuccessRequest = true | ||
output += chunk | ||
trySend(chunk) | ||
} | ||
} else { | ||
val result: ChatCompletionResult = | ||
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java) | ||
|
||
val completion = result.choices[0].message | ||
if (completion != null && completion.content != null) { | ||
output += completion.content | ||
trySend(completion.content) | ||
} | ||
} | ||
} | ||
|
||
// when stream finished, check if any response parsed succeeded | ||
// if not, notice user check response format | ||
if (!hasSuccessRequest) { | ||
val errorMsg = """ | ||
|**Failed** to parse response.please check your response format: | ||
|**$responseFormat** origin responses is: | ||
|- ${parseFailedResponses.joinToString("\n- ")} | ||
|""".trimMargin() | ||
|
||
// TODO add refresh feature | ||
// don't use trySend, it may be ignored by 'close()` op | ||
send(errorMsg) | ||
} | ||
|
||
messages += Message(ChatRole.Assistant.roleName(), output) | ||
close() | ||
} | ||
awaitClose() | ||
} | ||
} catch (e: Exception) { | ||
if (hasSuccessRequest) { | ||
logger.info("Failed to stream", e) | ||
} else { | ||
logger.error("Failed to stream", e) | ||
} | ||
|
||
return callbackFlow { | ||
close() | ||
} | ||
} finally { | ||
parseFailedResponses.clear() | ||
} | ||
} | ||
} | ||
|
||
@Serializable | ||
data class Message(val role: String, val content: String) | ||
|
||
@Serializable | ||
data class CustomRequest(val messages: List<Message>) | ||
|
||
@VisibleForTesting | ||
fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Builder = apply { | ||
runCatching { | ||
Json.parseToJsonElement(customRequestHeader) | ||
.jsonObject["customHeaders"].let { customFields -> | ||
customFields?.jsonObject?.forEach { (key, value) -> | ||
header(key, value.jsonPrimitive.content) | ||
} | ||
} | ||
}.onFailure { | ||
logger<CustomRequest>().warn("Failed to parse custom request header", it) | ||
} | ||
} | ||
|
||
@VisibleForTesting | ||
fun JsonObject.updateCustomBody(customRequest: String): JsonObject { | ||
return runCatching { | ||
buildJsonObject { | ||
// copy origin object | ||
this@updateCustomBody.forEach { u, v -> put(u, v) } | ||
|
||
val customRequestJson = Json.parseToJsonElement(customRequest).jsonObject | ||
customRequestJson["customFields"]?.let { customFields -> | ||
customFields.jsonObject.forEach { (key, value) -> | ||
put(key, value.jsonPrimitive) | ||
} | ||
} | ||
|
||
// TODO clean code with magic literals | ||
var roleKey = "role" | ||
var contentKey = "content" | ||
customRequestJson.jsonObject["messageKeys"]?.let { | ||
roleKey = it.jsonObject["role"]?.jsonPrimitive?.content ?: "role" | ||
contentKey = it.jsonObject["content"]?.jsonPrimitive?.content ?: "content" | ||
} | ||
|
||
val messages: JsonArray = this@updateCustomBody["messages"]?.jsonArray ?: buildJsonArray { } | ||
this.put("messages", buildJsonArray { | ||
messages.forEach { message -> | ||
val role: String = message.jsonObject["role"]?.jsonPrimitive?.content ?: "user" | ||
val content: String = message.jsonObject["content"]?.jsonPrimitive?.content ?: "" | ||
add(buildJsonObject { | ||
put(roleKey, role) | ||
put(contentKey, content) | ||
}) | ||
} | ||
}) | ||
} | ||
}.getOrElse { | ||
logger<CustomRequest>().error("Failed to parse custom request body", it) | ||
this | ||
} | ||
} | ||
|
||
fun CustomRequest.updateCustomFormat(format: String): String { | ||
val requestContentOri = Json.encodeToString<CustomRequest>(this) | ||
return Json.parseToJsonElement(requestContentOri) | ||
.jsonObject.updateCustomBody(format).toString() | ||
} |
34 changes: 34 additions & 0 deletions
34
src/main/kotlin/com/phodal/shire/custom/JSONBodyResponseCallback.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package com.phodal.shire.custom | ||
|
||
import com.nfeld.jsonpathkt.JsonPath | ||
import com.nfeld.jsonpathkt.extension.read | ||
import kotlinx.coroutines.runBlocking | ||
import okhttp3.Call | ||
import okhttp3.Callback | ||
import okhttp3.Response | ||
import java.io.IOException | ||
|
||
class JSONBodyResponseCallback(private val responseFormat: String,private val callback: suspend (String)->Unit): Callback { | ||
override fun onFailure(call: Call, e: IOException) { | ||
runBlocking { | ||
callback("error. ${e.message}") | ||
} | ||
} | ||
|
||
override fun onResponse(call: Call, response: Response) { | ||
val responseBody: String? = response.body?.string() | ||
if (responseFormat.isEmpty()) { | ||
runBlocking { | ||
callback(responseBody ?: "") | ||
} | ||
|
||
return | ||
} | ||
|
||
val responseContent: String = JsonPath.parse(responseBody)?.read(responseFormat) ?: "" | ||
|
||
runBlocking() { | ||
callback(responseContent) | ||
} | ||
} | ||
} |
Oops, something went wrong.