Skip to content

Commit

Permalink
feat: [vertexai] add custom headers support in VertexAI
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665423699
  • Loading branch information
jaycee-li authored and copybara-github committed Aug 20, 2024
1 parent 2ba5930 commit 43ce9d4
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -63,6 +66,7 @@ public class VertexAI implements AutoCloseable {
private final String location;
private final String apiEndpoint;
private final Transport transport;
private final HeaderProvider headerProvider;
private final CredentialsProvider credentialsProvider;

private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
Expand All @@ -85,6 +89,7 @@ public VertexAI(String projectId, String location) {
location,
Transport.GRPC,
ImmutableList.of(),
/* customHeaders= */ ImmutableMap.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
Expand All @@ -108,6 +113,7 @@ public VertexAI() {
null,
Transport.GRPC,
ImmutableList.of(),
/* customHeaders= */ ImmutableMap.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
Expand All @@ -119,6 +125,7 @@ private VertexAI(
String location,
Transport transport,
List<String> scopes,
Map<String, String> customHeaders,
Optional<Credentials> credentials,
Optional<String> apiEndpoint,
Optional<Supplier<PredictionServiceClient>> predictionClientSupplierOpt,
Expand All @@ -131,6 +138,15 @@ private VertexAI(
this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location;
this.transport = transport;

String sdkHeader =
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class));
Map<String, String> headers = new HashMap<>(customHeaders);
headers.compute("user-agent", (k, v) -> v == null ? sdkHeader : sdkHeader + " " + v);
this.headerProvider = FixedHeaderProvider.create(headers);

if (credentials.isPresent()) {
this.credentialsProvider = FixedCredentialsProvider.create(credentials.get());
} else {
Expand Down Expand Up @@ -160,6 +176,7 @@ public static class Builder {
private String location;
private Transport transport = Transport.GRPC;
private ImmutableList<String> scopes = ImmutableList.of();
private ImmutableMap<String, String> customHeaders = ImmutableMap.of();
private Optional<Credentials> credentials = Optional.empty();
private Optional<String> apiEndpoint = Optional.empty();

Expand All @@ -174,6 +191,7 @@ public VertexAI build() {
location,
transport,
scopes,
customHeaders,
credentials,
apiEndpoint,
Optional.ofNullable(predictionClientSupplier),
Expand Down Expand Up @@ -240,6 +258,14 @@ public Builder setScopes(List<String> scopes) {
this.scopes = ImmutableList.copyOf(scopes);
return this;
}

@CanIgnoreReturnValue
public Builder setCustomHeaders(Map<String, String> customHeaders) {
checkNotNull(customHeaders, "customHeaders can't be null");

this.customHeaders = ImmutableMap.copyOf(customHeaders);
return this;
}
}

/**
Expand Down Expand Up @@ -278,6 +304,15 @@ public String getApiEndpoint() {
return apiEndpoint;
}

/**
* Returns the headers to use when making API calls.
*
* @return a map of headers to use when making API calls.
*/
public Map<String, String> getHeaders() {
return headerProvider.getHeaders();
}

/**
* Returns the default credentials to use when making API calls.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;

import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.core.GoogleCredentialsProvider;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -397,4 +400,59 @@ public void testInstantiateVertexAI_builderWithTransport_shouldContainRightField
assertThat(vertexAi.getTransport()).isEqualTo(Transport.REST);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT);
}

@Test
public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightFields()
throws IOException {
Map<String, String> customHeaders = new HashMap<>();
customHeaders.put("test_key", "test_value");

vertexAi =
new VertexAI.Builder()
.setProjectId(TEST_PROJECT)
.setLocation(TEST_LOCATION)
.setCustomHeaders(customHeaders)
.build();

assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION);
// headers should include both the sdk header and the custom headers
Map<String, String> expectedHeaders = new HashMap<>(customHeaders);
expectedHeaders.put(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
}

@Test
public void
testInstantiateVertexAI_builderWithCustomHeadersWithSdkReservedKey_shouldContainRightFields()
throws IOException {
Map<String, String> customHeadersWithSdkReservedKey = new HashMap<>();
customHeadersWithSdkReservedKey.put("user-agent", "test_value");

vertexAi =
new VertexAI.Builder()
.setProjectId(TEST_PROJECT)
.setLocation(TEST_LOCATION)
.setCustomHeaders(customHeadersWithSdkReservedKey)
.build();

assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION);
// headers should include sdk reserved key with value of both the sdk header and the custom
// headers
Map<String, String> expectedHeaders = new HashMap<>();
expectedHeaders.put(
"user-agent",
String.format(
"%s/%s %s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class),
"test_value"));
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
}
}

0 comments on commit 43ce9d4

Please sign in to comment.