diff --git a/README.md b/README.md index fa0ee84..4f68aba 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,12 @@ Add the following dependencies to your `pom.xml`: [LATEST_VERSION] - + + + dev.langchain4j + langchain4j + 0.36.0 + dev.langchain4j langchain4j-open-ai diff --git a/langchain4j-kotlin/pom.xml b/langchain4j-kotlin/pom.xml index bef22db..2182440 100644 --- a/langchain4j-kotlin/pom.xml +++ b/langchain4j-kotlin/pom.xml @@ -16,6 +16,11 @@ dev.langchain4j langchain4j-core + + dev.langchain4j + langchain4j + true + org.jetbrains.kotlinx kotlinx-coroutines-core-jvm @@ -30,11 +35,7 @@ junit-jupiter-api test - - dev.langchain4j - langchain4j - test - + dev.langchain4j langchain4j-open-ai @@ -45,6 +46,16 @@ finchly test + + org.mockito.kotlin + mockito-kotlin + test + + + org.mockito + mockito-junit-jupiter + test + diff --git a/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/TypeAliases.kt b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/TypeAliases.kt new file mode 100644 index 0000000..4059ddf --- /dev/null +++ b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/TypeAliases.kt @@ -0,0 +1,3 @@ +package me.kpavlov.langchain4j.kotlin + +typealias ChatMemoryId = Any diff --git a/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AiServicesExtensions.kt b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AiServicesExtensions.kt new file mode 100644 index 0000000..473fd1f --- /dev/null +++ b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/AiServicesExtensions.kt @@ -0,0 +1,13 @@ +package me.kpavlov.langchain4j.kotlin.service + +import dev.langchain4j.service.AiServices + +/** + * Sets the system message provider for the AI services. + * + * @param provider The SystemMessageProvider that supplies system messages based on chat memory identifiers. + * @return The updated AiServices instance with the specified system message provider. + */ +public fun AiServices.systemMessageProvider( + provider: SystemMessageProvider, +): AiServices = this.systemMessageProvider(provider::getSystemMessage) diff --git a/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/SystemMessageProvider.kt b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/SystemMessageProvider.kt new file mode 100644 index 0000000..b12f011 --- /dev/null +++ b/langchain4j-kotlin/src/main/kotlin/me/kpavlov/langchain4j/kotlin/service/SystemMessageProvider.kt @@ -0,0 +1,16 @@ +package me.kpavlov.langchain4j.kotlin.service + +import me.kpavlov.langchain4j.kotlin.ChatMemoryId + +/** + * Interface for providing LLM system messages based on a given chat memory identifier. + */ +public interface SystemMessageProvider { + /** + * Provides a system message based on the given chat memory identifier. + * + * @param chatMemoryID Identifier for the chat memory used to generate the system message. + * @return A system prompt string associated with the provided chat memory identifier, maybe `null` + */ + public fun getSystemMessage(chatMemoryID: ChatMemoryId): String? +} diff --git a/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/service/ServiceWithSystemMessageProviderTest.kt b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/service/ServiceWithSystemMessageProviderTest.kt new file mode 100644 index 0000000..78382c2 --- /dev/null +++ b/langchain4j-kotlin/src/test/kotlin/me/kpavlov/langchain4j/kotlin/service/ServiceWithSystemMessageProviderTest.kt @@ -0,0 +1,47 @@ +package me.kpavlov.langchain4j.kotlin.service + +import assertk.assertThat +import assertk.assertions.isEqualTo +import dev.langchain4j.data.message.AiMessage +import dev.langchain4j.data.message.SystemMessage +import dev.langchain4j.model.chat.ChatLanguageModel +import dev.langchain4j.model.output.Response +import dev.langchain4j.service.AiServices +import dev.langchain4j.service.UserMessage +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.mockito.junit.jupiter.MockitoExtension + +@ExtendWith(MockitoExtension::class) +class ServiceWithSystemMessageProviderTest { + lateinit var model: ChatLanguageModel + + @Test + fun `Should use SystemMessageProvider`() { + model = + ChatLanguageModel { + assertThat(it.first()).isEqualTo(SystemMessage.from("You are helpful assistant")) + Response.from(AiMessage.from("I'm fine, thanks")) + } + + val assistant = + AiServices + .builder(Assistant::class.java) + .systemMessageProvider( + object : SystemMessageProvider { + override fun getSystemMessage(chatMemoryID: Any): String = + "You are helpful assistant" + }, + ).chatLanguageModel(model) + .build() + + val response = assistant.askQuestion("How are you") + assertThat(response).isEqualTo("I'm fine, thanks") + } + + private interface Assistant { + fun askQuestion( + @UserMessage question: String, + ): String + } +} diff --git a/pom.xml b/pom.xml index f53a007..9bb67b0 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ root 0.1.2-SNAPSHOT pom - LangChain4j-Kotlin :: Aggregator + LangChain4j-Kotlin :: Root Kotlin enhancements for LangChain4j https://github.com/kpavlov/langchain4j-kotlin @@ -41,18 +41,22 @@ + UTF-8 official 17 2.0.21 ${java.version} ${java.version} + ${java.version} 4.2.2 0.1.1 5.11.3 1.9.0 0.36.0 + 5.4.0 + 5.14.2 2.0.16 @@ -97,6 +101,19 @@ pom import + + org.mockito + mockito-bom + ${mockito.version} + pom + import + + + org.mockito.kotlin + mockito-kotlin + ${mockito-kotlin.version} + test + org.slf4j slf4j-simple @@ -157,6 +174,9 @@ org.apache.maven.plugins maven-surefire-plugin 3.5.2 + + @{argLine} -javaagent:${org.mockito:mockito-core:jar} + org.apache.maven.plugins @@ -222,8 +242,9 @@ -Xjsr305=strict - + + true @@ -231,6 +252,12 @@ org.apache.maven.plugins maven-dependency-plugin + + properties + + properties + + analyze-only @@ -238,6 +265,13 @@ + + org.apache.maven.plugins + maven-surefire-plugin + + @{argLine} -javaagent:${org.mockito:mockito-core:jar} + + org.apache.maven.plugins maven-failsafe-plugin