Skip to content

Commit

Permalink
Expression Language for LLM driven template replacements (#298)
Browse files Browse the repository at this point in the history
* Expression Language for LLM driven template replacements

* Prompt adjustments

* Prompt adjustments and better structure for building final prompt based on messages
  • Loading branch information
raulraja authored Aug 7, 2023
1 parent 2c69451 commit dbf500b
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 138 deletions.
33 changes: 27 additions & 6 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ kotlin {
api(libs.kotlinx.serialization.json)
api(libs.ktor.utils)
api(projects.xefTokenizer)

implementation(libs.bundles.ktor.client)
implementation(libs.klogging)
implementation(libs.uuid)
}
Expand All @@ -87,21 +87,42 @@ kotlin {
implementation(libs.logback)
implementation(libs.skrape)
implementation(libs.rss.reader)
api(libs.ktor.client.cio)
}
}

val jsMain by getting
// JS target: use the JavaScript-engine Ktor client.
val jsMain by getting {
dependencies {
api(libs.ktor.client.js)
}
}

// JVM tests run on Kotest's JUnit5 runner.
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
}
}

val linuxX64Main by getting
val macosX64Main by getting
val macosArm64Main by getting
val mingwX64Main by getting
// Native targets each bind their own Ktor HTTP engine:
// Linux and macOS use CIO; Windows (mingw) uses WinHttp instead
// (presumably because CIO is not available there — confirm against Ktor engine support).
val linuxX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val macosX64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val macosArm64Main by getting {
dependencies {
implementation(libs.ktor.client.cio)
}
}
val mingwX64Main by getting {
dependencies {
implementation(libs.ktor.client.winhttp)
}
}
val linuxX64Test by getting
val macosX64Test by getting
val macosArm64Test by getting
Expand Down
70 changes: 38 additions & 32 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ interface Chat : LLM {
): Flow<String> = flow {
val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: String =
val promptWithContext: List<Message> =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
Expand All @@ -55,7 +55,7 @@ interface Chat : LLM {
minResponseTokens = promptConfiguration.minResponseTokens
)

val messages: List<Message> = messages(memories, promptWithContext)
val messages: List<Message> = messagesFromMemory(memories) + promptWithContext

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
Expand Down Expand Up @@ -138,32 +138,22 @@ interface Chat : LLM {

@AiDsl
suspend fun promptMessages(
prompt: Prompt,
messages: List<Message>,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

val promptWithContext: String =
createPromptWithContextAwareOfTokens(
memories = memories,
ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
modelType = modelType,
prompt = prompt.message,
minResponseTokens = promptConfiguration.minResponseTokens
)

val messages: List<Message> = messages(memories, promptWithContext)
val allMessages = messagesFromMemory(memories) + messages

fun checkTotalLeftChatTokens(): Int {
val maxContextLength: Int = modelType.maxContextLength
val messagesTokens: Int = tokensFromMessages(messages)
val messagesTokens: Int = tokensFromMessages(allMessages)
val totalLeftTokens: Int = maxContextLength - messagesTokens
if (totalLeftTokens < 0) {
throw AIError.MessagesExceedMaxTokenLength(messages, messagesTokens, maxContextLength)
throw AIError.MessagesExceedMaxTokenLength(allMessages, messagesTokens, maxContextLength)
}
return totalLeftTokens
}
Expand Down Expand Up @@ -217,6 +207,29 @@ interface Chat : LLM {
}
}

/**
 * Prompt variant that takes a single [Prompt]: retrieves conversation memories and
 * similarity-search context, builds token-aware messages, and delegates to the
 * message-list overload.
 */
@AiDsl
suspend fun promptMessages(
  prompt: Prompt,
  context: VectorStore,
  conversationId: ConversationId? = null,
  functions: List<CFunction> = emptyList(),
  promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): List<String> {

  // Recall prior turns first: their token cost constrains how much context fits.
  val recalled: List<Memory> = memories(conversationId, context, promptConfiguration)

  // Build the final user-facing messages, truncating retrieved context to the model window.
  val contextualMessages: List<Message> =
    createPromptWithContextAwareOfTokens(
      memories = recalled,
      ctxInfo = context.similaritySearch(prompt.message, promptConfiguration.docsInContext),
      modelType = modelType,
      prompt = prompt.message,
      minResponseTokens = promptConfiguration.minResponseTokens
    )

  // The message-list overload prepends memory and performs the completion.
  return promptMessages(
    messages = contextualMessages,
    context = context,
    conversationId = conversationId,
    functions = functions,
    promptConfiguration = promptConfiguration
  )
}

private suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
request: ChatCompletionRequestWithFunctions,
context: VectorStore,
Expand Down Expand Up @@ -274,8 +287,8 @@ interface Chat : LLM {
}
}

private fun messages(memories: List<Memory>, promptWithContext: String): List<Message> =
memories.map { it.content } + listOf(Message(Role.USER, promptWithContext, Role.USER.name))
/** Projects stored memories back into their chat [Message] contents, preserving order. */
private fun messagesFromMemory(memories: List<Memory>): List<Message> {
  return memories.map(Memory::content)
}

private suspend fun memories(
conversationId: ConversationId?,
Expand All @@ -288,13 +301,13 @@ interface Chat : LLM {
emptyList()
}

private fun createPromptWithContextAwareOfTokens(
private suspend fun createPromptWithContextAwareOfTokens(
memories: List<Memory>,
ctxInfo: List<String>,
modelType: ModelType,
prompt: String,
minResponseTokens: Int,
): String {
): List<Message> {
val maxContextLength: Int = modelType.maxContextLength
val promptTokens: Int = modelType.encoding.countTokens(prompt)
val memoryTokens = tokensFromMessages(memories.map { it.content })
Expand All @@ -311,17 +324,10 @@ interface Chat : LLM {
// alternatively we could summarize the context, but that's not implemented yet
val ctxTruncated: String = modelType.encoding.truncateText(ctx, remainingTokens)

"""|```Context
|${ctxTruncated}
|```
|The context is related to the question try to answer the `goal` as best as you can
|or provide information about the found content
|```goal
|${prompt}
|```
|ANSWER:
|"""
.trimMargin()
} else prompt
listOf(
Message.assistantMessage { "Context: $ctxTruncated" },
Message.userMessage { prompt }
)
} else listOf(Message.userMessage { prompt })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.xebia.functional.xef.auto.AiDsl
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequestWithFunctions
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponseWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.functions.CFunction
import com.xebia.functional.xef.llm.models.functions.encodeJsonSchema
import com.xebia.functional.xef.prompt.Prompt
Expand Down Expand Up @@ -45,6 +46,29 @@ interface ChatWithFunctions : Chat {
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A = prompt(prompt, context, conversationId, functions, serializer, promptConfiguration)

/**
 * Prompts the model with pre-built [messages] and deserializes the function-call
 * response into an [A], retrying deserialization up to
 * [PromptConfiguration.maxDeserializationAttempts] times.
 */
@AiDsl
suspend fun <A> prompt(
  messages: List<Message>,
  context: VectorStore,
  serializer: KSerializer<A>,
  conversationId: ConversationId? = null,
  functions: List<CFunction> = generateCFunction(serializer.descriptor),
  promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A =
  tryDeserialize(
    { json -> Json.decodeFromString(serializer, json) },
    promptConfiguration.maxDeserializationAttempts
  ) {
    promptMessages(
      messages = messages,
      context = context,
      conversationId = conversationId,
      functions = functions,
      promptConfiguration = promptConfiguration
    )
  }

@AiDsl
suspend fun <A> prompt(
prompt: Prompt,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
package com.xebia.functional.xef.llm.models.chat

data class Message(val role: Role, val content: String, val name: String)
/** A single chat message: who said it ([role]), what was said ([content]), and the sender [name]. */
data class Message(val role: Role, val content: String, val name: String) {
  companion object {
    // Single factory: evaluates the suspending supplier once and names the message after its role.
    private suspend fun of(role: Role, message: suspend () -> String): Message =
      Message(role = role, content = message(), name = role.name)

    suspend fun systemMessage(message: suspend () -> String) = of(Role.SYSTEM, message)

    suspend fun userMessage(message: suspend () -> String) = of(Role.USER, message)

    suspend fun assistantMessage(message: suspend () -> String) = of(Role.ASSISTANT, message)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.llm.ChatWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.prompt.experts.ExpertSystem
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging

/**
 * Builds a multi-message prompt template whose `{{key}}` placeholders are filled in by the LLM.
 *
 * Inside [block], add template messages via [system]/[user]/[assistant] and mint placeholders
 * via [prompt]; [run] then asks [model] to generate a value for every registered key and
 * returns the fully substituted template.
 */
class Expression(
  private val scope: CoreAIScope,
  private val model: ChatWithFunctions,
  val block: suspend Expression.() -> Unit
) {

  private val logger: KLogger = KotlinLogging.logger {}

  // Template messages in declaration order.
  private val messages: MutableList<Message> = mutableListOf()

  // Placeholder keys the model is asked to produce values for.
  private val generationKeys: MutableList<String> = mutableListOf()

  /** Adds a SYSTEM message to the template. */
  suspend fun system(message: suspend () -> String) {
    messages.add(Message.systemMessage(message))
  }

  /** Adds a USER message to the template. */
  suspend fun user(message: suspend () -> String) {
    messages.add(Message.userMessage(message))
  }

  /** Adds an ASSISTANT message to the template. */
  suspend fun assistant(message: suspend () -> String) {
    messages.add(Message.assistantMessage(message))
  }

  /** Registers [key] for generation and returns the literal `{{key}}` placeholder to embed. */
  fun prompt(key: String): String {
    generationKeys.add(key)
    return "{{$key}}"
  }

  /**
   * Evaluates [block], asks the model for a value per registered key, and substitutes the
   * placeholders in every template message.
   *
   * Keys the model returned no replacement for are left as literal `{{key}}` text.
   */
  suspend fun run(
    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
  ): ExpressionResult {
    block()
    val instructionMessage =
      Message(
        role = Role.USER,
        content =
          ExpertSystem(
              system = "You are an expert in replacing variables in templates",
              query =
                """
                |I want to replace the following variables in the following template:
                |<template>
                |${messages.joinToString("\n") { it.content }}
                |</template>
                |The variables are:
                |${generationKeys.joinToString("\n")}
                """
                  .trimMargin(),
              instructions =
                listOf(
                  "Create a `ReplacedValues` object with the `replacements` where the keys are the variable names and the values are the values to replace them with.",
                )
            )
            .message,
        name = Role.USER.name
      )
    val values: ReplacedValues =
      model.prompt(
        messages = messages + instructionMessage,
        context = scope.context,
        serializer = ReplacedValues.serializer(),
        conversationId = scope.conversationId,
        promptConfiguration = promptConfiguration
      )
    logger.info { "replaced: ${values.replacements.joinToString { it.key }}" }
    // Index replacements once instead of scanning the list per key per message.
    // `distinctBy` keeps the FIRST occurrence of a duplicated key, matching the
    // previous `firstOrNull` lookup semantics.
    val replacementsByKey: Map<String, String> =
      values.replacements.distinctBy { it.key }.associate { it.key to it.value }
    // Named lambda parameters avoid the accumulator shadowing the original nested folds had.
    val replacedTemplate = buildString {
      for (message in messages) {
        val replaced =
          generationKeys.fold(message.content) { text, key ->
            text.replace("{{$key}}", replacementsByKey[key] ?: "{{$key}}")
          }
        append(replaced).append('\n')
      }
    }
    return ExpressionResult(messages = messages, result = replacedTemplate, values = values)
  }

  companion object {
    /** Convenience entry point: builds an [Expression] and immediately runs it. */
    suspend fun run(
      scope: CoreAIScope,
      model: ChatWithFunctions,
      block: suspend Expression.() -> Unit
    ): ExpressionResult = Expression(scope, model, block).run()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.llm.models.chat.Message

/**
 * Outcome of evaluating an [Expression] template.
 *
 * @property messages the template messages as authored, before substitution.
 * @property result the template text with each `{{key}}` placeholder replaced.
 * @property values the model-generated replacements that were applied.
 */
data class ExpressionResult(
val messages: List<Message>,
val result: String,
val values: ReplacedValues,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.Description
import kotlinx.serialization.Serializable

@Serializable
data class ReplacedValues(
@Description(["The values that are generated for the template"])
val replacements: List<Replacement>
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.xebia.functional.xef.prompt.expressions

import com.xebia.functional.xef.auto.Description
import kotlinx.serialization.Serializable

@Serializable
data class Replacement(
@Description(["The key originally in {{key}} format that was going to get replaced"])
val key: String,
@Description(["The Assistant generated value that the `key` should be replaced with"])
val value: String
)
Loading

0 comments on commit dbf500b

Please sign in to comment.