diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 8caad29c0171..589272d20ce8 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -16,6 +16,7 @@ package com.google.cloud.vertexai; +import com.google.api.core.InternalApi; import com.google.api.gax.core.CredentialsProvider; import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.core.GaxProperties; @@ -220,11 +221,30 @@ public void setApiEndpoint(String apiEndpoint) { } } + /** + * Returns the {@link PredictionServiceClient} with GRPC or REST, based on the Transport type. The + * client will be instantiated when the first prediction API call is made. + * + * @return {@link PredictionServiceClient} that send requests to the backing service through + * method calls that map to the API methods. + */ + @InternalApi + public PredictionServiceClient getPredictionServiceClient() throws IOException { + if (this.transport == Transport.GRPC) { + return getPredictionServiceGrpcClient(); + } else { + return getPredictionServiceRestClient(); + } + } + /** * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the * first prediction API call is made. + * + * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through + * method calls that map to the API methods. */ - public PredictionServiceClient getPredictionServiceClient() throws IOException { + private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException { if (predictionServiceClient == null) { PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); @@ -257,7 +277,7 @@ public PredictionServiceClient getPredictionServiceClient() throws IOException { * @return {@link PredictionServiceClient} that send REST requests to the backing service through * method calls that map to the API methods. */ - public PredictionServiceClient getPredictionServiceRestClient() throws IOException { + private PredictionServiceClient getPredictionServiceRestClient() throws IOException { if (predictionServiceRestClient == null) { PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newHttpJsonBuilder(); @@ -284,14 +304,30 @@ public PredictionServiceClient getPredictionServiceRestClient() throws IOExcepti return predictionServiceRestClient; } + /** + * Returns the {@link LlmUtilityServiceClient} with GRPC or REST, based on the Transport type. The + * client will be instantiated when the first API call is made. + * + * @return {@link LlmUtilityServiceClient} that makes calls to the backing service through method + * calls that map to the API methods. + */ + @InternalApi + public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { + if (this.transport == Transport.GRPC) { + return getLlmUtilityGrpcClient(); + } else { + return getLlmUtilityRestClient(); + } + } + /** * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the - * first prediction API call is made. + * first API call is made. * * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through * method calls that map to the API methods. */ - public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { + private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException { if (llmUtilityClient == null) { LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder(); settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint)); @@ -319,12 +355,12 @@ public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { /** * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the - * first prediction API call is made. + * first API call is made. * * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through * method calls that map to the API methods. */ - public LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException { + private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException { if (llmUtilityRestClient == null) { LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder(); diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index fd34bd39d5d8..fc232ae40fb4 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -289,11 +289,7 @@ public CountTokensResponse countTokens(List contents) throws IOExceptio @BetaApi private CountTokensResponse countTokensFromRequest(CountTokensRequest request) throws IOException { - if (vertexAi.getTransport() == Transport.REST) { - return vertexAi.getLlmUtilityRestClient().countTokens(request); - } else { - return vertexAi.getLlmUtilityClient().countTokens(request); - } + return vertexAi.getLlmUtilityClient().countTokens(request); } /** @@ -520,11 +516,7 @@ public GenerateContentResponse generateContent( */ private GenerateContentResponse generateContent(GenerateContentRequest request) throws IOException { - if (vertexAi.getTransport() == Transport.REST) { - return vertexAi.getPredictionServiceRestClient().generateContentCallable().call(request); - } else { - return vertexAi.getPredictionServiceClient().generateContentCallable().call(request); - } + return vertexAi.getPredictionServiceClient().generateContentCallable().call(request); } /** @@ -932,23 +924,13 @@ public ResponseStream generateContentStream( */ private ResponseStream generateContentStream( GenerateContentRequest request) throws IOException { - if (vertexAi.getTransport() == Transport.REST) { - return new ResponseStream( - new ResponseStreamIteratorWithHistory( - vertexAi - .getPredictionServiceRestClient() - .streamGenerateContentCallable() - .call(request) - .iterator())); - } else { - return new ResponseStream( - new ResponseStreamIteratorWithHistory( - vertexAi - .getPredictionServiceClient() - .streamGenerateContentCallable() - .call(request) - .iterator())); - } + return new ResponseStream( + new ResponseStreamIteratorWithHistory( + vertexAi + .getPredictionServiceClient() + .streamGenerateContentCallable() + .call(request) + .iterator())); } /**