Skip to content

Commit

Permalink
Introduce initial support for Google's Gemini chat model
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed May 1, 2024
1 parent fa0d38e commit b0d0472
Show file tree
Hide file tree
Showing 25 changed files with 1,398 additions and 0 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
<module>watsonx</module>
<module>websockets-next</module>
<module>vertex-ai</module>
<module>vertex-ai-gemini</module>
</modules>
<scm>
<connection>scm:git:git@github.com:quarkiverse/quarkus-langchain4j.git</connection>
Expand Down
64 changes: 64 additions & 0 deletions vertex-ai-gemini/deployment/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-deployment</artifactId>
<name>Quarkus LangChain4j - Vertex AI Gemini - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-vertex-ai-gemini</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest-client-reactive-jackson-deployment</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-testing-internal</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>


</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface ChatModelBuildConfig {

/**
* Whether the model should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.vertexai.gemini")
public interface LangChain4jVertexAiBuildConfig {
/**
* Chat model related settings
*/
ChatModelBuildConfig chatModel();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;

import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiRecorder;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.LangChain4jVertexAiGeminiConfig;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;

public class VertexAiGeminiProcessor {

private static final String FEATURE = "langchain4j-vertexai-gemini";
private static final String PROVIDER = "vertexai-gemini";

@BuildStep
FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
LangChain4jVertexAiBuildConfig config) {
if (config.chatModel().enabled().isEmpty()) {
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void generateBeans(VertexAiGeminiRecorder recorder, List<SelectedChatModelProviderBuildItem> selectedChatItem,
LangChain4jVertexAiGeminiConfig config, BuildProducer<SyntheticBeanBuildItem> beanProducer) {
for (var selected : selectedChatItem) {
if (PROVIDER.equals(selected.getProvider())) {
var modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config, modelName));

addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}
}

private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) {
if (!NamedModelUtil.isDefault(modelName)) {
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package io.quarkiverse.langchain4j.vertexai.gemini.deployment;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiChatLanguageModel;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertxAiGeminiRestApi;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

public class VertexAiGeminiChatLanguageModelSmokeTest extends WiremockAware {

private static final String API_KEY = "somekey";
private static final String CHAT_MODEL_ID = "gemini-pro";

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.base-url", WiremockAware.wiremockUrlForConfig())
.overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.log-requests", "true");

@Inject
ChatLanguageModel chatLanguageModel;

@Test
void test() {
assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(VertexAiGeminiChatLanguageModel.class);

wiremock().register(
post(urlEqualTo(
String.format("/v1/projects/dummy/locations/dummy/publishers/google/models/%s:generateContent",
CHAT_MODEL_ID)))
.withHeader("Authorization", equalTo("Bearer " + API_KEY))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("""
{
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": "Nice to meet you"
}
]
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.044847902,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.05592617
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.18877223,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.027324531
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15278918,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.045437217
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.15869519,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.036838707
}
]
}
],
"usageMetadata": {
"promptTokenCount": 11,
"candidatesTokenCount": 37,
"totalTokenCount": 48
}
}
""")));

String response = chatLanguageModel.generate("hello");
assertThat(response).isEqualTo("Nice to meet you");

LoggedRequest loggedRequest = singleLoggedRequest();
assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client");
String requestBody = new String(loggedRequest.getBody());
assertThat(requestBody).contains("hello");
}

@Singleton
public static class DummyAuthProvider implements VertxAiGeminiRestApi.AuthProvider {
@Override
public String getBearerToken() {
return API_KEY;
}
}

}
20 changes: 20 additions & 0 deletions vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-vertex-ai-gemini-parent</artifactId>
<name>Quarkus LangChain4j - Vertex AI Gemini - Parent</name>
<packaging>pom</packaging>

<modules>
<module>deployment</module>
<module>runtime</module>
</modules>

</project>
Loading

0 comments on commit b0d0472

Please sign in to comment.