Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [vertexai] Support Function calling #10242

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion java-vertexai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file:
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>libraries-bom</artifactId>
<version>26.30.0</version>
<version>26.29.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentRequest.Builder;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -40,8 +40,131 @@ public class GenerativeModel {
private final VertexAI vertexAi;
private GenerationConfig generationConfig = null;
private List<SafetySetting> safetySettings = null;
private List<Tool> tools = null;
private Transport transport;

public static Builder newBuilder() {
return new Builder();
}

private GenerativeModel(Builder builder) {
this.modelName = builder.modelName;

this.vertexAi = builder.vertexAi;

this.resourceName =
String.format(
"projects/%s/locations/%s/publishers/google/models/%s",
this.vertexAi.getProjectId(), this.vertexAi.getLocation(), this.modelName);

if (builder.generationConfig != null) {
this.generationConfig = builder.generationConfig;
}
if (builder.safetySettings != null) {
this.safetySettings = builder.safetySettings;
}
if (builder.tools != null) {
this.tools = builder.tools;
}

if (builder.transport != null) {
this.transport = builder.transport;
} else {
this.transport = this.vertexAi.getTransport();
}
}

/** Builder class for {@link GenerativeModel}. */
public static class Builder {
private String modelName;
private VertexAI vertexAi;
private GenerationConfig generationConfig;
private List<SafetySetting> safetySettings;
private List<Tool> tools;
private Transport transport;

private Builder() {}

public GenerativeModel build() {
if (this.modelName == null) {
throw new IllegalArgumentException(
"modelName is required. Please call setModelName() before building.");
}
if (this.vertexAi == null) {
throw new IllegalArgumentException(
"vertexAi is required. Please call setVertexAi() before building.");
}
return new GenerativeModel(this);
}

/**
* Set the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
public Builder setModelName(String modelName) {
this.modelName = validateModelName(modelName);
return this;
}

/**
* Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
* generative model. This is required for building a GenerativeModel instance.
*/
public Builder setVertexAi(VertexAI vertexAi) {
this.vertexAi = vertexAi;
return this;
}

/**
* Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
* interact with the generative model.
*/
public Builder setGenerationConfig(GenerationConfig generationConfig) {
this.generationConfig = generationConfig;
return this;
}

/**
* Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
* default to interact with the generative model.
*/
public Builder setSafetySettings(List<SafetySetting> safetySettings) {
this.safetySettings = new ArrayList<>();
for (SafetySetting safetySetting : safetySettings) {
if (safetySetting != null) {
this.safetySettings.add(safetySetting);
}
}
return this;
}

/**
* Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
* interact with the generative model.
*/
public Builder setTools(List<Tool> tools) {
this.tools = new ArrayList<>();
for (Tool tool : tools) {
if (tool != null) {
this.tools.add(tool);
}
}
return this;
}

/**
* Set the {@link Transport} layer for API calls in the generative model. It overrides the
* transport setting in {@link com.google.cloud.vertexai.VertexAI}
*/
public Builder setTransport(Transport transport) {
this.transport = transport;
return this;
}
}

/**
* Construct a GenerativeModel instance.
*
Expand Down Expand Up @@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
public GenerateContentResponse generateContent(
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
throws IOException {
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
Expand All @@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}
return ResponseHandler.aggregateStreamIntoResponse(generateContentStream(requestBuilder));
}

Expand Down Expand Up @@ -655,7 +782,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
public ResponseStream<GenerateContentResponse> generateContentStream(
List<Content> contents, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
throws IOException {
Builder requestBuilder = GenerateContentRequest.newBuilder().addAllContents(contents);
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (generationConfig != null) {
requestBuilder.setGenerationConfig(generationConfig);
} else if (this.generationConfig != null) {
Expand All @@ -666,6 +794,9 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}
return generateContentStream(requestBuilder);
}

Expand All @@ -678,8 +809,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
private ResponseStream<GenerateContentResponse> generateContentStream(Builder requestBuilder)
throws IOException {
private ResponseStream<GenerateContentResponse> generateContentStream(
GenerateContentRequest.Builder requestBuilder) throws IOException {
GenerateContentRequest request = requestBuilder.setModel(this.resourceName).build();
ResponseStream<GenerateContentResponse> responseStream = null;
if (this.transport == Transport.REST) {
Expand Down Expand Up @@ -723,6 +854,16 @@ public void setSafetySettings(List<SafetySetting> safetySettings) {
}
}

/**
* Sets the value for {@link #getTools}, which will be used by default for generating response.
*/
public void setTools(List<Tool> tools) {
this.tools = new ArrayList<>();
for (Tool tool : tools) {
this.tools.add(tool);
}
}

/**
* Sets the value for {@link #getTransport}, which defines the layer for API calls in this
* generative model.
Expand Down Expand Up @@ -760,6 +901,15 @@ public List<SafetySetting> getSafetySettings() {
}
}

/** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
public List<Tool> getTools() {
if (this.tools != null) {
return Collections.unmodifiableList(this.tools);
} else {
return null;
}
}

public ChatSession startChat() {
return new ChatSession(this);
}
Expand Down
Loading