diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c5cff0a41..1ca556221 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -39,10 +39,14 @@ jackson = "2.15.2" jsonschema = "4.31.1" jakarta = "3.0.2" suspend-transform = "0.3.1" +suspendApp = "0.4.0" [libraries] arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" } +arrow-continuations = { module = "io.arrow-kt:arrow-continuations", version.ref = "arrow" } arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" } +suspendApp-core = { module = "io.arrow-kt:suspendapp", version.ref = "suspendApp" } +suspendApp-ktor = { module = "io.arrow-kt:suspendapp-ktor", version.ref = "suspendApp" } kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" } kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" } kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines-reactive" } @@ -54,6 +58,13 @@ ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json" ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" } ktor-client-js = { module = "io.ktor:ktor-client-js", version.ref = "ktor" } ktor-client-winhttp = { module = "io.ktor:ktor-client-winhttp", version.ref = "ktor" } +ktor-server-auth = { module = "io.ktor:ktor-server-auth", version.ref = "ktor" } +ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" } +ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" } +ktor-server-contentNegotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor" } +ktor-server-resources = { module = "io.ktor:ktor-server-resources", version.ref = "ktor" } +ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor" } +ktor-server-request-validation = { module = "io.ktor:ktor-server-request-validation", version.ref = "ktor" } okio = { module = "com.squareup.okio:okio", version.ref = "okio" } okio-fakefilesystem = { module = "com.squareup.okio:okio-fakefilesystem", version.ref = "okio" } okio-nodefilesystem = { module = "com.squareup.okio:okio-nodefilesystem", version.ref = "okio" } @@ -63,6 +74,7 @@ kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" } kotest-testcontainers = { module = "io.kotest.extensions:kotest-extensions-testcontainers", version.ref = "kotest-testcontainers" } kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" } +ktor-serialization-json = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" } junit-jupiter-api = { module = "org.junit.jupiter:junit-jupiter-api", version.ref = "junit" } uuid = { module = "app.softwork:kotlinx-uuid-core", version.ref = "uuid" } klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging" } diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt index 0c1f1c10c..de58aa4d3 100644 --- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt +++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/auto/llm/openai/OpenAI.kt @@ -72,4 +72,27 @@ class OpenAI(internal val token: String) : AutoCloseable, AutoClose by autoClose @JvmField val DEFAULT_IMAGES = DEFAULT.DALLE_2 } + + fun supportedModels(): List { + return listOf( + GPT_4, + GPT_4_0314, + GPT_4_32K, + GPT_3_5_TURBO, + GPT_3_5_TURBO_16K, + GPT_3_5_TURBO_FUNCTIONS, + GPT_3_5_TURBO_0301, + TEXT_DAVINCI_003, + TEXT_DAVINCI_002, + TEXT_CURIE_001, + TEXT_BABBAGE_001, + TEXT_ADA_001, + TEXT_EMBEDDING_ADA_002, + DALLE_2 + ) + } +} + +fun String.toOpenAIModel(): OpenAIModel? { + return OpenAI.DEFAULT.supportedModels().find { it.name == this } } diff --git a/server/build.gradle.kts b/server/build.gradle.kts new file mode 100644 index 000000000..1ae876901 --- /dev/null +++ b/server/build.gradle.kts @@ -0,0 +1,43 @@ +plugins { + id(libs.plugins.kotlin.jvm.get().pluginId) + id(libs.plugins.kotlinx.serialization.get().pluginId) +} + +repositories { + mavenCentral() +} + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + toolchain { + languageVersion = JavaLanguageVersion.of(11) + } +} + +dependencies { + implementation(projects.xefCore) + implementation(projects.xefKotlin) + implementation(libs.kotlinx.serialization.json) + implementation(libs.logback) + implementation(libs.klogging) + implementation(libs.ktor.server.auth) + implementation(libs.ktor.server.netty) + implementation(libs.ktor.server.core) + implementation(libs.ktor.server.contentNegotiation) + implementation(libs.ktor.server.resources) + implementation(libs.ktor.server.cors) + implementation(libs.ktor.serialization.json) + implementation(libs.suspendApp.core) + implementation(libs.suspendApp.ktor) + implementation(libs.ktor.server.request.validation) + implementation(libs.openai.client) +} + +tasks.getByName("processResources") { + dependsOn(projects.xefGpt4all.dependencyProject.tasks.getByName("jvmProcessResources")) + from("${projects.xefGpt4all.dependencyProject.buildDir}/processedResources/jvm/main") + into("$buildDir/resources/main") +} + + diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/Main.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/Main.kt new file mode 100644 index 000000000..33b9d53bd --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/Main.kt @@ -0,0 +1,41 @@ +package com.xebia.functional.xef.server + + +import arrow.continuations.SuspendApp +import arrow.fx.coroutines.resourceScope +import arrow.continuations.ktor.server +import com.xebia.functional.xef.server.http.routes.routes +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.application.* +import io.ktor.server.auth.* +import io.ktor.server.netty.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.plugins.cors.routing.* +import io.ktor.server.resources.* +import io.ktor.server.routing.* +import kotlinx.coroutines.awaitCancellation + +object Main { + @JvmStatic + fun main(args: Array) = SuspendApp { + resourceScope { + server(factory = Netty, port = 8080, host = "0.0.0.0") { + install(CORS) { + allowNonSimpleContentTypes = true + anyHost() + } + install(ContentNegotiation) { json() } + install(Resources) + install(Authentication) { + bearer("auth-bearer") { + authenticate { tokenCredential -> + UserIdPrincipal(tokenCredential.token) + } + } + } + routing { routes() } + } + awaitCancellation() + } + } +} diff --git a/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/Routes.kt b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/Routes.kt new file mode 100644 index 000000000..04e27c2cc --- /dev/null +++ b/server/src/main/kotlin/com/xebia/functional/xef/server/http/routes/Routes.kt @@ -0,0 +1,81 @@ +package com.xebia.functional.xef.server.http.routes + +import com.aallam.openai.api.BetaOpenAI +import com.aallam.openai.api.chat.ChatCompletionRequest +import com.aallam.openai.api.chat.ChatRole +import com.xebia.functional.xef.auto.CoreAIScope +import com.xebia.functional.xef.auto.PromptConfiguration +import com.xebia.functional.xef.auto.llm.openai.* +import com.xebia.functional.xef.auto.llm.openai.OpenAI.Companion.DEFAULT_CHAT +import com.xebia.functional.xef.llm.Chat +import com.xebia.functional.xef.llm.models.chat.Message +import com.xebia.functional.xef.llm.models.chat.Role +import io.ktor.http.* +import io.ktor.server.application.* +import io.ktor.server.auth.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.util.pipeline.* +import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest as XefChatCompletionRequest + +@OptIn(BetaOpenAI::class) +fun Routing.routes() { + authenticate("auth-bearer") { + post("/chat/completions") { + val model: Chat = call.request.headers["xef-model"]?.toOpenAIModel() ?: DEFAULT_CHAT + val token = call.principal()?.name ?: throw IllegalArgumentException("No token found") + val scope = CoreAIScope(OpenAIEmbeddings(OpenAI(token).GPT_3_5_TURBO_16K)) + val data = call.receive().toCore() + response { + model.promptMessage( + question = data.messages.joinToString("\n") { "${it.role}: ${it.content}" }, + context = scope.context, + promptConfiguration = PromptConfiguration( + temperature = data.temperature, + numberOfPredictions = data.n, + user = data.user ?: "" + ) + ) + } + } + } +} + + +/** + * Responds with the data and converts any potential Throwable into a 404. + */ +private suspend inline fun PipelineContext<*, ApplicationCall>.response( + block: () -> T +) = arrow.core.raise.recover({ + call.respond(block()) +}) { + call.respondText(it.message ?: "Response not found", status = HttpStatusCode.NotFound) +} + +@OptIn(BetaOpenAI::class) +private fun ChatCompletionRequest.toCore(): XefChatCompletionRequest = XefChatCompletionRequest( + model = model.id, + messages = messages.map { Message(it.role.toCore(), it.content ?: "", it.name ?: "") }, + temperature = temperature ?: 0.0, + topP = topP ?: 1.0, + n = n ?: 1, + stream = false, + stop = stop, + maxTokens = maxTokens, + presencePenalty = presencePenalty ?: 0.0, + frequencyPenalty = frequencyPenalty ?: 0.0, + logitBias = logitBias ?: emptyMap(), + user = user, + streamToStandardOut = false +) + +@OptIn(BetaOpenAI::class) +private fun ChatRole.toCore(): Role = + when (this) { + ChatRole.System -> Role.SYSTEM + ChatRole.User -> Role.USER + ChatRole.Assistant -> Role.ASSISTANT + else -> Role.ASSISTANT + } diff --git a/settings.gradle.kts b/settings.gradle.kts index 08ff7e6c0..c59fdb9c0 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -87,3 +87,7 @@ project(":xef-reasoning").projectDir = file("reasoning") include("xef-java-examples") project(":xef-java-examples").projectDir = file("examples/java") // + +// +include("xef-server") +project(":xef-server").projectDir = file("server")