Skip to content

Commit

Permalink
feat(custom): add custom SSE processor and JSON response callback #10
Browse files Browse the repository at this point in the history
Add `CustomSSEProcessor` for SSE processing and `JSONBodyResponseCallback` for JSON response handling.
  • Loading branch information
phodal committed Jun 24, 2024
1 parent ee02a76 commit 6f75068
Show file tree
Hide file tree
Showing 5 changed files with 386 additions and 0 deletions.
4 changes: 4 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ project(":") {
dependencies {
implementation(project(":core"))

// custom agent deps
implementation(libs.json.pathkt)
implementation(libs.okhttp)
implementation(libs.okhttp.sse)
// open ai deps
implementation("com.theokanning.openai-gpt3-java:service:0.18.2")
implementation("com.squareup.retrofit2:converter-jackson:2.11.0")
Expand Down
4 changes: 4 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ kover = "0.8.1"
[libraries]
kotlinx-serialization-json = "org.jetbrains.kotlinx:kotlinx-serialization-json:1.7.0"
kotlinx-coroutines-core = "org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1"
json-pathkt = "com.nfeld.jsonpathkt:jsonpathkt:2.0.1"

okhttp = "com.squareup.okhttp3:okhttp:4.4.1"
okhttp-sse = "com.squareup.okhttp3:okhttp-sse:4.4.1"

[plugins]
changelog = { id = "org.jetbrains.changelog", version.ref = "changelog" }
Expand Down
201 changes: 201 additions & 0 deletions src/main/kotlin/com/phodal/shire/custom/CustomSSEProcessor.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package com.phodal.shire.custom

import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.nfeld.jsonpathkt.JsonPath
import com.nfeld.jsonpathkt.extension.read
import com.phodal.shire.llm.LlmProvider.Companion.ChatRole
import com.theokanning.openai.completion.chat.ChatCompletionResult
import com.theokanning.openai.service.SSE
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import io.reactivex.FlowableEmitter
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.withContext
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.*
import okhttp3.Call
import okhttp3.Request
import org.jetbrains.annotations.VisibleForTesting

/**
* The `CustomSSEProcessor` class is responsible for processing server-sent events (SSE) in a custom manner.
* It provides functions to stream JSON and SSE data from a given `Call` instance, and exposes properties for request and response formats.
*
* @property hasSuccessRequest A boolean flag indicating whether the request was successful.
* @property requestFormat A string representing the format of the request.
* @property responseFormat A string representing the format of the response.
* @property logger An instance of the logger for logging purposes.
*
* @constructor Creates an instance of `CustomSSEProcessor`.
*/
open class CustomSSEProcessor(private val project: Project) {
open var hasSuccessRequest: Boolean = false
private var parseFailedResponses: MutableList<String> = mutableListOf()
open val requestFormat: String = ""
open val responseFormat: String = ""
private val logger = logger<CustomSSEProcessor>()


fun streamJson(call: Call, promptText: String, messages: MutableList<Message>): Flow<String> = callbackFlow {
call.enqueue(JSONBodyResponseCallback(responseFormat) {
withContext(Dispatchers.IO) {
send(it)
}

messages += Message(ChatRole.Assistant.roleName(), it)
close()
})
awaitClose()
}

@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class)
fun streamSSE(call: Call, promptText: String, keepHistory: Boolean = false, messages: MutableList<Message>): Flow<String> {
val sseFlowable = Flowable
.create({ emitter: FlowableEmitter<SSE> ->
call.enqueue(ResponseBodyCallback(emitter, true))
}, BackpressureStrategy.BUFFER)

try {
var output = ""
return callbackFlow {
withContext(Dispatchers.IO) {
sseFlowable
.doOnError {
it.printStackTrace()
trySend(it.message ?: "Error occurs")
close()
}
.blockingForEach { sse ->
if (responseFormat.isNotEmpty()) {
// {"id":"cmpl-a22a0d78fcf845be98660628fe5d995b","object":"chat.completion.chunk","created":822330,"model":"moonshot-v1-8k","choices":[{"index":0,"delta":{},"finish_reason":"stop","usage":{"prompt_tokens":434,"completion_tokens":68,"total_tokens":502}}]}
// in some case, the response maybe not equal to our response format, so we need to ignore it
// {"id":"cmpl-ac26a17e","object":"chat.completion.chunk","created":1858403,"model":"yi-34b-chat","choices":[{"delta":{"role":"assistant"},"index":0}],"content":"","lastOne":false}
val chunk: String? = JsonPath.parse(sse!!.data)?.read(responseFormat)

// new JsonPath lib caught the exception, so we need to handle when it is null
if (chunk == null) {
parseFailedResponses.add(sse.data)
logger.warn("Failed to parse response.origin response is: ${sse.data}, response format: $responseFormat")
} else {
hasSuccessRequest = true
output += chunk
trySend(chunk)
}
} else {
val result: ChatCompletionResult =
ObjectMapper().readValue(sse!!.data, ChatCompletionResult::class.java)

val completion = result.choices[0].message
if (completion != null && completion.content != null) {
output += completion.content
trySend(completion.content)
}
}
}

// when stream finished, check if any response parsed succeeded
// if not, notice user check response format
if (!hasSuccessRequest) {
val errorMsg = """
|**Failed** to parse response.please check your response format:
|**$responseFormat** origin responses is:
|- ${parseFailedResponses.joinToString("\n- ")}
|""".trimMargin()

// TODO add refresh feature
// don't use trySend, it may be ignored by 'close()` op
send(errorMsg)
}

messages += Message(ChatRole.Assistant.roleName(), output)
close()
}
awaitClose()
}
} catch (e: Exception) {
if (hasSuccessRequest) {
logger.info("Failed to stream", e)
} else {
logger.error("Failed to stream", e)
}

return callbackFlow {
close()
}
} finally {
parseFailedResponses.clear()
}
}
}

@Serializable
data class Message(val role: String, val content: String)

@Serializable
data class CustomRequest(val messages: List<Message>)

@VisibleForTesting
fun Request.Builder.appendCustomHeaders(customRequestHeader: String): Request.Builder = apply {
runCatching {
Json.parseToJsonElement(customRequestHeader)
.jsonObject["customHeaders"].let { customFields ->
customFields?.jsonObject?.forEach { (key, value) ->
header(key, value.jsonPrimitive.content)
}
}
}.onFailure {
logger<CustomRequest>().warn("Failed to parse custom request header", it)
}
}

@VisibleForTesting
fun JsonObject.updateCustomBody(customRequest: String): JsonObject {
return runCatching {
buildJsonObject {
// copy origin object
this@updateCustomBody.forEach { u, v -> put(u, v) }

val customRequestJson = Json.parseToJsonElement(customRequest).jsonObject
customRequestJson["customFields"]?.let { customFields ->
customFields.jsonObject.forEach { (key, value) ->
put(key, value.jsonPrimitive)
}
}

// TODO clean code with magic literals
var roleKey = "role"
var contentKey = "content"
customRequestJson.jsonObject["messageKeys"]?.let {
roleKey = it.jsonObject["role"]?.jsonPrimitive?.content ?: "role"
contentKey = it.jsonObject["content"]?.jsonPrimitive?.content ?: "content"
}

val messages: JsonArray = this@updateCustomBody["messages"]?.jsonArray ?: buildJsonArray { }
this.put("messages", buildJsonArray {
messages.forEach { message ->
val role: String = message.jsonObject["role"]?.jsonPrimitive?.content ?: "user"
val content: String = message.jsonObject["content"]?.jsonPrimitive?.content ?: ""
add(buildJsonObject {
put(roleKey, role)
put(contentKey, content)
})
}
})
}
}.getOrElse {
logger<CustomRequest>().error("Failed to parse custom request body", it)
this
}
}

fun CustomRequest.updateCustomFormat(format: String): String {
val requestContentOri = Json.encodeToString<CustomRequest>(this)
return Json.parseToJsonElement(requestContentOri)
.jsonObject.updateCustomBody(format).toString()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.phodal.shire.custom

import com.nfeld.jsonpathkt.JsonPath
import com.nfeld.jsonpathkt.extension.read
import kotlinx.coroutines.runBlocking
import okhttp3.Call
import okhttp3.Callback
import okhttp3.Response
import java.io.IOException

class JSONBodyResponseCallback(private val responseFormat: String,private val callback: suspend (String)->Unit): Callback {
override fun onFailure(call: Call, e: IOException) {
runBlocking {
callback("error. ${e.message}")
}
}

override fun onResponse(call: Call, response: Response) {
val responseBody: String? = response.body?.string()
if (responseFormat.isEmpty()) {
runBlocking {
callback(responseBody ?: "")
}

return
}

val responseContent: String = JsonPath.parse(responseBody)?.read(responseFormat) ?: ""

runBlocking() {
callback(responseContent)
}
}
}
Loading

0 comments on commit 6f75068

Please sign in to comment.