Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce initial support for Google's Gemini chat model #534

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
<module>watsonx</module>
<module>websockets-next</module>
<module>vertex-ai</module>
<module>vertex-ai-gemini</module>
</modules>
<scm>
<connection>scm:git:git@github.com:quarkiverse/quarkus-langchain4j.git</connection>
Expand Down
64 changes: 64 additions & 0 deletions vertex-ai-gemini/deployment/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-deployment</artifactId>
<name>Quarkus LangChain4j - Vertex AI Gemini - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-vertex-ai-gemini</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest-client-reactive-jackson-deployment</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-testing-internal</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>


</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface ChatModelBuildConfig {

/**
* Whether the model should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.vertexai.gemini")
public interface LangChain4jVertexAiBuildConfig {
/**
* Chat model related settings
*/
ChatModelBuildConfig chatModel();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;

import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiRecorder;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.LangChain4jVertexAiGeminiConfig;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;

public class VertexAiGeminiProcessor {

private static final String FEATURE = "langchain4j-vertexai-gemini";
private static final String PROVIDER = "vertexai-gemini";

@BuildStep
FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
LangChain4jVertexAiBuildConfig config) {
if (config.chatModel().enabled().isEmpty()) {
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void generateBeans(VertexAiGeminiRecorder recorder, List<SelectedChatModelProviderBuildItem> selectedChatItem,
LangChain4jVertexAiGeminiConfig config, BuildProducer<SyntheticBeanBuildItem> beanProducer) {
for (var selected : selectedChatItem) {
if (PROVIDER.equals(selected.getProvider())) {
var modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config, modelName));

addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}
}

private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) {
if (!NamedModelUtil.isDefault(modelName)) {
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiChatLanguageModel;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertxAiGeminiRestApi;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class VertexAiGeminiChatLanguageModelSmokeTest extends WiremockAware {

private static final String API_KEY = "somekey";
private static final String CHAT_MODEL_ID = "gemini-pro";

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.base-url", WiremockAware.wiremockUrlForConfig())
.overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.log-requests", "true");

@Inject
ChatLanguageModel chatLanguageModel;

@Test
void test() {
assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(VertexAiGeminiChatLanguageModel.class);

wiremock().register(
post(urlEqualTo(
String.format("/v1/projects/dummy/locations/dummy/publishers/google/models/%s:generateContent",
CHAT_MODEL_ID)))
.withHeader("Authorization", equalTo("Bearer " + API_KEY))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Nice to meet you"
}
]
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.044847902,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.05592617
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.18877223,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.027324531
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15278918,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.045437217
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15869519,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.036838707
}
]
}
],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 37,
"totalTokenCount": 48
}
}
""")));

String response = chatLanguageModel.generate("hello");
assertThat(response).isEqualTo("Nice to meet you");

LoggedRequest loggedRequest = singleLoggedRequest();
assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client");
String requestBody = new String(loggedRequest.getBody());
assertThat(requestBody).contains("hello");
}

@Singleton
public static class DummyAuthProvider implements VertxAiGeminiRestApi.AuthProvider {
@Override
public String getBearerToken() {
return API_KEY;
}
}

}
20 changes: 20 additions & 0 deletions vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-parent</artifactId>
<name>Quarkus LangChain4j - Vertex AI Gemini - Parent</name>
<packaging>pom</packaging>

<modules>
<module>deployment</module>
<module>runtime</module>
</modules>

</project>
Loading
Loading