Skip to content

Commit

Permalink
chore: [vertexai] sync Github PR #1077
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631145038
  • Loading branch information
jaycee-li authored and copybara-github committed May 7, 2024
1 parent 51e7ad6 commit 4bc6a5f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public GenerativeModel build() {
* Sets the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found in the Gemini models documentation
* names can be found in the Gemini models documentation:
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
@CanIgnoreReturnValue
Expand Down Expand Up @@ -217,13 +217,9 @@ public Builder setSystemInstruction(Content systemInstruction) {
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
checkNotNull(generationConfig, "GenerationConfig can't be null.");
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
}

/**
Expand All @@ -234,11 +230,14 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
* @return a new {@link GenerativeModel} instance with the specified safetySettings.
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
checkNotNull(
safetySettings,
"safetySettings can't be null. Use an empty list if no safety settings is intended.");
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
tools,
systemInstruction,
vertexAi);
}
Expand All @@ -251,10 +250,11 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
checkNotNull(tools, "tools can't be null. Use an empty list if no tool is to be used.");
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
safetySettings,
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
Expand All @@ -268,11 +268,15 @@ public GenerativeModel withTools(List<Tool> tools) {
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withSystemInstruction(Content systemInstruction) {
checkNotNull(
systemInstruction,
"system instruction can't be null. "
+ "Use Optional.empty() if no system instruction should be provided.");
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
safetySettings,
tools,
Optional.of(systemInstruction),
vertexAi);
}
Expand Down Expand Up @@ -506,7 +510,6 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");

GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder()
.setModel(resourceName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,17 +325,17 @@ public void testGenerateContentwithContents() throws Exception {
}

@Test
public void testGenerateContentwithSystemInstructions() throws Exception {
String systemInstructionText =
"You're a helpful assistant that starts all its answers with: \"COOL\"";
Content systemInstructions = ContentMaker.fromString(systemInstructionText);

model = new GenerativeModel(MODEL_NAME, vertexAi).withSystemInstruction(systemInstructions);

public void testGenerateContentwithSystemInstruction() throws Exception {
when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

String systemInstructionText =
"You're a helpful assistant that starts all its answers with: \"COOL\"";
Content systemInstruction = ContentMaker.fromString(systemInstructionText);

model = new GenerativeModel(MODEL_NAME, vertexAi).withSystemInstruction(systemInstruction);

Content content = ContentMaker.fromString(TEXT);
GenerateContentResponse unused = model.generateContent(Arrays.asList(content));

Expand Down

0 comments on commit 4bc6a5f

Please sign in to comment.