Skip to content

Commit

Permalink
feat: [vertexai] allow setting ToolConfig and SystemInstruction in Ch…
Browse files Browse the repository at this point in the history
…atSession (#10953)

PiperOrigin-RevId: 641987014

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Jun 11, 2024
1 parent 0801812 commit 5ebfc33
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -40,8 +41,8 @@ public final class ChatSession {
private final GenerativeModel model;
private final Optional<ChatSession> rootChatSession;
private final Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder;
private List<Content> history = new ArrayList<>();
private int previousHistorySize = 0;
private List<Content> history;
private int previousHistorySize;
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
private Optional<GenerateContentResponse> currentResponse;

Expand All @@ -50,14 +51,17 @@ public final class ChatSession {
* GenerationConfig) inherits from the model.
*/
public ChatSession(GenerativeModel model) {
this(model, Optional.empty(), Optional.empty());
this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty());
}

/**
* Creates a new chat session given a GenerativeModel instance and a root chat session.
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
*
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
* @param history a list of {@link Content} containing interleaving conversation between "user"
* and "model".
* @param previousHistorySize the size of the previous history.
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
* chat session will be merged to the root chat session.
* @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance
Expand All @@ -66,10 +70,14 @@ public ChatSession(GenerativeModel model) {
*/
private ChatSession(
GenerativeModel model,
List<Content> history,
int previousHistorySize,
Optional<ChatSession> rootChatSession,
Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder) {
checkNotNull(model, "model should not be null");
this.model = model;
this.history = history;
this.previousHistorySize = previousHistorySize;
this.rootChatSession = rootChatSession;
this.automaticFunctionCallingResponder = automaticFunctionCallingResponder;
currentResponseStream = Optional.empty();
Expand All @@ -84,15 +92,12 @@ private ChatSession(
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
*/
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withGenerationConfig(generationConfig),
Optional.of(rootChat),
automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
return new ChatSession(
model.withGenerationConfig(generationConfig),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
Expand All @@ -103,15 +108,12 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
* @return a new {@link ChatSession} instance with the specified SafetySettings.
*/
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withSafetySettings(safetySettings),
Optional.of(rootChat),
automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
return new ChatSession(
model.withSafetySettings(safetySettings),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
Expand All @@ -122,13 +124,44 @@ public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
* @return a new {@link ChatSession} instance with the specified Tools.
*/
public ChatSession withTools(List<Tool> tools) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder);
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
return new ChatSession(
model.withTools(tools),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
* Creates a copy of the current ChatSession with updated ToolConfig.
*
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
* new ChatSession.
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
*/
public ChatSession withToolConfig(ToolConfig toolConfig) {
return new ChatSession(
model.withToolConfig(toolConfig),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
* Creates a copy of the current ChatSession with updated SystemInstruction.
*
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
* instructions.
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
*/
public ChatSession withSystemInstruction(Content systemInstruction) {
return new ChatSession(
model.withSystemInstruction(systemInstruction),
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
automaticFunctionCallingResponder);
}

/**
Expand All @@ -141,13 +174,12 @@ public ChatSession withTools(List<Tool> tools) {
*/
public ChatSession withAutomaticFunctionCallingResponder(
AutomaticFunctionCallingResponder automaticFunctionCallingResponder) {
ChatSession rootChat = rootChatSession.orElse(this);
ChatSession newChatSession =
new ChatSession(
model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder));
newChatSession.history = history;
newChatSession.previousHistorySize = previousHistorySize;
return newChatSession;
return new ChatSession(
model,
history,
previousHistorySize,
Optional.of(rootChatSession.orElse(this)),
Optional.of(automaticFunctionCallingResponder));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.cloud.vertexai.api.Candidate.FinishReason;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.FunctionCallingConfig;
import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
Expand All @@ -40,6 +41,7 @@
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.cloud.vertexai.api.Type;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
Expand Down Expand Up @@ -174,6 +176,16 @@ public final class ChatSessionTest {
.build())
.addRequired("location")))
.build();
private static final ToolConfig TOOL_CONFIG =
ToolConfig.newBuilder()
.setFunctionCallingConfig(
FunctionCallingConfig.newBuilder()
.setMode(FunctionCallingConfig.Mode.ANY)
.addAllowedFunctionNames("getCurrentWeather"))
.build();
private static final Content SYSTEM_INSTRUCTION =
ContentMaker.fromString(
"You're a helpful assistant that starts all its answers with: \"COOL\"");

@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();

Expand Down Expand Up @@ -518,7 +530,9 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
rootChat
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(Arrays.asList(SAFETY_SETTING))
.withTools(Arrays.asList(TOOL));
.withTools(Arrays.asList(TOOL))
.withToolConfig(TOOL_CONFIG)
.withSystemInstruction(SYSTEM_INSTRUCTION);
response = childChat.sendMessage(SAMPLE_MESSAGE_2);

// (Assert) root chat history should contain all 4 contents
Expand All @@ -532,8 +546,12 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable, times(2)).call(request.capture());
Content expectedSystemInstruction = SYSTEM_INSTRUCTION.toBuilder().clearRole().build();
assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL);
assertThat(request.getAllValues().get(1).getToolConfig()).isEqualTo(TOOL_CONFIG);
assertThat(request.getAllValues().get(1).getSystemInstruction())
.isEqualTo(expectedSystemInstruction);
}
}

0 comments on commit 5ebfc33

Please sign in to comment.