Merge branch 'main' into update/gpt3.5-models
Montagon authored Oct 2, 2024
2 parents 8f9e8e4 + c55ae53 commit 4c5eec5
Showing 92 changed files with 219,468 additions and 14,494 deletions.
242 changes: 111 additions & 131 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
@@ -1,155 +1,135 @@
 package com.xebia.functional.xef
 
-import com.xebia.functional.openai.generated.api.Chat
-import com.xebia.functional.openai.generated.api.Images
-import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequest
 import com.xebia.functional.xef.conversation.AiDsl
-import com.xebia.functional.xef.conversation.Conversation
+import com.xebia.functional.xef.llm.models.modelType
+import com.xebia.functional.xef.llm.prompt
+import com.xebia.functional.xef.llm.promptStreaming
 import com.xebia.functional.xef.prompt.Prompt
 import kotlin.coroutines.cancellation.CancellationException
-import kotlin.reflect.KClass
-import kotlin.reflect.KType
-import kotlin.reflect.typeOf
-import kotlinx.serialization.ExperimentalSerializationApi
-import kotlinx.serialization.InternalSerializationApi
-import kotlinx.serialization.KSerializer
-import kotlinx.serialization.descriptors.SerialKind
-import kotlinx.serialization.serializer
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.channelFlow
 
-sealed interface AI {
+class AI<out A>(private val config: AIConfig, val serializer: Tool<A>) {
 
-  interface PromptClassifier {
-    fun template(input: String, output: String, context: String): String
-  }
-
-  companion object {
+  private fun runStreamingWithStringSerializer(prompt: Prompt): Flow<String> =
+    config.api.promptStreaming(prompt, config.conversation, config.tools)
 
-    fun <A : Any> chat(
-      target: KType,
-      model: CreateChatCompletionRequestModel,
-      api: Chat,
-      conversation: Conversation,
-      enumSerializer: ((case: String) -> A)?,
-      caseSerializers: List<KSerializer<A>>,
-      serializer: () -> KSerializer<A>,
-    ): DefaultAI<A> =
-      DefaultAI(
-        target = target,
-        model = model,
-        api = api,
-        serializer = serializer,
-        conversation = conversation,
-        enumSerializer = enumSerializer,
-        caseSerializers = caseSerializers
-      )
-
-    fun images(
-      config: Config = Config(),
-    ): Images = OpenAI(config).images
+  @PublishedApi
+  internal suspend operator fun invoke(prompt: Prompt): A =
+    when (val serializer = serializer) {
+      is Tool.Callable -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Contextual -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Enumeration<A> -> runWithEnumSingleTokenSerializer(serializer, prompt)
+      is Tool.FlowOfStreamedFunctions<*> -> {
+        config.api.promptStreaming(prompt, config.conversation, serializer, config.tools) as A
+      }
+      is Tool.FlowOfStrings -> runStreamingWithStringSerializer(prompt) as A
+      is Tool.Primitive -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.Sealed -> config.api.prompt(prompt, config.conversation, serializer, config.tools)
+      is Tool.FlowOfAIEventsSealed ->
+        channelFlow {
+          send(AIEvent.Start)
+          config.api.prompt(
+            prompt = prompt,
+            scope = config.conversation,
+            serializer = serializer.sealedSerializer,
+            tools = config.tools,
+            collector = this
+          )
+        }
+          as A
+      is Tool.FlowOfAIEvents ->
+        channelFlow {
+          send(AIEvent.Start)
+          config.api.prompt(
+            prompt = prompt,
+            scope = config.conversation,
+            serializer = serializer.serializer,
+            tools = config.tools,
+            collector = this
+          )
+        }
+          as A
+    }
 
-    @PublishedApi
-    internal suspend inline fun <reified A : Any> invokeEnum(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A =
-      chat(
-        target = target,
-        model = prompt.model,
-        api = api,
-        conversation = conversation,
-        enumSerializer = { @Suppress("UPPER_BOUND_VIOLATED") enumValueOf<A>(it) },
-        caseSerializers = emptyList()
-      ) {
-        serializer<A>()
+  private suspend fun runWithEnumSingleTokenSerializer(
+    serializer: Tool.Enumeration<A>,
+    prompt: Prompt
+  ): A {
+    val encoding = prompt.model.modelType(forFunctions = false).encoding
+    val cases = serializer.cases
+    val logitBias =
+      cases
+        .flatMap {
+          val result = encoding.encode(it.function.name)
+          if (result.size > 1) {
+            error("Cannot encode enum case $it into one token")
+          }
+          result
         }
-        .invoke(prompt)
+        .associate { "$it" to 100 }
+    val result =
+      config.api.createChatCompletion(
+        CreateChatCompletionRequest(
+          messages = prompt.messages,
+          model = prompt.model,
+          logitBias = logitBias,
+          maxTokens = 1,
+          temperature = 0.0
+        )
+      )
+    val choice = result.choices[0].message.content
+    val enumSerializer = serializer.enumSerializer
+    return if (choice != null) {
+      enumSerializer(choice)
+    } else {
+      error("Cannot decode enum case from $choice")
+    }
+  }
 
   /**
    * Classify a prompt using a given enum.
    *
    * @param input The input to the model.
    * @param output The output to the model.
    * @param context The context to the model.
    * @param model The model to use.
    * @param target The target type to return.
    * @param api The chat API to use.
    * @param conversation The conversation to use.
    * @return The classified enum.
    * @throws IllegalArgumentException If no enum values are found.
    */
+  companion object {
     @AiDsl
     @Throws(IllegalArgumentException::class, CancellationException::class)
     suspend inline fun <reified E> classify(
       input: String,
       output: String,
       context: String,
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
-      target: KType = typeOf<E>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): E where E : PromptClassifier, E : Enum<E> {
-      val value = enumValues<E>().firstOrNull() ?: error("No enum values found")
-      return invoke(
+      config: AIConfig = AIConfig(),
+    ): E where E : Enum<E>, E : PromptClassifier {
+      val value = enumValues<E>().firstOrNull() ?: error("No values to classify")
+      return AI<E>(
         prompt = value.template(input, output, context),
-        model = model,
-        target = target,
         config = config,
-        api = api,
-        conversation = conversation
       )
     }
 
     @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: String,
-      target: KType = typeOf<A>(),
-      model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125,
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(Prompt(model, prompt), target, config, api, conversation)
-
-    @AiDsl
-    suspend inline operator fun <reified A : Any> invoke(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A = chat(prompt, target, config, api, conversation)
-
-    @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
-    @AiDsl
-    suspend inline fun <reified A : Any> chat(
-      prompt: Prompt,
-      target: KType = typeOf<A>(),
-      config: Config = Config(),
-      api: Chat = OpenAI(config).chat,
-      conversation: Conversation = Conversation()
-    ): A {
-      val kind =
-        (target.classifier as? KClass<*>)?.serializer()?.descriptor?.kind
-          ?: error("Cannot find SerialKind for $target")
-      return when (kind) {
-        SerialKind.ENUM -> invokeEnum<A>(prompt, target, config, api, conversation)
-        else -> {
-          chat(
-            target = target,
-            model = prompt.model,
-            api = api,
-            conversation = conversation,
-            enumSerializer = null,
-            caseSerializers = emptyList()
-          ) {
-            serializer<A>()
-          }
-            .invoke(prompt)
-        }
-      }
-    }
+    suspend inline fun <reified E> multipleClassify(
+      input: String,
+      config: AIConfig = AIConfig(),
+    ): List<E> where E : Enum<E>, E : PromptMultipleClassifier {
+      val values = enumValues<E>()
+      val value = values.firstOrNull() ?: error("No values to classify")
+      val selected: SelectedItems =
+        AI(
+          prompt = value.template(input),
+          serializer = Tool.fromKotlin<SelectedItems>(),
+          config = config
+        )
+      return selected.selectedItems.mapNotNull { values.elementAtOrNull(it) }
+    }
   }
 }
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: String,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig()
+): A = AI(Prompt(config.model, prompt), serializer, config)
+
+@AiDsl
+suspend inline fun <reified A> AI(
+  prompt: Prompt,
+  serializer: Tool<A> = Tool.fromKotlin<A>(),
+  config: AIConfig = AIConfig(),
+): A = AI(config, serializer).invoke(prompt)
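
Taken together, the refactor replaces the companion-object invoke/chat entry points with an AI<A> runner driven by a Tool<A> serializer. A minimal usage sketch, not part of the commit: Recipe and Sentiment are hypothetical types, the file is assumed to sit in the com.xebia.functional.xef package, and PromptClassifier is assumed to keep the template(input, output, context) member shown in the removed block above.

package com.xebia.functional.xef

import kotlinx.serialization.Serializable

// Hypothetical output type; the default Tool.fromKotlin<Recipe>() serializer
// is assumed to be derived from its @Serializable structure.
@Serializable data class Recipe(val title: String, val ingredients: List<String>)

// Hypothetical classifier enum. Short case names are deliberate:
// runWithEnumSingleTokenSerializer above rejects case names that do not
// encode to a single token.
enum class Sentiment : PromptClassifier {
  Good,
  Bad;

  override fun template(input: String, output: String, context: String): String =
    "Classify the sentiment of the input as Good or Bad.\nInput: $input"
}

suspend fun example() {
  // Structured output through the new top-level AI(...) entry point.
  val recipe: Recipe = AI("Write a three-ingredient lasagna recipe")
  println(recipe)

  // Single-token enum classification through the companion helper.
  val sentiment: Sentiment = AI.classify(input = "I love this library!", output = "", context = "")
  println(sentiment)
}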
15 changes: 15 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIConfig.kt
@@ -0,0 +1,15 @@
+package com.xebia.functional.xef
+
+import com.xebia.functional.openai.generated.api.Chat
+import com.xebia.functional.openai.generated.api.OpenAI
+import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
+import com.xebia.functional.xef.conversation.Conversation
+
+data class AIConfig(
+  val tools: List<Tool<*>> = emptyList(),
+  val model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4o,
+  val config: Config = Config(),
+  val openAI: OpenAI = OpenAI(config, logRequests = false),
+  val api: Chat = openAI.chat,
+  val conversation: Conversation = Conversation()
+)
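
AIConfig folds the model, Config, OpenAI client, chat API, conversation, and tool list that the old entry points took as separate parameters into one value. A hedged sketch of tweaking it, using only fields visible above (the gpt_3_5_turbo_0125 constant comes from the old default removed from AI.kt):

package com.xebia.functional.xef

import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.xef.conversation.Conversation

// Override only the model; tools, client, chat API, and conversation keep
// the declared defaults.
val cheapConfig = AIConfig(model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125)

// Being a data class, AIConfig can be forked per call site with copy(...),
// for example to start a fresh conversation.
val freshConversation = cheapConfig.copy(conversation = Conversation())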
36 changes: 36 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AIEvent.kt
@@ -0,0 +1,36 @@
+package com.xebia.functional.xef
+
+sealed class AIEvent<out A> {
+  data object Start : AIEvent<Nothing>()
+
+  data class Result<out A>(val value: A) : AIEvent<A>()
+
+  data class ToolExecutionRequest(val tool: Tool<*>, val input: Any?) : AIEvent<Nothing>()
+
+  data class ToolExecutionResponse(val tool: Tool<*>, val output: Any?) : AIEvent<Nothing>()
+
+  data class Stop(val usage: Usage) : AIEvent<Nothing>() {
+    data class Usage(
+      val llmCalls: Int,
+      val toolCalls: Int,
+      val inputTokens: Int,
+      val outputTokens: Int,
+      val totalTokens: Int,
+    )
+  }
+
+  fun debugPrint(): Unit =
+    when (this) {
+      // emoji for start is: 🚀
+      Start -> println("🚀 Starting...")
+      is Result -> println("🎉 $value")
+      is ToolExecutionRequest ->
+        println("🔧 Executing tool: ${tool.function.name} with input: $input")
+      is ToolExecutionResponse ->
+        println("🔨 Tool response: ${tool.function.name} resulted in: $output")
+      is Stop -> {
+        println("🛑 Stopping...")
+        println("📊 Usage: $usage")
+      }
+    }
+}
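
The Tool.FlowOfAIEvents branch of AI.invoke above wraps a run in these events via channelFlow. A consumption sketch under one assumption: that requesting Flow<AIEvent<String>> makes the default Tool.fromKotlin pick that branch.

package com.xebia.functional.xef

import kotlinx.coroutines.flow.Flow

suspend fun traceRun() {
  // Assumed: asking for Flow<AIEvent<String>> selects Tool.FlowOfAIEvents,
  // so the flow emits Start, tool requests/responses, Result, and Stop
  // around the model's answer.
  val events: Flow<AIEvent<String>> = AI("Summarize this commit in one line")
  events.collect { event -> event.debugPrint() }
}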
6 changes: 6 additions & 0 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/Classification.kt
@@ -0,0 +1,6 @@
+package com.xebia.functional.xef
+
+data class Classification(
+  val name: String,
+  val description: String,
+)
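
The commit does not show where Classification is consumed; presumably PromptMultipleClassifier renders a list of them into the multipleClassify prompt. A purely illustrative sketch, not part of the commit:

// Hypothetical data: one Classification per category that
// multipleClassify could select.
val categories = listOf(
  Classification(name = "bug", description = "The text reports broken behavior"),
  Classification(name = "feature", description = "The text requests new functionality"),
)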
core/src/commonMain/kotlin/com/xebia/functional/xef/Config.kt
@@ -23,13 +23,14 @@ data class Config(
     prettyPrint = false
     isLenient = true
     explicitNulls = false
-    classDiscriminator = "_type_"
+    classDiscriminator = TYPE_DISCRIMINATOR
   },
   val streamingPrefix: String = "data:",
   val streamingDelimiter: String = "data: [DONE]"
 ) {
   companion object {
     val DEFAULT = Config()
+    const val TYPE_DISCRIMINATOR = "_type_"
   }
 }
 
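
Extracting the literal into Config.TYPE_DISCRIMINATOR lets other modules reference the discriminator key instead of repeating the "_type_" string. A small illustration of what the setting controls; Animal and Dog are hypothetical types, and the file is assumed to sit in the same package as Config:

package com.xebia.functional.xef

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@Serializable sealed class Animal

@Serializable @SerialName("Dog") data class Dog(val name: String) : Animal()

val json = Json { classDiscriminator = Config.TYPE_DISCRIMINATOR }

fun demo() {
  // Serializing through the sealed base adds the discriminator field:
  // {"_type_":"Dog","name":"Rex"}
  println(json.encodeToString(Animal.serializer(), Dog("Rex")))
}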
