diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/Constants.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/Constants.java index 9ac034a89d5b..4e12bcdcbdc1 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/Constants.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/Constants.java @@ -20,8 +20,12 @@ /** A class that holds all constants for vertexai/generativeai. */ public final class Constants { + public static final String MODEL_NAME_PREFIX_PROJECTS = "projects/"; + public static final String MODEL_NAME_PREFIX_PUBLISHERS = "publishers/"; + public static final String MODEL_NAME_PREFIX_MODELS = "models/"; public static final ImmutableSet MODEL_NAME_PREFIXES = - ImmutableSet.of("publishers/google/models/", "models/"); + ImmutableSet.of( + MODEL_NAME_PREFIX_PROJECTS, MODEL_NAME_PREFIX_PUBLISHERS, MODEL_NAME_PREFIX_MODELS); private Constants() {} } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index 5a1d3d9a56f3..db5a6ab7cb6d 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -99,12 +99,9 @@ private GenerativeModel( checkNotNull(safetySettings, "ImmutableList can't be null."); checkNotNull(tools, "ImmutableList can't be null."); - modelName = reconcileModelName(modelName); - this.modelName = modelName; - this.resourceName = - String.format( - "projects/%s/locations/%s/publishers/google/models/%s", - vertexAi.getProjectId(), vertexAi.getLocation(), modelName); + this.resourceName = getResourceName(modelName, vertexAi); + // reconcileModelName should be called after getResourceName. + this.modelName = reconcileModelName(modelName); this.vertexAi = vertexAi; this.generationConfig = generationConfig; this.safetySettings = safetySettings; @@ -157,7 +154,7 @@ public Builder setModelName(String modelName) { + " https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models" + " to find the right model name."); - this.modelName = reconcileModelName(modelName); + this.modelName = modelName; return this; } @@ -584,10 +581,28 @@ public ChatSession startChat() { private static String reconcileModelName(String modelName) { for (String prefix : Constants.MODEL_NAME_PREFIXES) { if (modelName.startsWith(prefix)) { - modelName = modelName.substring(prefix.length()); + modelName = modelName.substring(modelName.lastIndexOf('/') + 1); break; } } return modelName; } + + /** + * Computes resourceName based on original modelName. Note: this should happen before the + * modelName is reconciled. + */ + private static String getResourceName(String modelName, VertexAI vertexAi) { + if (modelName.startsWith(Constants.MODEL_NAME_PREFIX_PROJECTS)) { + return modelName; + } else if (modelName.startsWith(Constants.MODEL_NAME_PREFIX_PUBLISHERS)) { + return String.format( + "projects/%s/locations/%s/%s", + vertexAi.getProjectId(), vertexAi.getLocation(), modelName); + } else { + return String.format( + "projects/%s/locations/%s/publishers/google/models/%s", + vertexAi.getProjectId(), vertexAi.getLocation(), reconcileModelName(modelName)); + } + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index c52f3be380ba..1df2354b6d96 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -108,9 +108,7 @@ public final class GenerativeModelTest { .setVertexAiSearch( VertexAISearch.newBuilder() .setDatastore( - String.format( - "projects/%s/locations/%s/collections/%s/dataStores/%s", - PROJECT, "global", "default_collection", "test_123"))) + "projects/test_project/locations/global/collections/default_collection/dataStores/test_123")) .setDisableAttribution(false)) .build(); @@ -164,6 +162,19 @@ public void testInstantiateGenerativeModel() { assertThat(model.getTools()).isEmpty(); } + @Test + public void + testInstantiateGenerativeModel_withModelNameStartingFromProjects_modelNameIsCorrect() { + model = + new GenerativeModel( + "projects/test_project/locations/test_location/publishers/google/models/gemini-pro", + vertexAi); + assertThat(model.getModelName()).isEqualTo(MODEL_NAME); + assertThat(model.getGenerationConfig()).isEqualTo(GenerationConfig.getDefaultInstance()); + assertThat(model.getSafetySettings()).isEmpty(); + assertThat(model.getTools()).isEmpty(); + } + @Test public void testInstantiateGenerativeModelwithBuilder() { model = new GenerativeModel.Builder().setModelName(MODEL_NAME).setVertexAi(vertexAi).build(); @@ -286,6 +297,32 @@ public void testGenerateContentwithText() throws Exception { ArgumentCaptor.forClass(GenerateContentRequest.class); verify(mockUnaryCallable).call(request.capture()); assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); + assertThat(request.getValue().getModel()) + .isEqualTo( + "projects/test_project/locations/test_location/publishers/google/models/gemini-pro"); + } + + @Test + public void testGenerateContentwithText_withFullModelName_requestHasCorrectResourceName() + throws Exception { + model = + new GenerativeModel( + "projects/another_project/locations/europe-west4/publishers/google/models/another_model", + vertexAi); + + when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); + when(mockUnaryCallable.call(any(GenerateContentRequest.class))) + .thenReturn(mockGenerateContentResponse); + + GenerateContentResponse unused = model.generateContent(TEXT); + + ArgumentCaptor request = + ArgumentCaptor.forClass(GenerateContentRequest.class); + verify(mockUnaryCallable).call(request.capture()); + assertThat(request.getValue().getModel()) + .isEqualTo( + "projects/another_project/locations/europe-west4/publishers/google/models/another_model"); + assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT); } @Test