From 06a4c267002958459c6793be18749e348f2753f9 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 30 Apr 2024 17:36:20 +0300 Subject: [PATCH] Introduce initial support for Google's Gemini chat model --- pom.xml | 1 + vertex-ai-gemini/deployment/pom.xml | 64 +++++ .../deployment/ChatModelBuildConfig.java | 16 ++ .../LangChain4jVertexAiBuildConfig.java | 15 ++ .../deployment/VertexAiGeminiProcessor.java | 68 +++++ ...texAiGeminiChatLanguageModelSmokeTest.java | 121 +++++++++ vertex-ai-gemini/pom.xml | 20 ++ vertex-ai-gemini/runtime/pom.xml | 103 ++++++++ .../runtime/gemini/ContentMapper.java | 56 ++++ .../runtime/gemini/FinishReasonMapper.java | 18 ++ .../vertexai/runtime/gemini/FunctionCall.java | 7 + .../runtime/gemini/FunctionDeclaration.java | 11 + .../runtime/gemini/FunctionResponse.java | 8 + .../gemini/GenerateContentRequest.java | 33 +++ .../gemini/GenerateContentResponse.java | 33 +++ .../GenerateContentResponseHandler.java | 45 ++++ .../runtime/gemini/GenerationConfig.java | 68 +++++ .../vertexai/runtime/gemini/RoleMapper.java | 18 ++ .../VertexAiGeminiChatLanguageModel.java | 155 +++++++++++ .../gemini/VertexAiGeminiRecorder.java | 89 +++++++ .../runtime/gemini/VertxAiGeminiRestApi.java | 246 ++++++++++++++++++ .../gemini/config/ChatModelConfig.java | 92 +++++++ .../LangChain4jVertexAiGeminiConfig.java | 87 +++++++ .../src/main/resources/META-INF/beans.xml | 0 .../resources/META-INF/quarkus-extension.yaml | 17 ++ 25 files changed, 1391 insertions(+) create mode 100644 vertex-ai-gemini/deployment/pom.xml create mode 100644 vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/ChatModelBuildConfig.java create mode 100644 vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/LangChain4jVertexAiBuildConfig.java create mode 100644 vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiProcessor.java create mode 
100644 vertex-ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiChatLanguageModelSmokeTest.java create mode 100644 vertex-ai-gemini/pom.xml create mode 100644 vertex-ai-gemini/runtime/pom.xml create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FinishReasonMapper.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionResponse.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerationConfig.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java create 
mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/ChatModelConfig.java create mode 100644 vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java create mode 100644 vertex-ai-gemini/runtime/src/main/resources/META-INF/beans.xml create mode 100644 vertex-ai-gemini/runtime/src/main/resources/META-INF/quarkus-extension.yaml diff --git a/pom.xml b/pom.xml index bc14ec306..bbbd4142d 100644 --- a/pom.xml +++ b/pom.xml @@ -38,6 +38,7 @@ watsonx websockets-next vertex-ai + vertex-ai-gemini scm:git:git@github.com:quarkiverse/quarkus-langchain4j.git diff --git a/vertex-ai-gemini/deployment/pom.xml b/vertex-ai-gemini/deployment/pom.xml new file mode 100644 index 000000000..d71bb6926 --- /dev/null +++ b/vertex-ai-gemini/deployment/pom.xml @@ -0,0 +1,64 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-vertex-ai-gemini-parent + 999-SNAPSHOT + + quarkus-langchain4j-vertex-ai-gemini-deployment + Quarkus LangChain4j - Vertex AI Gemini - Deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-vertex-ai-gemini + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkus + quarkus-rest-client-reactive-jackson-deployment + + + io.quarkus + quarkus-junit5-internal + test + + + io.quarkiverse.langchain4j + quarkus-langchain4j-testing-internal + ${project.version} + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + + + diff --git a/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/ChatModelBuildConfig.java b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/ChatModelBuildConfig.java new file mode 100644 index 000000000..faa29341d --- 
/dev/null +++ b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/ChatModelBuildConfig.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.vertexai.gemini.deployment; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface ChatModelBuildConfig { + + /** + * Whether the model should be enabled + */ + @ConfigDocDefault("true") + Optional enabled(); +} diff --git a/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/LangChain4jVertexAiBuildConfig.java b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/LangChain4jVertexAiBuildConfig.java new file mode 100644 index 000000000..9efa5650e --- /dev/null +++ b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/LangChain4jVertexAiBuildConfig.java @@ -0,0 +1,15 @@ +package io.quarkiverse.langchain4j.vertexai.gemini.deployment; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.vertexai.gemini") +public interface LangChain4jVertexAiBuildConfig { + /** + * Chat model related settings + */ + ChatModelBuildConfig chatModel(); +} diff --git a/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiProcessor.java b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiProcessor.java new file mode 100644 index 000000000..2ef4474b5 --- /dev/null +++ b/vertex-ai-gemini/deployment/src/main/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiProcessor.java @@ -0,0 +1,68 @@ +package 
io.quarkiverse.langchain4j.vertexai.gemini.deployment; + +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL; + +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; +import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiRecorder; +import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.LangChain4jVertexAiGeminiConfig; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +public class VertexAiGeminiProcessor { + + private static final String FEATURE = "langchain4j-vertexai-gemini"; + private static final String PROVIDER = "vertexai-gemini"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + public void providerCandidates(BuildProducer chatProducer, + LangChain4jVertexAiBuildConfig config) { + if (config.chatModel().enabled().isEmpty()) { + chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER)); + } + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + void generateBeans(VertexAiGeminiRecorder recorder, List selectedChatItem, + LangChain4jVertexAiGeminiConfig config, BuildProducer beanProducer) { + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + var modelName = selected.getModelName(); + var builder = 
SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); + } + } +} diff --git a/vertex-ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiChatLanguageModelSmokeTest.java b/vertex-ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiChatLanguageModelSmokeTest.java new file mode 100644 index 000000000..502b66802 --- /dev/null +++ b/vertex-ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/vertexai/gemini/deployment/VertexAiGeminiChatLanguageModelSmokeTest.java @@ -0,0 +1,121 @@ +package io.quarkiverse.langchain4j.vertexai.gemini.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import 
io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiChatLanguageModel; +import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertxAiGeminiRestApi; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +public class VertexAiGeminiChatLanguageModelSmokeTest extends WiremockAware { + + private static final String API_KEY = "somekey"; + private static final String CHAT_MODEL_ID = "gemini-pro"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.base-url", WiremockAware.wiremockUrlForConfig()) + .overrideRuntimeConfigKey("quarkus.langchain4j.vertexai.gemini.log-requests", "true"); + + @Inject + ChatLanguageModel chatLanguageModel; + + @Test + void test() { + assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(VertexAiGeminiChatLanguageModel.class); + + wiremock().register( + post(urlEqualTo( + String.format("/v1/projects/dummy/locations/dummy/publishers/google/models/%s:generateContent", + CHAT_MODEL_ID))) + .withHeader("Authorization", equalTo("Bearer " + API_KEY)) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Nice to meet you" + } + ] + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.044847902, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.05592617 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.18877223, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.027324531 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15278918, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + 
"severityScore": 0.045437217 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15869519, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.036838707 + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 37, + "totalTokenCount": 48 + } + } + """))); + + String response = chatLanguageModel.generate("hello"); + assertThat(response).isEqualTo("Nice to meet you"); + + LoggedRequest loggedRequest = singleLoggedRequest(); + assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client"); + String requestBody = new String(loggedRequest.getBody()); + assertThat(requestBody).contains("hello"); + } + + @Singleton + public static class DummyAuthProvider implements VertxAiGeminiRestApi.AuthProvider { + @Override + public String getBearerToken() { + return API_KEY; + } + } + +} diff --git a/vertex-ai-gemini/pom.xml b/vertex-ai-gemini/pom.xml new file mode 100644 index 000000000..9eb5ba728 --- /dev/null +++ b/vertex-ai-gemini/pom.xml @@ -0,0 +1,20 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-vertex-ai-gemini-parent + Quarkus LangChain4j - Vertex AI Gemini - Parent + pom + + + deployment + runtime + + + diff --git a/vertex-ai-gemini/runtime/pom.xml b/vertex-ai-gemini/runtime/pom.xml new file mode 100644 index 000000000..13806c785 --- /dev/null +++ b/vertex-ai-gemini/runtime/pom.xml @@ -0,0 +1,103 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-vertex-ai-gemini-parent + 999-SNAPSHOT + + quarkus-langchain4j-vertex-ai-gemini + Quarkus LangChain4j - Vertex AI Gemini - Runtime + + + io.quarkus + quarkus-arc + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + io.quarkus + quarkus-rest-client-reactive-jackson + + + + com.google.auth + google-auth-library-oauth2-http + 1.23.0 + + + + io.quarkus + quarkus-junit5-internal + test + + + 
org.mockito + mockito-core + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + maven-jar-plugin + + + generate-codestart-jar + generate-resources + + jar + + + ${project.basedir}/src/main + + codestarts/** + + codestarts + true + + + + + + + diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java new file mode 100644 index 000000000..54df26ec9 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/ContentMapper.java @@ -0,0 +1,56 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.ArrayList; +import java.util.List; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.Content; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; +import dev.langchain4j.data.message.UserMessage; + +final class ContentMapper { + + private ContentMapper() { + } + + static GenerateContentRequest map(List messages, List toolSpecifications, + GenerationConfig generationConfig) { + List systemPrompts = new ArrayList<>(); + List contents = new ArrayList<>(messages.size()); + + for (ChatMessage message : messages) { + if (message instanceof SystemMessage sm) { + systemPrompts.add(sm.text()); + } else { + String role = RoleMapper.map(message.type()); + if (message instanceof UserMessage um) { + List parts = new ArrayList<>(um.contents().size()); + for (Content userMessageContent 
: um.contents()) { + if (userMessageContent instanceof TextContent tc) { + parts.add(GenerateContentRequest.Content.Part.ofText(tc.text())); + } else { + throw new IllegalArgumentException("The Gemini integration currently only supports text content"); + } + } + contents.add(new GenerateContentRequest.Content(role, parts)); + } + } + } + + List tools; + if (toolSpecifications == null || toolSpecifications.isEmpty()) { + tools = null; + } else { + tools = new ArrayList<>(toolSpecifications.size()); + for (GenerateContentRequest.Tool tool : tools) { + // TODO: implement + } + } + + return new GenerateContentRequest(contents, + !systemPrompts.isEmpty() ? GenerateContentRequest.SystemInstruction.ofContent(systemPrompts) : null, tools, + generationConfig); + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FinishReasonMapper.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FinishReasonMapper.java new file mode 100644 index 000000000..818c454b0 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FinishReasonMapper.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import dev.langchain4j.model.output.FinishReason; + +final class FinishReasonMapper { + + private FinishReasonMapper() { + } + + static FinishReason map(GenerateContentResponse.FinishReason finishReason) { + return switch (finishReason) { + case STOP -> FinishReason.STOP; + case MAX_TOKENS -> FinishReason.LENGTH; + case SAFETY -> FinishReason.CONTENT_FILTER; + default -> FinishReason.OTHER; + }; + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java new file mode 100644 index 000000000..a22f3fc90 --- /dev/null +++ 
b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionCall.java @@ -0,0 +1,7 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.Map; + +public record FunctionCall(String name, Map args) { + +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java new file mode 100644 index 000000000..2c96daa83 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionDeclaration.java @@ -0,0 +1,11 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.List; +import java.util.Map; + +public record FunctionDeclaration(String name, String description, Parameters parameters) { + + public record Parameters(String type, Map> properties, List required) { + + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionResponse.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionResponse.java new file mode 100644 index 000000000..18c682af5 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/FunctionResponse.java @@ -0,0 +1,8 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +public record FunctionResponse(String name, Response response) { + + public record Response(String name, Object content) { + + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java new file mode 100644 index 000000000..7809c666a --- /dev/null +++ 
b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentRequest.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.List; + +public record GenerateContentRequest(List contents, SystemInstruction systemInstruction, List tools, + GenerationConfig generationConfig) { + + public record Content(String role, List parts) { + + public record Part(String text, FunctionCall functionCall) { + + public static Part ofText(String text) { + return new Part(text, null); + } + } + } + + public record SystemInstruction(List parts) { + + public static SystemInstruction ofContent(List contents) { + return new SystemInstruction(contents.stream().map(Part::new).toList()); + } + + public record Part(String text) { + + } + } + + public record Tool(List functionDeclarations) { + + } + +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java new file mode 100644 index 000000000..461554696 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponse.java @@ -0,0 +1,33 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.List; + +public record GenerateContentResponse(List candidates, UsageMetadata usageMetadata) { + + public record Candidate(Content content, FinishReason finishReason) { + + public record Content(List parts) { + + } + + public record Part(String text, FunctionCall functionCall) { + + } + + } + + public record UsageMetadata(Integer promptTokenCount, Integer candidatesTokenCount, Integer totalTokenCount) { + + } + + public enum FinishReason { + + FINISH_REASON_UNSPECIFIED, + STOP, + MAX_TOKENS, + SAFETY, + RECITATION, + OTHER, + UNRECOGNIZED + } +} diff --git 
a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java new file mode 100644 index 000000000..32b5fc503 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerateContentResponseHandler.java @@ -0,0 +1,45 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.util.List; + +import dev.langchain4j.model.output.TokenUsage; + +final class GenerateContentResponseHandler { + + private GenerateContentResponseHandler() { + } + + static String getText(GenerateContentResponse response) { + GenerateContentResponse.FinishReason finishReason = getFinishReason(response); + if (finishReason == GenerateContentResponse.FinishReason.SAFETY) { + throw new IllegalArgumentException("The response is blocked due to safety reason."); + } else if (finishReason == GenerateContentResponse.FinishReason.RECITATION) { + throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); + } + + StringBuilder text = new StringBuilder(); + List parts = response.candidates().get(0).content().parts(); + for (GenerateContentResponse.Candidate.Part part : parts) { + text.append(part.text()); + } + + return text.toString(); + } + + static GenerateContentResponse.FinishReason getFinishReason(GenerateContentResponse response) { + if (response.candidates().size() != 1) { + throw new IllegalArgumentException( + String.format( + "This response should have exactly 1 candidate, but it has %s.", + response.candidates().size())); + } + return response.candidates().get(0).finishReason(); + } + + static TokenUsage getTokenUsage(GenerateContentResponse.UsageMetadata usageMetadata) { + return new TokenUsage( + usageMetadata.promptTokenCount(), + usageMetadata.candidatesTokenCount(), + 
usageMetadata.totalTokenCount()); + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerationConfig.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerationConfig.java new file mode 100644 index 000000000..e6c149764 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/GenerationConfig.java @@ -0,0 +1,68 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +public class GenerationConfig { + + private final Double temperature; + private final Integer maxOutputTokens; + private final Integer topK; + private final Double topP; + + public GenerationConfig(Builder builder) { + this.temperature = builder.temperature; + this.maxOutputTokens = builder.maxOutputTokens; + this.topK = builder.topK; + this.topP = builder.topP; + } + + public Double getTemperature() { + return temperature; + } + + public Integer getMaxOutputTokens() { + return maxOutputTokens; + } + + public Integer getTopK() { + return topK; + } + + public Double getTopP() { + return topP; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private Double temperature; + private Integer maxOutputTokens; + private Integer topK; + private Double topP; + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public GenerationConfig build() { + return new GenerationConfig(this); + } + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java 
b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java new file mode 100644 index 000000000..cc623d296 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/RoleMapper.java @@ -0,0 +1,18 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import dev.langchain4j.data.message.ChatMessageType; + +final class RoleMapper { + + private RoleMapper() { + } + + static String map(ChatMessageType type) { + return switch (type) { + case USER -> "user"; + case AI -> "model"; + case TOOL_EXECUTION_RESULT -> null; + default -> throw new IllegalArgumentException(type + " is not allowed."); + }; + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java new file mode 100644 index 000000000..ca08513ee --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java @@ -0,0 +1,155 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import org.jboss.resteasy.reactive.client.api.LoggingScope; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; + +public class VertexAiGeminiChatLanguageModel implements ChatLanguageModel { + + private final GenerationConfig generationConfig; + private final VertxAiGeminiRestApi.ApiMetadata apiMetadata; + 
private final VertxAiGeminiRestApi restApi; + + private VertexAiGeminiChatLanguageModel(Builder builder) { + this.generationConfig = GenerationConfig.builder() + .maxOutputTokens(builder.maxOutputTokens) + .temperature(builder.temperature) + .topK(builder.topK) + .topP(builder.topP) + .build(); + + this.apiMetadata = VertxAiGeminiRestApi.ApiMetadata + .builder() + .modelId(builder.modelId) + .location(builder.location) + .projectId(builder.projectId) + .publisher(builder.publisher) + .build(); + + try { + String baseUrl = builder.baseUrl.orElse(String.format("https://%s-aiplatform.googleapis.com", builder.location)); + var restApiBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUri(new URI(baseUrl)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); + + if (builder.logRequests || builder.logResponses) { + restApiBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restApiBuilder.clientLogger(new VertxAiGeminiRestApi.VertxAiClientLogger(builder.logRequests, + builder.logResponses)); + } + restApi = restApiBuilder.build(VertxAiGeminiRestApi.class); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public Response generate(List messages) { + GenerateContentRequest request = ContentMapper.map(messages, Collections.emptyList(), generationConfig); + + GenerateContentResponse response = restApi.predict(request, apiMetadata); + + return Response.from( + AiMessage.from(GenerateContentResponseHandler.getText(response)), + GenerateContentResponseHandler.getTokenUsage(response.usageMetadata()), + FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(response))); + } + + public static Builder builder() { + return new Builder(); + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + public static final class Builder { + + private Optional baseUrl = Optional.empty(); + private String projectId; + private String location; + 
private String modelId; + private String publisher; + private Double temperature; + private Integer maxOutputTokens; + private Integer topK; + private Double topP; + private Duration timeout = Duration.ofSeconds(10); + private Boolean logRequests = false; + private Boolean logResponses = false; + + public Builder baseUrl(Optional baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder location(String location) { + this.location = location; + return this; + } + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder publisher(String publisher) { + this.publisher = publisher; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder maxOutputTokens(Integer maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder logRequests(boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public VertexAiGeminiChatLanguageModel build() { + return new VertexAiGeminiChatLanguageModel(this); + } + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java new file mode 100644 index 000000000..7ed6eabb5 --- /dev/null +++ 
b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java @@ -0,0 +1,89 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; + +import java.util.Optional; +import java.util.function.Supplier; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.DisabledChatLanguageModel; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; +import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.LangChain4jVertexAiGeminiConfig; +import io.quarkus.runtime.annotations.Recorder; +import io.smallrye.config.ConfigValidationException; + +@Recorder +public class VertexAiGeminiRecorder { + private static final String DUMMY_KEY = "dummy"; + + public Supplier chatModel(LangChain4jVertexAiGeminiConfig config, String modelName) { + var vertexAiConfig = correspondingVertexAiConfig(config, modelName); + + if (vertexAiConfig.enableIntegration()) { + var chatModelConfig = vertexAiConfig.chatModel(); + Optional baseUrl = vertexAiConfig.baseUrl(); + + String location = vertexAiConfig.location(); + if (baseUrl.isEmpty() && DUMMY_KEY.equals(location)) { + throw new ConfigValidationException(createConfigProblems("location", modelName)); + } + String projectId = vertexAiConfig.projectId(); + if (baseUrl.isEmpty() && DUMMY_KEY.equals(projectId)) { + throw new ConfigValidationException(createConfigProblems("project-id", modelName)); + } + var builder = VertexAiGeminiChatLanguageModel.builder() + .baseUrl(baseUrl) + .location(location) + .projectId(projectId) + .publisher(vertexAiConfig.publisher()) + .modelId(chatModelConfig.modelId()) + .maxOutputTokens(chatModelConfig.maxOutputTokens()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), vertexAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), vertexAiConfig.logResponses())); + + if 
(chatModelConfig.temperature().isEmpty()) { + builder.temperature(chatModelConfig.temperature().getAsDouble()); + } + if (chatModelConfig.topK().isPresent()) { + builder.topK(chatModelConfig.topK().getAsInt()); + } + if (chatModelConfig.topP().isPresent()) { + builder.topP(chatModelConfig.topP().getAsDouble()); + } + + // TODO: add the rest of the properties + + return new Supplier<>() { + @Override + public ChatLanguageModel get() { + return builder.build(); + } + }; + } else { + return new Supplier<>() { + @Override + public ChatLanguageModel get() { + return new DisabledChatLanguageModel(); + } + }; + } + + } + + private LangChain4jVertexAiGeminiConfig.VertexAiGeminiConfig correspondingVertexAiConfig( + LangChain4jVertexAiGeminiConfig runtimeConfig, String modelName) { + + return NamedModelUtil.isDefault(modelName) ? runtimeConfig.defaultConfig() : runtimeConfig.namedConfig().get(modelName); + } + + private static ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { + return new ConfigValidationException.Problem( + "SRCFG00014: The config property quarkus.langchain4j.vertexai%s%s is required but it could not be found in any config source" + .formatted( + NamedModelUtil.isDefault(modelName) ? "." : ("." 
+ modelName + "."), key)); + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java new file mode 100644 index 000000000..8c947b806 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java @@ -0,0 +1,246 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini; + +import static java.util.stream.Collectors.joining; +import static java.util.stream.StreamSupport.stream; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.concurrent.ExecutorService; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.ws.rs.BeanParam; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; + +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.RestPath; +import org.jboss.resteasy.reactive.client.api.ClientLogger; +import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext; +import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.auth.oauth2.GoogleCredentials; + +import io.quarkus.arc.DefaultBean; +import io.quarkus.rest.client.reactive.jackson.ClientObjectMapper; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpClientRequest; +import io.vertx.core.http.HttpClientResponse; + +@Path("v1/projects/{projectId}/locations/{location}/publishers/{publisher}/models") +@RegisterProvider(VertxAiGeminiRestApi.TokenFilter.class) +public interface 
VertxAiGeminiRestApi { + + @Path("{modelId}:generateContent") + @POST + GenerateContentResponse predict(GenerateContentRequest request, @BeanParam ApiMetadata apiMetadata); + + @ClientObjectMapper + static ObjectMapper mapper(ObjectMapper defaultObjectMapper) { + return defaultObjectMapper.copy().setSerializationInclusion(JsonInclude.Include.NON_NULL); + } + + class ApiMetadata { + @RestPath + public final String projectId; + + @RestPath + public final String location; + + @RestPath + public final String modelId; + + @RestPath + public final String publisher; + + private ApiMetadata(Builder builder) { + this.projectId = builder.projectId; + this.location = builder.location; + this.modelId = builder.modelId; + this.publisher = builder.publisher; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String projectId; + private String location; + private String modelId; + private String publisher; + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder location(String location) { + this.location = location; + return this; + } + + public Builder modelId(String modelId) { + this.modelId = modelId; + return this; + } + + public Builder publisher(String publisherId) { + this.publisher = publisherId; + return this; + } + + public ApiMetadata build() { + return new ApiMetadata(this); + } + } + } + + interface AuthProvider { + + String getBearerToken(); + } + + @ApplicationScoped + @DefaultBean + class ApplicationDefaultAuthProvider implements AuthProvider { + + @Override + public String getBearerToken() { + try { + var credentials = GoogleCredentials.getApplicationDefault(); + credentials.refreshIfExpired(); + return credentials.getAccessToken().getTokenValue(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + + class TokenFilter implements ResteasyReactiveClientRequestFilter { + + private final ExecutorService 
executorService; + private final AuthProvider authProvider; + + public TokenFilter(ExecutorService executorService, AuthProvider authProvider) { + this.executorService = executorService; + this.authProvider = authProvider; + } + + @Override + public void filter(ResteasyReactiveClientRequestContext context) { + context.suspend(); + executorService.submit(new Runnable() { + @Override + public void run() { + try { + context.getHeaders().add("Authorization", "Bearer " + authProvider.getBearerToken()); + context.resume(); + } catch (Exception e) { + context.resume(e); + } + } + }); + } + } + + class VertxAiClientLogger implements ClientLogger { + private static final Logger log = Logger.getLogger(VertxAiClientLogger.class); + + private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s*)(\\w{2})(\\w|\\.|-|_)+(\\w{2})"); + + private final boolean logRequests; + private final boolean logResponses; + + public VertxAiClientLogger(boolean logRequests, boolean logResponses) { + this.logRequests = logRequests; + this.logResponses = logResponses; + } + + @Override + public void setBodySize(int bodySize) { + // ignore + } + + @Override + public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) { + if (!logRequests || !log.isInfoEnabled()) { + return; + } + try { + log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s", + request.getMethod(), + request.absoluteURI(), + inOneLine(request.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log request", e); + } + } + + @Override + public void logResponse(HttpClientResponse response, boolean redirect) { + if (!logResponses || !log.isInfoEnabled()) { + return; + } + response.bodyHandler(new Handler<>() { + @Override + public void handle(Buffer body) { + try { + log.infof( + "Response:\n- status code: %s\n- headers: %s\n- body: %s", + response.statusCode(), + inOneLine(response.headers()), + bodyToString(body)); + } catch (Exception e) { + 
log.warn("Failed to log response", e); + } + } + }); + } + + private String bodyToString(Buffer body) { + if (body == null) { + return ""; + } + return body.toString(); + } + + private String inOneLine(MultiMap headers) { + + return stream(headers.spliterator(), false) + .map(header -> { + String headerKey = header.getKey(); + String headerValue = header.getValue(); + if (headerKey.equals("Authorization")) { + headerValue = maskAuthorizationHeaderValue(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + }) + .collect(joining(", ")); + } + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + + Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue); + + StringBuilder sb = new StringBuilder(); + while (matcher.find()) { + matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4)); + } + matcher.appendTail(sb); + + return sb.toString(); + } catch (Exception e) { + return "Failed to mask the API key."; + } + } + } +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/ChatModelConfig.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/ChatModelConfig.java new file mode 100644 index 000000000..554c43f1b --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/ChatModelConfig.java @@ -0,0 +1,92 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini.config; + +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelConfig { + + /** + * The id of the model to use + */ + @WithDefault("gemini-pro") + String modelId(); + + /** + * The temperature is used for 
sampling during response generation, which occurs when topP and topK are applied. + * Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a + * less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. + * A temperature of 0 means that the highest probability tokens are always selected. In this case, responses for a given + * prompt are mostly deterministic, but a small amount of variation is still possible. + *

+ * If the model returns a response that's too generic, too short, or the model gives a fallback response, try increasing the + * temperature. + *

+ * Range for gemini-1.0-pro-001: 0.0 - 1.0
+ *

+ * Range for gemini-1.0-pro-002, gemini-1.5-pro: 0.0 - 2.0 + *

+ * Default for gemini-1.5-pro and gemini-1.0-pro-002: 1.0 + *

+ * Default for gemini-1.0-pro-001: 0.9 + */ + @WithDefault("0.0") + OptionalDouble temperature(); + + /** + * Maximum number of tokens that can be generated in the response. A token is approximately four characters. 100 tokens + * correspond to roughly 60-80 words. + * Specify a lower value for shorter responses and a higher value for potentially longer responses. + */ + @WithDefault("8192") + Integer maxOutputTokens(); + + /** + * Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable + * until the sum of their probabilities equals the top-P value. + * For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model + * will select either A or B as the next token by using temperature and excludes C as a candidate. + *

+ * Specify a lower value for less random responses and a higher value for more random responses. + *

+ * Range: 0.0 - 1.0 + *

+ * gemini-1.0-pro and gemini-1.5-pro don't support topK + */ + OptionalDouble topP(); + + /** + * Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable + * among all tokens in the model's vocabulary (also called greedy decoding), + * while a top-K of 3 means that the next token is selected from among the three most probable tokens by using temperature. + *

+ * For each token selection step, the top-K tokens with the highest probabilities are sampled. Then tokens are further + * filtered based on top-P with the final token selected using temperature sampling. + *

+ * Specify a lower value for less random responses and a higher value for more random responses. + *

+ * Range: 1-40 + *

+ * Default for gemini-1.5-pro: 0.94 + *

+ * Default for gemini-1.0-pro: 1 + */ + OptionalInt topK(); + + /** + * Whether chat model requests should be logged + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether chat model responses should be logged + */ + @ConfigDocDefault("false") + Optional logResponses(); +} diff --git a/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java new file mode 100644 index 000000000..e6854b3b9 --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java @@ -0,0 +1,87 @@ +package io.quarkiverse.langchain4j.vertexai.runtime.gemini.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.util.Map; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.vertexai.gemini") +public interface LangChain4jVertexAiGeminiConfig { + /** + * Default model config + */ + @WithParentName + VertexAiGeminiConfig defaultConfig(); + + /** + * Named model config + */ + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); + + @ConfigGroup + interface VertexAiGeminiConfig { + + /** + * The unique identifier of the project + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye 
Config doesn't like it + String projectId(); + + /** + * GCP location + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it + String location(); + + /** + * Publisher of model + */ + @WithDefault("google") + String publisher(); + + /** + * Meant to be used for testing only in order to override the base URL used by the client + */ + Optional baseUrl(); + + /** + * Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Anthropic + * provider. + * Set to {@code false} to disable all requests. + */ + @WithDefault("true") + Boolean enableIntegration(); + + /** + * Whether the Vertex AI client should log requests + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether the Vertex AI client should log responses + */ + @ConfigDocDefault("false") + Optional logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + } +} diff --git a/vertex-ai-gemini/runtime/src/main/resources/META-INF/beans.xml b/vertex-ai-gemini/runtime/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/vertex-ai-gemini/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/vertex-ai-gemini/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..aa0495c7f --- /dev/null +++ b/vertex-ai-gemini/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,17 @@ +name: LangChain4j Vertex AI Gemini +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides integration of Quarkus LangChain4j with Vertex AI Gemini +metadata: + keywords: + - ai + - langchain4j + - vertex + # guide: https://quarkiverse.github.io/quarkiverse-docs/langchain4j/dev/ # To create and publish this guide, see https://github.com/quarkiverse/quarkiverse/wiki#documenting-your-extension + categories: + - "miscellaneous" + status: "experimental" + codestart: + name: 
langchain4j-vertex-ai-gemini
+ languages:
+ - "java"
+ artifact: "io.quarkiverse.langchain4j:quarkus-langchain4j-vertex-ai-gemini:codestarts:jar:${project.version}"