Skip to content

Commit

Permalink
feat: [vertexai] add FunctionDeclarationMaker.fromFunc to create Func…
Browse files Browse the repository at this point in the history
…tionDeclaration from a Java static method (#10915)

PiperOrigin-RevId: 639154403

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Jun 11, 2024
1 parent 5ebfc33 commit 5a10656
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -41,8 +40,8 @@ public final class ChatSession {
private final GenerativeModel model;
private final Optional<ChatSession> rootChatSession;
private final Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder;
private List<Content> history;
private int previousHistorySize;
private List<Content> history = new ArrayList<>();
private int previousHistorySize = 0;
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
private Optional<GenerateContentResponse> currentResponse;

Expand All @@ -51,17 +50,14 @@ public final class ChatSession {
* GenerationConfig) inherits from the model.
*/
public ChatSession(GenerativeModel model) {
this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty());
this(model, Optional.empty(), Optional.empty());
}

/**
* Creates a new chat session given a GenerativeModel instance and a root chat session.
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
*
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
* @param history a list of {@link Content} containing interleaving conversation between "user"
* and "model".
* @param previousHistorySize the size of the previous history.
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
* chat session will be merged to the root chat session.
* @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance
Expand All @@ -70,14 +66,10 @@ public ChatSession(GenerativeModel model) {
*/
private ChatSession(
GenerativeModel model,
List<Content> history,
int previousHistorySize,
Optional<ChatSession> rootChatSession,
Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder) {
checkNotNull(model, "model should not be null");
this.model = model;
this.history = history;
this.previousHistorySize = previousHistorySize;
this.rootChatSession = rootChatSession;
this.automaticFunctionCallingResponder = automaticFunctionCallingResponder;
currentResponseStream = Optional.empty();
Expand All @@ -92,12 +84,15 @@ private ChatSession(
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
*/
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
return new ChatSession(
model.withGenerationConfig(generationConfig),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withGenerationConfig(generationConfig),
Optional.of(rootChat),
automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
}

/**
Expand All @@ -108,12 +103,15 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
* @return a new {@link ChatSession} instance with the specified SafetySettings.
*/
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
return new ChatSession(
model.withSafetySettings(safetySettings),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withSafetySettings(safetySettings),
Optional.of(rootChat),
automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
}

/**
Expand All @@ -124,44 +122,13 @@ public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
* @return a new {@link ChatSession} instance with the specified Tools.
*/
public ChatSession withTools(List<Tool> tools) {
return new ChatSession(
model.withTools(tools),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
* Creates a copy of the current ChatSession with updated ToolConfig.
*
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
* new ChatSession.
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
*/
public ChatSession withToolConfig(ToolConfig toolConfig) {
return new ChatSession(
model.withToolConfig(toolConfig),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
* Creates a copy of the current ChatSession with updated SystemInstruction.
*
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
* instructions.
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
*/
public ChatSession withSystemInstruction(Content systemInstruction) {
return new ChatSession(
model.withSystemInstruction(systemInstruction),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
}

/**
Expand All @@ -174,12 +141,13 @@ public ChatSession withSystemInstruction(Content systemInstruction) {
*/
public ChatSession withAutomaticFunctionCallingResponder(
AutomaticFunctionCallingResponder automaticFunctionCallingResponder) {
return new ChatSession(
model,
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
Optional.of(automaticFunctionCallingResponder));
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder));
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import com.google.common.base.Strings;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;

/** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */
public final class FunctionDeclarationMaker {
Expand Down Expand Up @@ -60,4 +65,92 @@ public static FunctionDeclaration fromJsonObject(JsonObject jsonObject)
checkNotNull(jsonObject, "JsonObject can't be null.");
return fromJsonString(jsonObject.toString());
}

/**
* Creates a FunctionDeclaration from a Java static method
*
* <p><b>Note:</b>: If you don't want to manually provide parameter names, you can ignore
* `orderedParameterNames` and compile your code with the "-parameters" flag. In this case, the
* parameter names can be auto retrieved from reflection.
*
* @param functionDescription A description of the method.
* @param function A Java static method.
* @param orderedParameterNames A list of parameter names in the order they are passed to the
* method.
* @return a {@link com.google.cloud.vertexai.api.FunctionDeclaration} instance.
* @throws IllegalArgumentException if the method is not a static method or the number of provided
* parameter names doesn't match the number of parameters in the callable function or
* parameter types in this method are not String, boolean, int, double, or float.
* @throws IllegalStateException if the parameter names are not provided and cannot be retrieved
* from reflection
*/
public static FunctionDeclaration fromFunc(
String functionDescription, Method function, String... orderedParameterNames) {
if (!Modifier.isStatic(function.getModifiers())) {
throw new IllegalArgumentException(
"Instance methods are not supported. Please use static methods.");
}
Schema.Builder parametersBuilder = Schema.newBuilder().setType(Type.OBJECT);

Parameter[] parameters = function.getParameters();
// If parameter names are provided, the number of parameter names should match the number of
// parameters in the method.
if (orderedParameterNames.length > 0 && orderedParameterNames.length != parameters.length) {
throw new IllegalArgumentException(
"The number of parameter names does not match the number of parameters in the method.");
}

for (int i = 0; i < parameters.length; i++) {
if (orderedParameterNames.length == 0) {
// If parameter names are not provided, try to retrieve them from reflection.
if (!parameters[i].isNamePresent()) {
throw new IllegalStateException(
"Failed to retrieve the parameter name from reflection. Please compile your"
+ " code with \"-parameters\" flag or use `fromFunc(String, Method,"
+ " String...)` to manually enter parameter names");
}
addParameterToParametersBuilder(
parametersBuilder, parameters[i].getName(), parameters[i].getType());
} else {
addParameterToParametersBuilder(
parametersBuilder, orderedParameterNames[i], parameters[i].getType());
}
}

return FunctionDeclaration.newBuilder()
.setName(function.getName())
.setDescription(functionDescription)
.setParameters(parametersBuilder)
.build();
}

/** Adds a parameter to the parameters builder. */
private static void addParameterToParametersBuilder(
Schema.Builder parametersBuilder, String parameterName, Class<?> parameterType) {
Schema.Builder parameterBuilder = Schema.newBuilder().setDescription(parameterName);
switch (parameterType.getName()) {
case "java.lang.String":
parameterBuilder.setType(Type.STRING);
break;
case "boolean":
parameterBuilder.setType(Type.BOOLEAN);
break;
case "int":
parameterBuilder.setType(Type.INTEGER);
break;
case "double":
case "float":
parameterBuilder.setType(Type.NUMBER);
break;
default:
throw new IllegalArgumentException(
"Unsupported parameter type "
+ parameterType.getName()
+ " for parameter "
+ parameterName);
}
parametersBuilder
.addRequired(parameterName)
.putProperties(parameterName, parameterBuilder.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import com.google.cloud.vertexai.api.Candidate.FinishReason;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.FunctionCallingConfig;
import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
Expand All @@ -41,7 +40,6 @@
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.cloud.vertexai.api.Type;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
Expand Down Expand Up @@ -176,16 +174,6 @@ public final class ChatSessionTest {
.build())
.addRequired("location")))
.build();
private static final ToolConfig TOOL_CONFIG =
ToolConfig.newBuilder()
.setFunctionCallingConfig(
FunctionCallingConfig.newBuilder()
.setMode(FunctionCallingConfig.Mode.ANY)
.addAllowedFunctionNames("getCurrentWeather"))
.build();
private static final Content SYSTEM_INSTRUCTION =
ContentMaker.fromString(
"You're a helpful assistant that starts all its answers with: \"COOL\"");

@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();

Expand Down Expand Up @@ -530,9 +518,7 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
rootChat
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(Arrays.asList(SAFETY_SETTING))
.withTools(Arrays.asList(TOOL))
.withToolConfig(TOOL_CONFIG)
.withSystemInstruction(SYSTEM_INSTRUCTION);
.withTools(Arrays.asList(TOOL));
response = childChat.sendMessage(SAMPLE_MESSAGE_2);

// (Assert) root chat history should contain all 4 contents
Expand All @@ -546,12 +532,8 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable, times(2)).call(request.capture());
Content expectedSystemInstruction = SYSTEM_INSTRUCTION.toBuilder().clearRole().build();
assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL);
assertThat(request.getAllValues().get(1).getToolConfig()).isEqualTo(TOOL_CONFIG);
assertThat(request.getAllValues().get(1).getSystemInstruction())
.isEqualTo(expectedSystemInstruction);
}
}
Loading

0 comments on commit 5a10656

Please sign in to comment.