Skip to content

Commit

Permalink
feat: add x-goog-request-params to header (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
kolea2 authored Dec 19, 2022
1 parent fab1ece commit dee8cb4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ private DatastoreException invalidResponseException(String method, IOException e
}

public AllocateIdsResponse allocateIds(AllocateIdsRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("allocateIds", request)) {
try (InputStream is =
remoteRpc.call("allocateIds", request, request.getProjectId(), request.getDatabaseId())) {
return AllocateIdsResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("allocateIds", exception);
Expand All @@ -76,47 +77,54 @@ public AllocateIdsResponse allocateIds(AllocateIdsRequest request) throws Datast

public BeginTransactionResponse beginTransaction(BeginTransactionRequest request)
throws DatastoreException {
try (InputStream is = remoteRpc.call("beginTransaction", request)) {
try (InputStream is =
remoteRpc.call(
"beginTransaction", request, request.getProjectId(), request.getDatabaseId())) {
return BeginTransactionResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("beginTransaction", exception);
}
}

public CommitResponse commit(CommitRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("commit", request)) {
try (InputStream is =
remoteRpc.call("commit", request, request.getProjectId(), request.getDatabaseId())) {
return CommitResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("commit", exception);
}
}

public LookupResponse lookup(LookupRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("lookup", request)) {
try (InputStream is =
remoteRpc.call("lookup", request, request.getProjectId(), request.getDatabaseId())) {
return LookupResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("lookup", exception);
}
}

public ReserveIdsResponse reserveIds(ReserveIdsRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("reserveIds", request)) {
try (InputStream is =
remoteRpc.call("reserveIds", request, request.getProjectId(), request.getDatabaseId())) {
return ReserveIdsResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("reserveIds", exception);
}
}

public RollbackResponse rollback(RollbackRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("rollback", request)) {
try (InputStream is =
remoteRpc.call("rollback", request, request.getProjectId(), request.getDatabaseId())) {
return RollbackResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("rollback", exception);
}
}

public RunQueryResponse runQuery(RunQueryRequest request) throws DatastoreException {
try (InputStream is = remoteRpc.call("runQuery", request)) {
try (InputStream is =
remoteRpc.call("runQuery", request, request.getProjectId(), request.getDatabaseId())) {
return RunQueryResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("runQuery", exception);
Expand All @@ -125,7 +133,9 @@ public RunQueryResponse runQuery(RunQueryRequest request) throws DatastoreExcept

public RunAggregationQueryResponse runAggregationQuery(RunAggregationQueryRequest request)
throws DatastoreException {
try (InputStream is = remoteRpc.call("runAggregationQuery", request)) {
try (InputStream is =
remoteRpc.call(
"runAggregationQuery", request, request.getProjectId(), request.getDatabaseId())) {
return RunAggregationQueryResponse.parseFrom(is);
} catch (IOException exception) {
throw invalidResponseException("runAggregationQuery", exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.api.client.http.protobuf.ProtoHttpContent;
import com.google.api.client.util.IOUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import com.google.rpc.Status;
Expand All @@ -46,6 +47,8 @@ class RemoteRpc {
@VisibleForTesting static final String API_FORMAT_VERSION_HEADER = "X-Goog-Api-Format-Version";
private static final String API_FORMAT_VERSION = "2";

@VisibleForTesting static final String X_GOOG_REQUEST_PARAMS_HEADER = "x-goog-request-params";

private final HttpRequestFactory client;
private final HttpRequestInitializer initializer;
private final String url;
Expand Down Expand Up @@ -74,7 +77,9 @@ class RemoteRpc {
*
* @throws DatastoreException if the RPC fails.
*/
public InputStream call(String methodName, MessageLite request) throws DatastoreException {
public InputStream call(
String methodName, MessageLite request, String projectId, String databaseId)
throws DatastoreException {
logger.fine("remote datastore call " + methodName);

long startTime = System.currentTimeMillis();
Expand All @@ -84,7 +89,7 @@ public InputStream call(String methodName, MessageLite request) throws Datastore
rpcCount.incrementAndGet();
ProtoHttpContent payload = new ProtoHttpContent(request);
HttpRequest httpRequest = client.buildPostRequest(resolveURL(methodName), payload);
setHeaders(request, httpRequest);
setHeaders(request, httpRequest, projectId, databaseId);
// Don't throw an HTTPResponseException on error. It converts the response to a String and
// throws away the original, whereas we need the raw bytes to parse it as a proto.
httpRequest.setThrowExceptionOnExecuteError(false);
Expand Down Expand Up @@ -123,8 +128,16 @@ public InputStream call(String methodName, MessageLite request) throws Datastore
}

@VisibleForTesting
void setHeaders(MessageLite request, HttpRequest httpRequest) {
void setHeaders(
MessageLite request, HttpRequest httpRequest, String projectId, String databaseId) {
httpRequest.getHeaders().put(API_FORMAT_VERSION_HEADER, API_FORMAT_VERSION);
StringBuilder builder = new StringBuilder("project_id=");
builder.append(projectId);
if (!Strings.isNullOrEmpty(databaseId)) {
builder.append("&database_id=");
builder.append(databaseId);
}
httpRequest.getHeaders().put(X_GOOG_REQUEST_PARAMS_HEADER, builder.toString());
if (enableE2EChecksum && request != null) {
String checksum = EndToEndChecksumHandler.computeChecksum(request.toByteArray());
if (checksum != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ public void testGzip() throws IOException, DatastoreException {
new InjectedTestValues(gzip(response), new byte[1], true);
RemoteRpc rpc = newRemoteRpc(injectedTestValues);

InputStream is = rpc.call("beginTransaction", BeginTransactionResponse.getDefaultInstance());
InputStream is =
rpc.call("beginTransaction", BeginTransactionResponse.getDefaultInstance(), "", "");
BeginTransactionResponse parsedResponse = BeginTransactionResponse.parseFrom(is);
is.close();

Expand All @@ -159,14 +160,15 @@ public void testGzip() throws IOException, DatastoreException {
public void testHttpHeaders_expectE2eChecksumHeader() throws IOException {
// Enable E2E-Checksum system env variable
RemoteRpc.setSystemEnvE2EChecksum(true);
String projectId = "project-id";
MessageLite request =
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8("project-id")).build();
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build();
RemoteRpc rpc =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest =
rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request));
rpc.setHeaders(request, httpRequest);
rpc.setHeaders(request, httpRequest, projectId, "");
assertNotNull(
httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER));
// Expect to find e2e-checksum header
Expand All @@ -181,14 +183,15 @@ public void testHttpHeaders_expectE2eChecksumHeader() throws IOException {
public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException {
// disable E2E-Checksum system env variable
RemoteRpc.setSystemEnvE2EChecksum(false);
String projectId = "project-id";
MessageLite request =
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8("project-id")).build();
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build();
RemoteRpc rpc =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest =
rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request));
rpc.setHeaders(request, httpRequest);
rpc.setHeaders(request, httpRequest, projectId, "");
assertNotNull(
httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER));
// Do not expect to find e2e-checksum header
Expand All @@ -198,6 +201,38 @@ public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException {
.getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER));
}

@Test
public void testHttpHeaders_prefixHeader() throws IOException {
String projectId = "my-project";
String databaseId = "my-db";
MessageLite request =
RollbackRequest.newBuilder()
.setTransaction(ByteString.copyFromUtf8(projectId))
.setDatabaseId(databaseId)
.build();
RemoteRpc rpc =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest =
rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request));
rpc.setHeaders(request, httpRequest, projectId, databaseId);
assertEquals(
"project_id=my-project&database_id=my-db",
httpRequest.getHeaders().get(RemoteRpc.X_GOOG_REQUEST_PARAMS_HEADER));

MessageLite request2 =
RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build();
RemoteRpc rpc2 =
newRemoteRpc(
new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true));
HttpRequest httpRequest2 =
rpc2.getClient().buildPostRequest(rpc2.resolveURL("blah"), new ProtoHttpContent(request2));
rpc2.setHeaders(request, httpRequest2, projectId, "");
assertEquals(
"project_id=my-project",
httpRequest2.getHeaders().get(RemoteRpc.X_GOOG_REQUEST_PARAMS_HEADER));
}

private static BeginTransactionResponse newBeginTransactionResponse() {
return BeginTransactionResponse.newBuilder()
.setTransaction(ByteString.copyFromUtf8("blah-blah-blah"))
Expand Down

0 comments on commit dee8cb4

Please sign in to comment.