Skip to content

Commit

Permalink
Extend AiServices to support SystemMessageProvider (#25)
Browse files Browse the repository at this point in the history
This pull request includes several updates to the `langchain4j-kotlin`
project, focusing on dependency management, new features, and testing
improvements. The most important changes include adding new
dependencies, introducing a new interface for system message providers,
and updating the `pom.xml` files to reflect these changes.

### Dependency Management:
* Added `langchain4j` as an optional dependency in
`langchain4j-kotlin/pom.xml` and updated the main `pom.xml` to include
versions for `mockito` and `mockito-kotlin` dependencies.
[[1]](diffhunk://#diff-988dd67f6a0de416687ee39291f2713a8229de37918339f305c4ec8ae58e539bR19-R23)
[[2]](diffhunk://#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8R104-R116)
* Removed `langchain4j` from the test scope in
`langchain4j-kotlin/pom.xml`.
* Added `mockito-kotlin` and `mockito-junit-jupiter` as test
dependencies in `langchain4j-kotlin/pom.xml`.

### New Features:
* Introduced a new type alias `ChatMemoryId` in `TypeAliases.kt`.
* Added `SystemMessageProvider` interface to provide system messages
based on chat memory identifiers.
* Created an extension function `systemMessageProvider` for `AiServices`
to set the system message provider.

### Testing Improvements:
* Added a new test class `ServiceWithSystemMessageProviderTest` to
verify the functionality of the `SystemMessageProvider`.

### Documentation:
* Updated the `README.md` to include the new `langchain4j` dependency.

### Miscellaneous:
* Updated the project name in the main `pom.xml` from "Aggregator" to
"Root".
* Added additional configuration for Maven plugins in `pom.xml` to
support new dependencies and improve build processes.
[[1]](diffhunk://#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8R177-R179)
[[2]](diffhunk://#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8L225-R274)

---------

Co-authored-by: Konstantin Pavlov <{ID}+{username}@users.noreply.github.com>
  • Loading branch information
kpavlov and Konstantin Pavlov authored Nov 14, 2024
1 parent 1b73fad commit 47060e8
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 8 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ Add the following dependencies to your `pom.xml`:
<version>[LATEST_VERSION]</version>
</dependency>

<!-- Required Dependencies -->
<!-- Extra Dependencies -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>0.36.0</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
Expand Down
21 changes: 16 additions & 5 deletions langchain4j-kotlin/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlinx</groupId>
<artifactId>kotlinx-coroutines-core-jvm</artifactId>
Expand All @@ -30,11 +35,7 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
Expand All @@ -45,6 +46,16 @@
<artifactId>finchly</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito.kotlin</groupId>
<artifactId>mockito-kotlin</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package me.kpavlov.langchain4j.kotlin

typealias ChatMemoryId = Any
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package me.kpavlov.langchain4j.kotlin.service

import dev.langchain4j.service.AiServices

/**
* Sets the system message provider for the AI services.
*
* @param provider The SystemMessageProvider that supplies system messages based on chat memory identifiers.
* @return The updated AiServices instance with the specified system message provider.
*/
public fun <T> AiServices<T>.systemMessageProvider(
provider: SystemMessageProvider,
): AiServices<T> = this.systemMessageProvider(provider::getSystemMessage)
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package me.kpavlov.langchain4j.kotlin.service

import me.kpavlov.langchain4j.kotlin.ChatMemoryId

/**
* Interface for providing LLM system messages based on a given chat memory identifier.
*/
public interface SystemMessageProvider {
/**
* Provides a system message based on the given chat memory identifier.
*
* @param chatMemoryID Identifier for the chat memory used to generate the system message.
* @return A system prompt string associated with the provided chat memory identifier, maybe `null`
*/
public fun getSystemMessage(chatMemoryID: ChatMemoryId): String?
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package me.kpavlov.langchain4j.kotlin.service

import assertk.assertThat
import assertk.assertions.isEqualTo
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.output.Response
import dev.langchain4j.service.AiServices
import dev.langchain4j.service.UserMessage
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.junit.jupiter.MockitoExtension

@ExtendWith(MockitoExtension::class)
class ServiceWithSystemMessageProviderTest {
lateinit var model: ChatLanguageModel

@Test
fun `Should use SystemMessageProvider`() {
model =
ChatLanguageModel {
assertThat(it.first()).isEqualTo(SystemMessage.from("You are helpful assistant"))
Response.from(AiMessage.from("I'm fine, thanks"))
}

val assistant =
AiServices
.builder(Assistant::class.java)
.systemMessageProvider(
object : SystemMessageProvider {
override fun getSystemMessage(chatMemoryID: Any): String =
"You are helpful assistant"
},
).chatLanguageModel(model)
.build()

val response = assistant.askQuestion("How are you")
assertThat(response).isEqualTo("I'm fine, thanks")
}

private interface Assistant {
fun askQuestion(
@UserMessage question: String,
): String
}
}
38 changes: 36 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<artifactId>root</artifactId>
<version>0.1.2-SNAPSHOT</version>
<packaging>pom</packaging>
<name>LangChain4j-Kotlin :: Aggregator</name>
<name>LangChain4j-Kotlin :: Root</name>
<description>Kotlin enhancements for LangChain4j</description>
<url>https://github.com/kpavlov/langchain4j-kotlin</url>

Expand Down Expand Up @@ -41,18 +41,22 @@
</scm>

<properties>
<argLine/>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<kotlin.code.style>official</kotlin.code.style>
<java.version>17</java.version>
<kotlin.version>2.0.21</kotlin.version>
<kotlin.compiler.jvmTarget>${java.version}</kotlin.compiler.jvmTarget>
<java.compiler.release>${java.version}</java.compiler.release>
<java.compiler.source>${java.version}</java.compiler.source>
<!-- Dependencies -->
<awaitility.version>4.2.2</awaitility.version>
<finchly.version>0.1.1</finchly.version>
<junit.version>5.11.3</junit.version>
<kotlinx.version>1.9.0</kotlinx.version>
<langchain4j.version>0.36.0</langchain4j.version>
<mockito-kotlin.version>5.4.0</mockito-kotlin.version>
<mockito.version>5.14.2</mockito.version>
<slf4j.version>2.0.16</slf4j.version>
</properties>

Expand Down Expand Up @@ -97,6 +101,19 @@
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-bom</artifactId>
<version>${mockito.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>org.mockito.kotlin</groupId>
<artifactId>mockito-kotlin</artifactId>
<version>${mockito-kotlin.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
Expand Down Expand Up @@ -157,6 +174,9 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.5.2</version>
<configuration>
<argLine>@{argLine} -javaagent:${org.mockito:mockito-core:jar}</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
Expand Down Expand Up @@ -222,22 +242,36 @@
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
<!--arg>-Werror</arg-->
<!--<arg>-Werror</arg>-->
</args>
<javaParameters>true</javaParameters>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>properties</id>
<goals>
<goal>properties</goal>
</goals>
</execution>
<execution>
<goals>
<goal>analyze-only</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine>@{argLine} -javaagent:${org.mockito:mockito-core:jar}</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
Expand Down

0 comments on commit 47060e8

Please sign in to comment.